Compare commits
95 Commits (`dockerfile`...`fused-moe-`)
| SHA1 | Author | Date | |
|---|---|---|---|
| 94e7c6dac7 | |||
| 13729ad0af | |||
| 550f8a052c | |||
| 8ce3cad72f | |||
| 270d05d9fd | |||
| 6e2c176e1f | |||
| a86754a12b | |||
| c2a2f19aba | |||
| 2c11a738b3 | |||
| b639327ad9 | |||
| 4afe687a82 | |||
| 5de8d9f111 | |||
| c1c8ca57ff | |||
| a3a5a47e48 | |||
| fb25e95688 | |||
| 0d4891cd03 | |||
| f56d2996ca | |||
| 147afb448b | |||
| 3c7d942da8 | |||
| 890323dc1b | |||
| 01cae37713 | |||
| 11c0198615 | |||
| b1235c3e10 | |||
| 44d02f54db | |||
| a8593237c0 | |||
| fc0f41d10a | |||
| 7b828e30d5 | |||
| 5f0af36af5 | |||
| 0d21b2664c | |||
| 9907fc4494 | |||
| d47661f0cd | |||
| 53fa457391 | |||
| 6fb162447b | |||
| 66177189c5 | |||
| b4f0b5f9aa | |||
| cbd14ed561 | |||
| 7bd4c37ae7 | |||
| 8020e98c9f | |||
| 762be26a8e | |||
| 6a9e6b2abf | |||
| 5d09152ff1 | |||
| 31d5c1797f | |||
| 35514b682a | |||
| e2de455c34 | |||
| 5b032352cc | |||
| 922f316441 | |||
| 5923ab9524 | |||
| 0cf893cae1 | |||
| cf75cd2098 | |||
| b854321ffe | |||
| 5b6fe23d05 | |||
| f0c98cae27 | |||
| 574ad60db9 | |||
| fdadb6f43a | |||
| 41060c6e08 | |||
| 3de2ed767f | |||
| 299252ea82 | |||
| d6902ce79f | |||
| 5e53c89a74 | |||
| c66e38ea4c | |||
| 251595368f | |||
| 4bed167768 | |||
| b140416abf | |||
| 5b8366b61a | |||
| c7753a9809 | |||
| 4b9a9435bb | |||
| 3482fd7e4e | |||
| 77f77a951e | |||
| 1a4f35e2ea | |||
| be1e128dfb | |||
| 65393ee064 | |||
| dc221ad72d | |||
| 7571a4a7e5 | |||
| f67d986dd1 | |||
| cc876d0f29 | |||
| fdfd409f8f | |||
| ffbcc9e757 | |||
| 59389c927b | |||
| 8f2720def9 | |||
| ad6c2e1a0b | |||
| 49e8c7ea25 | |||
| 805d62ca88 | |||
| b7d9e9416f | |||
| 7c12a765aa | |||
| cd587c93ef | |||
| 332d4cb17b | |||
| bf03ff3575 | |||
| 47043eb678 | |||
| 31b96d1c64 | |||
| e59ba9e142 | |||
| 403b481573 | |||
| 138709f8d1 | |||
| 0bbac1c1b4 | |||
| a3e4e85ece | |||
| eb58f5953d |
```diff
@@ -46,6 +46,6 @@ while getopts "m:b:l:f:t:" OPT; do
 done
 
 lm_eval --model vllm \
-    --model_args "pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,distributed_executor_backend=ray,trust_remote_code=true,max_model_len=4096" \
+    --model_args "pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,add_bos_token=true,trust_remote_code=true,max_model_len=4096" \
     --tasks gsm8k --num_fewshot "$FEWSHOT" --limit "$LIMIT" \
     --batch_size "$BATCH_SIZE"
```
```diff
@@ -18,12 +18,14 @@ RTOL = 0.08
 
 def launch_lm_eval(eval_config, tp_size):
     trust_remote_code = eval_config.get("trust_remote_code", False)
+    max_model_len = eval_config.get("max_model_len", 4096)
     model_args = (
         f"pretrained={eval_config['model_name']},"
         f"tensor_parallel_size={tp_size},"
         f"enforce_eager=true,"
         f"add_bos_token=true,"
-        f"trust_remote_code={trust_remote_code}"
+        f"trust_remote_code={trust_remote_code},"
+        f"max_model_len={max_model_len}"
     )
     results = lm_eval.simple_evaluate(
         model="vllm",
```
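To see how the fragments above fit together, here is a minimal, self-contained sketch of driving GSM8K through the lm-evaluation-harness `vllm` backend. The config values, fewshot count, and batch size are illustrative assumptions, not taken from the repository's test suite.

```python
import lm_eval


def launch_lm_eval(eval_config: dict, tp_size: int = 1):
    trust_remote_code = eval_config.get("trust_remote_code", False)
    max_model_len = eval_config.get("max_model_len", 4096)
    model_args = (
        f"pretrained={eval_config['model_name']},"
        f"tensor_parallel_size={tp_size},"
        f"enforce_eager=true,"
        f"add_bos_token=true,"
        f"trust_remote_code={trust_remote_code},"
        f"max_model_len={max_model_len}"
    )
    # Run GSM8K through the harness's vLLM backend.
    return lm_eval.simple_evaluate(
        model="vllm",
        model_args=model_args,
        tasks=["gsm8k"],
        num_fewshot=5,      # illustrative
        batch_size="auto",  # illustrative
    )


results = launch_lm_eval({"model_name": "facebook/opt-125m"})  # placeholder model
# Per-task scores live under results["results"]["gsm8k"]; the exact metric key
# (an exact-match variant) depends on the harness version.
```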
```diff
@@ -22,16 +22,6 @@ trap remove_docker_container EXIT
 # Remove the container that might not be cleaned up in the previous run.
 remove_docker_container
 
-# Build docker image.
-# TODO: build the image outside the script and share the image with other
-# tpu test if building time is too long.
-DOCKER_BUILDKIT=1 docker build \
-  --build-arg max_jobs=16 \
-  --build-arg USE_SCCACHE=1 \
-  --build-arg GIT_REPO_CHECK=0 \
-  --tag vllm/vllm-tpu-bm \
-  --progress plain -f docker/Dockerfile.tpu .
-
 LOG_ROOT=$(mktemp -d)
 # If mktemp fails, set -e will cause the script to exit.
 echo "Results will be stored in: $LOG_ROOT"
```
```diff
@@ -282,7 +282,7 @@ steps:
       - python3 offline_inference/llm_engine_example.py
       - python3 offline_inference/audio_language.py --seed 0
       - python3 offline_inference/vision_language.py --seed 0
-      - python3 offline_inference/vision_language_embedding.py --seed 0
+      - python3 offline_inference/vision_language_pooling.py --seed 0
       - python3 offline_inference/vision_language_multi_image.py --seed 0
       - VLLM_USE_V1=0 python3 others/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 others/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
       - python3 offline_inference/encoder_decoder.py
```
.github/mergify.yml (vendored): 2 changes

```diff
@@ -86,8 +86,6 @@ pull_request_rules:
     - and:
         - files~=^vllm/model_executor/models/
         - files=vllm/model_executor/models/registry.py
         - files=tests/models/registry.py
         - files=docs/models/supported_models.md
   actions:
     label:
       add:
```
.gitignore (vendored): 1 change

```diff
@@ -146,6 +146,7 @@ venv.bak/
 
 # mkdocs documentation
 /site
 docs/argparse
 docs/examples
 
 # mypy
```
```diff
@@ -171,7 +171,6 @@ if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA")
   list(APPEND VLLM_GPU_FLAGS "--threads=${NVCC_THREADS}")
 endif()
 
-
 #
 # Use FetchContent for C++ dependencies that are compiled as part of vLLM's build process.
 # setup.py will override FETCHCONTENT_BASE_DIR to play nicely with sccache.
@@ -232,7 +231,6 @@ endif()
 
 set(VLLM_EXT_SRC
   "csrc/mamba/mamba_ssm/selective_scan_fwd.cu"
-  "csrc/mamba/causal_conv1d/causal_conv1d.cu"
   "csrc/cache_kernels.cu"
   "csrc/attention/paged_attention_v1.cu"
   "csrc/attention/paged_attention_v2.cu"
@@ -393,7 +391,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   # The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
   # CUDA 12.0 or later
   cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
-  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_ARCHS)
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS)
     set(SRCS
       "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu"
       "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu"
@@ -409,7 +407,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
     message(STATUS "Building scaled_mm_c3x_sm90 for archs: ${SCALED_MM_ARCHS}")
   else()
-    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_ARCHS)
+    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS)
       message(STATUS "Not building scaled_mm_c3x_sm90 as CUDA Compiler version is "
                      "not >= 12.0, we recommend upgrading to CUDA 12.0 or "
                      "later if you intend on running FP8 quantized models on "
@@ -424,7 +422,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   # The cutlass_scaled_mm kernels for Geforce Blackwell SM120 (c3x, i.e. CUTLASS 3.x) require
   # CUDA 12.8 or later
   cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0;12.0a" "${CUDA_ARCHS}")
-  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS)
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
     set(SRCS
       "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu"
       "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu"
@@ -438,7 +436,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
     message(STATUS "Building scaled_mm_c3x_sm120 for archs: ${SCALED_MM_ARCHS}")
   else()
-    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS)
+    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
       message(STATUS "Not building scaled_mm_c3x_sm120 as CUDA Compiler version is "
                      "not >= 12.8, we recommend upgrading to CUDA 12.8 or "
                      "later if you intend on running FP8 quantized models on "
@@ -453,7 +451,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   # The cutlass_scaled_mm kernels for Blackwell SM100 (c3x, i.e. CUTLASS 3.x)
   # require CUDA 12.8 or later
   cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a" "${CUDA_ARCHS}")
-  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS)
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
     set(SRCS
       "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu"
       "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu"
@@ -468,7 +466,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
     message(STATUS "Building scaled_mm_c3x_sm100 for archs: ${SCALED_MM_ARCHS}")
   else()
-    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS)
+    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
       message(STATUS "Not building scaled_mm_c3x_sm100 as CUDA Compiler version is "
                      "not >= 12.8, we recommend upgrading to CUDA 12.8 or "
                      "later if you intend on running FP8 quantized models on "
@@ -511,7 +509,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   # The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor
   # require CUDA 12.2 or later (and only work on Hopper).
   cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
-  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_ARCHS)
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.2 AND SCALED_MM_ARCHS)
     set(SRCS "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
@@ -520,7 +518,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     list(APPEND VLLM_GPU_FLAGS "-DENABLE_SPARSE_SCALED_MM_C3X=1")
     message(STATUS "Building sparse_scaled_mm_c3x for archs: ${SCALED_MM_ARCHS}")
   else()
-    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_ARCHS)
+    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.2 AND SCALED_MM_ARCHS)
       message(STATUS "Not building sparse_scaled_mm_c3x kernels as CUDA Compiler version is "
                      "not >= 12.2, we recommend upgrading to CUDA 12.2 or later "
                      "if you intend on running FP8 sparse quantized models on Hopper.")
@@ -532,7 +530,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
 
   # FP4 Archs and flags
   cuda_archs_loose_intersection(FP4_ARCHS "10.0a" "${CUDA_ARCHS}")
-  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND FP4_ARCHS)
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
     set(SRCS
       "csrc/quantization/fp4/nvfp4_quant_kernels.cu"
       "csrc/quantization/fp4/nvfp4_experts_quant.cu"
@@ -553,7 +551,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
 
   # CUTLASS MLA Archs and flags
   cuda_archs_loose_intersection(MLA_ARCHS "10.0a" "${CUDA_ARCHS}")
-  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND MLA_ARCHS)
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND MLA_ARCHS)
     set(SRCS
       "csrc/attention/mla/cutlass_mla_kernels.cu")
     set_gencode_flags_for_srcs(
@@ -642,7 +640,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   # The machete kernels only work on hopper and require CUDA 12.0 or later.
   # Only build Machete kernels if we are building for something compatible with sm90a
   cuda_archs_loose_intersection(MACHETE_ARCHS "9.0a" "${CUDA_ARCHS}")
-  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND MACHETE_ARCHS)
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND MACHETE_ARCHS)
     #
     # For the Machete kernels we automatically generate sources for various
     # preselected input type pairs and schedules.
@@ -694,7 +692,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
 
     message(STATUS "Building Machete kernels for archs: ${MACHETE_ARCHS}")
   else()
-    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0
+    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0
         AND MACHETE_ARCHS)
       message(STATUS "Not building Machete kernels as CUDA Compiler version is "
                      "not >= 12.0, we recommend upgrading to CUDA 12.0 or "
```
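The recurring change in these hunks swaps strict `VERSION_GREATER` checks for `VERSION_GREATER_EQUAL`. The surrounding comments say "CUDA X or later", so a compiler that reports exactly the minimum version should be accepted, while a strict greater-than rejects it. The same comparison semantics are shown below in Python with the third-party `packaging` library, purely as an illustration; this snippet is not part of the build.

```python
from packaging.version import Version

minimum = Version("12.8")

# A toolkit that reports exactly the minimum version:
exact = Version("12.8")
print(exact > minimum)    # False: a strict greater-than check rejects it
print(exact >= minimum)   # True:  greater-or-equal accepts it, matching "12.8 or later"

# A later patch release passes either way:
newer = Version("12.8.1")
print(newer > minimum, newer >= minimum)  # True True
```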
Dockerfile (new file): 20 lines

```dockerfile
ARG CUDA_VERSION=12.8.1
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04

RUN apt update && apt install git -y && apt install curl -y

WORKDIR /workspace
RUN git clone https://github.com/vllm-project/vllm.git

COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/

# Install vllm.
WORKDIR /workspace/vllm
RUN uv venv .vllm --python 3.12
RUN . .vllm/bin/activate && VLLM_USE_PRECOMPILED=1 uv pip install -e .

# Checkout a specific commit.
ENV VLLM_SHA=550f8a052cae03c7e14a46767f689ab09c1cc28d
RUN git fetch && git checkout ${VLLM_SHA}

ENTRYPOINT ["/bin/bash"]
```
```diff
@@ -69,7 +69,7 @@ vLLM is flexible and easy to use with:
 
 - Seamless integration with popular Hugging Face models
 - High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more
-- Tensor parallelism and pipeline parallelism support for distributed inference
+- Tensor, pipeline, data and expert parallelism support for distributed inference
 - Streaming outputs
 - OpenAI-compatible API server
 - Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, TPU, and AWS Neuron
```
```diff
@@ -324,6 +324,9 @@ class RandomDataset(BenchmarkDataset):
         input_low = int(real_input_len * (1 - range_ratio))
         input_high = int(real_input_len * (1 + range_ratio))
         output_low = int(output_len * (1 - range_ratio))
+        # Ensure the lower bound for output length is at least 1 to prevent
+        # sampling 0 tokens, which can cause request failures.
+        output_low = max(output_low, 1)
         output_high = int(output_len * (1 + range_ratio))
 
         # Add logging for debugging
```
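The clamp added above protects the degenerate case where the sampled lower bound rounds down to zero. A quick sketch of the bound arithmetic with illustrative numbers (not the benchmark's defaults):

```python
# Illustrative values: a wide sampling range around short outputs.
real_input_len, output_len, range_ratio = 512, 4, 1.0

input_low = int(real_input_len * (1 - range_ratio))        # 0
input_high = int(real_input_len * (1 + range_ratio))       # 1024
output_low = max(int(output_len * (1 - range_ratio)), 1)   # clamped from 0 up to 1
output_high = int(output_len * (1 + range_ratio))          # 8

# Without the clamp, an output length of 0 tokens could be sampled, which the
# comment in the hunk above notes can cause request failures.
print(input_low, input_high, output_low, output_high)
```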
benchmarks/kernels/Justfile (new file): 63 lines

```just
all:
    just llama-scout-bf16 && \
    just llama-scout-fp8 && \
    just llama-maverick && \
    just qwen-30b && \
    just qwen-30b-fp8 && \
    just qwen-235b && \
    just deepseek-r1


llama-scout-bf16:
    python3 benchmark_moe.py \
        --model meta-llama/Llama-4-Scout-17B-16E-Instruct \
        --tp-size 1 \
        --ep-size 8 \
        --tune

llama-scout-fp8:
    python3 benchmark_moe.py \
        --model meta-llama/Llama-4-Scout-17B-16E-Instruct \
        --tp-size 1 \
        --ep-size 8 \
        --dtype fp8_w8a8 \
        --tune

llama-maverick:
    python3 benchmark_moe.py \
        --model meta-llama/Llama-4-Maverick-17B-128E-Instruct \
        --tp-size 1 \
        --ep-size 8 \
        --dtype fp8_w8a8 \
        --tune

qwen-30b:
    python3 benchmark_moe.py \
        --model Qwen/Qwen3-30B-A3B \
        --tp-size 1 \
        --ep-size 8 \
        --tune

qwen-30b-fp8:
    python3 benchmark_moe.py \
        --model Qwen/Qwen3-30B-A3B-FP8 \
        --tp-size 1 \
        --ep-size 8 \
        --dtype fp8_w8a8 \
        --tune

qwen-235b:
    python3 benchmark_moe.py \
        --model Qwen/Qwen3-235B-A22B \
        --tp-size 1 \
        --ep-size 8 \
        --dtype fp8_w8a8 \
        --tune

deepseek-r1:
    python3 benchmark_moe.py \
        --model deepseek-ai/DeepSeek-R1-0528 \
        --tp-size 1 \
        --ep-size 8 \
        --dtype fp8_w8a8 \
        --tune
```
141
benchmarks/kernels/bench_nvfp4_gemm.py
Normal file
141
benchmarks/kernels/bench_nvfp4_gemm.py
Normal file
@ -0,0 +1,141 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
import copy
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
from weight_shapes import WEIGHT_SHAPES
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
from vllm.triton_utils import triton
|
||||
|
||||
if not current_platform.has_device_capability(100):
|
||||
raise RuntimeError("NVFP4 requires compute capability of 10.0 (Blackwell)")
|
||||
|
||||
|
||||
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
|
||||
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
|
||||
|
||||
PROVIDER_CFGS = {
|
||||
"torch-bf16": dict(enabled=True),
|
||||
"nvfp4": dict(no_a_quant=False, enabled=True),
|
||||
"nvfp4-noquant": dict(no_a_quant=True, enabled=True),
|
||||
}
|
||||
|
||||
_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]
|
||||
|
||||
|
||||
def _quant_weight_nvfp4(b: torch.Tensor, device: str):
|
||||
# Compute global scale for weight
|
||||
b_amax = torch.abs(b).max().to(torch.float32)
|
||||
b_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax
|
||||
b_fp4, scale_b_fp4 = ops.scaled_fp4_quant(b, b_global_scale)
|
||||
return b_fp4, scale_b_fp4, b_global_scale
|
||||
|
||||
|
||||
def build_nvfp4_runner(cfg, a, b, dtype, device):
|
||||
b_fp4, scale_b_fp4, b_global_scale = _quant_weight_nvfp4(b, device)
|
||||
|
||||
# Compute global scale for activation
|
||||
# NOTE: This is generally provided ahead-of-time by the model checkpoint.
|
||||
a_amax = torch.abs(a).max().to(torch.float32)
|
||||
a_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax
|
||||
|
||||
# Alpha for the GEMM operation
|
||||
alpha = 1.0 / (a_global_scale * b_global_scale)
|
||||
|
||||
if cfg["no_a_quant"]:
|
||||
# Pre-quantize activation
|
||||
a_fp4, scale_a_fp4 = ops.scaled_fp4_quant(a, a_global_scale)
|
||||
|
||||
def run():
|
||||
return ops.cutlass_scaled_fp4_mm(
|
||||
a_fp4, b_fp4, scale_a_fp4, scale_b_fp4, alpha, dtype
|
||||
)
|
||||
|
||||
return run
|
||||
|
||||
# Quantize activation on-the-fly
|
||||
def run():
|
||||
a_fp4, scale_a_fp4 = ops.scaled_fp4_quant(a, a_global_scale)
|
||||
return ops.cutlass_scaled_fp4_mm(
|
||||
a_fp4, b_fp4, scale_a_fp4, scale_b_fp4, alpha, dtype
|
||||
)
|
||||
|
||||
return run
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size"],
|
||||
x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384],
|
||||
x_log=False,
|
||||
line_arg="provider",
|
||||
line_vals=_enabled,
|
||||
line_names=_enabled,
|
||||
ylabel="TFLOP/s (larger is better)",
|
||||
plot_name="BF16 vs NVFP4 GEMMs",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(batch_size, provider, N, K):
|
||||
M = batch_size
|
||||
device = "cuda"
|
||||
dtype = torch.bfloat16
|
||||
|
||||
a = torch.randn((M, K), device=device, dtype=dtype)
|
||||
b = torch.randn((N, K), device=device, dtype=dtype)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
if provider == "torch-bf16":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: torch.nn.functional.linear(a, b), quantiles=quantiles
|
||||
)
|
||||
else:
|
||||
cfg = PROVIDER_CFGS[provider]
|
||||
run_quant = build_nvfp4_runner(cfg, a, b, dtype, device)
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: run_quant(), quantiles=quantiles
|
||||
)
|
||||
|
||||
to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
|
||||
return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)
|
||||
|
||||
|
||||
def prepare_shapes(args):
|
||||
out = []
|
||||
for model, tp_size in itertools.product(args.models, args.tp_sizes):
|
||||
for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
|
||||
KN[tp_dim] //= tp_size
|
||||
KN.append(model)
|
||||
out.append(KN)
|
||||
return out
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--models",
|
||||
nargs="+",
|
||||
type=str,
|
||||
default=["meta-llama/Llama-3.1-8B-Instruct"],
|
||||
choices=list(WEIGHT_SHAPES.keys()),
|
||||
)
|
||||
parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1])
|
||||
args = parser.parse_args()
|
||||
|
||||
for K, N, model in prepare_shapes(args):
|
||||
print(f"{model}, N={N} K={K}, BF16 vs NVFP4 GEMMs TFLOP/s:")
|
||||
benchmark.run(
|
||||
print_data=True,
|
||||
show_plots=True,
|
||||
save_path=f"bench_nvfp4_res_n{N}_k{K}",
|
||||
N=N,
|
||||
K=K,
|
||||
)
|
||||
|
||||
print("Benchmark finished!")
|
||||
98
benchmarks/kernels/bench_per_token_quant_fp8.py
Normal file
98
benchmarks/kernels/bench_per_token_quant_fp8.py
Normal file
@ -0,0 +1,98 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import itertools
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.triton_utils import triton
|
||||
|
||||
|
||||
# TODO(luka): use standalone_compile utility
|
||||
def with_dyn_arg(fn: Callable, arg_index: int, dim_index: int):
|
||||
def inner(*args):
|
||||
torch._dynamo.mark_dynamic(args[arg_index], dim_index)
|
||||
return fn(*args)
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
torch._dynamo.config.recompile_limit = 8888
|
||||
compilation_config = CompilationConfig(custom_ops=["none"])
|
||||
with set_current_vllm_config(VllmConfig(compilation_config=compilation_config)):
|
||||
torch_per_token_quant_fp8 = torch.compile(
|
||||
QuantFP8(False, GroupShape.PER_TOKEN),
|
||||
fullgraph=True,
|
||||
dynamic=False, # recompile for different shapes
|
||||
)
|
||||
|
||||
# First dim is explicitly dynamic to simulate vLLM usage
|
||||
torch_per_token_quant_fp8 = with_dyn_arg(torch_per_token_quant_fp8, 0, 0)
|
||||
|
||||
|
||||
def cuda_per_token_quant_fp8(
|
||||
input: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return ops.scaled_fp8_quant(input)
|
||||
|
||||
|
||||
def calculate_diff(batch_size: int, seq_len: int):
|
||||
"""Calculate difference between Triton and CUDA implementations."""
|
||||
device = torch.device("cuda")
|
||||
x = torch.rand((batch_size * seq_len, 4096), dtype=torch.float16, device=device)
|
||||
|
||||
torch_out, torch_scale = torch_per_token_quant_fp8(x)
|
||||
cuda_out, cuda_scale = cuda_per_token_quant_fp8(x)
|
||||
|
||||
if torch.allclose(
|
||||
cuda_out.to(torch.float32), torch_out.to(torch.float32), rtol=1e-3, atol=1e-5
|
||||
) and torch.allclose(cuda_scale, torch_scale, rtol=1e-3, atol=1e-5):
|
||||
print("✅ All implementations match")
|
||||
else:
|
||||
print("❌ Implementations differ")
|
||||
|
||||
|
||||
batch_size_range = [1, 16, 32, 64, 128]
|
||||
seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
|
||||
|
||||
configs = list(itertools.product(batch_size_range, seq_len_range))
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size", "seq_len"],
|
||||
x_vals=configs,
|
||||
line_arg="provider",
|
||||
line_vals=["torch", "cuda"],
|
||||
line_names=["Torch", "CUDA"],
|
||||
styles=[("blue", "-"), ("green", "-")],
|
||||
ylabel="us",
|
||||
plot_name="per-token-dynamic-quant-fp8-performance",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark_quantization(batch_size, seq_len, provider):
|
||||
dtype = torch.float16
|
||||
device = torch.device("cuda")
|
||||
|
||||
x = torch.randn(batch_size * seq_len, 4096, device=device, dtype=dtype)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
if provider == "torch":
|
||||
fn = lambda: torch_per_token_quant_fp8(x.clone())
|
||||
elif provider == "cuda":
|
||||
fn = lambda: cuda_per_token_quant_fp8(x.clone())
|
||||
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)
|
||||
|
||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
calculate_diff(batch_size=4, seq_len=4096)
|
||||
benchmark_quantization.run(print_data=True)
|
||||
```diff
@@ -86,6 +86,9 @@ def benchmark_config(
             (num_experts, 2 * shard_intermediate_size), dtype=torch.float32
         )
         w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
+    if use_deep_gemm:
+        # we use the default block shape for deepgemm
+        block_quant_shape = [128, 128]
     if use_fp8_w8a8:
         if block_quant_shape:
             block_n, block_k = block_quant_shape[0], block_quant_shape[1]
@@ -592,6 +595,13 @@ def main(args: argparse.Namespace):
         intermediate_size = config.intermediate_size
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
 
+    # Expert parallelism
+    if E % args.ep_size != 0:
+        raise ValueError(
+            f"Number of experts {E} must be divisible by expert parallel size {args.ep_size}"
+        )
+    E = E // args.ep_size
+
     hidden_size = config.hidden_size
     dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
     use_fp8_w8a8 = args.dtype == "fp8_w8a8"
@@ -721,7 +731,10 @@ if __name__ == "__main__":
         "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
     )
     parser.add_argument(
-        "--tp-size", "-tp", "--tensor-parallel-size", type=int, default=2
+        "--tp-size", "-tp", "--tensor-parallel-size", type=int, default=1
     )
+    parser.add_argument(
+        "--ep-size", "-ep", "--expert-parallel-size", type=int, default=1
+    )
     parser.add_argument(
         "--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto"
```
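A short sketch of the expert-parallel bookkeeping introduced above: the global expert count must divide evenly across expert-parallel ranks, and each rank then tunes only its local share. The numbers in the example are illustrative, not taken from any particular model in the diff.

```python
def local_expert_count(global_num_experts: int, ep_size: int) -> int:
    """Number of experts handled by each expert-parallel rank."""
    if global_num_experts % ep_size != 0:
        raise ValueError(
            f"Number of experts {global_num_experts} must be divisible "
            f"by expert parallel size {ep_size}"
        )
    return global_num_experts // ep_size


# Example: 128 experts sharded over 8 expert-parallel ranks means the MoE
# kernel is tuned for 16 experts per rank.
assert local_expert_count(128, 8) == 16
```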
240
benchmarks/kernels/benchmark_trtllm_attention.py
Normal file
240
benchmarks/kernels/benchmark_trtllm_attention.py
Normal file
@ -0,0 +1,240 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import csv
|
||||
import os
|
||||
import random
|
||||
from datetime import datetime
|
||||
|
||||
import flashinfer
|
||||
import torch
|
||||
|
||||
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
|
||||
|
||||
# KV Cache Layout for TRT-LLM
|
||||
# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)
|
||||
|
||||
|
||||
def to_float8(x, dtype=torch.float8_e4m3fn):
|
||||
finfo = torch.finfo(dtype)
|
||||
min_val, max_val = x.aminmax()
|
||||
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
|
||||
scale = finfo.max / amax * 0.1
|
||||
x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
|
||||
return x_scl_sat.to(dtype), scale.float().reciprocal()
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def benchmark_decode(
|
||||
num_seqs,
|
||||
max_seq_len,
|
||||
page_size=16,
|
||||
dtype=torch.bfloat16,
|
||||
kv_layout="HND",
|
||||
num_kv_heads=8,
|
||||
kv_cache_dtype="auto",
|
||||
head_dim=128,
|
||||
warmup=10,
|
||||
trials=20,
|
||||
):
|
||||
torch.set_default_device("cuda")
|
||||
device = "cuda"
|
||||
torch.manual_seed(0)
|
||||
|
||||
# Currently only HEAD_GRP_SIZE == 8 is supported
|
||||
HEAD_GRP_SIZE = 8
|
||||
MAX_SEQ_LEN = max_seq_len
|
||||
|
||||
# large number to reduce kv_cache reuse
|
||||
NUM_BLOCKS = int(256000 / page_size)
|
||||
|
||||
workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8, device=device)
|
||||
|
||||
# For decode, batch_size is num_decode_token
|
||||
num_qo_heads = num_kv_heads * HEAD_GRP_SIZE
|
||||
sm_scale = float(1.0 / (head_dim**0.5))
|
||||
q = torch.randn(num_seqs, num_qo_heads, head_dim, device=device, dtype=dtype)
|
||||
kv_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
|
||||
|
||||
max_kv_len = max(kv_lens)
|
||||
kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int, device=device)
|
||||
max_num_blocks_per_seq = (max_kv_len + page_size - 1) // page_size
|
||||
|
||||
block_tables = torch.randint(
|
||||
0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
|
||||
)
|
||||
|
||||
kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim)
|
||||
kv_cache = torch.randn(size=kv_cache_shape, device=device, dtype=dtype)
|
||||
k_scale = v_scale = 1.0
|
||||
|
||||
if kv_cache_dtype.startswith("fp8"):
|
||||
kv_cache, _ = to_float8(kv_cache)
|
||||
|
||||
# Benchmark TRT decode
|
||||
def trt_decode():
|
||||
return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
|
||||
q,
|
||||
kv_cache,
|
||||
workspace_buffer,
|
||||
num_qo_heads,
|
||||
num_kv_heads,
|
||||
sm_scale,
|
||||
block_tables,
|
||||
kv_lens_tensor,
|
||||
page_size,
|
||||
max_kv_len,
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
|
||||
def time_fn(fn, warmup=10, trials=20):
|
||||
torch.cuda.synchronize()
|
||||
start = torch.cuda.Event(enable_timing=True)
|
||||
end = torch.cuda.Event(enable_timing=True)
|
||||
times = []
|
||||
for i in range(warmup):
|
||||
fn()
|
||||
for i in range(trials):
|
||||
start.record()
|
||||
fn()
|
||||
end.record()
|
||||
torch.cuda.synchronize()
|
||||
times.append(start.elapsed_time(end)) # ms
|
||||
return sum(times) / len(times), torch.std(torch.tensor(times))
|
||||
|
||||
# TRT Decode
|
||||
trt_mean, trt_std = time_fn(trt_decode)
|
||||
|
||||
kv_indptr = [0]
|
||||
kv_indices = []
|
||||
kv_last_page_lens = []
|
||||
for i in range(num_seqs):
|
||||
seq_len = kv_lens[i]
|
||||
assert seq_len > 0
|
||||
num_blocks = (seq_len + page_size - 1) // page_size
|
||||
kv_indices.extend(block_tables[i, :num_blocks])
|
||||
kv_indptr.append(kv_indptr[-1] + num_blocks)
|
||||
kv_last_page_len = seq_len % page_size
|
||||
if kv_last_page_len == 0:
|
||||
kv_last_page_len = page_size
|
||||
kv_last_page_lens.append(kv_last_page_len)
|
||||
|
||||
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
|
||||
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
|
||||
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
|
||||
|
||||
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
|
||||
workspace_buffer,
|
||||
kv_layout,
|
||||
use_tensor_cores=((num_qo_heads // num_kv_heads) > 4),
|
||||
)
|
||||
|
||||
wrapper.plan(
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_lens,
|
||||
num_qo_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
page_size,
|
||||
"NONE",
|
||||
q_data_type=dtype,
|
||||
kv_data_type=torch.float8_e4m3fn if kv_cache_dtype.startswith("fp8") else dtype,
|
||||
)
|
||||
|
||||
def baseline_decode():
|
||||
return wrapper.run(q, kv_cache, sm_scale, k_scale, v_scale)
|
||||
|
||||
baseline_mean, baseline_std = time_fn(baseline_decode)
|
||||
|
||||
# Calculate percentage speedup (positive means TRT is faster)
|
||||
speedup_percent = (baseline_mean - trt_mean) / baseline_mean
|
||||
|
||||
print(
|
||||
f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.3f}\t{trt_std.item():.3f}"
|
||||
f"\t{baseline_mean:.3f}\t{baseline_std.item():.3f}\t{speedup_percent:.3f}"
|
||||
)
|
||||
|
||||
# Return results for CSV writing
|
||||
return {
|
||||
"num_seqs": num_seqs,
|
||||
"trt_mean": trt_mean,
|
||||
"trt_std": trt_std.item(),
|
||||
"baseline_mean": baseline_mean,
|
||||
"baseline_std": baseline_std.item(),
|
||||
"speedup_percent": speedup_percent,
|
||||
"q_dtype": str(dtype),
|
||||
"kv_cache_dtype": kv_cache_dtype,
|
||||
"page_size": page_size,
|
||||
"num_kv_heads": num_kv_heads,
|
||||
"head_dim": head_dim,
|
||||
"max_seq_len": max_seq_len,
|
||||
}
|
||||
|
||||
|
||||
def write_results_to_csv(results, filename=None):
|
||||
"""Write benchmark results to CSV file."""
|
||||
if filename is None:
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv"
|
||||
|
||||
fieldnames = [
|
||||
"num_seqs",
|
||||
"trt_mean",
|
||||
"trt_std",
|
||||
"baseline_mean",
|
||||
"baseline_std",
|
||||
"speedup_percent",
|
||||
"q_dtype",
|
||||
"kv_cache_dtype",
|
||||
"page_size",
|
||||
"num_kv_heads",
|
||||
"head_dim",
|
||||
"max_seq_len",
|
||||
]
|
||||
|
||||
file_exists = os.path.exists(filename)
|
||||
|
||||
with open(filename, "a", newline="") as csvfile:
|
||||
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
||||
|
||||
if not file_exists:
|
||||
writer.writeheader()
|
||||
|
||||
for result in results:
|
||||
writer.writerow(result)
|
||||
|
||||
print(f"Results written to {filename}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
num_seqs = [1, 4, 8, 16, 32, 64, 128, 256]
|
||||
max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
|
||||
all_results = []
|
||||
|
||||
print("Running benchmark for kv_cache_dtype: bfloat16")
|
||||
print(
|
||||
"\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\tbaseline_std\tspeedup_percent"
|
||||
)
|
||||
for max_seq_len in max_seq_lens:
|
||||
for bs in num_seqs:
|
||||
result = benchmark_decode(
|
||||
bs, max_seq_len, dtype=torch.bfloat16, kv_cache_dtype="auto"
|
||||
)
|
||||
all_results.append(result)
|
||||
|
||||
print("Running benchmark for q_dtype = bfloat16, kv_cache_dtype: fp8")
|
||||
print(
|
||||
"\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\tbaseline_std\tspeedup_percent"
|
||||
)
|
||||
for max_seq_len in max_seq_lens:
|
||||
for bs in num_seqs:
|
||||
result = benchmark_decode(
|
||||
bs, max_seq_len, dtype=torch.bfloat16, kv_cache_dtype="fp8"
|
||||
)
|
||||
all_results.append(result)
|
||||
|
||||
# Write all results to CSV
|
||||
write_results_to_csv(all_results)
|
||||
@ -165,17 +165,32 @@ else()
|
||||
endif()
|
||||
|
||||
#
|
||||
# Build oneDNN for W8A8 GEMM kernels (only for x86-AVX512 platforms)
|
||||
#
|
||||
if (AVX512_FOUND AND NOT AVX512_DISABLED)
|
||||
# Build oneDNN for W8A8 GEMM kernels (only for x86-AVX512 /ARM platforms)
|
||||
# Flag to enable ACL kernels for AARCH64 platforms
|
||||
if ( VLLM_BUILD_ACL STREQUAL "ON")
|
||||
set(USE_ACL ON)
|
||||
else()
|
||||
set(USE_ACL OFF)
|
||||
endif()
|
||||
|
||||
if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND)
|
||||
FetchContent_Declare(
|
||||
oneDNN
|
||||
GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
|
||||
GIT_TAG v3.7.1
|
||||
GIT_TAG v3.8.1
|
||||
GIT_PROGRESS TRUE
|
||||
GIT_SHALLOW TRUE
|
||||
)
|
||||
|
||||
if(USE_ACL)
|
||||
find_library(ARM_COMPUTE_LIBRARY NAMES arm_compute PATHS $ENV{ACL_ROOT_DIR}/build/)
|
||||
if(NOT ARM_COMPUTE_LIBRARY)
|
||||
message(FATAL_ERROR "Could not find ARM Compute Library: please set ACL_ROOT_DIR")
|
||||
endif()
|
||||
set(ONEDNN_AARCH64_USE_ACL "ON")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ENV{ACL_ROOT_DIR}/build/")
|
||||
endif()
|
||||
|
||||
set(ONEDNN_LIBRARY_TYPE "STATIC")
|
||||
set(ONEDNN_BUILD_DOC "OFF")
|
||||
set(ONEDNN_BUILD_EXAMPLES "OFF")
|
||||
@ -264,6 +279,11 @@ elseif(POWER10_FOUND)
|
||||
"csrc/cpu/quant.cpp"
|
||||
${VLLM_EXT_SRC})
|
||||
endif()
|
||||
if (ASIMD_FOUND)
|
||||
set(VLLM_EXT_SRC
|
||||
"csrc/cpu/quant.cpp"
|
||||
${VLLM_EXT_SRC})
|
||||
endif()
|
||||
|
||||
message(STATUS "CPU extension source files: ${VLLM_EXT_SRC}")
|
||||
|
||||
|
||||
@ -33,6 +33,8 @@ namespace vec_op {
|
||||
#endif
|
||||
|
||||
#define FORCE_INLINE __attribute__((always_inline)) inline
|
||||
// Number of elements in single ASIMD vector of given Datatype
|
||||
#define NUM_ELEMENTS_REG(vec) (sizeof(vec) / sizeof(vec[0]))
|
||||
|
||||
namespace {
|
||||
template <typename T, T... indexes, typename F>
|
||||
@ -86,8 +88,8 @@ struct FP16Vec16 : public Vec<FP16Vec16> {
|
||||
}
|
||||
|
||||
void save(void* ptr, const int elem_num) const {
|
||||
int full_blocks = elem_num / 8;
|
||||
int remainder = elem_num % 8;
|
||||
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
|
||||
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
|
||||
|
||||
if (full_blocks > 0) {
|
||||
vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]);
|
||||
@ -197,6 +199,25 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
|
||||
vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[2]), v.val[3])}) {};
|
||||
|
||||
void save(void* ptr) const { *reinterpret_cast<bfloat16x8x2_t*>(ptr) = reg; };
|
||||
void save(void* ptr, const int elem_num) const {
|
||||
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
|
||||
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
|
||||
for (int i = 0; i < full_blocks; i++)
|
||||
vst1q_bf16(
|
||||
reinterpret_cast<__bf16*>(ptr) + NUM_ELEMENTS_REG(reg.val[0]) * i,
|
||||
reg.val[i]);
|
||||
if (remainder > 0) {
|
||||
bfloat16x8_t temp = reg.val[full_blocks];
|
||||
bfloat16_t* base = reinterpret_cast<bfloat16_t*>(ptr) + full_blocks * 8;
|
||||
if (remainder > 0) base[0] = vgetq_lane_bf16(temp, 0);
|
||||
if (remainder > 1) base[1] = vgetq_lane_bf16(temp, 1);
|
||||
if (remainder > 2) base[2] = vgetq_lane_bf16(temp, 2);
|
||||
if (remainder > 3) base[3] = vgetq_lane_bf16(temp, 3);
|
||||
if (remainder > 4) base[4] = vgetq_lane_bf16(temp, 4);
|
||||
if (remainder > 5) base[5] = vgetq_lane_bf16(temp, 5);
|
||||
if (remainder > 6) base[6] = vgetq_lane_bf16(temp, 6);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
struct BF16Vec32 : public Vec<BF16Vec32> {
|
||||
@ -213,6 +234,25 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
|
||||
: reg({vec8_data.reg, vec8_data.reg, vec8_data.reg, vec8_data.reg}) {};
|
||||
|
||||
void save(void* ptr) const { *reinterpret_cast<bfloat16x8x4_t*>(ptr) = reg; };
|
||||
void save(void* ptr, const int elem_num) const {
|
||||
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
|
||||
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
|
||||
for (int i = 0; i < full_blocks; i++)
|
||||
vst1q_bf16(
|
||||
reinterpret_cast<__bf16*>(ptr) + NUM_ELEMENTS_REG(reg.val[0]) * i,
|
||||
reg.val[i]);
|
||||
if (remainder > 0) {
|
||||
bfloat16x8_t temp = reg.val[full_blocks];
|
||||
bfloat16_t* base = reinterpret_cast<bfloat16_t*>(ptr) + full_blocks * 8;
|
||||
base[0] = vgetq_lane_bf16(temp, 0);
|
||||
if (remainder > 1) base[1] = vgetq_lane_bf16(temp, 1);
|
||||
if (remainder > 2) base[2] = vgetq_lane_bf16(temp, 2);
|
||||
if (remainder > 3) base[3] = vgetq_lane_bf16(temp, 3);
|
||||
if (remainder > 4) base[4] = vgetq_lane_bf16(temp, 4);
|
||||
if (remainder > 5) base[5] = vgetq_lane_bf16(temp, 5);
|
||||
if (remainder > 6) base[6] = vgetq_lane_bf16(temp, 6);
|
||||
}
|
||||
};
|
||||
};
|
||||
#endif
|
||||
|
||||
@ -372,6 +412,48 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
|
||||
}
|
||||
};
|
||||
|
||||
struct INT32Vec16 : public Vec<INT32Vec16> {
|
||||
constexpr static int VEC_ELEM_NUM = 16;
|
||||
union AliasReg {
|
||||
int32x4x4_t reg;
|
||||
int32_t values[VEC_ELEM_NUM];
|
||||
};
|
||||
int32x4x4_t reg;
|
||||
|
||||
explicit INT32Vec16(const void* ptr) {
|
||||
reg.val[0] = vld1q_s32(reinterpret_cast<const int32_t*>(ptr));
|
||||
reg.val[1] = vld1q_s32(reinterpret_cast<const int32_t*>(ptr) + 4);
|
||||
reg.val[2] = vld1q_s32(reinterpret_cast<const int32_t*>(ptr) + 8);
|
||||
reg.val[3] = vld1q_s32(reinterpret_cast<const int32_t*>(ptr) + 12);
|
||||
}
|
||||
|
||||
void save(int32_t* ptr) const {
|
||||
vst1q_s32(ptr, reg.val[0]);
|
||||
vst1q_s32(ptr + 4, reg.val[1]);
|
||||
vst1q_s32(ptr + 8, reg.val[2]);
|
||||
vst1q_s32(ptr + 12, reg.val[3]);
|
||||
};
|
||||
|
||||
void save(int32_t* ptr, const int elem_num) const {
|
||||
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
|
||||
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
|
||||
|
||||
for (int i = 0; i < full_blocks; i++)
|
||||
vst1q_s32(
|
||||
reinterpret_cast<__int32_t*>(ptr) + NUM_ELEMENTS_REG(reg.val[0]) * i,
|
||||
reg.val[i]);
|
||||
|
||||
if (remainder > 0) {
|
||||
int32x4_t temp = reg.val[full_blocks];
|
||||
int32_t* base = reinterpret_cast<int32_t*>(ptr) + full_blocks * 4;
|
||||
if (remainder > 0) base[0] = vgetq_lane_s32(temp, 0);
|
||||
if (remainder > 1) base[1] = vgetq_lane_s32(temp, 1);
|
||||
if (remainder > 2) base[2] = vgetq_lane_s32(temp, 2);
|
||||
if (remainder > 3) base[3] = vgetq_lane_s32(temp, 3);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
constexpr static int VEC_ELEM_NUM = 16;
|
||||
union AliasReg {
|
||||
@ -434,7 +516,12 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
reg.val[2] = vcvt_f32_f16(vget_low_f16(v.reg.val[1]));
|
||||
reg.val[3] = vcvt_f32_f16(vget_high_f16(v.reg.val[1]));
|
||||
};
|
||||
|
||||
explicit FP32Vec16(const INT32Vec16& v) {
|
||||
reg.val[0] = vcvtq_f32_s32(v.reg.val[0]);
|
||||
reg.val[1] = vcvtq_f32_s32(v.reg.val[1]);
|
||||
reg.val[2] = vcvtq_f32_s32(v.reg.val[2]);
|
||||
reg.val[3] = vcvtq_f32_s32(v.reg.val[3]);
|
||||
};
|
||||
FP32Vec16 operator+(const FP32Vec16& b) const {
|
||||
return FP32Vec16(float32x4x4_t({vaddq_f32(reg.val[0], b.reg.val[0]),
|
||||
vaddq_f32(reg.val[1], b.reg.val[1]),
|
||||
@ -463,6 +550,85 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
vdivq_f32(reg.val[3], b.reg.val[3])}));
|
||||
};
|
||||
|
||||
FP32Vec16 clamp(const FP32Vec16& min, const FP32Vec16& max) const {
|
||||
return FP32Vec16(float32x4x4_t(
|
||||
{vminq_f32(max.reg.val[0], vmaxq_f32(min.reg.val[0], reg.val[0])),
|
||||
vminq_f32(max.reg.val[1], vmaxq_f32(min.reg.val[1], reg.val[1])),
|
||||
vminq_f32(max.reg.val[2], vmaxq_f32(min.reg.val[2], reg.val[2])),
|
||||
vminq_f32(max.reg.val[3], vmaxq_f32(min.reg.val[3], reg.val[3]))}));
|
||||
};
|
||||
|
||||
FP32Vec16 max(const FP32Vec16& b) const {
|
||||
return FP32Vec16(float32x4x4_t({vmaxq_f32(b.reg.val[0], reg.val[0]),
|
||||
vmaxq_f32(b.reg.val[1], reg.val[1]),
|
||||
vmaxq_f32(b.reg.val[2], reg.val[2]),
|
||||
vmaxq_f32(b.reg.val[3], reg.val[3])}));
|
||||
};
|
||||
|
||||
FP32Vec16 max(const FP32Vec16& b, const int elem_num) const {
|
||||
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
|
||||
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
|
||||
float32x4x4_t temp;
|
||||
|
||||
for (int i = 0; i < full_blocks; i++)
|
||||
temp.val[i] = vmaxq_f32(b.reg.val[i], reg.val[i]);
|
||||
|
||||
if (remainder > 0) {
|
||||
float max_v = std::max(vgetq_lane_f32(reg.val[full_blocks], 0),
|
||||
vgetq_lane_f32(b.reg.val[full_blocks], 0));
|
||||
temp.val[full_blocks] = vsetq_lane_f32(max_v, temp.val[full_blocks], 0);
|
||||
}
|
||||
if (remainder > 1) {
|
||||
float max_v = std::max(vgetq_lane_f32(reg.val[full_blocks], 1),
|
||||
vgetq_lane_f32(b.reg.val[full_blocks], 1));
|
||||
temp.val[full_blocks] = vsetq_lane_f32(max_v, temp.val[full_blocks], 1);
|
||||
}
|
||||
if (remainder > 2) {
|
||||
float max_v = std::max(vgetq_lane_f32(reg.val[full_blocks], 2),
|
||||
vgetq_lane_f32(b.reg.val[full_blocks], 2));
|
||||
temp.val[full_blocks] = vsetq_lane_f32(max_v, temp.val[full_blocks], 2);
|
||||
}
|
||||
return FP32Vec16(temp);
|
||||
};
|
||||
|
||||
FP32Vec16 min(const FP32Vec16& b) const {
|
||||
return FP32Vec16(float32x4x4_t({
|
||||
vminq_f32(b.reg.val[0], reg.val[0]),
|
||||
vminq_f32(b.reg.val[1], reg.val[1]),
|
||||
vminq_f32(b.reg.val[2], reg.val[2]),
|
||||
vminq_f32(b.reg.val[3], reg.val[3]),
|
||||
}));
|
||||
};
|
||||
FP32Vec16 min(const FP32Vec16& b, const int elem_num) const {
|
||||
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
|
||||
const int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
|
||||
float32x4x4_t temp;
|
||||
for (int i = 0; i < full_blocks; i++)
|
||||
temp.val[i] = vminq_f32(b.reg.val[i], reg.val[i]);
|
||||
|
||||
if (remainder > 0) {
|
||||
float min_v = std::min(vgetq_lane_f32(reg.val[full_blocks], 0),
|
||||
vgetq_lane_f32(b.reg.val[full_blocks], 0));
|
||||
temp.val[full_blocks] = vsetq_lane_f32(min_v, temp.val[full_blocks], 0);
|
||||
}
|
||||
if (remainder > 1) {
|
||||
float min_v = std::min(vgetq_lane_f32(reg.val[full_blocks], 1),
|
||||
vgetq_lane_f32(b.reg.val[full_blocks], 1));
|
||||
temp.val[full_blocks] = vsetq_lane_f32(min_v, temp.val[full_blocks], 1);
|
||||
}
|
||||
if (remainder > 2) {
|
||||
float min_v = std::min(vgetq_lane_f32(reg.val[full_blocks], 2),
|
||||
vgetq_lane_f32(b.reg.val[full_blocks], 2));
|
||||
temp.val[full_blocks] = vsetq_lane_f32(min_v, temp.val[full_blocks], 2);
|
||||
}
|
||||
|
||||
return FP32Vec16(temp);
|
||||
};
|
||||
FP32Vec16 abs() const {
|
||||
return FP32Vec16(
|
||||
float32x4x4_t({vabsq_f32(reg.val[0]), vabsq_f32(reg.val[1]),
|
||||
vabsq_f32(reg.val[2]), vabsq_f32(reg.val[3])}));
|
||||
}
|
||||
float reduce_sum() const {
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
@ -473,6 +639,24 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
return answer;
|
||||
};
|
||||
|
||||
float reduce_max() const {
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
float max_v = std::numeric_limits<float>::lowest();
|
||||
unroll_loop<int, VEC_ELEM_NUM>(
|
||||
[&max_v, &ar](int i) { max_v = std::max(max_v, ar.values[i]); });
|
||||
return max_v;
|
||||
}
|
||||
|
||||
float reduce_min() const {
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
float min_v = std::numeric_limits<float>::max();
|
||||
unroll_loop<int, VEC_ELEM_NUM>(
|
||||
[&min_v, &ar](int i) { min_v = std::min(min_v, ar.values[i]); });
|
||||
return min_v;
|
||||
}
|
||||
|
||||
template <int group_size>
|
||||
float reduce_sub_sum(int idx) {
|
||||
static_assert(VEC_ELEM_NUM % group_size == 0);
|
||||
@ -493,6 +677,83 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
vst1q_f32(ptr + 8, reg.val[2]);
|
||||
vst1q_f32(ptr + 12, reg.val[3]);
|
||||
};
|
||||
|
||||
void save(float* ptr, const int elem_num) const {
|
||||
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
|
||||
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
|
||||
|
||||
for (int i = 0; i < full_blocks; i++)
|
||||
vst1q_f32(
|
||||
reinterpret_cast<float32_t*>(ptr) + NUM_ELEMENTS_REG(reg.val[0]) * i,
|
||||
reg.val[i]);
|
||||
|
||||
if (remainder > 0) {
|
||||
float32x4_t temp = reg.val[full_blocks];
|
||||
float* base = reinterpret_cast<float32_t*>(ptr) +
|
||||
full_blocks * NUM_ELEMENTS_REG(reg.val[0]);
|
||||
if (remainder > 0) base[0] = vgetq_lane_f32(temp, 0);
|
||||
if (remainder > 1) base[1] = vgetq_lane_f32(temp, 1);
|
||||
if (remainder > 2) base[2] = vgetq_lane_f32(temp, 2);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct INT8Vec16 : public Vec<INT8Vec16> {
|
||||
constexpr static int VEC_ELEM_NUM = 16;
|
||||
union AliasReg {
|
||||
int8x16_t reg;
|
||||
int8_t values[VEC_ELEM_NUM];
|
||||
};
|
||||
int8x16_t reg;
|
||||
|
||||
explicit INT8Vec16(const FP32Vec16& vec) {
|
||||
// Convert each 128-bit float32 vector to int32
|
||||
int32x4_t part0 =
|
||||
vcvtq_s32_f32(vec.reg.val[0]); // Convert first 128-bit block
|
||||
int32x4_t part1 =
|
||||
vcvtq_s32_f32(vec.reg.val[1]); // Convert second 128-bit block
|
||||
int32x4_t part2 =
|
||||
vcvtq_s32_f32(vec.reg.val[2]); // Convert third 128-bit block
|
||||
int32x4_t part3 =
|
||||
vcvtq_s32_f32(vec.reg.val[3]); // Convert fourth 128-bit block
|
||||
|
||||
// Narrow each 32-bit vector to 8 bits and combine
|
||||
int8x8_t lower =
|
||||
vqmovn_s16(vcombine_s16(vqmovn_s32(part0), vqmovn_s32(part1)));
|
||||
int8x8_t upper =
|
||||
vqmovn_s16(vcombine_s16(vqmovn_s32(part2), vqmovn_s32(part3)));
|
||||
reg = vcombine_s8(lower, upper); // Combine to form a single 128-bit vector
|
||||
}
|
||||
|
||||
void save(int8_t* ptr) const { vst1q_s8(ptr, reg); };
|
||||
|
||||
void save(int8_t* ptr, const int elem_num) const {
|
||||
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg);
|
||||
int remainder = elem_num % NUM_ELEMENTS_REG(reg);
|
||||
|
||||
for (int i = 0; i < full_blocks; i++)
|
||||
vst1q_s8(reinterpret_cast<int8_t*>(ptr) + NUM_ELEMENTS_REG(reg) * i, reg);
|
||||
if (remainder > 0) {
|
||||
int8x16_t temp = reg;
|
||||
int8_t* base =
|
||||
reinterpret_cast<int8_t*>(ptr) + full_blocks * NUM_ELEMENTS_REG(reg);
|
||||
if (remainder > 0) base[0] = vgetq_lane_s8(temp, 0);
|
||||
if (remainder > 1) base[1] = vgetq_lane_s8(temp, 1);
|
||||
if (remainder > 2) base[2] = vgetq_lane_s8(temp, 2);
|
||||
if (remainder > 3) base[3] = vgetq_lane_s8(temp, 3);
|
||||
if (remainder > 4) base[4] = vgetq_lane_s8(temp, 4);
|
||||
if (remainder > 5) base[5] = vgetq_lane_s8(temp, 5);
|
||||
if (remainder > 6) base[6] = vgetq_lane_s8(temp, 6);
|
||||
if (remainder > 7) base[7] = vgetq_lane_s8(temp, 7);
|
||||
if (remainder > 8) base[8] = vgetq_lane_s8(temp, 8);
|
||||
if (remainder > 9) base[9] = vgetq_lane_s8(temp, 9);
|
||||
if (remainder > 10) base[10] = vgetq_lane_s8(temp, 10);
|
||||
if (remainder > 11) base[11] = vgetq_lane_s8(temp, 11);
|
||||
if (remainder > 12) base[12] = vgetq_lane_s8(temp, 12);
|
||||
if (remainder > 13) base[13] = vgetq_lane_s8(temp, 13);
|
||||
if (remainder > 14) base[14] = vgetq_lane_s8(temp, 14);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
|
||||
@ -57,6 +57,7 @@ class DNNLPrimitiveHelper {
|
||||
// Note: Due to the limitation of oneDNN
|
||||
// (https://github.com/oneapi-src/oneDNN/issues/1636), the quantized bias is
|
||||
// not supported.
|
||||
|
||||
template <typename OutputT, typename BiasT>
|
||||
static void gemm_s8s8_jit(const int8_t* a, const int8_t* b, OutputT* c,
|
||||
const BiasT* bias, dnnl_dim_t M, dnnl_dim_t N,
|
||||
@ -90,6 +91,27 @@ class DNNLPrimitiveHelper {
|
||||
}
|
||||
|
||||
dnnl::matmul::primitive_desc matmul_pd;
|
||||
// Create memory descriptors with format_tag::any for the primitive. This
|
||||
// enables the matmul primitive to choose memory layouts for an
|
||||
// optimized primitive implementation, and these layouts may differ from the
|
||||
// ones provided by the user.
|
||||
#ifdef __aarch64__
|
||||
auto mat_src_md = dnnl::memory::desc({M, K}, dnnl::memory::data_type::s8,
|
||||
dnnl::memory::format_tag::any);
|
||||
auto mat_weights_md = dnnl::memory::desc(
|
||||
{K, N}, dnnl::memory::data_type::s8, dnnl::memory::format_tag::any);
|
||||
auto mat_dst_md =
|
||||
dnnl::memory::desc({M, N}, OutputType, dnnl::memory::format_tag::any);
|
||||
if (bias) {
|
||||
dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1});
|
||||
matmul_pd = dnnl::matmul::primitive_desc(default_engine(), mat_src_md,
|
||||
mat_weights_md, bias_md,
|
||||
mat_dst_md, attr);
|
||||
} else {
|
||||
matmul_pd = dnnl::matmul::primitive_desc(
|
||||
default_engine(), mat_src_md, mat_weights_md, mat_dst_md, attr);
|
||||
}
|
||||
#else
|
||||
if (bias) {
|
||||
dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1});
|
||||
matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
|
||||
@ -98,6 +120,7 @@ class DNNLPrimitiveHelper {
|
||||
matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
|
||||
c_md, attr);
|
||||
}
|
||||
#endif
|
||||
dnnl::matmul matmul(matmul_pd);
|
||||
|
||||
auto& engine = default_engine();
|
||||
@ -111,24 +134,34 @@ class DNNLPrimitiveHelper {
|
||||
(void*)b_scales);
|
||||
|
||||
auto& stream = default_stream();
|
||||
|
||||
auto mat_src_mem = a_m;
|
||||
auto mat_weights_mem = b_m;
|
||||
auto mat_dst_mem = c_m;
|
||||
#ifdef __aarch64__
|
||||
if (matmul_pd.weights_desc() != b_m.get_desc()) {
|
||||
mat_weights_mem = dnnl::memory(matmul_pd.weights_desc(), engine);
|
||||
dnnl::reorder(b_m, mat_weights_mem).execute(stream, b_m, mat_weights_mem);
|
||||
}
|
||||
#endif
|
||||
if constexpr (InputNoScale) {
|
||||
if (bias) {
|
||||
dnnl::memory::desc bias_md({N}, BiasType, {1});
|
||||
dnnl::memory bias_m(bias_md, engine, (void*)bias);
|
||||
matmul.execute(
|
||||
stream, {
|
||||
{DNNL_ARG_SRC, a_m},
|
||||
{DNNL_ARG_WEIGHTS, b_m},
|
||||
{DNNL_ARG_SRC, mat_src_mem},
|
||||
{DNNL_ARG_WEIGHTS, mat_weights_mem},
|
||||
{DNNL_ARG_BIAS, bias_m},
|
||||
{DNNL_ARG_DST, c_m},
|
||||
{DNNL_ARG_DST, mat_dst_mem},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
|
||||
});
|
||||
} else {
|
||||
matmul.execute(
|
||||
stream, {
|
||||
{DNNL_ARG_SRC, a_m},
|
||||
{DNNL_ARG_WEIGHTS, b_m},
|
||||
{DNNL_ARG_DST, c_m},
|
||||
{DNNL_ARG_SRC, mat_src_mem},
|
||||
{DNNL_ARG_WEIGHTS, mat_weights_mem},
|
||||
{DNNL_ARG_DST, mat_dst_mem},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
|
||||
});
|
||||
}
|
||||
@ -138,19 +171,19 @@ class DNNLPrimitiveHelper {
|
||||
dnnl::memory bias_m(bias_md, engine, (void*)bias);
|
||||
matmul.execute(
|
||||
stream, {
|
||||
{DNNL_ARG_SRC, a_m},
|
||||
{DNNL_ARG_WEIGHTS, b_m},
|
||||
{DNNL_ARG_SRC, mat_src_mem},
|
||||
{DNNL_ARG_WEIGHTS, mat_weights_mem},
|
||||
{DNNL_ARG_BIAS, bias_m},
|
||||
{DNNL_ARG_DST, c_m},
|
||||
{DNNL_ARG_DST, mat_dst_mem},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
|
||||
});
|
||||
} else {
|
||||
matmul.execute(
|
||||
stream, {
|
||||
{DNNL_ARG_SRC, a_m},
|
||||
{DNNL_ARG_WEIGHTS, b_m},
|
||||
{DNNL_ARG_DST, c_m},
|
||||
{DNNL_ARG_SRC, mat_src_mem},
|
||||
{DNNL_ARG_WEIGHTS, mat_weights_mem},
|
||||
{DNNL_ARG_DST, mat_dst_mem},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
|
||||
});
|
||||
@ -170,5 +203,4 @@ class DNNLPrimitiveHelper {
|
||||
return stream;
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
@ -36,7 +36,7 @@ struct KernelVecType<c10::Half> {
|
||||
using cvt_vec_type = vec_op::FP32Vec16;
|
||||
};
|
||||
|
||||
#ifdef __AVX512F__
|
||||
#if defined(__AVX512F__) || defined(__aarch64__)
|
||||
template <bool AZP, typename scalar_t>
|
||||
void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
|
||||
const float* scale, const int32_t* azp,
|
||||
@ -598,8 +598,9 @@ void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
|
||||
const float* scale, const int32_t* azp,
|
||||
const int num_tokens,
|
||||
const int hidden_size) {
|
||||
TORCH_CHECK(
|
||||
false, "static_scaled_int8_quant_impl requires AVX512/powerpc64 support.")
|
||||
TORCH_CHECK(false,
|
||||
"static_scaled_int8_quant_impl requires AVX512/powerpc64/AArch64 "
|
||||
"support.")
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
@ -607,9 +608,9 @@ void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
|
||||
float* scale, int32_t* azp,
|
||||
const int num_tokens,
|
||||
const int hidden_size) {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"dynamic_scaled_int8_quant_impl requires AVX512/powerpc64 support.")
|
||||
TORCH_CHECK(false,
|
||||
"dynamic_scaled_int8_quant_impl requires "
|
||||
"AVX512/powerpc64/AArch64 support.")
|
||||
}
|
||||
|
||||
template <bool PerChannel, typename scalar_t>
|
||||
@ -617,7 +618,8 @@ void static_quant_epilogue(const float* input, scalar_t* output,
|
||||
const float a_scale, const float* b_scale,
|
||||
const int32_t* azp_with_adj, const int num_tokens,
|
||||
const int hidden_size) {
|
||||
TORCH_CHECK(false, "static_quant_epilogue requires AVX512/powerpc64 support.")
|
||||
TORCH_CHECK(
|
||||
false, "static_quant_epilogue requires AVX512/powerpc64/AArch64 support.")
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
@ -626,8 +628,9 @@ void dynamic_quant_epilogue(const float* input, scalar_t* output,
|
||||
const int32_t* azp, const int32_t* azp_with_adj,
|
||||
const scalar_t* bias, const int num_tokens,
|
||||
const int hidden_size) {
|
||||
TORCH_CHECK(false,
|
||||
"dynamic_quant_epilogue requires AVX512/powerpc64 support.")
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"dynamic_quant_epilogue requires AVX512/powerpc64/AArch64 support.")
|
||||
}
|
||||
#endif
|
||||
} // namespace
|
||||
|
||||
@ -151,8 +151,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding);
|
||||
|
||||
// Quantization
|
||||
#ifdef __AVX512F__
|
||||
#if defined(__AVX512F__) || defined(__aarch64__)
|
||||
at::Tag stride_tag = at::Tag::needs_fixed_stride_order;
|
||||
|
||||
// Compute int8 quantized tensor for given scaling factor.
|
||||
ops.def(
|
||||
"static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
|
||||
|
||||
@ -1,656 +0,0 @@
|
||||
// clang-format off
|
||||
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_fwd.cu
|
||||
// and https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_update.cu
|
||||
#include <torch/all.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "causal_conv1d.h"
|
||||
#include <c10/util/BFloat16.h>
|
||||
#include <c10/util/Half.h>
|
||||
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
||||
|
||||
#include <cub/block/block_load.cuh>
|
||||
#include <cub/block/block_store.cuh>
|
||||
|
||||
#ifdef USE_ROCM
|
||||
namespace cub = hipcub;
|
||||
#endif
|
||||
|
||||
#include "static_switch.h"
|
||||
|
||||
|
||||
|
||||
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
||||
|
||||
#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
|
||||
if (ITYPE == at::ScalarType::Half) { \
|
||||
using input_t = at::Half; \
|
||||
using weight_t = at::Half; \
|
||||
__VA_ARGS__(); \
|
||||
} else if (ITYPE == at::ScalarType::BFloat16) { \
|
||||
using input_t = at::BFloat16; \
|
||||
using weight_t = at::BFloat16; \
|
||||
__VA_ARGS__(); \
|
||||
} else if (ITYPE == at::ScalarType::Float) { \
|
||||
using input_t = float; \
|
||||
using weight_t = float; \
|
||||
__VA_ARGS__(); \
|
||||
} else { \
|
||||
AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
|
||||
}
|
||||
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
|
||||
void set_conv_params_fwd(ConvParamsBase ¶ms,
|
||||
// sizes
|
||||
const size_t batch,
|
||||
const size_t dim,
|
||||
const size_t seqlen,
|
||||
const size_t width,
|
||||
// device pointers
|
||||
const at::Tensor x,
|
||||
const at::Tensor weight,
|
||||
const at::Tensor out,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
bool silu_activation,
|
||||
int64_t pad_slot_id,
|
||||
const std::optional<at::Tensor>& query_start_loc = std::nullopt,
|
||||
const std::optional<at::Tensor>& cache_indices = std::nullopt,
|
||||
const std::optional<at::Tensor>& has_initial_state = std::nullopt) {
|
||||
|
||||
// Reset the parameters
|
||||
memset(¶ms, 0, sizeof(params));
|
||||
|
||||
params.batch = batch;
|
||||
params.dim = dim;
|
||||
params.seqlen = seqlen;
|
||||
params.width = width;
|
||||
params.pad_slot_id = pad_slot_id;
|
||||
|
||||
params.silu_activation = silu_activation;
|
||||
|
||||
// Set the pointers and strides.
|
||||
params.x_ptr = x.data_ptr();
|
||||
params.weight_ptr = weight.data_ptr();
|
||||
params.bias_ptr = bias.has_value() ? bias.value().data_ptr() : nullptr;
|
||||
params.out_ptr = out.data_ptr();
|
||||
// All stride are in elements, not bytes.
|
||||
params.query_start_loc_ptr = query_start_loc.has_value() ? query_start_loc.value().data_ptr() : nullptr;
|
||||
params.cache_indices_ptr = cache_indices.has_value() ? cache_indices.value().data_ptr() : nullptr;
|
||||
params.has_initial_state_ptr = has_initial_state.has_value() ? has_initial_state.value().data_ptr() : nullptr;
|
||||
const bool varlen = params.query_start_loc_ptr != nullptr;
|
||||
params.x_batch_stride = x.stride(varlen ? 1 : 0);
|
||||
params.x_c_stride = x.stride(varlen ? 0 : 1);
|
||||
params.x_l_stride = x.stride(varlen ? 1 : -1);
|
||||
params.weight_c_stride = weight.stride(0);
|
||||
params.weight_width_stride = weight.stride(1);
|
||||
params.out_batch_stride = out.stride(varlen ? 1 : 0);
|
||||
params.out_c_stride = out.stride(varlen ? 0 : 1);
|
||||
params.out_l_stride = out.stride(varlen ? 1 : -1);
|
||||
}
|
||||
|
||||
|
||||
void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
|
||||
const std::optional<at::Tensor> &bias_,
|
||||
const std::optional<at::Tensor> &conv_states,
|
||||
const std::optional<at::Tensor> &query_start_loc,
|
||||
const std::optional<at::Tensor> &cache_indices,
|
||||
const std::optional<at::Tensor> &has_initial_state,
|
||||
bool silu_activation,
|
||||
// used to identify padding entries if cache_indices provided
|
||||
// in case of padding, the kernel will return early
|
||||
int64_t pad_slot_id) {
|
||||
auto input_type = x.scalar_type();
|
||||
auto weight_type = weight.scalar_type();
|
||||
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
||||
TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
|
||||
|
||||
TORCH_CHECK(x.is_cuda());
|
||||
TORCH_CHECK(weight.is_cuda());
|
||||
|
||||
const bool varlen = query_start_loc.has_value() ? true : false;
|
||||
const auto sizes = x.sizes();
|
||||
const int batch_size = varlen ? query_start_loc.value().sizes()[0] - 1 : sizes[0];
|
||||
const int dim = varlen ? sizes[0] : sizes[1];
|
||||
const int seqlen = varlen ? sizes[1] : sizes[2];
|
||||
const int width = weight.size(-1);
|
||||
if (varlen){
|
||||
CHECK_SHAPE(x, dim, seqlen);
|
||||
}
|
||||
else {
|
||||
CHECK_SHAPE(x, batch_size, dim, seqlen);
|
||||
}
|
||||
CHECK_SHAPE(weight, dim, width);
|
||||
|
||||
|
||||
|
||||
if (bias_.has_value()) {
|
||||
auto bias = bias_.value();
|
||||
TORCH_CHECK(bias.scalar_type() == weight_type);
|
||||
TORCH_CHECK(bias.is_cuda());
|
||||
TORCH_CHECK(bias.stride(-1) == 1);
|
||||
CHECK_SHAPE(bias, dim);
|
||||
}
|
||||
|
||||
|
||||
if (has_initial_state.has_value()) {
|
||||
auto has_initial_state_ = has_initial_state.value();
|
||||
TORCH_CHECK(has_initial_state_.scalar_type() == at::ScalarType::Bool);
|
||||
TORCH_CHECK(has_initial_state_.is_cuda());
|
||||
CHECK_SHAPE(has_initial_state_, batch_size);
|
||||
}
|
||||
|
||||
|
||||
if (query_start_loc.has_value()) {
|
||||
auto query_start_loc_ = query_start_loc.value();
|
||||
TORCH_CHECK(query_start_loc_.scalar_type() == at::ScalarType::Int);
|
||||
TORCH_CHECK(query_start_loc_.is_cuda());
|
||||
}
|
||||
|
||||
|
||||
if (cache_indices.has_value()) {
|
||||
auto cache_indices_ = cache_indices.value();
|
||||
TORCH_CHECK(cache_indices_.scalar_type() == at::ScalarType::Int);
|
||||
TORCH_CHECK(cache_indices_.is_cuda());
|
||||
CHECK_SHAPE(cache_indices_, batch_size);
|
||||
}
|
||||
|
||||
at::Tensor out = x;
|
||||
|
||||
ConvParamsBase params;
|
||||
set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
|
||||
bias_,
|
||||
silu_activation,
|
||||
pad_slot_id,
|
||||
query_start_loc,
|
||||
cache_indices,
|
||||
has_initial_state
|
||||
);
|
||||
|
||||
if (conv_states.has_value()) {
|
||||
auto conv_states_ = conv_states.value();
|
||||
TORCH_CHECK(conv_states_.scalar_type() == input_type);
|
||||
TORCH_CHECK(conv_states_.is_cuda());
|
||||
params.conv_states_ptr = conv_states_.data_ptr();
|
||||
params.conv_states_batch_stride = conv_states_.stride(0);
|
||||
params.conv_states_c_stride = conv_states_.stride(1);
|
||||
params.conv_states_l_stride = conv_states_.stride(2);
|
||||
} else {
|
||||
params.conv_states_ptr = nullptr;
|
||||
}
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] {
|
||||
causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
void causal_conv1d_update(const at::Tensor &x,
|
||||
const at::Tensor &conv_state,
|
||||
const at::Tensor &weight,
|
||||
const std::optional<at::Tensor> &bias_,
|
||||
bool silu_activation,
|
||||
const std::optional<at::Tensor> &cache_seqlens_,
|
||||
const std::optional<at::Tensor> &conv_state_indices_,
|
||||
// used to identify padding entries if cache_indices provided
|
||||
// in case of padding, the kernel will return early
|
||||
int64_t pad_slot_id) {
|
||||
auto input_type = x.scalar_type();
|
||||
auto weight_type = weight.scalar_type();
|
||||
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
||||
TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
|
||||
TORCH_CHECK(weight_type == input_type, "weight type must equal to input type, other variations are disabled due to binary size limitations");
|
||||
TORCH_CHECK(conv_state.scalar_type() == input_type);
|
||||
|
||||
TORCH_CHECK(x.is_cuda());
|
||||
TORCH_CHECK(conv_state.is_cuda());
|
||||
TORCH_CHECK(weight.is_cuda());
|
||||
|
||||
const auto sizes = x.sizes();
|
||||
const int batch_size = sizes[0];
|
||||
const int dim = sizes[1];
|
||||
const int seqlen = sizes[2];
|
||||
const int width = weight.size(-1);
|
||||
const int conv_state_len = conv_state.size(2);
|
||||
TORCH_CHECK(conv_state_len >= width - 1);
|
||||
|
||||
CHECK_SHAPE(x, batch_size, dim, seqlen);
|
||||
CHECK_SHAPE(weight, dim, width);
|
||||
|
||||
TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
|
||||
|
||||
if (bias_.has_value()) {
|
||||
auto bias = bias_.value();
|
||||
TORCH_CHECK(bias.scalar_type() == weight_type);
|
||||
TORCH_CHECK(bias.is_cuda());
|
||||
TORCH_CHECK(bias.stride(-1) == 1);
|
||||
CHECK_SHAPE(bias, dim);
|
||||
}
|
||||
|
||||
at::Tensor out = x;
|
||||
|
||||
ConvParamsBase params;
|
||||
set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
|
||||
bias_,
|
||||
silu_activation,
|
||||
pad_slot_id);
|
||||
params.conv_state_ptr = conv_state.data_ptr();
|
||||
params.conv_state_len = conv_state_len;
|
||||
// All stride are in elements, not bytes.
|
||||
params.conv_state_batch_stride = conv_state.stride(0);
|
||||
params.conv_state_c_stride = conv_state.stride(1);
|
||||
params.conv_state_l_stride = conv_state.stride(2);
|
||||
|
||||
if (cache_seqlens_.has_value()) {
|
||||
auto cache_seqlens = cache_seqlens_.value();
|
||||
TORCH_CHECK(cache_seqlens.scalar_type() == torch::kInt32);
|
||||
TORCH_CHECK(cache_seqlens.is_cuda());
|
||||
TORCH_CHECK(cache_seqlens.stride(-1) == 1);
|
||||
CHECK_SHAPE(cache_seqlens, batch_size);
|
||||
params.cache_seqlens = cache_seqlens.data_ptr<int32_t>();
|
||||
} else {
|
||||
params.cache_seqlens = nullptr;
|
||||
}
|
||||
|
||||
if (conv_state_indices_.has_value()) {
|
||||
auto conv_state_indices = conv_state_indices_.value();
|
||||
TORCH_CHECK(conv_state_indices.scalar_type() == torch::kInt32)
|
||||
TORCH_CHECK(conv_state_indices.is_cuda());
|
||||
TORCH_CHECK(conv_state_indices.stride(0) == 1)
|
||||
CHECK_SHAPE(conv_state_indices, batch_size);
|
||||
|
||||
int conv_state_entries = conv_state.size(0);
|
||||
CHECK_SHAPE(conv_state, conv_state_entries, dim, conv_state_len);
|
||||
|
||||
params.conv_state_indices_ptr = conv_state_indices.data_ptr<int32_t>();
|
||||
} else {
|
||||
CHECK_SHAPE(conv_state, batch_size, dim, conv_state_len);
|
||||
params.conv_state_indices_ptr = nullptr;
|
||||
}
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] {
|
||||
causal_conv1d_update_cuda<input_t, weight_t>(params, stream);
|
||||
});
|
||||
}
|
||||
|
||||
template<int kNThreads_, int kWidth_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
|
||||
struct Causal_conv1d_fwd_kernel_traits {
|
||||
using input_t = input_t_;
|
||||
using weight_t = weight_t_;
|
||||
static constexpr int kNThreads = kNThreads_;
|
||||
static constexpr int kWidth = kWidth_;
|
||||
static constexpr int kNBytes = sizeof(input_t);
|
||||
static_assert(kNBytes == 2 || kNBytes == 4);
|
||||
static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
|
||||
static_assert(kWidth <= kNElts);
|
||||
static constexpr bool kIsVecLoad = kIsVecLoad_;
|
||||
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
|
||||
using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
||||
using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>;
|
||||
using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
||||
using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, 1, cub::BLOCK_STORE_DIRECT>;
|
||||
static constexpr int kSmemIOSize = kIsVecLoad
|
||||
? 0
|
||||
: custom_max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)});
|
||||
static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts;
|
||||
static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize;
|
||||
};
|
||||
|
||||
template<typename Ktraits>
|
||||
__global__ __launch_bounds__(Ktraits::kNThreads)
|
||||
void causal_conv1d_fwd_kernel(ConvParamsBase params) {
|
||||
constexpr int kWidth = Ktraits::kWidth;
|
||||
constexpr int kNThreads = Ktraits::kNThreads;
|
||||
constexpr int kNElts = Ktraits::kNElts;
|
||||
constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
|
||||
using input_t = typename Ktraits::input_t;
|
||||
using vec_t = typename Ktraits::vec_t;
|
||||
using weight_t = typename Ktraits::weight_t;
|
||||
|
||||
// Shared memory.
|
||||
extern __shared__ char smem_[];
|
||||
auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
|
||||
auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_);
|
||||
auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
|
||||
auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
|
||||
vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize);
|
||||
|
||||
const bool kVarlen = params.query_start_loc_ptr != nullptr;
|
||||
const int tidx = threadIdx.x;
|
||||
const int batch_id = blockIdx.x;
|
||||
const int channel_id = blockIdx.y;
|
||||
const int *query_start_loc = kVarlen ? reinterpret_cast<int *>(params.query_start_loc_ptr) : nullptr;
|
||||
const int sequence_start_index = kVarlen ? query_start_loc[batch_id] : batch_id;
|
||||
const int seqlen = kVarlen ? query_start_loc[batch_id + 1] - sequence_start_index : params.seqlen;
|
||||
|
||||
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + sequence_start_index * params.x_batch_stride
|
||||
+ channel_id * params.x_c_stride;
|
||||
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
|
||||
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + sequence_start_index * params.out_batch_stride
|
||||
+ channel_id * params.out_c_stride;
|
||||
float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
|
||||
|
||||
bool has_initial_state = params.has_initial_state_ptr == nullptr ? false
|
||||
: reinterpret_cast<bool *>(params.has_initial_state_ptr)[batch_id];
|
||||
|
||||
int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
|
||||
: reinterpret_cast<int *>(params.cache_indices_ptr);
|
||||
int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
|
||||
// cache_index == params.pad_slot_id is defined as padding, so we exit early
|
||||
if (cache_index == params.pad_slot_id){
|
||||
return;
|
||||
}
|
||||
input_t *conv_states = params.conv_states_ptr == nullptr ? nullptr
|
||||
: reinterpret_cast<input_t *>(params.conv_states_ptr) + cache_index * params.conv_states_batch_stride + channel_id * params.conv_states_c_stride;
|
||||
|
||||
// Thread 0 will load the last elements of the previous chunk, so we initialize those to 0.
|
||||
if (tidx == 0) {
|
||||
input_t initial_state[kNElts] = {0};
|
||||
if (has_initial_state) {
|
||||
#pragma unroll
|
||||
for (int w = 0; w < kWidth - 1; ++w){ initial_state[kNElts - 1 - (kWidth - 2) + w ] = conv_states[w]; }
|
||||
}
|
||||
smem_exchange[kNThreads - 1] = reinterpret_cast<vec_t *>(initial_state)[0];
|
||||
}
|
||||
|
||||
float weight_vals[kWidth];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
|
||||
|
||||
constexpr int kChunkSize = kNThreads * kNElts;
|
||||
const int n_chunks = (seqlen + kChunkSize - 1) / kChunkSize;
|
||||
for (int chunk = 0; chunk < n_chunks; ++chunk) {
|
||||
input_t x_vals_load[2 * kNElts] = {0};
|
||||
if constexpr(kIsVecLoad) {
|
||||
typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (seqlen - chunk * kChunkSize) / kNElts);
|
||||
} else {
|
||||
__syncthreads();
|
||||
typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), seqlen - chunk * kChunkSize);
|
||||
}
|
||||
x += kChunkSize;
|
||||
__syncthreads();
|
||||
// Thread kNThreads - 1 don't write yet, so that thread 0 can read
|
||||
// the last elements of the previous chunk.
|
||||
if (tidx < kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
|
||||
__syncthreads();
|
||||
reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1];
|
||||
__syncthreads();
|
||||
// Now thread kNThreads - 1 can write the last elements of the current chunk.
|
||||
if (tidx == kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
|
||||
|
||||
float x_vals[2 * kNElts];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); }
|
||||
|
||||
float out_vals[kNElts];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNElts; ++i) {
|
||||
out_vals[i] = bias_val;
|
||||
#pragma unroll
|
||||
for (int w = 0; w < kWidth; ++w) {
|
||||
out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)];
|
||||
}
|
||||
}
|
||||
|
||||
if (params.silu_activation) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNElts; ++i) {
|
||||
out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i]));
|
||||
}
|
||||
}
|
||||
|
||||
input_t out_vals_store[kNElts];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; }
|
||||
if constexpr(kIsVecLoad) {
|
||||
typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(out), reinterpret_cast<vec_t (&)[1]>(out_vals_store), (seqlen - chunk * kChunkSize) / kNElts);
|
||||
} else {
|
||||
typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, seqlen - chunk * kChunkSize);
|
||||
}
|
||||
out += kChunkSize;
|
||||
|
||||
int final_state_position = ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize);
|
||||
// in case the final state is separated between the last "smem_exchange" and
|
||||
// and the one before it (chunk = n_chunks - 1 and chunk = n_chunks - 2),
|
||||
// (which occurs when `final_state_position` is a non-positive index)
|
||||
// we load the correct data from smem_exchange from both chunks, the last chunk iteration and the one before it
|
||||
if (conv_states != nullptr && final_state_position < 0 && seqlen > kWidth){
|
||||
input_t vals_load[kNElts] = {0};
|
||||
if ((chunk == n_chunks - 2) && (tidx == kNThreads - 1)){
|
||||
// chunk = n_chunks - 2, a segment of the final state sits in the last index
|
||||
reinterpret_cast<vec_t *>(vals_load)[0] = smem_exchange[kNThreads - 1];
|
||||
#pragma unroll
|
||||
for (int w = 0; w < -final_state_position; ++w){
|
||||
conv_states[w] = vals_load[kNElts + final_state_position + w];
|
||||
}
|
||||
}
|
||||
if ((chunk == n_chunks - 1) && tidx == 0){
|
||||
// chunk = n_chunks - 1, the second segment of the final state first positions
|
||||
reinterpret_cast<vec_t *>(vals_load)[0] = smem_exchange[0];
|
||||
for (int w = -final_state_position; w < kWidth - 1; ++w){
|
||||
conv_states[w] = vals_load[w + final_state_position];
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Final state is stored in the smem_exchange last token slot,
|
||||
// in case seqlen < kWidth, we would need to take the final state from the
|
||||
// initial state which is stored in conv_states
|
||||
// in case seqlen > kWidth, we would need to load the last kWidth - 1 data
|
||||
// and load it into conv_state accordingly
|
||||
int last_thread = ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize) / kNElts;
|
||||
if (conv_states != nullptr && tidx == last_thread) {
|
||||
input_t x_vals_load[kNElts * 2] = {0};
|
||||
// in case we are on the first kWidth tokens
|
||||
if (last_thread == 0 && seqlen < kWidth){
|
||||
// Need to take the initial state
|
||||
reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[0];
|
||||
const int offset = seqlen - (kWidth - 1);
|
||||
#pragma unroll
|
||||
for (int w = 0; w < kWidth - 1; ++w){
|
||||
// pad the existing state
|
||||
if ((w - seqlen) >= 0 && has_initial_state) { conv_states[w - seqlen] = conv_states[w]; }
|
||||
else if ((w - seqlen) >= 0 && !has_initial_state) { conv_states[w - seqlen] = input_t(0.0f); }
|
||||
}
|
||||
#pragma unroll
|
||||
for (int w = 0; w < kWidth - 1; ++w){
|
||||
if (offset + w >= 0)
|
||||
conv_states[w] = x_vals_load[offset + w ];
|
||||
}
|
||||
}
|
||||
else {
|
||||
// in case the final state is in between the threads data
|
||||
const int offset = ((seqlen - (kWidth - 1)) % (kNElts));
|
||||
if ((offset + kWidth - 2) >= kNElts && (last_thread + 1 < kNThreads)){
|
||||
// In case last_thread == kNThreads - 1, accessing last_thread + 1 will result in a
|
||||
// illegal access error on H100.
|
||||
// Therefore, we access last_thread + 1, only if the final state data sits there
|
||||
reinterpret_cast<vec_t *>(x_vals_load)[1] = smem_exchange[last_thread + 1];
|
||||
}
|
||||
reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[last_thread];
|
||||
#pragma unroll
|
||||
for (int w = 0; w < kWidth - 1; ++w){
|
||||
conv_states[w] = x_vals_load[offset + w ];
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
|
||||
void causal_conv1d_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) {
|
||||
static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
|
||||
const bool kVarlen = params.query_start_loc_ptr != nullptr;
|
||||
BOOL_SWITCH(params.seqlen % kNElts == 0 && !kVarlen, kIsVecLoad, [&] {
|
||||
using Ktraits = Causal_conv1d_fwd_kernel_traits<kNThreads, kWidth, kIsVecLoad, input_t, weight_t>;
|
||||
constexpr int kSmemSize = Ktraits::kSmemSize;
|
||||
dim3 grid(params.batch, params.dim);
|
||||
|
||||
auto kernel = &causal_conv1d_fwd_kernel<Ktraits>;
|
||||
|
||||
if (kSmemSize >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
(void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
||||
std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
|
||||
}
|
||||
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
||||
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
}
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) {
|
||||
if (params.width == 2) {
|
||||
causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream);
|
||||
} else if (params.width == 3) {
|
||||
causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream);
|
||||
} else if (params.width == 4) {
|
||||
causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template void causal_conv1d_fwd_cuda<float, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
template void causal_conv1d_fwd_cuda<at::Half, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
template void causal_conv1d_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
|
||||
|
||||
|
||||
|
||||
template<int kNThreads_, int kWidth_, typename input_t_, typename weight_t_>
|
||||
struct Causal_conv1d_update_kernel_traits {
|
||||
using input_t = input_t_;
|
||||
using weight_t = weight_t_;
|
||||
static constexpr int kNThreads = kNThreads_;
|
||||
static constexpr int kWidth = kWidth_;
|
||||
static constexpr int kNBytes = sizeof(input_t);
|
||||
static_assert(kNBytes == 2 || kNBytes == 4);
|
||||
};
|
||||
|
||||
template<typename Ktraits, bool kIsCircularBuffer>
|
||||
__global__ __launch_bounds__(Ktraits::kNThreads)
|
||||
void causal_conv1d_update_kernel(ConvParamsBase params) {
|
||||
constexpr int kWidth = Ktraits::kWidth;
|
||||
constexpr int kNThreads = Ktraits::kNThreads;
|
||||
using input_t = typename Ktraits::input_t;
|
||||
using weight_t = typename Ktraits::weight_t;
|
||||
|
||||
const int tidx = threadIdx.x;
|
||||
const int batch_id = blockIdx.x;
|
||||
const int channel_id = blockIdx.y * kNThreads + tidx;
|
||||
if (channel_id >= params.dim) return;
|
||||
|
||||
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
|
||||
+ channel_id * params.x_c_stride;
|
||||
|
||||
// If params.conv_state_batch_indices is set, then the conv state is gathered from the conv state tensor
|
||||
// along the batch axis. Otherwise, the conv state coordinate is the same as the batch id.
|
||||
const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr
|
||||
? batch_id
|
||||
: params.conv_state_indices_ptr[batch_id];
|
||||
// conv_state_batch_coord == params.pad_slot_id is defined as padding so we exit early
|
||||
if (conv_state_batch_coord == params.pad_slot_id){
|
||||
return;
|
||||
}
|
||||
input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr)
|
||||
+ conv_state_batch_coord * params.conv_state_batch_stride
|
||||
+ channel_id * params.conv_state_c_stride;
|
||||
|
||||
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
|
||||
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
|
||||
+ channel_id * params.out_c_stride;
|
||||
float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
|
||||
|
||||
int state_len = params.conv_state_len;
|
||||
int advance_len = params.seqlen;
|
||||
int cache_seqlen = kIsCircularBuffer ? params.cache_seqlens[batch_id] % state_len : 0;
|
||||
int update_idx = cache_seqlen - (kWidth - 1);
|
||||
update_idx = update_idx < 0 ? update_idx + state_len : update_idx;
|
||||
|
||||
float weight_vals[kWidth] = {0};
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
|
||||
|
||||
float x_vals[kWidth] = {0};
|
||||
if constexpr (!kIsCircularBuffer) {
|
||||
#pragma unroll 2
|
||||
for (int i = 0; i < state_len - advance_len - (kWidth - 1); ++i) {
|
||||
conv_state[i * params.conv_state_l_stride] = conv_state[(i + advance_len) * params.conv_state_l_stride];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth - 1; ++i) {
|
||||
input_t state_val = conv_state[(state_len - (kWidth - 1) + i) * params.conv_state_l_stride];
|
||||
if (i < advance_len + (kWidth - 1) && state_len - advance_len - (kWidth - 1) + i >= 0) {
|
||||
conv_state[(state_len - advance_len - (kWidth - 1) + i) * params.conv_state_l_stride] = state_val;
|
||||
}
|
||||
x_vals[i] = float(state_val);
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth - 1; ++i, update_idx = update_idx + 1 >= state_len ? update_idx + 1 - state_len : update_idx + 1) {
|
||||
input_t state_val = conv_state[update_idx * params.conv_state_l_stride];
|
||||
x_vals[i] = float(state_val);
|
||||
}
|
||||
}
|
||||
#pragma unroll 2
|
||||
for (int i = 0; i < params.seqlen; ++i) {
|
||||
input_t x_val = x[i * params.x_l_stride];
|
||||
if constexpr (!kIsCircularBuffer) {
|
||||
if (i < advance_len && state_len - advance_len + i >= 0) {
|
||||
conv_state[(state_len - advance_len + i) * params.conv_state_l_stride] = x_val;
|
||||
}
|
||||
} else {
|
||||
conv_state[update_idx * params.conv_state_l_stride] = x_val;
|
||||
++update_idx;
|
||||
update_idx = update_idx >= state_len ? update_idx - state_len : update_idx;
|
||||
}
|
||||
x_vals[kWidth - 1] = float(x_val);
|
||||
float out_val = bias_val;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < kWidth; ++j) { out_val += weight_vals[j] * x_vals[j]; }
|
||||
if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); }
|
||||
out[i * params.out_l_stride] = input_t(out_val);
|
||||
// Shift the input buffer by 1
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = x_vals[i + 1]; }
|
||||
}
|
||||
}
|
||||
|
||||
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
|
||||
void causal_conv1d_update_launch(ConvParamsBase ¶ms, cudaStream_t stream) {
|
||||
using Ktraits = Causal_conv1d_update_kernel_traits<kNThreads, kWidth, input_t, weight_t>;
|
||||
dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads);
|
||||
auto kernel = params.cache_seqlens == nullptr
|
||||
? &causal_conv1d_update_kernel<Ktraits, false>
|
||||
: &causal_conv1d_update_kernel<Ktraits, true>;
|
||||
kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
}
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream) {
|
||||
if (params.width == 2) {
|
||||
causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream);
|
||||
} else if (params.width == 3) {
|
||||
causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream);
|
||||
} else if (params.width == 4) {
|
||||
causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream);
|
||||
}
|
||||
}
|
||||
|
||||
template void causal_conv1d_update_cuda<float, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
template void causal_conv1d_update_cuda<at::Half, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
template void causal_conv1d_update_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
@ -1,159 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024, Tri Dao.
|
||||
******************************************************************************/
|
||||
// clang-format off
|
||||
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d.h
|
||||
#pragma once
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct ConvParamsBase {
|
||||
using index_t = uint32_t;
|
||||
|
||||
int batch, dim, seqlen, width;
|
||||
int64_t pad_slot_id;
|
||||
bool silu_activation;
|
||||
|
||||
index_t x_batch_stride;
|
||||
index_t x_c_stride;
|
||||
index_t x_l_stride;
|
||||
index_t weight_c_stride;
|
||||
index_t weight_width_stride;
|
||||
index_t out_batch_stride;
|
||||
index_t out_c_stride;
|
||||
index_t out_l_stride;
|
||||
|
||||
int conv_state_len;
|
||||
index_t conv_state_batch_stride;
|
||||
index_t conv_state_c_stride;
|
||||
index_t conv_state_l_stride;
|
||||
|
||||
// Common data pointers.
|
||||
void *__restrict__ x_ptr;
|
||||
void *__restrict__ weight_ptr;
|
||||
void *__restrict__ bias_ptr;
|
||||
void *__restrict__ out_ptr;
|
||||
|
||||
void *__restrict__ conv_state_ptr;
|
||||
void *__restrict__ query_start_loc_ptr;
|
||||
void *__restrict__ has_initial_state_ptr;
|
||||
void *__restrict__ cache_indices_ptr;
|
||||
int32_t *__restrict__ cache_seqlens;
|
||||
|
||||
// For the continuous batching case. Makes it so that the mamba state for
|
||||
// the current batch doesn't need to be a contiguous tensor.
|
||||
int32_t *__restrict__ conv_state_indices_ptr;
|
||||
|
||||
void *__restrict__ seq_idx_ptr;
|
||||
|
||||
// No __restrict__ since initial_states could be the same as final_states.
|
||||
void * initial_states_ptr;
|
||||
index_t initial_states_batch_stride;
|
||||
index_t initial_states_l_stride;
|
||||
index_t initial_states_c_stride;
|
||||
|
||||
void * final_states_ptr;
|
||||
index_t final_states_batch_stride;
|
||||
index_t final_states_l_stride;
|
||||
index_t final_states_c_stride;
|
||||
|
||||
void * conv_states_ptr;
|
||||
index_t conv_states_batch_stride;
|
||||
index_t conv_states_l_stride;
|
||||
index_t conv_states_c_stride;
|
||||
};
|
||||
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
template<typename T>
|
||||
__device__ inline T shuffle_xor(T val, int offset) {
|
||||
return __shfl_xor_sync(uint32_t(-1), val, offset);
|
||||
}
|
||||
|
||||
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
|
||||
{
|
||||
return std::max(ilist);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr T constexpr_min(T a, T b) {
|
||||
return std::min(a, b);
|
||||
}
|
||||
|
||||
#else
|
||||
#include <hip/hip_bf16.h>
|
||||
|
||||
template<typename T>
|
||||
__device__ inline T shuffle_xor(T val, int offset) {
|
||||
return __shfl_xor(val, offset);
|
||||
}
|
||||
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
|
||||
{
|
||||
return *std::max_element(ilist.begin(), ilist.end());
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr T constexpr_min(T a, T b) {
|
||||
return a < b ? a : b;
|
||||
}
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<int BYTES> struct BytesToType {};
|
||||
|
||||
template<> struct BytesToType<16> {
|
||||
using Type = uint4;
|
||||
static_assert(sizeof(Type) == 16);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<8> {
|
||||
using Type = uint64_t;
|
||||
static_assert(sizeof(Type) == 8);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<4> {
|
||||
using Type = uint32_t;
|
||||
static_assert(sizeof(Type) == 4);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<2> {
|
||||
using Type = uint16_t;
|
||||
static_assert(sizeof(Type) == 2);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<1> {
|
||||
using Type = uint8_t;
|
||||
static_assert(sizeof(Type) == 1);
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T>
|
||||
struct SumOp {
|
||||
__device__ inline T operator()(T const & x, T const & y) { return x + y; }
|
||||
};
|
||||
|
||||
template<int THREADS>
|
||||
struct Allreduce {
|
||||
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
|
||||
template<typename T, typename Operator>
|
||||
static __device__ inline T run(T x, Operator &op) {
|
||||
constexpr int OFFSET = THREADS / 2;
|
||||
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
|
||||
return Allreduce<OFFSET>::run(x, op);
|
||||
}
|
||||
};
|
||||
|
||||
template<>
|
||||
struct Allreduce<2> {
|
||||
template<typename T, typename Operator>
|
||||
static __device__ inline T run(T x, Operator &op) {
|
||||
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
|
||||
return x;
|
||||
}
|
||||
};
|
||||
@ -1,28 +0,0 @@
// Inspired by
// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
// clang-format off
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/static_switch.h

#pragma once

/// @param COND - a boolean expression to switch by
/// @param CONST_NAME - a name given for the constexpr bool variable.
/// @param ... - code to execute for true and false
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
///   some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...)          \
  [&] {                                             \
    if (COND) {                                     \
      static constexpr bool CONST_NAME = true;      \
      return __VA_ARGS__();                         \
    } else {                                        \
      static constexpr bool CONST_NAME = false;     \
      return __VA_ARGS__();                         \
    }                                               \
  }()
@ -312,19 +312,20 @@ void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) {
|
||||
// kIsVariableB, kIsVariableC and kHasZ are all set to True to reduce binary size
|
||||
constexpr bool kIsVariableB = true;
|
||||
constexpr bool kIsVariableC = true;
|
||||
constexpr bool kHasZ = true;
|
||||
BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
|
||||
BOOL_SWITCH(params.query_start_loc_ptr != nullptr , kVarlen, [&] {
|
||||
using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kVarlen, input_t, weight_t>;
|
||||
constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
|
||||
dim3 grid(params.batch, params.dim / kNRows);
|
||||
auto kernel = &selective_scan_fwd_kernel<Ktraits>;
|
||||
if (kSmemSize >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
(void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
||||
}
|
||||
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] {
|
||||
BOOL_SWITCH(params.query_start_loc_ptr != nullptr , kVarlen, [&] {
|
||||
using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kVarlen, input_t, weight_t>;
|
||||
constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
|
||||
dim3 grid(params.batch, params.dim / kNRows);
|
||||
auto kernel = &selective_scan_fwd_kernel<Ktraits>;
|
||||
if (kSmemSize >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
||||
}
|
||||
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
@ -612,19 +613,20 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
|
||||
|
||||
at::Tensor z, out_z;
|
||||
const bool has_z = z_.has_value();
|
||||
TORCH_CHECK(has_z, "has_z = False is disabled in favor of reduced binary size")
|
||||
z = z_.value();
|
||||
TORCH_CHECK(z.scalar_type() == input_type);
|
||||
TORCH_CHECK(z.is_cuda());
|
||||
TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
|
||||
if (varlen){
|
||||
CHECK_SHAPE(z, dim, seqlen);
|
||||
} else {
|
||||
CHECK_SHAPE(z, batch_size, dim, seqlen);
|
||||
if (has_z) {
|
||||
z = z_.value();
|
||||
TORCH_CHECK(z.scalar_type() == input_type);
|
||||
TORCH_CHECK(z.is_cuda());
|
||||
TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
|
||||
if (varlen){
|
||||
CHECK_SHAPE(z, dim, seqlen);
|
||||
} else {
|
||||
CHECK_SHAPE(z, batch_size, dim, seqlen);
|
||||
}
|
||||
|
||||
out_z = z;
|
||||
}
|
||||
|
||||
out_z = z;
|
||||
|
||||
// Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
|
||||
at::Tensor out = delta;
|
||||
TORCH_CHECK(ssm_states.scalar_type() == input_type);
|
||||
@ -653,4 +655,3 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
|
||||
selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
csrc/ops.h (16 changed lines)
@ -326,22 +326,6 @@ void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
                        const std::optional<torch::Tensor>& has_initial_state,
                        const torch::Tensor& ssm_states, int64_t pad_slot_id);

void causal_conv1d_update(const at::Tensor& x, const at::Tensor& conv_state,
                          const at::Tensor& weight,
                          const std::optional<at::Tensor>& bias_,
                          bool silu_activation,
                          const std::optional<at::Tensor>& cache_seqlens_,
                          const std::optional<at::Tensor>& conv_state_indices_,
                          int64_t pad_slot_id);

void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
                       const std::optional<at::Tensor>& bias_,
                       const std::optional<at::Tensor>& conv_states,
                       const std::optional<at::Tensor>& query_start_loc,
                       const std::optional<at::Tensor>& cache_indices,
                       const std::optional<at::Tensor>& has_initial_state,
                       bool silu_activation, int64_t pad_slot_id);

using fptr_t = int64_t;
fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
                      torch::Tensor& rank_data, int64_t rank,

@ -201,11 +201,10 @@ void run_blockwise_scaled_group_mm(
      reinterpret_cast<typename ScheduleConfig::LayoutSFB*>(
          layout_sfb.data_ptr())};

  cutlass::KernelHardwareInfo hw_info;
  hw_info.device_id = a_ptrs.get_device();
  hw_info.sm_count =
      cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
          hw_info.device_id);
  int device_id = a_ptrs.device().index();
  static const cutlass::KernelHardwareInfo hw_info{
      device_id, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
                     device_id)};

  // Epilogue Arguments
  typename GemmKernel::EpilogueArguments epilogue_args{

@ -30,35 +30,40 @@
|
||||
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
|
||||
#include "core/math.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
// Kernel Perf config
|
||||
template <typename T>
|
||||
struct KernelTraits;
|
||||
|
||||
template <>
|
||||
struct KernelTraits<float> {
|
||||
using MmaTileShape = Shape<_128, _128, _256>;
|
||||
// Configuration for M in (256, inf)
|
||||
struct sm100_fp4_config_default {
|
||||
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
|
||||
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||
using TileShape = Shape<_256, _256, _256>;
|
||||
using ClusterShape = Shape<_2, _1, _1>;
|
||||
using PerSmTileShape_MNK = Shape<_128, _256, _256>;
|
||||
};
|
||||
|
||||
// Configuration for M in (16, 256]
|
||||
struct sm100_fp4_config_M256 {
|
||||
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
|
||||
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||
using TileShape = Shape<_256, _128, _256>;
|
||||
using ClusterShape = Shape<_2, _1, _1>;
|
||||
using PerSmTileShape_MNK = Shape<_128, _128, _256>;
|
||||
};
|
||||
|
||||
// Configuration for M in [1, 16]
|
||||
struct sm100_fp4_config_M16 {
|
||||
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
|
||||
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||
using TileShape = Shape<_128, _128, _256>;
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
using PerSmTileShape_MNK = Shape<_128, _128, _256>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct KernelTraits<cutlass::half_t> {
|
||||
using MmaTileShape = Shape<_256, _256, _256>;
|
||||
using ClusterShape = Shape<_4, _4, _1>;
|
||||
using PerSmTileShape_MNK = Shape<_128, _256, _256>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct KernelTraits<cutlass::bfloat16_t> {
|
||||
using MmaTileShape = Shape<_256, _256, _256>;
|
||||
using ClusterShape = Shape<_4, _4, _1>;
|
||||
using PerSmTileShape_MNK = Shape<_128, _256, _256>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
template <typename Config, typename OutType>
|
||||
struct Fp4GemmSm100 {
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
|
||||
@ -71,21 +76,22 @@ struct Fp4GemmSm100 {
|
||||
static constexpr int AlignmentB = 32;
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementD = T;
|
||||
using ElementC = T;
|
||||
using ElementD = OutType;
|
||||
using ElementC = OutType;
|
||||
using LayoutCTag = cutlass::layout::RowMajor;
|
||||
using LayoutDTag = cutlass::layout::RowMajor;
|
||||
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
// Kernel functional config
|
||||
using ElementAccumulator = float;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;
|
||||
|
||||
// Kernel Perf config
|
||||
using MmaTileShape = typename KernelTraits<T>::MmaTileShape;
|
||||
using ClusterShape = typename KernelTraits<T>::ClusterShape;
|
||||
using PerSmTileShape_MNK = typename KernelTraits<T>::PerSmTileShape_MNK;
|
||||
// Use config's tile shapes
|
||||
using MmaTileShape = typename Config::TileShape;
|
||||
using ClusterShape = typename Config::ClusterShape;
|
||||
using PerSmTileShape_MNK = typename Config::PerSmTileShape_MNK;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
@ -119,22 +125,22 @@ struct Fp4GemmSm100 {
|
||||
using LayoutD = decltype(cute::make_layout(make_shape(0, 0, 0), StrideD{}));
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
typename T::Gemm::Arguments args_from_options(
|
||||
template <typename Config>
|
||||
typename Config::Gemm::Arguments args_from_options(
|
||||
at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
|
||||
at::Tensor const& A_sf, at::Tensor const& B_sf, at::Tensor const& alpha,
|
||||
int64_t M, int64_t N, int64_t K) {
|
||||
using ElementA = typename T::Gemm::ElementA;
|
||||
using ElementB = typename T::Gemm::ElementB;
|
||||
using ElementA = typename Config::Gemm::ElementA;
|
||||
using ElementB = typename Config::Gemm::ElementB;
|
||||
using ElementSFA = cutlass::float_ue4m3_t;
|
||||
using ElementSFB = cutlass::float_ue4m3_t;
|
||||
using ElementD = typename T::Gemm::ElementD;
|
||||
using ElementD = typename Config::Gemm::ElementD;
|
||||
using ElementCompute = float;
|
||||
using StrideA = typename T::StrideA;
|
||||
using StrideB = typename T::StrideB;
|
||||
using StrideD = typename T::StrideD;
|
||||
using Sm100BlkScaledConfig =
|
||||
typename T::Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
|
||||
using StrideA = typename Config::StrideA;
|
||||
using StrideB = typename Config::StrideB;
|
||||
using StrideD = typename Config::StrideD;
|
||||
using Sm100BlkScaledConfig = typename Config::Gemm::GemmKernel::
|
||||
CollectiveMainloop::Sm1xxBlkScaledConfig;
|
||||
|
||||
int m = static_cast<int>(M);
|
||||
int n = static_cast<int>(N);
|
||||
@ -148,7 +154,7 @@ typename T::Gemm::Arguments args_from_options(
|
||||
auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(
|
||||
cute::make_shape(m, n, k, 1));
|
||||
|
||||
typename T::Gemm::Arguments arguments{
|
||||
typename Config::Gemm::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{m, n, k, 1},
|
||||
{// Mainloop arguments
|
||||
@ -167,17 +173,17 @@ typename T::Gemm::Arguments args_from_options(
|
||||
return arguments;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename Config>
|
||||
void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
|
||||
at::Tensor const& A_sf, at::Tensor const& B_sf,
|
||||
at::Tensor const& alpha, int64_t m, int64_t n, int64_t k,
|
||||
cudaStream_t stream) {
|
||||
typename Fp4GemmSm100<T>::Gemm gemm;
|
||||
typename Config::Gemm gemm;
|
||||
|
||||
auto arguments =
|
||||
args_from_options<Fp4GemmSm100<T>>(D, A, B, A_sf, B_sf, alpha, m, n, k);
|
||||
args_from_options<Config>(D, A, B, A_sf, B_sf, alpha, m, n, k);
|
||||
|
||||
size_t workspace_size = Fp4GemmSm100<T>::Gemm::get_workspace_size(arguments);
|
||||
size_t workspace_size = Config::Gemm::get_workspace_size(arguments);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
@ -188,12 +194,40 @@ void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
|
||||
|
||||
CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream));
|
||||
}
|
||||
|
||||
// Dispatch function to select appropriate config based on M
|
||||
template <typename OutType>
|
||||
void cutlass_fp4_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
|
||||
torch::Tensor const& B,
|
||||
torch::Tensor const& A_sf,
|
||||
torch::Tensor const& B_sf,
|
||||
torch::Tensor const& alpha, int64_t m, int64_t n,
|
||||
int64_t k, cudaStream_t stream) {
|
||||
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
|
||||
|
||||
if (mp2 <= 16) {
|
||||
// m in [1, 16]
|
||||
runGemm<Fp4GemmSm100<sm100_fp4_config_M16, OutType>>(
|
||||
D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||
} else if (mp2 <= 256) {
|
||||
// m in (16, 256]
|
||||
runGemm<Fp4GemmSm100<sm100_fp4_config_M256, OutType>>(
|
||||
D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||
} else {
|
||||
// m in (256, inf)
|
||||
runGemm<Fp4GemmSm100<sm100_fp4_config_default, OutType>>(
|
||||
D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||
}
|
||||
}
|
||||
|
||||
#else
|
||||
template <typename T>
|
||||
void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
|
||||
at::Tensor const& A_sf, at::Tensor const& B_sf,
|
||||
at::Tensor const& alpha, int64_t m, int64_t n, int64_t k,
|
||||
cudaStream_t stream) {
|
||||
template <typename OutType>
|
||||
void cutlass_fp4_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
|
||||
torch::Tensor const& B,
|
||||
torch::Tensor const& A_sf,
|
||||
torch::Tensor const& B_sf,
|
||||
torch::Tensor const& alpha, int64_t m, int64_t n,
|
||||
int64_t k, cudaStream_t stream) {
|
||||
TORCH_CHECK(false,
|
||||
"Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
|
||||
"a CUTLASS 3.8 source directory to enable support.");
|
||||
@ -271,12 +305,13 @@ void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());
|
||||
|
||||
if (out_dtype == at::ScalarType::Half) {
|
||||
runGemm<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||
cutlass_fp4_gemm_dispatch<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n,
|
||||
k, stream);
|
||||
} else if (out_dtype == at::ScalarType::BFloat16) {
|
||||
runGemm<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||
} else if (out_dtype == at::ScalarType::Float) {
|
||||
runGemm<float>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||
cutlass_fp4_gemm_dispatch<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha,
|
||||
m, n, k, stream);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm");
|
||||
TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm (", out_dtype,
|
||||
")");
|
||||
}
|
||||
}
|
||||
|
||||
@ -594,28 +594,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
      "int pad_slot_id) -> ()");
  ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);

  ops.def(
      "causal_conv1d_update(Tensor! x,"
      "Tensor! conv_state,"
      "Tensor! weight,"
      "Tensor? bias_,"
      "bool silu_activation,"
      "Tensor? cache_seqlens_,"
      "Tensor? conv_state_indices,"
      "int pad_slot_id) -> ()");
  ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);

  ops.def(
      "causal_conv1d_fwd(Tensor! x, Tensor! weight,"
      "Tensor? bias_,"
      "Tensor!? conv_states,"
      "Tensor? query_start_loc,"
      "Tensor? cache_indices,"
      "Tensor? has_initial_state,"
      "bool silu_activation,"
      "int pad_slot_id) -> ()");
  ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);

#ifndef USE_ROCM
  // reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
  ops.def(

@ -379,38 +379,38 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
ARG FLASHINFER_CUDA128_INDEX_URL="https://download.pytorch.org/whl/cu128/flashinfer"
ARG FLASHINFER_CUDA128_WHEEL="flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl"
ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
ARG FLASHINFER_GIT_REF="v0.2.6.post1"
ARG FLASHINFER_GIT_REF="v0.2.8rc1"
# Flag to control whether to use pre-built FlashInfer wheels (set to false to force build from source)
# TODO: Currently disabled because the pre-built wheels are not available for FLASHINFER_GIT_REF
ARG USE_FLASHINFER_PREBUILT_WHEEL=false
RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
    . /etc/environment
    if [ "$TARGETPLATFORM" != "linux/arm64" ]; then
        # FlashInfer already has a wheel for PyTorch 2.7.0 and CUDA 12.8. This is enough for CI use
        if [[ "$CUDA_VERSION" == 12.8* ]]; then
        if [[ "$CUDA_VERSION" == 12.8* ]] && [[ "$USE_FLASHINFER_PREBUILT_WHEEL" == "true" ]]; then
            uv pip install --system ${FLASHINFER_CUDA128_INDEX_URL}/${FLASHINFER_CUDA128_WHEEL}
        else
            export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0a 10.0a 12.0'
            git clone ${FLASHINFER_GIT_REPO} --single-branch --branch ${FLASHINFER_GIT_REF} --recursive
            # Needed to build AOT kernels
            (cd flashinfer && \
                python3 -m flashinfer.aot && \
                uv pip install --system --no-build-isolation . \
            )
            rm -rf flashinfer

            # Default arches (skipping 10.0a and 12.0 since these need 12.8)
            # Exclude CUDA arches for older versions (11.x and 12.0-12.7)
            # TODO: Update this to allow setting TORCH_CUDA_ARCH_LIST as a build arg.
            TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a"
            if [[ "${CUDA_VERSION}" == 11.* ]]; then
                TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9"
                FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9"
            elif [[ "${CUDA_VERSION}" == 12.[0-7]* ]]; then
                FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a"
            else
                # CUDA 12.8+ supports 10.0a and 12.0
                FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0"
            fi
            echo "🏗️ Building FlashInfer for arches: ${TORCH_CUDA_ARCH_LIST}"
            echo "🏗️ Building FlashInfer for arches: ${FI_TORCH_CUDA_ARCH_LIST}"

            git clone --depth 1 --recursive --shallow-submodules \
                --branch v0.2.6.post1 \
                https://github.com/flashinfer-ai/flashinfer.git flashinfer
                --branch ${FLASHINFER_GIT_REF} \
                ${FLASHINFER_GIT_REPO} flashinfer

            # Needed to build AOT kernels
            pushd flashinfer
                python3 -m flashinfer.aot
                TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST}" \
                TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \
                    python3 -m flashinfer.aot
                TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \
                    uv pip install --system --no-build-isolation .
            popd


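Since `USE_FLASHINFER_PREBUILT_WHEEL` is an ordinary Docker build argument, it can be toggled from the build command line rather than by editing the Dockerfile. A minimal sketch of how that might look; the `docker/Dockerfile` path and the image tags are illustrative assumptions rather than values taken from this diff, and per the condition above the pre-built wheel path is only taken on CUDA 12.8.x builds:

```console
# Force FlashInfer to be built from source (the current default in this diff)
docker build -f docker/Dockerfile \
  --build-arg USE_FLASHINFER_PREBUILT_WHEEL=false \
  -t vllm-flashinfer-src .

# Re-enable the pre-built wheel once one exists for the pinned FLASHINFER_GIT_REF
# (only takes effect when CUDA_VERSION is 12.8.x, per the condition above)
docker build -f docker/Dockerfile \
  --build-arg USE_FLASHINFER_PREBUILT_WHEEL=true \
  -t vllm-flashinfer-wheel .
```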
@ -36,7 +36,7 @@ vLLM is flexible and easy to use with:

- Seamless integration with popular HuggingFace models
- High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more
- Tensor parallelism and pipeline parallelism support for distributed inference
- Tensor, pipeline, data and expert parallelism support for distributed inference
- Streaming outputs
- OpenAI-compatible API server
- Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs, Gaudi® accelerators and GPUs, IBM Power CPUs, TPU, and AWS Trainium and Inferentia Accelerators.

BIN docs/assets/deployment/dp_external_lb.png (new binary file, not shown; 84 KiB)
BIN docs/assets/deployment/dp_internal_lb.png (new binary file, not shown; 68 KiB)
@ -1,3 +1,7 @@
|
||||
---
|
||||
toc_depth: 3
|
||||
---
|
||||
|
||||
# Engine Arguments
|
||||
|
||||
Engine arguments control the behavior of the vLLM engine.
|
||||
@ -5,11 +9,12 @@ Engine arguments control the behavior of the vLLM engine.
|
||||
- For [offline inference](../serving/offline_inference.md), they are part of the arguments to [LLM][vllm.LLM] class.
|
||||
- For [online serving](../serving/openai_compatible_server.md), they are part of the arguments to `vllm serve`.
|
||||
|
||||
You can look at [EngineArgs][vllm.engine.arg_utils.EngineArgs] and [AsyncEngineArgs][vllm.engine.arg_utils.AsyncEngineArgs] to see the available engine arguments.
|
||||
The engine argument classes, [EngineArgs][vllm.engine.arg_utils.EngineArgs] and [AsyncEngineArgs][vllm.engine.arg_utils.AsyncEngineArgs], are a combination of the configuration classes defined in [vllm.config][]. Therefore, if you are interested in developer documentation, we recommend looking at these configuration classes as they are the source of truth for types, defaults and docstrings.
|
||||
|
||||
However, these classes are a combination of the configuration classes defined in [vllm.config][]. Therefore, we would recommend you read about them there where they are best documented.
|
||||
## `EngineArgs`
|
||||
|
||||
For offline inference you will have access to these configuration classes and for online serving you can cross-reference the configs with `vllm serve --help`, which has its arguments grouped by config.
|
||||
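As a minimal illustration (reusing the small `facebook/opt-125m` model used elsewhere in these docs; the specific flags are arbitrary), each engine argument simply becomes a CLI flag of `vllm serve`:

```bash
# The same knobs are available offline as LLM(model=..., max_model_len=..., enforce_eager=...)
# and online as the equivalent flags; `vllm serve --help` lists them grouped by config class.
vllm serve facebook/opt-125m --max-model-len 2048 --enforce-eager
```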
--8<-- "docs/argparse/engine_args.md"
|
||||
|
||||
!!! note
|
||||
Additional arguments are available to the [AsyncLLMEngine][vllm.engine.async_llm_engine.AsyncLLMEngine] which is used for online serving. These can be found by running `vllm serve --help`
|
||||
## `AsyncEngineArgs`
|
||||
|
||||
--8<-- "docs/argparse/async_engine_args.md"
|
||||
|
||||
@ -99,16 +99,16 @@ Once your `CMakeUserPresets.json` is configured:
|
||||
1. **Initialize the CMake build environment:**
|
||||
This step configures the build system according to your chosen preset (e.g., `release`) and creates the build directory at `binaryDir`
|
||||
|
||||
```console
|
||||
cmake --preset release
|
||||
```
|
||||
```console
|
||||
cmake --preset release
|
||||
```
|
||||
|
||||
2. **Build and install the vLLM components:**
|
||||
This command compiles the code and installs the resulting binaries into your vLLM source directory, making them available to your editable Python installation.
|
||||
|
||||
```console
|
||||
cmake --build --preset release --target install
|
||||
```
|
||||
```console
|
||||
cmake --build --preset release --target install
|
||||
```
|
||||
|
||||
3. **Make changes and repeat!**
|
||||
Now you start using your editable install of vLLM, testing and making changes as needed. If you need to build again to update based on changes, simply run the CMake command again to build only the affected files.
|
||||
|
||||
@ -73,6 +73,8 @@ def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
...
|
||||
```
|
||||
|
||||
@ -272,3 +272,80 @@ The new format of `--lora-modules` is mainly to support the display of parent mo
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
## Default LoRA Models For Multimodal Models
|
||||
|
||||
Some multimodal models, e.g., [Granite Speech](https://huggingface.co/ibm-granite/granite-speech-3.3-8b) and [Phi-4-multimodal-instruct](https://huggingface.co/microsoft/Phi-4-multimodal-instruct), contain LoRA adapter(s) that are expected to always be applied when a given modality is present. This can be a bit tedious to manage with the above approaches, as it requires the user to send the `LoRARequest` (offline) or to filter requests between the base model and LoRA model (server) depending on the content of the request's multimodal data.
|
||||
|
||||
To this end, we allow registration of default multimodal LoRAs to handle this automatically: users can map each modality to a LoRA adapter, which is then applied automatically whenever the corresponding inputs are present. Note that currently we only allow one LoRA per prompt; if several modalities are provided, each registered to its own LoRA, none of them will be applied.
|
||||
|
||||
??? code "Example usage for offline inference"
|
||||
|
||||
```python
|
||||
from transformers import AutoTokenizer
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.assets.audio import AudioAsset
|
||||
|
||||
model_id = "ibm-granite/granite-speech-3.3-2b"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
def get_prompt(question: str, has_audio: bool):
|
||||
"""Build the input prompt to send to vLLM."""
|
||||
if has_audio:
|
||||
question = f"<|audio|>{question}"
|
||||
chat = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": question
|
||||
}
|
||||
]
|
||||
return tokenizer.apply_chat_template(chat, tokenize=False)
|
||||
|
||||
|
||||
model = LLM(
|
||||
model=model_id,
|
||||
enable_lora=True,
|
||||
max_lora_rank=64,
|
||||
max_model_len=2048,
|
||||
limit_mm_per_prompt={"audio": 1},
|
||||
# Will always pass a `LoRARequest` with the `model_id`
|
||||
# whenever audio is contained in the request data.
|
||||
default_mm_loras = {"audio": model_id},
|
||||
enforce_eager=True,
|
||||
)
|
||||
|
||||
question = "can you transcribe the speech into a written format?"
|
||||
prompt_with_audio = get_prompt(
|
||||
question=question,
|
||||
has_audio=True,
|
||||
)
|
||||
audio = AudioAsset("mary_had_lamb").audio_and_sample_rate
|
||||
|
||||
inputs = {
|
||||
"prompt": prompt_with_audio,
|
||||
"multi_modal_data": {
|
||||
"audio": audio,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
outputs = model.generate(
|
||||
inputs,
|
||||
sampling_params=SamplingParams(
|
||||
temperature=0.2,
|
||||
max_tokens=64,
|
||||
),
|
||||
)
|
||||
```
|
||||
|
||||
You can also pass a json dictionary of `--default-mm-loras` mapping modalities to LoRA model IDs. For example, when starting the server:
|
||||
|
||||
```bash
|
||||
vllm serve ibm-granite/granite-speech-3.3-2b \
|
||||
--max-model-len 2048 \
|
||||
--enable-lora \
|
||||
--default-mm-loras '{"audio":"ibm-granite/granite-speech-3.3-2b"}' \
|
||||
--max-lora-rank 64
|
||||
```
|
||||
|
||||
Note: Default multimodal LoRAs are currently only available for `.generate` and chat completions.
|
||||
|
||||
@ -229,3 +229,28 @@ python3 quantize_quark.py --model_dir meta-llama/Llama-2-70b-chat-hf \
|
||||
--model_export hf_format \
|
||||
--tasks gsm8k
|
||||
```
|
||||
|
||||
## Using MXFP4 models
|
||||
|
||||
vLLM supports loading MXFP4 models quantized offline through AMD Quark, compliant with [Open Compute Project (OCP) specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf).
|
||||
|
||||
The scheme currently only supports dynamic quantization for activations.
|
||||
|
||||
Example usage, after installing the latest AMD Quark release:
|
||||
|
||||
```bash
|
||||
vllm serve fxmarty/qwen_1.5-moe-a2.7b-mxfp4 --tensor-parallel-size 1
|
||||
```
|
||||
|
||||
A simulation of the matrix multiplication execution in MXFP4 can be run on devices that do not support MXFP4 operations natively (e.g. AMD Instinct MI325, MI300 and MI250), dequantizing weights from MXFP4 to half precision on the fly, using a fused kernel. This is useful e.g. to evaluate MXFP4 models using vLLM, or alternatively to benefit from the ~4x memory savings (compared to float16 and bfloat16).
|
||||
|
||||
To generate models quantized offline with the MXFP4 data type, the easiest approach is to use AMD Quark's [quantization script](https://quark.docs.amd.com/latest/pytorch/example_quark_torch_llm_ptq.html). For example:
|
||||
|
||||
```bash
|
||||
python quantize_quark.py --model_dir Qwen/Qwen1.5-MoE-A2.7B-Chat \
|
||||
--quant_scheme w_mxfp4_a_mxfp4_sym \
|
||||
--output_dir qwen_1.5-moe-a2.7b-mxfp4 \
|
||||
--skip_evaluation \
|
||||
--model_export hf_format \
|
||||
--group_size 32
|
||||
```
|
||||
|
||||
@ -256,12 +256,12 @@ speculative decoding, breaking down the guarantees into three key areas:
|
||||
2. **Algorithmic Losslessness**
|
||||
\- vLLM’s implementation of speculative decoding is algorithmically validated to be lossless. Key validation tests include:
|
||||
|
||||
> - **Rejection Sampler Convergence**: Ensures that samples from vLLM’s rejection sampler align with the target
|
||||
> distribution. [View Test Code](https://github.com/vllm-project/vllm/blob/47b65a550866c7ffbd076ecb74106714838ce7da/tests/samplers/test_rejection_sampler.py#L252)
|
||||
> - **Greedy Sampling Equality**: Confirms that greedy sampling with speculative decoding matches greedy sampling
|
||||
> without it. This verifies that vLLM's speculative decoding framework, when integrated with the vLLM forward pass and the vLLM rejection sampler,
|
||||
> provides a lossless guarantee. Almost all of the tests in <gh-dir:tests/spec_decode/e2e>
|
||||
> verify this property using [this assertion implementation](https://github.com/vllm-project/vllm/blob/b67ae00cdbbe1a58ffc8ff170f0c8d79044a684a/tests/spec_decode/e2e/conftest.py#L291)
|
||||
> - **Rejection Sampler Convergence**: Ensures that samples from vLLM’s rejection sampler align with the target
|
||||
> distribution. [View Test Code](https://github.com/vllm-project/vllm/blob/47b65a550866c7ffbd076ecb74106714838ce7da/tests/samplers/test_rejection_sampler.py#L252)
|
||||
> - **Greedy Sampling Equality**: Confirms that greedy sampling with speculative decoding matches greedy sampling
|
||||
> without it. This verifies that vLLM's speculative decoding framework, when integrated with the vLLM forward pass and the vLLM rejection sampler,
|
||||
> provides a lossless guarantee. Almost all of the tests in <gh-dir:tests/spec_decode/e2e>
|
||||
> verify this property using [this assertion implementation](https://github.com/vllm-project/vllm/blob/b67ae00cdbbe1a58ffc8ff170f0c8d79044a684a/tests/spec_decode/e2e/conftest.py#L291)
|
||||
|
||||
3. **vLLM Logprob Stability**
|
||||
\- vLLM does not currently guarantee stable token log probabilities (logprobs). This can result in different outputs for the
|
||||
|
||||
@ -282,6 +282,14 @@ Supported models:
|
||||
|
||||
Flags: `--tool-call-parser deepseek_v3 --chat-template {see_above}`
|
||||
|
||||
### Kimi-K2 Models (`kimi_k2`)
|
||||
|
||||
Supported models:
|
||||
|
||||
* `moonshotai/Kimi-K2-Instruct`
|
||||
|
||||
Flags: `--tool-call-parser kimi_k2`
|
||||
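For example, a hedged sketch of an online serving command (the tensor-parallel size is a placeholder to adjust for your hardware, and depending on the checkpoint you may also need `--trust-remote-code`):

```bash
vllm serve moonshotai/Kimi-K2-Instruct \
    --tensor-parallel-size 16 \
    --enable-auto-tool-choice \
    --tool-call-parser kimi_k2
```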
|
||||
### Models with Pythonic Tool Calls (`pythonic`)
|
||||
|
||||
A growing number of models output a python list to represent tool calls instead of using JSON. This has the advantage of inherently supporting parallel tool calls and removing ambiguity around the JSON schema required for tool calls. The `pythonic` tool parser can support such models.
|
||||
|
||||
@ -76,78 +76,56 @@ Currently, there are no pre-built CPU wheels.
|
||||
|
||||
### Build image from source
|
||||
|
||||
??? console "Commands"
|
||||
=== "Intel/AMD x86"
|
||||
|
||||
```bash
|
||||
docker build -f docker/Dockerfile.cpu \
|
||||
--tag vllm-cpu-env \
|
||||
--target vllm-openai .
|
||||
--8<-- "docs/getting_started/installation/cpu/x86.inc.md:build-image-from-source"
|
||||
|
||||
# Launching OpenAI server
|
||||
docker run --rm \
|
||||
--privileged=true \
|
||||
--shm-size=4g \
|
||||
-p 8000:8000 \
|
||||
-e VLLM_CPU_KVCACHE_SPACE=<KV cache space> \
|
||||
-e VLLM_CPU_OMP_THREADS_BIND=<CPU cores for inference> \
|
||||
vllm-cpu-env \
|
||||
--model=meta-llama/Llama-3.2-1B-Instruct \
|
||||
--dtype=bfloat16 \
|
||||
other vLLM OpenAI server arguments
|
||||
```
|
||||
=== "ARM AArch64"
|
||||
|
||||
!!! tip
|
||||
For ARM or Apple silicon, use `docker/Dockerfile.arm`
|
||||
--8<-- "docs/getting_started/installation/cpu/arm.inc.md:build-image-from-source"
|
||||
|
||||
!!! tip
|
||||
For IBM Z (s390x), use `docker/Dockerfile.s390x` and in `docker run` use flag `--dtype float`
|
||||
=== "Apple silicon"
|
||||
|
||||
## Supported features
|
||||
--8<-- "docs/getting_started/installation/cpu/arm.inc.md:build-image-from-source"
|
||||
|
||||
vLLM CPU backend supports the following vLLM features:
|
||||
|
||||
- Tensor Parallel
|
||||
- Model Quantization (`INT8 W8A8, AWQ, GPTQ`)
|
||||
- Chunked-prefill
|
||||
- Prefix-caching
|
||||
- FP8-E5M2 KV cache
|
||||
=== "IBM Z (S390X)"
|
||||
--8<-- "docs/getting_started/installation/cpu/s390x.inc.md:build-image-from-source"
|
||||
|
||||
## Related runtime environment variables
|
||||
|
||||
- `VLLM_CPU_KVCACHE_SPACE`: specify the KV cache size (e.g., `VLLM_CPU_KVCACHE_SPACE=40` means 40 GiB of space for the KV cache); a larger setting allows vLLM to run more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users. Default value is `0`.
|
||||
- `VLLM_CPU_OMP_THREADS_BIND`: specify the CPU cores dedicated to the OpenMP threads. For example, `VLLM_CPU_OMP_THREADS_BIND=0-31` means 32 OpenMP threads bound to CPU cores 0-31. `VLLM_CPU_OMP_THREADS_BIND=0-31|32-63` means 2 tensor parallel processes: the 32 OpenMP threads of rank 0 are bound to CPU cores 0-31, and the OpenMP threads of rank 1 are bound to CPU cores 32-63. When set to `auto`, the OpenMP threads of each rank are bound to the CPU cores of its NUMA node. When set to `all`, the OpenMP threads of each rank use all CPU cores available on the system. Default value is `auto`.
|
||||
- `VLLM_CPU_NUM_OF_RESERVED_CPU`: specify the number of CPU cores that are not dedicated to the OpenMP threads of each rank. The variable only takes effect when `VLLM_CPU_OMP_THREADS_BIND` is set to `auto`. Default value is `0`.
|
||||
- `VLLM_CPU_MOE_PREPACK`: whether to use prepack for MoE layer. This will be passed to `ipex.llm.modules.GatedMLPMOE`. Default is `1` (True). On unsupported CPUs, you might need to set this to `0` (False).
|
||||
- `VLLM_CPU_SGL_KERNEL` (Experimental): whether to use small-batch optimized kernels for linear layer and MoE layer, especially for low-latency requirements like online serving. The kernels require AMX instruction set, BFloat16 weight type and weight shapes divisible by 32. Default is `0` (False).
|
||||
- `VLLM_CPU_MOE_PREPACK` (x86 only): whether to use prepack for the MoE layer. This will be passed to `ipex.llm.modules.GatedMLPMOE`. Default is `1` (True). On unsupported CPUs, you might need to set this to `0` (False).
|
||||
- `VLLM_CPU_SGL_KERNEL` (x86 only, Experimental): whether to use small-batch optimized kernels for the linear and MoE layers, especially for low-latency requirements like online serving. The kernels require the AMX instruction set, BFloat16 weight type, and weight shapes divisible by 32. Default is `0` (False).
|
||||
|
||||
## Performance tips
|
||||
## FAQ
|
||||
|
||||
- We highly recommend using TCMalloc for high-performance memory allocation and better cache locality. For example, on Ubuntu 22.04, you can run:
|
||||
### Which `dtype` should be used?
|
||||
|
||||
```bash
|
||||
sudo apt-get install libtcmalloc-minimal4 # install TCMalloc library
|
||||
find / -name "*libtcmalloc*" # find the dynamic link library path
|
||||
export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:$LD_PRELOAD # prepend the library to LD_PRELOAD
|
||||
python examples/offline_inference/basic/basic.py # run vLLM
|
||||
```
|
||||
- Currently the vLLM CPU backend uses the model's default `dtype`. However, due to unstable float16 support in torch CPU, it is recommended to explicitly set `dtype=bfloat16` if you run into any performance or accuracy problems.
|
||||
|
||||
- When using the online serving, it is recommended to reserve 1-2 CPU cores for the serving framework to avoid CPU oversubscription. For example, on a platform with 32 physical CPU cores, reserving CPU 30 and 31 for the framework and using CPU 0-29 for OpenMP:
|
||||
### How to launch a vLLM service on CPU?
|
||||
|
||||
- When using online serving, it is recommended to reserve 1-2 CPU cores for the serving framework to avoid CPU oversubscription. For example, on a platform with 32 physical CPU cores, reserve CPU 31 for the framework and use CPUs 0-30 for the inference threads:
|
||||
|
||||
```bash
|
||||
export VLLM_CPU_KVCACHE_SPACE=40
|
||||
export VLLM_CPU_OMP_THREADS_BIND=0-29
|
||||
vllm serve facebook/opt-125m
|
||||
export VLLM_CPU_OMP_THREADS_BIND=0-30
|
||||
vllm serve facebook/opt-125m --dtype=bfloat16
|
||||
```
|
||||
|
||||
or using default auto thread binding:
|
||||
|
||||
```bash
|
||||
export VLLM_CPU_KVCACHE_SPACE=40
|
||||
export VLLM_CPU_NUM_OF_RESERVED_CPU=2
|
||||
vllm serve facebook/opt-125m
|
||||
export VLLM_CPU_NUM_OF_RESERVED_CPU=1
|
||||
vllm serve facebook/opt-125m --dtype=bfloat16
|
||||
```
|
||||
|
||||
- If using vLLM CPU backend on a machine with hyper-threading, it is recommended to bind only one OpenMP thread on each physical CPU core using `VLLM_CPU_OMP_THREADS_BIND` or using auto thread binding feature by default. On a hyper-threading enabled platform with 16 logical CPU cores / 8 physical CPU cores:
|
||||
### How to decide `VLLM_CPU_OMP_THREADS_BIND`?
|
||||
|
||||
- Bind each OpenMP thread to a dedicated physical CPU core, or rely on the default auto thread binding feature. On a hyper-threading enabled platform with 16 logical CPU cores / 8 physical CPU cores:
|
||||
|
||||
??? console "Commands"
|
||||
|
||||
@ -178,34 +156,21 @@ vllm serve facebook/opt-125m
|
||||
$ python examples/offline_inference/basic/basic.py
|
||||
```
|
||||
|
||||
- If using vLLM CPU backend on a multi-socket machine with NUMA, be aware to set CPU cores using `VLLM_CPU_OMP_THREADS_BIND` to avoid cross NUMA node memory access.
|
||||
- When deploying the vLLM CPU backend on a multi-socket machine with NUMA and enabling tensor parallelism or pipeline parallelism, each NUMA node is treated as a TP/PP rank. Make sure the CPU cores of a single rank stay on the same NUMA node to avoid cross-NUMA-node memory access.
|
||||
|
||||
## Other considerations
|
||||
### How to decide `VLLM_CPU_KVCACHE_SPACE`?
|
||||
|
||||
- The CPU backend significantly differs from the GPU backend since the vLLM architecture was originally optimized for GPU use. A number of optimizations are needed to enhance its performance.
|
||||
- This value is 4 GiB by default. A larger space supports more concurrent requests and longer context lengths. However, users should watch the memory capacity of each NUMA node: the memory usage of each TP rank is the sum of the weight shard size and `VLLM_CPU_KVCACHE_SPACE`; if it exceeds the capacity of a single NUMA node, the TP worker will be killed with `exitcode 9` due to out-of-memory.
|
||||
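- For example (a rough, hedged estimate): a 7B-parameter model in BF16 takes roughly 14 GB of weights, so with TP=2 each rank holds a ~7 GB shard; with `VLLM_CPU_KVCACHE_SPACE=40` each rank then needs roughly 47 GB plus runtime overhead, and that total must fit within a single NUMA node.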
|
||||
- Decouple the HTTP serving components from the inference components. In a GPU backend configuration, the HTTP serving and tokenization tasks operate on the CPU, while inference runs on the GPU, which typically does not pose a problem. However, in a CPU-based setup, the HTTP serving and tokenization can cause significant context switching and reduced cache efficiency. Therefore, it is strongly recommended to segregate these two components for improved performance.
|
||||
### Which quantization configs does vLLM CPU support?
|
||||
|
||||
- On a CPU-based setup with NUMA enabled, memory access performance may be largely impacted by the [topology](https://github.com/intel/intel-extension-for-pytorch/blob/main/docs/tutorials/performance_tuning/tuning_guide.md#non-uniform-memory-access-numa). For NUMA architectures, Tensor Parallel is an option for better performance.
|
||||
- vLLM CPU supports the following quantization schemes:
|
||||
- AWQ (x86 only)
|
||||
- GPTQ (x86 only)
|
||||
- compressed-tensors INT8 W8A8 (x86, s390x); see the example below
|
||||
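A rough sketch of serving such a checkpoint on CPU (the model ID is a placeholder for any compressed-tensors INT8 W8A8 checkpoint; the quantization scheme is detected from the checkpoint config, so no extra flag is needed in the common case):

```bash
VLLM_CPU_KVCACHE_SPACE=40 vllm serve <your-w8a8-quantized-model> --dtype=bfloat16
```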
|
||||
- Tensor Parallel is supported for serving and offline inferencing. In general each NUMA node is treated as one GPU card. Below is the example script to enable Tensor Parallel = 2 for serving:
|
||||
### (x86 only) What is the purpose of `VLLM_CPU_MOE_PREPACK` and `VLLM_CPU_SGL_KERNEL`?
|
||||
|
||||
```bash
|
||||
VLLM_CPU_KVCACHE_SPACE=40 VLLM_CPU_OMP_THREADS_BIND="0-31|32-63" \
|
||||
vllm serve meta-llama/Llama-2-7b-chat-hf \
|
||||
-tp=2 \
|
||||
--distributed-executor-backend mp
|
||||
```
|
||||
|
||||
or using default auto thread binding:
|
||||
|
||||
```bash
|
||||
VLLM_CPU_KVCACHE_SPACE=40 \
|
||||
vllm serve meta-llama/Llama-2-7b-chat-hf \
|
||||
-tp=2 \
|
||||
--distributed-executor-backend mp
|
||||
```
|
||||
|
||||
- For each thread id list in `VLLM_CPU_OMP_THREADS_BIND`, users should guarantee threads in the list belong to a same NUMA node.
|
||||
|
||||
- Meanwhile, users should also take care of memory capacity of each NUMA node. The memory usage of each TP rank is the sum of `weight shard size` and `VLLM_CPU_KVCACHE_SPACE`, if it exceeds the capacity of a single NUMA node, TP worker will be killed due to out-of-memory.
|
||||
- Both of them require the `amx` CPU flag.
|
||||
- `VLLM_CPU_MOE_PREPACK` can provide better performance for MoE models.
|
||||
- `VLLM_CPU_SGL_KERNEL` can provide better performance for MoE models and small-batch scenarios (see the example below).
|
||||
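A hedged example of turning both on (the model ID is illustrative, reusing the MoE checkpoint mentioned in the quantization docs; it assumes an AMX-capable x86 CPU and BF16 weights):

```bash
VLLM_CPU_MOE_PREPACK=1 VLLM_CPU_SGL_KERNEL=1 \
    vllm serve Qwen/Qwen1.5-MoE-A2.7B-Chat --dtype=bfloat16
```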
|
||||
@ -35,23 +35,22 @@ pip install -e .
|
||||
!!! note
|
||||
On macOS the `VLLM_TARGET_DEVICE` is automatically set to `cpu`, which currently is the only supported device.
|
||||
|
||||
#### Troubleshooting
|
||||
!!! example "Troubleshooting"
|
||||
If the build has error like the following snippet where standard C++ headers cannot be found, try to remove and reinstall your
|
||||
[Command Line Tools for Xcode](https://developer.apple.com/download/all/).
|
||||
|
||||
If the build has error like the following snippet where standard C++ headers cannot be found, try to remove and reinstall your
|
||||
[Command Line Tools for Xcode](https://developer.apple.com/download/all/).
|
||||
```text
|
||||
[...] fatal error: 'map' file not found
|
||||
1 | #include <map>
|
||||
| ^~~~~
|
||||
1 error generated.
|
||||
[2/8] Building CXX object CMakeFiles/_C.dir/csrc/cpu/pos_encoding.cpp.o
|
||||
|
||||
```text
|
||||
[...] fatal error: 'map' file not found
|
||||
1 | #include <map>
|
||||
| ^~~~~
|
||||
1 error generated.
|
||||
[2/8] Building CXX object CMakeFiles/_C.dir/csrc/cpu/pos_encoding.cpp.o
|
||||
|
||||
[...] fatal error: 'cstddef' file not found
|
||||
10 | #include <cstddef>
|
||||
| ^~~~~~~~~
|
||||
1 error generated.
|
||||
```
|
||||
[...] fatal error: 'cstddef' file not found
|
||||
10 | #include <cstddef>
|
||||
| ^~~~~~~~~
|
||||
1 error generated.
|
||||
```
|
||||
|
||||
# --8<-- [end:build-wheel-from-source]
|
||||
# --8<-- [start:pre-built-images]
|
||||
|
||||
@ -32,7 +32,22 @@ Testing has been conducted on AWS Graviton3 instances for compatibility.
|
||||
|
||||
# --8<-- [end:pre-built-images]
|
||||
# --8<-- [start:build-image-from-source]
|
||||
```bash
|
||||
docker build -f docker/Dockerfile.arm \
|
||||
--tag vllm-cpu-env .
|
||||
|
||||
# Launching OpenAI server
|
||||
docker run --rm \
|
||||
--privileged=true \
|
||||
--shm-size=4g \
|
||||
-p 8000:8000 \
|
||||
-e VLLM_CPU_KVCACHE_SPACE=<KV cache space> \
|
||||
-e VLLM_CPU_OMP_THREADS_BIND=<CPU cores for inference> \
|
||||
vllm-cpu-env \
|
||||
--model=meta-llama/Llama-3.2-1B-Instruct \
|
||||
--dtype=bfloat16 \
|
||||
other vLLM OpenAI server arguments
|
||||
```
|
||||
# --8<-- [end:build-image-from-source]
|
||||
# --8<-- [start:extra-information]
|
||||
# --8<-- [end:extra-information]
|
||||
|
||||
@ -2,7 +2,7 @@ First, install recommended compiler. We recommend to use `gcc/g++ >= 12.3.0` as
|
||||
|
||||
```bash
|
||||
sudo apt-get update -y
|
||||
sudo apt-get install -y gcc-12 g++-12 libnuma-dev python3-dev
|
||||
sudo apt-get install -y --no-install-recommends ccache git curl wget ca-certificates gcc-12 g++-12 libtcmalloc-minimal4 libnuma-dev ffmpeg libsm6 libxext6 libgl1 jq lsof
|
||||
sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12
|
||||
```
|
||||
|
||||
@ -17,7 +17,7 @@ Third, install Python packages for vLLM CPU backend building:
|
||||
|
||||
```bash
|
||||
pip install --upgrade pip
|
||||
pip install "cmake>=3.26.1" wheel packaging ninja "setuptools-scm>=8" numpy
|
||||
pip install -v -r requirements/cpu-build.txt --extra-index-url https://download.pytorch.org/whl/cpu
|
||||
pip install -v -r requirements/cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu
|
||||
```
|
||||
|
||||
@ -33,4 +33,7 @@ If you want to develop vllm, install it in editable mode instead.
|
||||
VLLM_TARGET_DEVICE=cpu python setup.py develop
|
||||
```
|
||||
|
||||
!!! note
|
||||
If you are building vLLM from source and not using the pre-built images, remember to set `LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:$LD_PRELOAD"` on x86 machines before running vLLM.
|
||||
|
||||
# --8<-- [end:extra-information]
|
||||
|
||||
@ -61,6 +61,23 @@ Execute the following commands to build and install vLLM from the source.
|
||||
# --8<-- [end:pre-built-images]
|
||||
# --8<-- [start:build-image-from-source]
|
||||
|
||||
```bash
|
||||
docker build -f docker/Dockerfile.s390x \
|
||||
--tag vllm-cpu-env .
|
||||
|
||||
# Launching OpenAI server
|
||||
docker run --rm \
|
||||
--privileged=true \
|
||||
--shm-size=4g \
|
||||
-p 8000:8000 \
|
||||
-e VLLM_CPU_KVCACHE_SPACE=<KV cache space> \
|
||||
-e VLLM_CPU_OMP_THREADS_BIND=<CPU cores for inference> \
|
||||
vllm-cpu-env \
|
||||
--model=meta-llama/Llama-3.2-1B-Instruct \
|
||||
--dtype=float \
|
||||
other vLLM OpenAI server arguments
|
||||
```
|
||||
|
||||
# --8<-- [end:build-image-from-source]
|
||||
# --8<-- [start:extra-information]
|
||||
# --8<-- [end:extra-information]
|
||||
|
||||
@ -1,19 +1,15 @@
|
||||
# --8<-- [start:installation]
|
||||
|
||||
vLLM initially supports basic model inferencing and serving on x86 CPU platform, with data types FP32, FP16 and BF16.
|
||||
|
||||
!!! warning
|
||||
There are no pre-built wheels or images for this device, so you must build vLLM from source.
|
||||
vLLM supports basic model inference and serving on the x86 CPU platform, with data types FP32, FP16 and BF16.
|
||||
|
||||
# --8<-- [end:installation]
|
||||
# --8<-- [start:requirements]
|
||||
|
||||
- OS: Linux
|
||||
- Compiler: `gcc/g++ >= 12.3.0` (optional, recommended)
|
||||
- Instruction Set Architecture (ISA): AVX512 (optional, recommended)
|
||||
- CPU flags: `avx512f`, `avx512_bf16` (Optional), `avx512_vnni` (Optional)
|
||||
|
||||
!!! tip
|
||||
[Intel Extension for PyTorch (IPEX)](https://github.com/intel/intel-extension-for-pytorch) extends PyTorch with up-to-date features optimizations for an extra performance boost on Intel hardware.
|
||||
Use `lscpu` to check the CPU flags.
|
||||
|
||||
# --8<-- [end:requirements]
|
||||
# --8<-- [start:set-up-using-python]
|
||||
@ -26,18 +22,37 @@ vLLM initially supports basic model inferencing and serving on x86 CPU platform,
|
||||
|
||||
--8<-- "docs/getting_started/installation/cpu/build.inc.md"
|
||||
|
||||
!!! note
|
||||
- AVX512_BF16 is an extension ISA that provides native BF16 data type conversion and vector product instructions, which brings some performance improvement compared with pure AVX512. The CPU backend build script will check the host CPU flags to determine whether to enable AVX512_BF16.
|
||||
- If you want to force-enable AVX512_BF16 for cross-compilation, set the environment variable `VLLM_CPU_AVX512BF16=1` before building; see the sketch below.
|
||||
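For instance, a minimal sketch of an editable source build with AVX512_BF16 forced on (assumes the build prerequisites from the section above are installed):

```bash
# Force AVX512_BF16 code generation when cross-compiling the CPU backend.
VLLM_CPU_AVX512BF16=1 VLLM_TARGET_DEVICE=cpu python setup.py develop
```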
|
||||
# --8<-- [end:build-wheel-from-source]
|
||||
# --8<-- [start:pre-built-images]
|
||||
|
||||
See [https://gallery.ecr.aws/q9t5s3a7/vllm-cpu-release-repo](https://gallery.ecr.aws/q9t5s3a7/vllm-cpu-release-repo)
|
||||
[https://gallery.ecr.aws/q9t5s3a7/vllm-cpu-release-repo](https://gallery.ecr.aws/q9t5s3a7/vllm-cpu-release-repo)
|
||||
|
||||
!!! warning
|
||||
If deploying the pre-built images on machines that only support `avx512f`, an `Illegal instruction` error may be raised. It is recommended to build images for these machines with `--build-arg VLLM_CPU_AVX512BF16=false` and `--build-arg VLLM_CPU_AVX512VNNI=false`.
|
||||
|
||||
# --8<-- [end:pre-built-images]
|
||||
# --8<-- [start:build-image-from-source]
|
||||
|
||||
```bash
|
||||
docker build -f docker/Dockerfile.cpu \
|
||||
--build-arg VLLM_CPU_AVX512BF16=false (default)|true \
|
||||
--build-arg VLLM_CPU_AVX512VNNI=false (default)|true \
|
||||
--tag vllm-cpu-env \
|
||||
--target vllm-openai .
|
||||
|
||||
# Launching OpenAI server
|
||||
docker run --rm \
|
||||
--privileged=true \
|
||||
--shm-size=4g \
|
||||
-p 8000:8000 \
|
||||
-e VLLM_CPU_KVCACHE_SPACE=<KV cache space> \
|
||||
-e VLLM_CPU_OMP_THREADS_BIND=<CPU cores for inference> \
|
||||
vllm-cpu-env \
|
||||
--model=meta-llama/Llama-3.2-1B-Instruct \
|
||||
--dtype=bfloat16 \
|
||||
other vLLM OpenAI server arguments
|
||||
```
|
||||
|
||||
# --8<-- [end:build-image-from-source]
|
||||
# --8<-- [start:extra-information]
|
||||
# --8<-- [end:extra-information]
|
||||
|
||||
@ -133,36 +133,20 @@ docker run \
|
||||
The following configurations have been validated to function with
|
||||
Gaudi2 devices. Configurations that are not listed may or may not work.
|
||||
|
||||
- [meta-llama/Llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b)
|
||||
on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16
|
||||
datatype with random or greedy sampling
|
||||
- [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
|
||||
on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16
|
||||
datatype with random or greedy sampling
|
||||
- [meta-llama/Meta-Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B)
|
||||
on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16
|
||||
datatype with random or greedy sampling
|
||||
- [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
|
||||
on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16
|
||||
datatype with random or greedy sampling
|
||||
- [meta-llama/Meta-Llama-3.1-8B](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B)
|
||||
on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16
|
||||
datatype with random or greedy sampling
|
||||
- [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct)
|
||||
on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16
|
||||
datatype with random or greedy sampling
|
||||
- [meta-llama/Llama-2-70b](https://huggingface.co/meta-llama/Llama-2-70b)
|
||||
with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling
|
||||
- [meta-llama/Llama-2-70b-chat-hf](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf)
|
||||
with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling
|
||||
- [meta-llama/Meta-Llama-3-70B](https://huggingface.co/meta-llama/Meta-Llama-3-70B)
|
||||
with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling
|
||||
- [meta-llama/Meta-Llama-3-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-70B-Instruct)
|
||||
with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling
|
||||
- [meta-llama/Meta-Llama-3.1-70B](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B)
|
||||
with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling
|
||||
- [meta-llama/Meta-Llama-3.1-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B-Instruct)
|
||||
with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling
|
||||
| Model | TP Size | dtype | Sampling |
|
||||
|-------|--------|--------|----------|
|
||||
| [meta-llama/Llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b) | 1, 2, 8 | BF16 | Random / Greedy |
|
||||
| [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) | 1, 2, 8 | BF16 | Random / Greedy |
|
||||
| [meta-llama/Meta-Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B) | 1, 2, 8 | BF16 | Random / Greedy |
|
||||
| [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) | 1, 2, 8 | BF16 | Random / Greedy |
|
||||
| [meta-llama/Meta-Llama-3.1-8B](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B) | 1, 2, 8 | BF16 | Random / Greedy |
|
||||
| [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct) | 1, 2, 8 | BF16 | Random / Greedy |
|
||||
| [meta-llama/Llama-2-70b](https://huggingface.co/meta-llama/Llama-2-70b) | 8 | BF16 | Random / Greedy |
|
||||
| [meta-llama/Llama-2-70b-chat-hf](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf) | 8 | BF16 | Random / Greedy |
|
||||
| [meta-llama/Meta-Llama-3-70B](https://huggingface.co/meta-llama/Meta-Llama-3-70B) | 8 | BF16 | Random / Greedy |
|
||||
| [meta-llama/Meta-Llama-3-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-70B-Instruct) | 8 | BF16 | Random / Greedy |
|
||||
| [meta-llama/Meta-Llama-3.1-70B](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B) | 8 | BF16 | Random / Greedy |
|
||||
| [meta-llama/Meta-Llama-3.1-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B-Instruct) | 8 | BF16 | Random / Greedy |
|
||||
|
||||
## Performance tuning
|
||||
|
||||
|
||||
docs/mkdocs/hooks/generate_argparse.py (new file, 105 lines)
@ -0,0 +1,105 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import logging
|
||||
import sys
|
||||
from argparse import SUPPRESS, HelpFormatter
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
ROOT_DIR = Path(__file__).parent.parent.parent.parent
|
||||
ARGPARSE_DOC_DIR = ROOT_DIR / "docs/argparse"
|
||||
|
||||
sys.path.insert(0, str(ROOT_DIR))
|
||||
sys.modules["aiohttp"] = MagicMock()
|
||||
sys.modules["blake3"] = MagicMock()
|
||||
sys.modules["vllm._C"] = MagicMock()
|
||||
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs # noqa: E402
|
||||
from vllm.utils import FlexibleArgumentParser # noqa: E402
|
||||
|
||||
logger = logging.getLogger("mkdocs")
|
||||
|
||||
|
||||
class MarkdownFormatter(HelpFormatter):
|
||||
"""Custom formatter that generates markdown for argument groups."""
|
||||
|
||||
def __init__(self, prog):
|
||||
super().__init__(prog,
|
||||
max_help_position=float('inf'),
|
||||
width=float('inf'))
|
||||
self._markdown_output = []
|
||||
|
||||
def start_section(self, heading):
|
||||
if heading not in {"positional arguments", "options"}:
|
||||
self._markdown_output.append(f"\n### {heading}\n\n")
|
||||
|
||||
def end_section(self):
|
||||
pass
|
||||
|
||||
def add_text(self, text):
|
||||
if text:
|
||||
self._markdown_output.append(f"{text.strip()}\n\n")
|
||||
|
||||
def add_usage(self, usage, actions, groups, prefix=None):
|
||||
pass
|
||||
|
||||
def add_arguments(self, actions):
|
||||
for action in actions:
|
||||
|
||||
option_strings = f'`{"`, `".join(action.option_strings)}`'
|
||||
self._markdown_output.append(f"#### {option_strings}\n\n")
|
||||
|
||||
if choices := action.choices:
|
||||
choices = f'`{"`, `".join(str(c) for c in choices)}`'
|
||||
self._markdown_output.append(
|
||||
f"Possible choices: {choices}\n\n")
|
||||
|
||||
self._markdown_output.append(f"{action.help}\n\n")
|
||||
|
||||
if (default := action.default) != SUPPRESS:
|
||||
self._markdown_output.append(f"Default: `{default}`\n\n")
|
||||
|
||||
def format_help(self):
|
||||
"""Return the formatted help as markdown."""
|
||||
return "".join(self._markdown_output)
|
||||
|
||||
|
||||
def create_parser(cls, **kwargs) -> FlexibleArgumentParser:
|
||||
"""Create a parser for the given class with markdown formatting.
|
||||
|
||||
Args:
|
||||
cls: The class to create a parser for
|
||||
**kwargs: Additional keyword arguments to pass to `cls.add_cli_args`.
|
||||
|
||||
Returns:
|
||||
FlexibleArgumentParser: A parser with markdown formatting for the class.
|
||||
"""
|
||||
parser = FlexibleArgumentParser()
|
||||
parser.formatter_class = MarkdownFormatter
|
||||
with patch("vllm.config.DeviceConfig.__post_init__"):
|
||||
return cls.add_cli_args(parser, **kwargs)
|
||||
|
||||
|
||||
def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool):
|
||||
logger.info("Generating argparse documentation")
|
||||
logger.debug("Root directory: %s", ROOT_DIR.resolve())
|
||||
logger.debug("Output directory: %s", ARGPARSE_DOC_DIR.resolve())
|
||||
|
||||
# Create the ARGPARSE_DOC_DIR if it doesn't exist
|
||||
if not ARGPARSE_DOC_DIR.exists():
|
||||
ARGPARSE_DOC_DIR.mkdir(parents=True)
|
||||
|
||||
# Create parsers to document
|
||||
parsers = {
|
||||
"engine_args": create_parser(EngineArgs),
|
||||
"async_engine_args": create_parser(AsyncEngineArgs,
|
||||
async_args_only=True),
|
||||
}
|
||||
|
||||
# Generate documentation for each parser
|
||||
for stem, parser in parsers.items():
|
||||
doc_path = ARGPARSE_DOC_DIR / f"{stem}.md"
|
||||
with open(doc_path, "w") as f:
|
||||
f.write(parser.format_help())
|
||||
logger.info("Argparse generated: %s", doc_path.relative_to(ROOT_DIR))
|
||||
@ -161,8 +161,8 @@ def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool):
|
||||
for example in sorted(examples, key=lambda e: e.path.stem):
|
||||
example_name = f"{example.path.stem}.md"
|
||||
doc_path = EXAMPLE_DOC_DIR / example.category / example_name
|
||||
logger.debug("Example generated: %s", doc_path.relative_to(ROOT_DIR))
|
||||
if not doc_path.parent.exists():
|
||||
doc_path.parent.mkdir(parents=True)
|
||||
with open(doc_path, "w+") as f:
|
||||
f.write(example.generate())
|
||||
logger.debug("Example generated: %s", doc_path.relative_to(ROOT_DIR))
|
||||
|
||||
docs/mkdocs/overrides/partials/toc-item.html (new file, 21 lines)
@ -0,0 +1,21 @@
|
||||
<!-- Enables the use of toc_depth in document frontmatter https://github.com/squidfunk/mkdocs-material/issues/4827#issuecomment-1869812019 -->
|
||||
<li class="md-nav__item">
|
||||
<a href="{{ toc_item.url }}" class="md-nav__link">
|
||||
<span class="md-ellipsis">
|
||||
{{ toc_item.title }}
|
||||
</span>
|
||||
</a>
|
||||
|
||||
<!-- Table of contents list -->
|
||||
{% if toc_item.children %}
|
||||
<nav class="md-nav" aria-label="{{ toc_item.title | striptags }}">
|
||||
<ul class="md-nav__list">
|
||||
{% for toc_item in toc_item.children %}
|
||||
{% if not page.meta.toc_depth or toc_item.level <= page.meta.toc_depth %}
|
||||
{% include "partials/toc-item.html" %}
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
</ul>
|
||||
</nav>
|
||||
{% endif %}
|
||||
</li>
|
||||
@ -374,6 +374,7 @@ Specified using `--task generate`.
|
||||
| `Phi3ForCausalLM` | Phi-4, Phi-3 | `microsoft/Phi-4-mini-instruct`, `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Phi3SmallForCausalLM` | Phi-3-Small | `microsoft/Phi-3-small-8k-instruct`, `microsoft/Phi-3-small-128k-instruct`, etc. | | ✅︎ | ✅︎ |
|
||||
| `PhiMoEForCausalLM` | Phi-3.5-MoE | `microsoft/Phi-3.5-MoE-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Phi4FlashForCausalLM` | Phi-4-mini-flash-reasoning | `microsoft/Phi-4-mini-flash-reasoning`, etc. | | | |
|
||||
| `PersimmonForCausalLM` | Persimmon | `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc. | | ✅︎ | ✅︎ |
|
||||
| `Plamo2ForCausalLM` | PLaMo2 | `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. | | | |
|
||||
| `QWenLMHeadModel` | Qwen | `Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
@ -712,6 +713,14 @@ The following table lists those that are tested in vLLM.
|
||||
|
||||
---
|
||||
|
||||
#### Scoring
|
||||
|
||||
Specified using `--task score`.
|
||||
|
||||
| Architecture | Models | Inputs | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | [V1](gh-issue:8779) |
|
||||
|-------------------------------------|--------------------|----------|--------------------------|------------------------|-----------------------------|-----------------------|
|
||||
| `JinaVLForSequenceClassification` | JinaVL-based | T + I<sup>E+</sup> | `jinaai/jina-reranker-m0`, etc. | | | ✅︎ |
|
||||
|
||||
## Model Support Policy
|
||||
|
||||
At vLLM, we are committed to facilitating the integration and support of third-party models within our ecosystem. Our approach is designed to balance the need for robustness and the practical limitations of supporting a wide range of models. Here’s how we manage third-party model support:
|
||||
|
||||
docs/serving/data_parallel_deployment.md (new file, 112 lines)
@ -0,0 +1,112 @@
|
||||
# Data Parallel Deployment
|
||||
|
||||
vLLM supports Data Parallel deployment, where model weights are replicated across separate instances/GPUs to process independent batches of requests.
|
||||
|
||||
This will work with both dense and MoE models.
|
||||
|
||||
For MoE models, particularly those like DeepSeek that employ MLA (Multi-head Latent Attention), it can be advantageous to use data parallel for the attention layers and expert or tensor parallel (EP or TP) for the expert layers.
|
||||
|
||||
In these cases, the data parallel ranks are not completely independent. Forward passes must be aligned, and expert layers across all ranks are required to synchronize during every forward pass, even when there are fewer requests to be processed than DP ranks.
|
||||
|
||||
The expert layers will by default form a (DP x TP) sized tensor parallel group. To enable expert parallelism, include the `--enable-expert-parallel` CLI arg (on all nodes in the multi-node case).
|
||||
|
||||
In vLLM, each DP rank is deployed as a separate "core engine" process that communicates with front-end process(es) via ZMQ sockets. Data Parallel attention can be combined with Tensor Parallel attention, in which case each DP engine owns a number of per-GPU worker processes equal to the configured TP size.
|
||||
|
||||
For MoE models, when any requests are in progress in any rank, we must ensure that empty "dummy" forward passes are performed in all ranks that don't currently have any requests scheduled. This is handled via a separate DP Coordinator process that communicates with all ranks, and a collective operation performed every N steps to determine when all ranks become idle and can be paused. When TP is used in conjunction with DP, expert layers form an EP or TP group of size (DP x TP).
|
||||
|
||||
In all cases, it is beneficial to load-balance requests between DP ranks. For online deployments, this balancing can be optimized by taking into account the state of each DP engine - in particular its currently scheduled and waiting (queued) requests, and KV cache state. Each DP engine has an independent KV cache, and the benefit of prefix caching can be maximized by directing prompts intelligently.
|
||||
|
||||
This document focuses on online deployments (with the API server). DP + EP is also supported for offline usage (via the LLM class), for an example see <gh-file:examples/offline_inference/data_parallel.py>.
|
||||
|
||||
There are two distinct modes supported for online deployments: self-contained with internal load balancing, or external per-rank process deployment and load balancing.
|
||||
|
||||
## Internal Load Balancing
|
||||
|
||||
vLLM supports "self-contained" data parallel deployments that expose a single API endpoint.
|
||||
|
||||
It can be configured by simply including e.g. `--data-parallel-size=4` in the vllm serve command line arguments. This will require 4 GPUs. It can be combined with tensor parallel, for example `--data-parallel-size=4 --tensor-parallel-size=2`, which would require 8 GPUs.
|
||||
|
||||
Running a single data parallel deployment across multiple nodes requires a different `vllm serve` to be run on each node, specifying which DP ranks should run on that node. In this case, there will still be a single HTTP entrypoint - the API server(s) will run only on one node, but it doesn't necessarily need to be co-located with the DP ranks.
|
||||
|
||||
This will run DP=4, TP=2 on a single 8-GPU node:
|
||||
|
||||
```bash
|
||||
vllm serve $MODEL --data-parallel-size 4 --tensor-parallel-size 2
|
||||
```
|
||||
|
||||
This will run DP=4 with DP ranks 0 and 1 on the head node and ranks 2 and 3 on the second node:
|
||||
|
||||
```bash
|
||||
# Node 0 (with ip address 10.99.48.128)
|
||||
vllm serve $MODEL --data-parallel-size 4 --data-parallel-size-local 2 \
|
||||
--data-parallel-address 10.99.48.128 --data-parallel-rpc-port 13345
|
||||
# Node 1
|
||||
vllm serve $MODEL --headless --data-parallel-size 4 --data-parallel-size-local 2 \
|
||||
--data-parallel-start-rank 2 \
|
||||
--data-parallel-address 10.99.48.128 --data-parallel-rpc-port 13345
|
||||
```
|
||||
|
||||
This will run DP=4 with only the API server on the first node and all engines on the second node:
|
||||
|
||||
```bash
|
||||
# Node 0 (with ip address 10.99.48.128)
|
||||
vllm serve $MODEL --data-parallel-size 4 --data-parallel-size-local 0 \
|
||||
--data-parallel-address 10.99.48.128 --data-parallel-rpc-port 13345
|
||||
# Node 1
|
||||
vllm serve $MODEL --headless --data-parallel-size 4 --data-parallel-size-local 4 \
|
||||
--data-parallel-address 10.99.48.128 --data-parallel-rpc-port 13345
|
||||
```
|
||||
|
||||
This DP mode can also be used with Ray, in which case only a single launch command is needed irrespective of the number of nodes:
|
||||
|
||||
```bash
|
||||
vllm serve $MODEL --data-parallel-size 16 --tensor-parallel-size 2 --data-parallel-backend=ray
|
||||
```
|
||||
|
||||
Currently, the internal DP load balancing is done within the API server process(es) and is based on the running and waiting queues in each of the engines. This could be made more sophisticated in future by incorporating KV cache aware logic.
|
||||
|
||||
When deploying large DP sizes using this method, the API server process can become a bottleneck. In this case, the orthogonal `--api-server-count` command line option can be used to scale this out (for example `--api-server-count=4`). This is transparent to users - a single HTTP endpoint / port is still exposed. Note that this API server scale-out is "internal" and still confined to the "head" node.
|
||||
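For example (a sketch; the sizes are illustrative):

```bash
vllm serve $MODEL --data-parallel-size 8 --api-server-count 4
```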
|
||||
<figure markdown="1">
|
||||

|
||||
</figure>
|
||||
|
||||
## External Load Balancing
|
||||
|
||||
For larger scale deployments especially, it can make sense to handle the orchestration and load balancing of data parallel ranks externally.
|
||||
|
||||
In this case, it's more convenient to treat each DP rank like a separate vLLM deployment, with its own endpoint, and have an external router balance HTTP requests between them, making use of appropriate real-time telemetry from each server for routing decisions.
|
||||
|
||||
This can already be done trivially for non-MoE models, since each deployed server is fully independent. No data parallel CLI options need to be used for this.
|
||||
|
||||
We support an equivalent topology for MoE DP+EP which can be configured via the following CLI arguments.
|
||||
|
||||
If DP ranks are co-located (same node / ip address), a default RPC port is used, but a different HTTP server port must be specified for each rank:
|
||||
|
||||
```bash
|
||||
# Rank 0
|
||||
CUDA_VISIBLE_DEVICES=0 vllm serve $MODEL --data-parallel-size 2 --data-parallel-rank 0 \
|
||||
--port 8000
|
||||
# Rank 1
|
||||
CUDA_VISIBLE_DEVICES=1 vllm serve $MODEL --data-parallel-size 2 --data-parallel-rank 1 \
|
||||
--port 8001
|
||||
```
|
||||
|
||||
For multi-node cases, the address/port of rank 0 must also be specified:
|
||||
|
||||
```bash
|
||||
# Rank 0 (with ip address 10.99.48.128)
|
||||
vllm serve $MODEL --data-parallel-size 2 --data-parallel-rank 0 \
|
||||
--data-parallel-address 10.99.48.128 --data-parallel-rpc-port 13345
|
||||
# Rank 1
|
||||
vllm serve $MODEL --data-parallel-size 2 --data-parallel-rank 1 \
|
||||
--data-parallel-address 10.99.48.128 --data-parallel-rpc-port 13345
|
||||
```
|
||||
|
||||
The coordinator process also runs in this scenario, co-located with the DP rank 0 engine.
|
||||
|
||||
<figure markdown="1">
|
||||

|
||||
</figure>
|
||||
|
||||
In the above diagram, each of the dotted boxes corresponds to a separate launch of `vllm serve` - these could be separate Kubernetes pods, for example.
|
||||
@ -15,6 +15,10 @@ After adding enough GPUs and nodes to hold the model, you can run vLLM first, wh
|
||||
!!! note
|
||||
There is one edge case: if the model fits in a single node with multiple GPUs, but the number of GPUs cannot divide the model size evenly, you can use pipeline parallelism, which splits the model along layers and supports uneven splits. In this case, the tensor parallel size should be 1 and the pipeline parallel size should be the number of GPUs.
|
||||
|
||||
### Distributed serving of MoE (Mixture of Experts) models
|
||||
|
||||
It is often advantageous to exploit the inherent parallelism of experts by using a separate parallelism strategy for the expert layers. vLLM supports large-scale deployment combining Data Parallel attention with Expert or Tensor Parallel MoE layers. See the page on [Data Parallel Deployment](data_parallel_deployment.md) for more information.
|
||||
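For example, a minimal sketch of such a deployment on a single 8-GPU node (sizes are illustrative):

```bash
vllm serve $MODEL --data-parallel-size 4 --tensor-parallel-size 2 --enable-expert-parallel
```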
|
||||
## Running vLLM on a single node
|
||||
|
||||
vLLM supports distributed tensor-parallel and pipeline-parallel inference and serving. Currently, we support [Megatron-LM's tensor parallel algorithm](https://arxiv.org/pdf/1909.08053.pdf). We manage the distributed runtime with either [Ray](https://github.com/ray-project/ray) or python native multiprocessing. Multiprocessing can be used when deploying on a single node, multi-node inference currently requires Ray.
|
||||
|
||||
@ -537,7 +537,7 @@ The following extra parameters are supported:
|
||||
|
||||
### Score API
|
||||
|
||||
Our Score API can apply a cross-encoder model or an embedding model to predict scores for sentence pairs. When using an embedding model the score corresponds to the cosine similarity between each embedding pair.
|
||||
Our Score API can apply a cross-encoder model or an embedding model to predict scores for sentence or multimodal pairs. When using an embedding model the score corresponds to the cosine similarity between each embedding pair.
|
||||
Usually, the score for a sentence pair refers to the similarity between two sentences, on a scale of 0 to 1.
|
||||
|
||||
You can find the documentation for cross encoder models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html).
|
||||
@ -676,6 +676,55 @@ The total number of pairs is `len(text_2)`.
|
||||
}
|
||||
```
|
||||
|
||||
#### Multi-modal inputs
|
||||
|
||||
You can pass multi-modal inputs to scoring models by including a `content` list of multi-modal inputs (images, etc.) in the request. Refer to the examples below for illustration.
|
||||
|
||||
=== "JinaVL-Reranker"
|
||||
|
||||
To serve the model:
|
||||
|
||||
```bash
|
||||
vllm serve jinaai/jina-reranker-m0
|
||||
```
|
||||
|
||||
Since this request schema is not defined by the OpenAI client, we post a request to the server using the lower-level `requests` library:
|
||||
|
||||
??? Code
|
||||
|
||||
```python
|
||||
import requests
|
||||
|
||||
response = requests.post(
|
||||
"http://localhost:8000/v1/score",
|
||||
json={
|
||||
"model": "jinaai/jina-reranker-m0",
|
||||
"text_1": "slm markdown",
|
||||
"text_2": {
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/handelsblatt-preview.png"
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png"
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
print("Scoring output:", response_json["data"][0]["score"])
|
||||
print("Scoring output:", response_json["data"][1]["score"])
|
||||
```
|
||||
Full example: <gh-file:examples/online_serving/openai_cross_encoder_score_for_multimodal.py>
|
||||
|
||||
#### Extra parameters
|
||||
|
||||
The following [pooling parameters][pooling-params] are supported.
|
||||
@ -695,8 +744,7 @@ The following extra parameters are supported:
|
||||
### Re-rank API
|
||||
|
||||
Our Re-rank API can apply an embedding model or a cross-encoder model to predict relevant scores between a single query, and
|
||||
each of a list of documents. Usually, the score for a sentence pair refers to the similarity between two sentences, on
|
||||
a scale of 0 to 1.
|
||||
each of a list of documents. Usually, the score for a sentence pair refers to the similarity between two sentences or multi-modal inputs (image, etc.), on a scale of 0 to 1.
|
||||
|
||||
You can find the documentation for cross encoder models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html).
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
This example shows how to use vLLM for running offline inference with
|
||||
the correct prompt format on vision language models for multimodal embedding.
|
||||
the correct prompt format on vision language models for multimodal pooling.
|
||||
|
||||
For most models, the prompt format should follow corresponding examples
|
||||
on HuggingFace model repository.
|
||||
@ -15,6 +15,7 @@ from typing import Literal, NamedTuple, Optional, TypedDict, Union, get_args
|
||||
from PIL.Image import Image
|
||||
|
||||
from vllm import LLM, EngineArgs
|
||||
from vllm.entrypoints.score_utils import ScoreMultiModalParam
|
||||
from vllm.multimodal.utils import fetch_image
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
@ -35,14 +36,22 @@ class TextImageQuery(TypedDict):
|
||||
image: Image
|
||||
|
||||
|
||||
QueryModality = Literal["text", "image", "text+image"]
|
||||
Query = Union[TextQuery, ImageQuery, TextImageQuery]
|
||||
class TextImagesQuery(TypedDict):
|
||||
modality: Literal["text+images"]
|
||||
text: str
|
||||
image: ScoreMultiModalParam
|
||||
|
||||
|
||||
QueryModality = Literal["text", "image", "text+image", "text+images"]
|
||||
Query = Union[TextQuery, ImageQuery, TextImageQuery, TextImagesQuery]
|
||||
|
||||
|
||||
class ModelRequestData(NamedTuple):
|
||||
engine_args: EngineArgs
|
||||
prompt: str
|
||||
image: Optional[Image]
|
||||
prompt: Optional[str] = None
|
||||
image: Optional[Image] = None
|
||||
query: Optional[str] = None
|
||||
documents: Optional[ScoreMultiModalParam] = None
|
||||
|
||||
|
||||
def run_e5_v(query: Query) -> ModelRequestData:
|
||||
@ -107,6 +116,29 @@ def run_vlm2vec(query: Query) -> ModelRequestData:
|
||||
)
|
||||
|
||||
|
||||
def run_jinavl_reranker(query: Query) -> ModelRequestData:
|
||||
if query["modality"] != "text+images":
|
||||
raise ValueError(f"Unsupported query modality: '{query['modality']}'")
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model="jinaai/jina-reranker-m0",
|
||||
task="score",
|
||||
max_model_len=32768,
|
||||
trust_remote_code=True,
|
||||
mm_processor_kwargs={
|
||||
"min_pixels": 3136,
|
||||
"max_pixels": 602112,
|
||||
},
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
query=query["text"],
|
||||
documents=query["image"],
|
||||
)
|
||||
|
||||
|
||||
def get_query(modality: QueryModality):
|
||||
if modality == "text":
|
||||
return TextQuery(modality="text", text="A dog sitting in the grass")
|
||||
@ -128,6 +160,28 @@ def get_query(modality: QueryModality):
|
||||
),
|
||||
)
|
||||
|
||||
if modality == "text+images":
|
||||
return TextImagesQuery(
|
||||
modality="text+images",
|
||||
text="slm markdown",
|
||||
image={
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/handelsblatt-preview.png"
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png"
|
||||
},
|
||||
},
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
msg = f"Modality {modality} is not supported."
|
||||
raise ValueError(msg)
|
||||
|
||||
@ -162,16 +216,31 @@ def run_encode(model: str, modality: QueryModality, seed: Optional[int]):
|
||||
print("-" * 50)
|
||||
|
||||
|
||||
def run_score(model: str, modality: QueryModality, seed: Optional[int]):
|
||||
query = get_query(modality)
|
||||
req_data = model_example_map[model](query)
|
||||
|
||||
engine_args = asdict(req_data.engine_args) | {"seed": seed}
|
||||
llm = LLM(**engine_args)
|
||||
|
||||
outputs = llm.score(req_data.query, req_data.documents)
|
||||
|
||||
print("-" * 30)
|
||||
print([output.outputs.score for output in outputs])
|
||||
print("-" * 30)
|
||||
|
||||
|
||||
model_example_map = {
|
||||
"e5_v": run_e5_v,
|
||||
"vlm2vec": run_vlm2vec,
|
||||
"jinavl_reranker": run_jinavl_reranker,
|
||||
}
|
||||
|
||||
|
||||
def parse_args():
    parser = FlexibleArgumentParser(
        description="Demo on using vLLM for offline inference with "
        "vision language models for multimodal embedding"
        "vision language models for multimodal pooling tasks."
    )
    parser.add_argument(
        "--model-name",
@ -181,6 +250,14 @@ def parse_args():
        choices=model_example_map.keys(),
        help="The name of the embedding model.",
    )
    parser.add_argument(
        "--task",
        "-t",
        type=str,
        default="embedding",
        choices=["embedding", "scoring"],
        help="The task type.",
    )
    parser.add_argument(
        "--modality",
        type=str,
@ -198,7 +275,12 @@ def parse_args():


def main(args: Namespace):
    run_encode(args.model_name, args.modality, args.seed)
    if args.task == "embedding":
        run_encode(args.model_name, args.modality, args.seed)
    elif args.task == "scoring":
        run_score(args.model_name, args.modality, args.seed)
    else:
        raise ValueError(f"Unsupported task: {args.task}")


if __name__ == "__main__":
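Taken together, the changes above wire a "scoring" task into the offline multimodal example. A minimal end-to-end sketch of that path, independent of the argument parser (it assumes the `jinaai/jina-reranker-m0` weights can be fetched and a GPU is available; `LLM.score` accepting a multimodal document dict is what this change set introduces):

```python
# Minimal offline sketch of the new scoring path (illustrative, not CI-sized).
from vllm import LLM

llm = LLM(
    model="jinaai/jina-reranker-m0",
    task="score",
    max_model_len=32768,
    trust_remote_code=True,
    limit_mm_per_prompt={"image": 1},
)

documents = {
    "content": [
        {
            "type": "image_url",
            "image_url": {
                "url": "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png"
            },
        },
    ]
}

# Rank the image document against the text query, as run_score() does above.
outputs = llm.score("slm markdown", documents)
print([output.outputs.score for output in outputs])
```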
@ -0,0 +1,60 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Example online usage of Score API.

Run `vllm serve <model> --task score` to start up the server in vLLM.
"""

import argparse
import pprint

import requests


def post_http_request(prompt: dict, api_url: str) -> requests.Response:
    headers = {"User-Agent": "Test Client"}
    response = requests.post(api_url, headers=headers, json=prompt)
    return response


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--model", type=str, default="jinaai/jina-reranker-m0")
    return parser.parse_args()


def main(args):
    api_url = f"http://{args.host}:{args.port}/score"
    model_name = args.model

    text_1 = "slm markdown"
    text_2 = {
        "content": [
            {
                "type": "image_url",
                "image_url": {
                    "url": "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/handelsblatt-preview.png"
                },
            },
            {
                "type": "image_url",
                "image_url": {
                    "url": "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png"
                },
            },
        ]
    }
    prompt = {"model": model_name, "text_1": text_1, "text_2": text_2}
    score_response = post_http_request(prompt=prompt, api_url=api_url)
    print("\nPrompt when text_1 is a string and text_2 is an image list:")
    pprint.pprint(prompt)
    print("\nScore Response:")
    pprint.pprint(score_response.json())


if __name__ == "__main__":
    args = parse_args()
    main(args)
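For a quick smoke test of the same endpoint without the helper script, a bare `requests` call is enough. The snippet below is a sketch; the response shape (a top-level "data" list whose entries carry a "score") is an assumption based on the Score API tests elsewhere in this change set.

```python
import requests

# Assumes `vllm serve jinaai/jina-reranker-m0 --task score` is running locally.
api_url = "http://localhost:8000/score"
payload = {
    "model": "jinaai/jina-reranker-m0",
    "text_1": "slm markdown",
    "text_2": "A plain-text document to rank against the query.",
}

response = requests.post(api_url, json=payload, timeout=60)
response.raise_for_status()
body = response.json()
# "data" / "score" field names are assumptions based on the Score API tests.
print([item.get("score") for item in body.get("data", [])])
```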
@ -11,7 +11,7 @@
|
||||
{% set ns.system_prompt = ns.system_prompt + '\n\n' + message['content'] %}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- endfor -%}
|
||||
|
||||
{#- Adapted from https://github.com/sgl-project/sglang/blob/main/examples/chat_template/tool_chat_template_deepseekr1.jinja #}
|
||||
{% if tools is defined and tools is not none %}
|
||||
@ -27,8 +27,8 @@
|
||||
{% set ns.system_prompt = ns.system_prompt + '\n\n' + tool_ns.text %}
|
||||
{% endif %}
|
||||
|
||||
{{ bos_token }}
|
||||
{{ ns.system_prompt }}
|
||||
{{- bos_token }}
|
||||
{{- ns.system_prompt }}
|
||||
{%- for message in messages %}
|
||||
{% set content = message['content'] %}
|
||||
{%- if message['role'] == 'user' %}
|
||||
@ -45,7 +45,7 @@
|
||||
{%- if message['role'] == 'assistant' and message['tool_calls'] is defined and message['tool_calls'] is not none %}
|
||||
{%- set ns.is_last_user = false -%}
|
||||
{%- if ns.is_tool %}
|
||||
{{'<|tool▁outputs▁end|>'}}
|
||||
{{- '<|tool▁outputs▁end|>'}}
|
||||
{%- endif %}
|
||||
{%- set ns.is_first = false %}
|
||||
{%- set ns.is_tool = false -%}
|
||||
@ -53,40 +53,40 @@
|
||||
{%- for tool in message['tool_calls'] %}
|
||||
{%- if not ns.is_first %}
|
||||
{%- if content is none %}
|
||||
{{'<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments']|tojson + '\n' + '```' + '<|tool▁call▁end|>'}}
|
||||
{{- '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments']|tojson + '\n' + '```' + '<|tool▁call▁end|>'}}
|
||||
{%- else %}
|
||||
{{content + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments']|tojson + '\n' + '```' + '<|tool▁call▁end|>'}}
|
||||
{{- content + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments']|tojson + '\n' + '```' + '<|tool▁call▁end|>'}}
|
||||
{%- endif %}
|
||||
{%- set ns.is_first = true -%}
|
||||
{%- else %}
|
||||
{{'\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments']|tojson + '\n' + '```' + '<|tool▁call▁end|>'}}
|
||||
{{- '\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments']|tojson + '\n' + '```' + '<|tool▁call▁end|>'}}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}
|
||||
{{- '<|tool▁calls▁end|><|end▁of▁sentence|>'}}
|
||||
{%- endif %}
|
||||
{%- if message['role'] == 'assistant' and (message['tool_calls'] is not defined or message['tool_calls'] is none)%}
|
||||
{%- set ns.is_last_user = false -%}
|
||||
{%- if ns.is_tool %}
|
||||
{{'<|tool▁outputs▁end|>' + content + '<|end▁of▁sentence|>'}}
|
||||
{{- '<|tool▁outputs▁end|>' + content + '<|end▁of▁sentence|>'}}
|
||||
{%- set ns.is_tool = false -%}
|
||||
{%- else %}
|
||||
{{content + '<|end▁of▁sentence|>'}}
|
||||
{{- content + '<|end▁of▁sentence|>'}}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- if message['role'] == 'tool' %}
|
||||
{%- set ns.is_last_user = false -%}
|
||||
{%- set ns.is_tool = true -%}
|
||||
{%- if ns.is_output_first %}
|
||||
{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + content + '<|tool▁output▁end|>'}}
|
||||
{{- '<|tool▁outputs▁begin|><|tool▁output▁begin|>' + content + '<|tool▁output▁end|>'}}
|
||||
{%- set ns.is_output_first = false %}
|
||||
{%- else %}
|
||||
{{'\n<|tool▁output▁begin|>' + content + '<|tool▁output▁end|>'}}
|
||||
{{- '\n<|tool▁output▁begin|>' + content + '<|tool▁output▁end|>'}}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- endfor -%}
|
||||
{% if ns.is_tool %}
|
||||
{{'<|tool▁outputs▁end|>'}}
|
||||
{% endif %}
|
||||
{{- '<|tool▁outputs▁end|>'}}
|
||||
{%- endif %}
|
||||
{% if add_generation_prompt and not ns.is_last_user and not ns.is_tool %}
|
||||
{{'<|Assistant|>'}}
|
||||
{% endif %}
|
||||
{{- '<|Assistant|>'}}
|
||||
{%- endif %}
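Most of the template edits above swap `{{ ... }}` for `{{- ... }}` (and `{% ... %}` for `{%- ... %}`), which trims the whitespace and newlines that the block markup would otherwise leak into the rendered prompt. A small standalone Jinja2 check of that behaviour (hypothetical template strings, not the chat template itself):

```python
from jinja2 import Template

loose = Template("{% if true %}\n{{ 'x' }}\n{% endif %}")
tight = Template("{%- if true %}\n{{- 'x' }}\n{%- endif %}")

# Without the minus sign the surrounding newlines survive into the output.
print(repr(loose.render()))  # '\nx\n'
# With '{%-' / '{{-' the preceding whitespace is stripped.
print(repr(tight.render()))  # 'x'
```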
|
||||
@ -3,6 +3,7 @@ site_url: https://docs.vllm.ai
|
||||
repo_url: https://github.com/vllm-project/vllm
|
||||
edit_uri: edit/main/docs/
|
||||
exclude_docs: |
|
||||
argparse
|
||||
*.inc.md
|
||||
*.template.md
|
||||
theme:
|
||||
@ -47,6 +48,7 @@ theme:
|
||||
hooks:
|
||||
- docs/mkdocs/hooks/remove_announcement.py
|
||||
- docs/mkdocs/hooks/generate_examples.py
|
||||
- docs/mkdocs/hooks/generate_argparse.py
|
||||
- docs/mkdocs/hooks/url_schemes.py
|
||||
|
||||
# Required to stop api-autonav from raising an error
|
||||
|
||||
@ -21,7 +21,9 @@ prometheus-fastapi-instrumentator >= 7.0.0
|
||||
tiktoken >= 0.6.0 # Required for DBRX tokenizer
|
||||
lm-format-enforcer >= 0.10.11, < 0.11
|
||||
llguidance >= 0.7.11, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64"
|
||||
outlines == 0.1.11
|
||||
outlines_core == 0.2.10
|
||||
# required for outlines backend disk cache
|
||||
diskcache == 5.6.3
|
||||
lark == 1.2.2
|
||||
xgrammar == 0.1.19; platform_machine == "x86_64" or platform_machine == "aarch64" or platform_machine == "arm64"
|
||||
typing_extensions >= 4.10
|
||||
@ -38,7 +40,7 @@ six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that need
|
||||
setuptools>=77.0.3,<80; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12
|
||||
einops # Required for Qwen2-VL.
|
||||
compressed-tensors == 0.10.2 # required for compressed-tensors
|
||||
depyf==0.18.0 # required for profiling and debugging with compilation config
|
||||
depyf==0.19.0 # required for profiling and debugging with compilation config
|
||||
cloudpickle # allows pickling lambda functions in model_executor/models/registry.py
|
||||
watchfiles # required for http server to monitor the updates of TLS files
|
||||
python-json-logger # Used by logging as per examples/others/logging_configuration.md
|
||||
|
||||
@ -7,3 +7,18 @@ mkdocs-awesome-nav
|
||||
python-markdown-math
|
||||
regex
|
||||
ruff
|
||||
|
||||
# Required for argparse hook only
|
||||
-f https://download.pytorch.org/whl/cpu
|
||||
cachetools
|
||||
cloudpickle
|
||||
fastapi
|
||||
msgspec
|
||||
openai
|
||||
pillow
|
||||
psutil
|
||||
pybase64
|
||||
pydantic
|
||||
torch
|
||||
transformers
|
||||
zmq
|
||||
|
||||
@ -17,3 +17,4 @@ setuptools>=77.0.3,<80.0.0
|
||||
setuptools-scm>=8
|
||||
runai-model-streamer==0.11.0
|
||||
runai-model-streamer-s3==0.11.0
|
||||
conch-triton-kernels==1.2.1
|
||||
|
||||
@ -34,7 +34,7 @@ opencv-python-headless >= 4.11.0 # required for video test
|
||||
datamodel_code_generator # required for minicpm3 test
|
||||
lm-eval[api]==0.4.8 # required for model evaluation test
|
||||
mteb[bm25s]>=1.38.11, <2 # required for mteb test
|
||||
transformers==4.52.4
|
||||
transformers==4.53.2
|
||||
tokenizers==0.21.1
|
||||
huggingface-hub[hf_xet]>=0.33.0 # Required for Xet downloads.
|
||||
schemathesis>=3.39.15 # Required for openai schema test.
|
||||
|
||||
@ -800,7 +800,7 @@ tqdm==4.66.6
|
||||
# transformers
|
||||
tqdm-multiprocess==0.0.11
|
||||
# via lm-eval
|
||||
transformers==4.52.4
|
||||
transformers==4.53.2
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
# genai-perf
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import tempfile
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import pytest
|
||||
@ -111,6 +112,11 @@ def test_full_graph(
|
||||
pass_config=PassConfig(enable_fusion=True,
|
||||
enable_noop=True)), model)
|
||||
for model in models_list(keywords=["FP8-dynamic", "quantized.w8a8"])
|
||||
] + [
|
||||
# Test depyf integration works
|
||||
(CompilationConfig(level=CompilationLevel.PIECEWISE,
|
||||
debug_dump_path=tempfile.gettempdir()),
|
||||
("facebook/opt-125m", {})),
|
||||
])
|
||||
# only test some of the models
|
||||
@create_new_process_for_each_test()
|
||||
|
||||
@ -44,7 +44,9 @@ class TestModel(torch.nn.Module):
|
||||
]
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
cutlass_fp8_supported=cutlass_fp8_enabled,
|
||||
use_per_token_if_dynamic=True)
|
||||
act_quant_static=static,
|
||||
act_quant_group_shape=group_shape,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
resid = torch.sqrt(x)
|
||||
@ -91,9 +93,10 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
|
||||
maybe_create_device_identity() # needed for certain non-cutlass fp8 paths
|
||||
|
||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"]))
|
||||
vllm_config.compilation_config.pass_config = \
|
||||
PassConfig(enable_fusion=True, enable_noop=True)
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
custom_ops=["+rms_norm", "+quant_fp8"],
|
||||
pass_config=PassConfig(enable_fusion=True, enable_noop=True),
|
||||
))
|
||||
with vllm.config.set_current_vllm_config(vllm_config):
|
||||
# Reshape pass is needed for the fusion pass to work
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
|
||||
tests/compile/test_fusion_all_reduce.py (new file, 152 lines)
@ -0,0 +1,152 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from importlib.util import find_spec
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.collective_fusion import AllReduceFusionPass
|
||||
from vllm.config import (CompilationConfig, CompilationLevel, DeviceConfig,
|
||||
ModelConfig, PassConfig, VllmConfig)
|
||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import (init_distributed_environment,
|
||||
initialize_model_parallel)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import update_environment_variables
|
||||
|
||||
from ..utils import multi_gpu_test
|
||||
from .backend import TestBackend
|
||||
|
||||
|
||||
class TestAllReduceRMSNormModel(torch.nn.Module):
|
||||
|
||||
def __init__(self, hidden_size=16, eps=1e-6):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.eps = eps
|
||||
self.norm = RMSNorm(hidden_size, eps)
|
||||
|
||||
def forward(self, hidden_states, residual):
|
||||
view = hidden_states.reshape(-1, self.hidden_size)
|
||||
all_reduce = tensor_model_parallel_all_reduce(view)
|
||||
norm = self.norm(all_reduce)
|
||||
return norm
|
||||
|
||||
def ops_in_model_before(self):
|
||||
return [torch.ops.vllm.all_reduce.default]
|
||||
|
||||
def ops_in_model_after(self):
|
||||
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
|
||||
|
||||
|
||||
class TestAllReduceFusedAddRMSNormModel(torch.nn.Module):
|
||||
|
||||
def __init__(self, hidden_size=16, eps=1e-6):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.eps = eps
|
||||
self.norm = RMSNorm(hidden_size, eps)
|
||||
|
||||
def forward(self, hidden_states, residual):
|
||||
view = hidden_states.reshape(-1, self.hidden_size)
|
||||
all_reduce = tensor_model_parallel_all_reduce(view)
|
||||
norm, _ = self.norm(all_reduce, residual)
|
||||
return norm
|
||||
|
||||
def ops_in_model_before(self):
|
||||
return [torch.ops.vllm.all_reduce.default]
|
||||
|
||||
def ops_in_model_after(self):
|
||||
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@pytest.mark.parametrize(
|
||||
"test_model",
|
||||
[TestAllReduceRMSNormModel, TestAllReduceFusedAddRMSNormModel])
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize("seq_len", [8])
|
||||
@pytest.mark.parametrize("hidden_size", [4096])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
|
||||
reason="Only test on CUDA")
|
||||
@pytest.mark.skipif(not find_spec("flashinfer"),
|
||||
reason="flashinfer is not installed")
|
||||
@pytest.mark.skipif(not current_platform.is_device_capability(100),
|
||||
reason="Only test on SM100")
|
||||
def test_all_reduce_fusion_pass_replace(test_model: torch.nn.Module,
|
||||
batch_size: int, seq_len: int,
|
||||
hidden_size: int, dtype: torch.dtype):
|
||||
num_processes = 2
|
||||
|
||||
def run_torch_spawn(fn, nprocs):
|
||||
torch.multiprocessing.spawn(fn,
|
||||
args=(num_processes, test_model,
|
||||
batch_size, seq_len, hidden_size,
|
||||
dtype),
|
||||
nprocs=nprocs)
|
||||
|
||||
run_torch_spawn(all_reduce_fusion_pass_on_test_model, num_processes)
|
||||
|
||||
|
||||
def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
|
||||
test_model_cls: torch.nn.Module,
|
||||
batch_size: int, seq_len: int,
|
||||
hidden_size: int, dtype: torch.dtype):
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
torch.cuda.set_device(device)
|
||||
torch.set_default_device(device)
|
||||
torch.set_default_dtype(dtype)
|
||||
|
||||
update_environment_variables({
|
||||
'RANK': str(local_rank),
|
||||
'LOCAL_RANK': str(local_rank),
|
||||
'WORLD_SIZE': str(world_size),
|
||||
'MASTER_ADDR': 'localhost',
|
||||
'MASTER_PORT': '12345',
|
||||
})
|
||||
|
||||
init_distributed_environment()
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=CompilationConfig(level=CompilationLevel.PIECEWISE,
|
||||
custom_ops=["+rms_norm"],
|
||||
compile_sizes=[2, 4, 8]))
|
||||
vllm_config.compilation_config.pass_config = PassConfig(
|
||||
enable_fi_allreduce_fusion=True)
|
||||
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
|
||||
|
||||
# this is a fake model name to construct the model config
|
||||
# in the vllm_config, it's not really used.
|
||||
model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
|
||||
vllm_config.model_config = ModelConfig(model=model_name,
|
||||
task="auto",
|
||||
tokenizer=model_name,
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=True,
|
||||
dtype=dtype,
|
||||
seed=42)
|
||||
|
||||
all_reduce_fusion_pass = AllReduceFusionPass(
|
||||
vllm_config, vllm_config.compilation_config.pass_config.
|
||||
fi_allreduce_fusion_max_token_num)
|
||||
backend = TestBackend(all_reduce_fusion_pass)
|
||||
|
||||
model = test_model_cls(hidden_size)
|
||||
|
||||
hidden_states = torch.randn((batch_size * seq_len, hidden_size),
|
||||
requires_grad=False)
|
||||
residual = torch.randn((batch_size * seq_len, hidden_size),
|
||||
requires_grad=False)
|
||||
|
||||
compiled_model = torch.compile(model, backend=backend)
|
||||
compiled_model(hidden_states, residual)
|
||||
|
||||
backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
|
||||
backend.check_after_ops(model.ops_in_model_after())
|
||||
del all_reduce_fusion_pass
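The pass exercised above rewrites an all-reduce followed by (fused-add) RMSNorm into the single `flashinfer_trtllm_fused_allreduce_norm` op. A pure-PyTorch sketch of the math being fused, with the all-reduce modelled as an element-wise sum of per-rank partials (reference semantics only, not the FlashInfer kernel):

```python
import torch


def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    variance = x.pow(2).mean(-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight


hidden_size = 16
weight = torch.ones(hidden_size)

# Two ranks' partial activations; tensor-parallel all-reduce is their sum.
partials = [torch.randn(4, hidden_size) for _ in range(2)]
reduced = sum(partials)

# Unfused reference: all-reduce result fed into RMSNorm. The fused op is
# expected to produce the same values in a single kernel launch.
out = rms_norm(reduced, weight)
print(out.shape)
```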
|
||||
@ -50,6 +50,7 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str,
|
||||
# DYNAMO_ONCE does not properly propagate shapes.
|
||||
level=CompilationLevel.DYNAMO_AS_IS,
|
||||
backend="tests.compile.test_fusion_attn.backend_unfused",
|
||||
custom_ops=["+quant_fp8"],
|
||||
)
|
||||
vllm_config = VllmConfig(compilation_config=compile_config)
|
||||
backend_unfused = TestBackend(NoOpEliminationPass(vllm_config))
|
||||
@ -73,6 +74,7 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str,
|
||||
# DYNAMO_ONCE does not properly propagate shapes.
|
||||
level=CompilationLevel.DYNAMO_AS_IS,
|
||||
backend="tests.compile.test_fusion_attn.backend",
|
||||
custom_ops=["+quant_fp8"],
|
||||
)
|
||||
vllm_config = VllmConfig(compilation_config=compile_config)
|
||||
|
||||
|
||||
@ -4,33 +4,56 @@ import pytest
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm._custom_ops import scaled_fp8_quant
|
||||
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
|
||||
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.config import CompilationConfig, PassConfig, VllmConfig
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
CUTLASS_FP8_SUPPORTED, Fp8LinearOp)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .backend import TestBackend
|
||||
|
||||
|
||||
class TestModel(torch.nn.Module):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, hidden_size: int, cutlass_fp8_enabled: bool, *args,
|
||||
**kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.silu_and_mul = SiluAndMul()
|
||||
self.wscale = torch.rand(1, dtype=torch.float32)
|
||||
self.scale = torch.rand(1, dtype=torch.float32)
|
||||
|
||||
self.w = (torch.rand(
|
||||
hidden_size,
|
||||
hidden_size).to(dtype=current_platform.fp8_dtype()).t())
|
||||
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
cutlass_fp8_supported=cutlass_fp8_enabled,
|
||||
act_quant_static=True,
|
||||
act_quant_group_shape=GroupShape.PER_TENSOR,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
y = self.silu_and_mul(x)
|
||||
x2 = scaled_fp8_quant(y, self.scale)
|
||||
x2 = self.fp8_linear.apply(y,
|
||||
self.w,
|
||||
self.wscale,
|
||||
input_scale=self.wscale)
|
||||
return x2
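`act_quant_static=True` with `GroupShape.PER_TENSOR` means the activation is quantized with one pre-calibrated scale shared by the whole tensor, rather than a scale derived from the current batch. A pure-torch sketch of that behaviour (the scale value is made up; this is not the vLLM kernel):

```python
import torch

FP8 = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8).max

x = torch.randn(16, 64)
static_scale = 0.05  # pre-calibrated, one value for the whole tensor

# Static per-tensor quantization followed by dequantization.
x_q = (x / static_scale).clamp(-FP8_MAX, FP8_MAX).to(FP8)
x_dq = x_q.to(torch.float32) * static_scale
print((x - x_dq).abs().mean())
```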
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", [256])
|
||||
@pytest.mark.parametrize("hidden_size", [64])
|
||||
@pytest.mark.parametrize("cutlass_fp8_enabled",
|
||||
[True, False] if CUTLASS_FP8_SUPPORTED else [False])
|
||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
|
||||
reason="Only test on CUDA and ROCm")
|
||||
def test_fusion_silu_and_mul_quant(num_tokens, hidden_size):
|
||||
def test_fusion_silu_and_mul_quant(num_tokens, hidden_size,
|
||||
cutlass_fp8_enabled):
|
||||
torch.set_default_device("cuda")
|
||||
torch.set_default_dtype(torch.float16)
|
||||
|
||||
@ -40,11 +63,11 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size):
|
||||
pass_config=PassConfig(enable_fusion=True, enable_noop=True))
|
||||
fusion_pass = ActivationQuantFusionPass(config)
|
||||
|
||||
backend = TestBackend(fusion_pass)
|
||||
model = TestModel()
|
||||
backend = TestBackend(NoOpEliminationPass(config), fusion_pass)
|
||||
model = TestModel(hidden_size, cutlass_fp8_enabled)
|
||||
|
||||
# First dimension dynamic
|
||||
x = torch.rand(num_tokens, hidden_size)
|
||||
x = torch.rand(num_tokens, hidden_size * 2)
|
||||
torch._dynamo.mark_dynamic(x, 0)
|
||||
|
||||
result = model(x)
|
||||
|
||||
@ -759,7 +759,8 @@ class VllmRunner:
|
||||
- `trust_remote_code`: Set to `True` instead of `False` for convenience.
|
||||
- `seed`: Set to `0` instead of `None` for test reproducibility.
|
||||
- `max_model_len`: Set to `1024` instead of `None` to reduce memory usage.
|
||||
- `block_size`: Set to `16` instead of `None` to reduce memory usage.
|
||||
- `block_size`: To reduce memory usage, set default to `64` if on XPU
|
||||
devices, otherwise default to `16`.
|
||||
- `enable_chunked_prefill`: Set to `False` instead of `None` for
|
||||
test reproducibility.
|
||||
- `enforce_eager`: Set to `False` to test CUDA graph.
|
||||
@ -777,7 +778,7 @@ class VllmRunner:
|
||||
dtype: str = "auto",
|
||||
disable_log_stats: bool = True,
|
||||
tensor_parallel_size: int = 1,
|
||||
block_size: int = 16,
|
||||
block_size: int = 16 if not torch.xpu.is_available() else 64,
|
||||
enable_chunked_prefill: Optional[bool] = False,
|
||||
swap_space: int = 4,
|
||||
enforce_eager: Optional[bool] = False,
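The new default above picks a larger block size on XPU devices to keep memory usage down. The same choice can be written as a tiny helper; a sketch (the function name is ours, the `torch.xpu` guard mirrors the diff):

```python
import torch


def default_block_size() -> int:
    # Mirrors the VllmRunner default: 64 on XPU devices, 16 elsewhere.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return 64
    return 16


print(default_block_size())
```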
|
||||
|
||||
@ -4,6 +4,7 @@
|
||||
import multiprocessing
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed
|
||||
@ -177,6 +178,38 @@ def test_pynccl_all_gather():
|
||||
distributed_run(all_gather_worker_fn, 2)
|
||||
|
||||
|
||||
@worker_fn_wrapper
|
||||
def all_gatherv_worker_fn():
|
||||
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
|
||||
device=get_world_group().device)
|
||||
|
||||
rank = pynccl_comm.rank
|
||||
world_size = pynccl_comm.world_size
|
||||
device = f'cuda:{pynccl_comm.rank}'
|
||||
|
||||
assert world_size <= 8
|
||||
sizes = [81, 20, 57, 52, 81, 5, 49, 49][:world_size]
|
||||
num_elems = sizes[rank]
|
||||
tensor = torch.arange(num_elems, dtype=torch.float32,
|
||||
device=device) + rank * 100
|
||||
result = torch.zeros(sum(sizes), dtype=torch.float32, device=device)
|
||||
|
||||
expected = torch.cat([
|
||||
torch.arange(sizes[r], dtype=torch.float32) + r * 100
|
||||
for r in range(world_size)
|
||||
]).to(device)
|
||||
|
||||
pynccl_comm.all_gatherv(result, tensor, sizes=sizes)
|
||||
torch.cuda.synchronize()
|
||||
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
||||
reason="Need at least 2 GPUs to run the test.")
|
||||
def test_pynccl_all_gatherv():
|
||||
distributed_run(all_gatherv_worker_fn, 2)
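`all_gatherv` is a variable-length all-gather: rank `r` contributes `sizes[r]` elements and every rank receives the concatenation in rank order. A single-process sketch of the expected layout used by the test (no NCCL required; sizes taken from the 2-GPU run):

```python
import torch

sizes = [81, 20]  # per-rank element counts, as in the 2-GPU test run
world_size = len(sizes)

# What each rank contributes.
contributions = [
    torch.arange(sizes[r], dtype=torch.float32) + r * 100
    for r in range(world_size)
]

# all_gatherv output on every rank: simple concatenation in rank order.
expected = torch.cat(contributions)
assert expected.numel() == sum(sizes)
```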
|
||||
|
||||
|
||||
@worker_fn_wrapper
|
||||
def reduce_scatter_worker_fn():
|
||||
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
|
||||
@ -214,6 +247,43 @@ def test_pynccl_reduce_scatter():
|
||||
distributed_run(reduce_scatter_worker_fn, 2)
|
||||
|
||||
|
||||
@worker_fn_wrapper
|
||||
def reduce_scatterv_worker_fn():
|
||||
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
|
||||
device=get_world_group().device)
|
||||
|
||||
rank = pynccl_comm.rank
|
||||
world_size = pynccl_comm.world_size
|
||||
device = f'cuda:{pynccl_comm.rank}'
|
||||
|
||||
assert world_size <= 8
|
||||
sizes = [81, 20, 57, 52, 81, 5, 49, 49][:world_size]
|
||||
num_elems = sum(sizes)
|
||||
tensor = torch.arange(num_elems, dtype=torch.float32,
|
||||
device=device) + rank * 100
|
||||
result = torch.zeros(sizes[rank], dtype=torch.float32, device=device)
|
||||
|
||||
# Calculate expected result for this rank's chunk
|
||||
all_tensors = [
|
||||
torch.arange(num_elems, dtype=torch.float32) + r * 100
|
||||
for r in range(world_size)
|
||||
]
|
||||
sizes_cumsum = np.cumsum(sizes)
|
||||
start = 0 if rank == 0 else sizes_cumsum[rank - 1]
|
||||
end = sizes_cumsum[rank]
|
||||
expected = sum(tensor[start:end] for tensor in all_tensors).to(device)
|
||||
|
||||
pynccl_comm.reduce_scatterv(result, tensor, sizes=sizes)
|
||||
torch.cuda.synchronize()
|
||||
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
||||
reason="Need at least 2 GPUs to run the test.")
|
||||
def test_pynccl_reduce_scatterv():
|
||||
distributed_run(reduce_scatterv_worker_fn, 2)
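`reduce_scatterv` is the mirror image: every rank contributes a full-length tensor, the tensors are summed element-wise, and rank `r` keeps only its `sizes[r]`-element chunk. A single-process sketch of the chunking the test checks:

```python
import torch

sizes = [81, 20]  # per-rank output sizes, as in the 2-GPU test run
world_size = len(sizes)
num_elems = sum(sizes)

# Every rank contributes a full-length tensor; the reduction is a sum.
inputs = [
    torch.arange(num_elems, dtype=torch.float32) + r * 100
    for r in range(world_size)
]
total = sum(inputs)

# Rank r keeps the r-th contiguous chunk of the reduced tensor.
offsets = [0]
for s in sizes:
    offsets.append(offsets[-1] + s)
chunks = [total[offsets[r]:offsets[r + 1]] for r in range(world_size)]
assert [c.numel() for c in chunks] == sizes
```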
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
||||
reason="Need at least 2 GPUs to run the test.")
|
||||
def test_pynccl_with_cudagraph():
|
||||
|
||||
@ -69,6 +69,12 @@ def test_lm_eval_accuracy_v1_engine(model, monkeypatch: pytest.MonkeyPatch):
|
||||
more_args = None
|
||||
if current_platform.is_tpu():
|
||||
# Limit compilation time for TPU V1
|
||||
|
||||
if model == "google/gemma-3-1b-it":
|
||||
pytest.skip(
|
||||
"Temporarily disabled due to test failures"
|
||||
"(timeout or accuracy mismatch). Re-enable once fixed.")
|
||||
|
||||
more_args = "max_model_len=2048,max_num_seqs=64"
|
||||
|
||||
# Add TP test (if provided)
|
||||
|
||||
@ -16,14 +16,18 @@ from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
|
||||
|
||||
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
|
||||
GUIDED_DECODING_BACKENDS = [
|
||||
|
||||
# Separate backends which support grammars vs ones
|
||||
# which only support regex based constraints in tests.
|
||||
GRAMMAR_DECODING_BACKENDS = [
|
||||
# (backend, disable_any_whitespace),
|
||||
("outlines", False),
|
||||
("lm-format-enforcer", False),
|
||||
("xgrammar", True),
|
||||
("guidance", True),
|
||||
]
|
||||
|
||||
ALL_DECODING_BACKENDS = ([("outlines", False)] + GRAMMAR_DECODING_BACKENDS)
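The split above lets tests that need full grammar support run only on grammar-capable backends, while regex and JSON tests run everywhere. A minimal sketch of how the two lists are consumed (backend membership here is illustrative; the real lists live in the test module):

```python
import pytest

GRAMMAR_DECODING_BACKENDS = [("xgrammar", True), ("guidance", True)]
ALL_DECODING_BACKENDS = [("outlines", False)] + GRAMMAR_DECODING_BACKENDS


@pytest.mark.parametrize("backend,disable_any_whitespace",
                         ALL_DECODING_BACKENDS)
def test_regex_style_constraint(backend, disable_any_whitespace):
    # Placeholder body: regex-based constraints run on every backend.
    assert isinstance(backend, str)


@pytest.mark.parametrize("backend,disable_any_whitespace",
                         GRAMMAR_DECODING_BACKENDS)
def test_grammar_constraint(backend, disable_any_whitespace):
    # Placeholder body: EBNF grammar constraints only run on these backends.
    assert isinstance(disable_any_whitespace, bool)
```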
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def llm():
|
||||
@ -39,7 +43,7 @@ def llm():
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
||||
GUIDED_DECODING_BACKENDS)
|
||||
ALL_DECODING_BACKENDS)
|
||||
def test_guided_regex(sample_regex, llm, guided_decoding_backend: str,
|
||||
disable_any_whitespace: bool):
|
||||
sampling_params = SamplingParams(
|
||||
@ -49,6 +53,7 @@ def test_guided_regex(sample_regex, llm, guided_decoding_backend: str,
|
||||
regex=sample_regex,
|
||||
backend=guided_decoding_backend,
|
||||
disable_any_whitespace=disable_any_whitespace))
|
||||
|
||||
outputs = llm.generate(prompts=[
|
||||
f"Give an example IPv4 address with this regex: {sample_regex}"
|
||||
] * 2,
|
||||
@ -69,7 +74,7 @@ def test_guided_regex(sample_regex, llm, guided_decoding_backend: str,
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
||||
GUIDED_DECODING_BACKENDS)
|
||||
ALL_DECODING_BACKENDS)
|
||||
def test_guided_json_completion(sample_json_schema, llm,
|
||||
guided_decoding_backend: str,
|
||||
disable_any_whitespace: bool):
|
||||
@ -103,7 +108,7 @@ def test_guided_json_completion(sample_json_schema, llm,
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
||||
GUIDED_DECODING_BACKENDS)
|
||||
ALL_DECODING_BACKENDS)
|
||||
def test_guided_complex_json_completion(sample_complex_json_schema, llm,
|
||||
guided_decoding_backend: str,
|
||||
disable_any_whitespace: bool):
|
||||
@ -138,7 +143,7 @@ def test_guided_complex_json_completion(sample_complex_json_schema, llm,
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
||||
GUIDED_DECODING_BACKENDS)
|
||||
ALL_DECODING_BACKENDS)
|
||||
def test_guided_definition_json_completion(sample_definition_json_schema, llm,
|
||||
guided_decoding_backend: str,
|
||||
disable_any_whitespace: bool):
|
||||
@ -173,7 +178,7 @@ def test_guided_definition_json_completion(sample_definition_json_schema, llm,
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
||||
GUIDED_DECODING_BACKENDS)
|
||||
ALL_DECODING_BACKENDS)
|
||||
def test_guided_enum_json_completion(sample_enum_json_schema, llm,
|
||||
guided_decoding_backend: str,
|
||||
disable_any_whitespace: bool):
|
||||
@ -218,7 +223,7 @@ def test_guided_enum_json_completion(sample_enum_json_schema, llm,
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
||||
GUIDED_DECODING_BACKENDS)
|
||||
ALL_DECODING_BACKENDS)
|
||||
def test_guided_choice_completion(sample_guided_choice, llm,
|
||||
guided_decoding_backend: str,
|
||||
disable_any_whitespace: bool):
|
||||
@ -248,7 +253,7 @@ def test_guided_choice_completion(sample_guided_choice, llm,
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
||||
GUIDED_DECODING_BACKENDS)
|
||||
GRAMMAR_DECODING_BACKENDS)
|
||||
def test_guided_grammar(sample_sql_statements, llm,
|
||||
guided_decoding_backend: str,
|
||||
disable_any_whitespace: bool):
|
||||
@ -344,7 +349,7 @@ def test_disable_guided_decoding_fallback(sample_regex, llm):
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
||||
GUIDED_DECODING_BACKENDS)
|
||||
GRAMMAR_DECODING_BACKENDS)
|
||||
def test_guided_json_object(llm, guided_decoding_backend: str,
|
||||
disable_any_whitespace: bool):
|
||||
sampling_params = SamplingParams(
|
||||
@ -377,7 +382,9 @@ def test_guided_json_object(llm, guided_decoding_backend: str,
|
||||
|
||||
# Parse to verify it is valid JSON
|
||||
parsed_json = json.loads(generated_text)
|
||||
assert isinstance(parsed_json, dict)
|
||||
# A list is not what was intended, but is still valid
|
||||
# json.
|
||||
assert isinstance(parsed_json, (dict, list))
|
||||
|
||||
|
||||
class CarType(str, Enum):
|
||||
@ -395,7 +402,7 @@ class CarDescription(BaseModel):
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
||||
GUIDED_DECODING_BACKENDS)
|
||||
ALL_DECODING_BACKENDS)
|
||||
def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str,
|
||||
disable_any_whitespace: bool):
|
||||
json_schema = CarDescription.model_json_schema()
|
||||
@ -427,7 +434,7 @@ def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str,
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
||||
GUIDED_DECODING_BACKENDS)
|
||||
ALL_DECODING_BACKENDS)
|
||||
def test_guided_number_range_json_completion(llm, guided_decoding_backend: str,
|
||||
disable_any_whitespace: bool):
|
||||
sample_output_schema = {
|
||||
|
||||
@ -1113,10 +1113,7 @@ async def test_http_chat_no_model_name_with_curl(server: RemoteOpenAIServer):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME, ""])
|
||||
async def test_http_chat_no_model_name_with_openai(server: RemoteOpenAIServer,
|
||||
model_name: str):
|
||||
|
||||
async def test_http_chat_no_model_name_with_openai(server: RemoteOpenAIServer):
|
||||
openai_api_key = "EMPTY"
|
||||
openai_api_base = f"http://localhost:{server.port}/v1"
|
||||
|
||||
@ -1135,3 +1132,35 @@ async def test_http_chat_no_model_name_with_openai(server: RemoteOpenAIServer,
|
||||
messages=messages,
|
||||
)
|
||||
assert response.model == MODEL_NAME
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invocations(server: RemoteOpenAIServer,
|
||||
client: openai.AsyncOpenAI):
|
||||
messages = [{
|
||||
"role": "system",
|
||||
"content": "you are a helpful assistant"
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "what is 1+1?"
|
||||
}]
|
||||
|
||||
request_args = {
|
||||
"model": MODEL_NAME,
|
||||
"messages": messages,
|
||||
"max_completion_tokens": 5,
|
||||
"temperature": 0.0,
|
||||
"logprobs": False,
|
||||
}
|
||||
|
||||
chat_completion = await client.chat.completions.create(**request_args)
|
||||
|
||||
invocation_response = requests.post(server.url_for("invocations"),
|
||||
json=request_args)
|
||||
invocation_response.raise_for_status()
|
||||
|
||||
chat_output = chat_completion.model_dump()
|
||||
invocation_output = invocation_response.json()
|
||||
|
||||
assert chat_output.keys() == invocation_output.keys()
|
||||
assert chat_output["choices"] == invocation_output["choices"]
|
||||
|
||||
@ -155,3 +155,25 @@ def test_batch_classification_empty_list(server: RemoteOpenAIServer,
|
||||
assert output.object == "list"
|
||||
assert isinstance(output.data, list)
|
||||
assert len(output.data) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invocations(server: RemoteOpenAIServer):
|
||||
request_args = {
|
||||
"model": MODEL_NAME,
|
||||
"input": "This product was excellent and exceeded my expectations"
|
||||
}
|
||||
|
||||
classification_response = requests.post(server.url_for("classify"),
|
||||
json=request_args)
|
||||
classification_response.raise_for_status()
|
||||
|
||||
invocation_response = requests.post(server.url_for("invocations"),
|
||||
json=request_args)
|
||||
invocation_response.raise_for_status()
|
||||
|
||||
classification_output = classification_response.json()
|
||||
invocation_output = invocation_response.json()
|
||||
|
||||
assert classification_output.keys() == invocation_output.keys()
|
||||
assert classification_output["data"] == invocation_output["data"]
|
||||
|
||||
@ -11,6 +11,7 @@ import openai # use the official client for correctness check
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import regex as re
|
||||
import requests
|
||||
# downloading lora to test lora requests
|
||||
from huggingface_hub import snapshot_download
|
||||
from openai import BadRequestError
|
||||
@ -833,3 +834,27 @@ async def test_echo_stream_completion(client: openai.AsyncOpenAI,
|
||||
assert content is not None and saying in content
|
||||
else:
|
||||
assert content is not None and saying not in content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invocations(server: RemoteOpenAIServer,
|
||||
client: openai.AsyncOpenAI):
|
||||
request_args = {
|
||||
"model": MODEL_NAME,
|
||||
"prompt": "Hello, my name is",
|
||||
"max_tokens": 5,
|
||||
"temperature": 0.0,
|
||||
"logprobs": None,
|
||||
}
|
||||
|
||||
completion = await client.completions.create(**request_args)
|
||||
|
||||
invocation_response = requests.post(server.url_for("invocations"),
|
||||
json=request_args)
|
||||
invocation_response.raise_for_status()
|
||||
|
||||
completion_output = completion.model_dump()
|
||||
invocation_output = invocation_response.json()
|
||||
|
||||
assert completion_output.keys() == invocation_output.keys()
|
||||
assert completion_output["choices"] == invocation_output["choices"]
|
||||
|
||||
@ -72,8 +72,43 @@ async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str,
|
||||
"The unit to fetch the temperature in",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
},
|
||||
"options": {
|
||||
"$ref": "#/$defs/WeatherOptions",
|
||||
"description":
|
||||
"Optional parameters for weather query",
|
||||
},
|
||||
},
|
||||
"required": ["country", "unit"],
|
||||
"$defs": {
|
||||
"WeatherOptions": {
|
||||
"title": "WeatherOptions",
|
||||
"type": "object",
|
||||
"additionalProperties": False,
|
||||
"properties": {
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"default": "celsius",
|
||||
"description": "Temperature unit",
|
||||
"title": "Temperature Unit",
|
||||
},
|
||||
"include_forecast": {
|
||||
"type": "boolean",
|
||||
"default": False,
|
||||
"description":
|
||||
"Whether to include a 24-hour forecast",
|
||||
"title": "Include Forecast",
|
||||
},
|
||||
"language": {
|
||||
"type": "string",
|
||||
"default": "zh-CN",
|
||||
"description": "Language of the response",
|
||||
"title": "Language",
|
||||
"enum": ["zh-CN", "en-US", "ja-JP"],
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
tests/entrypoints/openai/test_default_mm_loras.py (new file, 107 lines)
@ -0,0 +1,107 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
|
||||
import openai # use the official client for correctness check
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from ...conftest import AudioTestAssets
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
# NOTE - the tests in this module are currently analogous to test_chat, but are
|
||||
# separated to avoid OOM killing due to module-scoped servers, since we
|
||||
# need a multimodal model for these tests.
|
||||
|
||||
# Contains a modality specific lora alongside the base model
|
||||
MULTIMODAL_MODEL_NAME = snapshot_download(
|
||||
"microsoft/Phi-4-multimodal-instruct")
|
||||
AUDIO_LORA_PATH = os.path.join(MULTIMODAL_MODEL_NAME, "speech-lora")
|
||||
|
||||
ACTIVE_MM_LORA_RESPONSE = "Spoken text: The first words I spoke in the original chronograph, a little piece of practical poetry. Mary had a little lamb, it slept with quite a snow, and everywhere that Mary went, the lamb was sure to go." # noqa: E501
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def monkeypatch_module():
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
mpatch = MonkeyPatch()
|
||||
yield mpatch
|
||||
mpatch.undo()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=[False, True])
|
||||
def multimodal_server(request, monkeypatch_module): # noqa: F811
|
||||
|
||||
use_v1 = request.param
|
||||
monkeypatch_module.setenv('VLLM_USE_V1', '1' if use_v1 else '0')
|
||||
|
||||
args = [
|
||||
# use half precision for speed and memory savings in CI environment
|
||||
"--dtype",
|
||||
"half",
|
||||
"--max-model-len",
|
||||
"12800",
|
||||
"--enforce-eager",
|
||||
# lora config below
|
||||
"--enable-lora",
|
||||
"--lora-modules",
|
||||
f"speech={AUDIO_LORA_PATH}",
|
||||
"--max-lora-rank",
|
||||
"320",
|
||||
"--max-num-seqs",
|
||||
"2",
|
||||
"--trust-remote-code",
|
||||
"--gpu-memory-utilization",
|
||||
"0.8",
|
||||
"--default-mm-loras",
|
||||
f"{{\"audio\": \"{AUDIO_LORA_PATH}\"}}",
|
||||
]
|
||||
|
||||
with RemoteOpenAIServer(MULTIMODAL_MODEL_NAME, args) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def multi_modal_client(multimodal_server):
|
||||
async with multimodal_server.get_async_client() as async_client:
|
||||
yield async_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
# base model with default lora should give the same response as lora model
|
||||
"model_name",
|
||||
[MULTIMODAL_MODEL_NAME, "speech"],
|
||||
)
|
||||
async def test_default_mm_lora_chat_completions(
|
||||
model_name: str,
|
||||
multi_modal_client: openai.AsyncOpenAI,
|
||||
audio_assets: AudioTestAssets,
|
||||
):
|
||||
messages = [{
|
||||
"role":
|
||||
"user",
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": "Can you transcribe this audio?",
|
||||
}, {
|
||||
"type": "audio_url",
|
||||
"audio_url": {
|
||||
"url": audio_assets[0].url
|
||||
},
|
||||
}]
|
||||
}]
|
||||
|
||||
chat_completion = await multi_modal_client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
max_completion_tokens=128,
|
||||
temperature=0.0)
|
||||
|
||||
assert len(chat_completion.choices) > 0
|
||||
|
||||
message = chat_completion.choices[0].message
|
||||
assert message.content is not None and len(message.content) >= 0
|
||||
assert message.content == ACTIVE_MM_LORA_RESPONSE
|
||||
@ -296,3 +296,63 @@ async def test_single_embedding_truncation_invalid(client: openai.AsyncOpenAI,
|
||||
assert "error" in response.object
|
||||
assert "truncate_prompt_tokens value is greater than max_model_len. "\
|
||||
"Please, select a smaller truncation size." in response.message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invocations(server: RemoteOpenAIServer,
|
||||
client: openai.AsyncOpenAI):
|
||||
input_texts = [
|
||||
"The chef prepared a delicious meal.",
|
||||
]
|
||||
|
||||
request_args = {
|
||||
"model": MODEL_NAME,
|
||||
"input": input_texts,
|
||||
"encoding_format": "float",
|
||||
}
|
||||
|
||||
completion_response = await client.embeddings.create(**request_args)
|
||||
|
||||
invocation_response = requests.post(server.url_for("invocations"),
|
||||
json=request_args)
|
||||
invocation_response.raise_for_status()
|
||||
|
||||
completion_output = completion_response.model_dump()
|
||||
invocation_output = invocation_response.json()
|
||||
|
||||
assert completion_output.keys() == invocation_output.keys()
|
||||
assert completion_output["data"] == invocation_output["data"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invocations_conversation(server: RemoteOpenAIServer):
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": "The cat sat on the mat.",
|
||||
}, {
|
||||
"role": "assistant",
|
||||
"content": "A feline was resting on a rug.",
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "Stars twinkle brightly in the night sky.",
|
||||
}]
|
||||
|
||||
request_args = {
|
||||
"model": MODEL_NAME,
|
||||
"messages": messages,
|
||||
"encoding_format": "float",
|
||||
}
|
||||
|
||||
chat_response = requests.post(server.url_for("v1/embeddings"),
|
||||
json=request_args)
|
||||
chat_response.raise_for_status()
|
||||
|
||||
invocation_response = requests.post(server.url_for("invocations"),
|
||||
json=request_args)
|
||||
invocation_response.raise_for_status()
|
||||
|
||||
chat_output = chat_response.json()
|
||||
invocation_output = invocation_response.json()
|
||||
|
||||
assert chat_output.keys() == invocation_output.keys()
|
||||
assert chat_output["data"] == invocation_output["data"]
|
||||
|
||||
@ -13,7 +13,7 @@ from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach"
|
||||
MODEL_NAME = "internlm/internlm2-1_8b-reward"
|
||||
DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}""" # noqa: E501
|
||||
|
||||
|
||||
@ -21,15 +21,16 @@ DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' +
|
||||
def server():
|
||||
args = [
|
||||
"--task",
|
||||
"classify",
|
||||
"reward",
|
||||
# use half precision for speed and memory savings in CI environment
|
||||
"--dtype",
|
||||
"bfloat16",
|
||||
"--enforce-eager",
|
||||
"--max-model-len",
|
||||
"8192",
|
||||
"512",
|
||||
"--chat-template",
|
||||
DUMMY_CHAT_TEMPLATE,
|
||||
"--trust-remote-code",
|
||||
]
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
@ -57,10 +58,10 @@ async def test_single_pooling(server: RemoteOpenAIServer, model_name: str):
|
||||
|
||||
assert poolings.id is not None
|
||||
assert len(poolings.data) == 1
|
||||
assert len(poolings.data[0].data) == 2
|
||||
assert len(poolings.data[0].data) == 8
|
||||
assert poolings.usage.completion_tokens == 0
|
||||
assert poolings.usage.prompt_tokens == 7
|
||||
assert poolings.usage.total_tokens == 7
|
||||
assert poolings.usage.prompt_tokens == 8
|
||||
assert poolings.usage.total_tokens == 8
|
||||
|
||||
# test using token IDs
|
||||
input_tokens = [1, 1, 1, 1, 1]
|
||||
@ -77,7 +78,7 @@ async def test_single_pooling(server: RemoteOpenAIServer, model_name: str):
|
||||
|
||||
assert poolings.id is not None
|
||||
assert len(poolings.data) == 1
|
||||
assert len(poolings.data[0].data) == 2
|
||||
assert len(poolings.data[0].data) == 5
|
||||
assert poolings.usage.completion_tokens == 0
|
||||
assert poolings.usage.prompt_tokens == 5
|
||||
assert poolings.usage.total_tokens == 5
|
||||
@ -104,10 +105,10 @@ async def test_batch_pooling(server: RemoteOpenAIServer, model_name: str):
|
||||
|
||||
assert poolings.id is not None
|
||||
assert len(poolings.data) == 3
|
||||
assert len(poolings.data[0].data) == 2
|
||||
assert len(poolings.data[0].data) == 8
|
||||
assert poolings.usage.completion_tokens == 0
|
||||
assert poolings.usage.prompt_tokens == 25
|
||||
assert poolings.usage.total_tokens == 25
|
||||
assert poolings.usage.prompt_tokens == 29
|
||||
assert poolings.usage.total_tokens == 29
|
||||
|
||||
# test list[list[int]]
|
||||
input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24],
|
||||
@ -125,7 +126,7 @@ async def test_batch_pooling(server: RemoteOpenAIServer, model_name: str):
|
||||
|
||||
assert poolings.id is not None
|
||||
assert len(poolings.data) == 4
|
||||
assert len(poolings.data[0].data) == 2
|
||||
assert len(poolings.data[0].data) == 5
|
||||
assert poolings.usage.completion_tokens == 0
|
||||
assert poolings.usage.prompt_tokens == 17
|
||||
assert poolings.usage.total_tokens == 17
|
||||
@ -157,7 +158,11 @@ async def test_conversation_pooling(server: RemoteOpenAIServer,
|
||||
chat_response.raise_for_status()
|
||||
chat_poolings = PoolingResponse.model_validate(chat_response.json())
|
||||
|
||||
tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
|
||||
tokenizer = get_tokenizer(
|
||||
tokenizer_name=model_name,
|
||||
tokenizer_mode="fast",
|
||||
trust_remote_code=True,
|
||||
)
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
chat_template=DUMMY_CHAT_TEMPLATE,
|
||||
@ -206,6 +211,9 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer,
|
||||
)
|
||||
float_response.raise_for_status()
|
||||
responses_float = PoolingResponse.model_validate(float_response.json())
|
||||
float_data = [
|
||||
np.array(d.data).squeeze(-1).tolist() for d in responses_float.data
|
||||
]
|
||||
|
||||
base64_response = requests.post(
|
||||
server.url_for("pooling"),
|
||||
@ -224,11 +232,10 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer,
|
||||
np.frombuffer(base64.b64decode(data.data),
|
||||
dtype="float32").tolist())
|
||||
|
||||
check_embeddings_close(
|
||||
embeddings_0_lst=[d.data for d in responses_float.data],
|
||||
embeddings_1_lst=decoded_responses_base64_data,
|
||||
name_0="float32",
|
||||
name_1="base64")
|
||||
check_embeddings_close(embeddings_0_lst=float_data,
|
||||
embeddings_1_lst=decoded_responses_base64_data,
|
||||
name_0="float32",
|
||||
name_1="base64")
|
||||
|
||||
# Default response is float32 decoded from base64 by OpenAI Client
|
||||
default_response = requests.post(
|
||||
@ -240,9 +247,71 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer,
|
||||
)
|
||||
default_response.raise_for_status()
|
||||
responses_default = PoolingResponse.model_validate(default_response.json())
|
||||
default_data = [
|
||||
np.array(d.data).squeeze(-1).tolist() for d in responses_default.data
|
||||
]
|
||||
|
||||
check_embeddings_close(
|
||||
embeddings_0_lst=[d.data for d in responses_default.data],
|
||||
embeddings_1_lst=[d.data for d in responses_default.data],
|
||||
name_0="float32",
|
||||
name_1="base64")
|
||||
check_embeddings_close(embeddings_0_lst=float_data,
|
||||
embeddings_1_lst=default_data,
|
||||
name_0="float32",
|
||||
name_1="default")
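The base64 branch above round-trips raw float32 buffers; the decoding pattern the test relies on can be checked in isolation:

```python
import base64

import numpy as np

original = np.asarray([0.25, -1.5, 3.0], dtype=np.float32)

# Server side: encode the raw float32 bytes for encoding_format="base64".
encoded = base64.b64encode(original.tobytes()).decode("ascii")

# Client side: decode back into floats, as the test does.
decoded = np.frombuffer(base64.b64decode(encoded), dtype="float32")
assert np.allclose(original, decoded)
```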
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invocations(server: RemoteOpenAIServer):
|
||||
input_texts = [
|
||||
"The chef prepared a delicious meal.",
|
||||
]
|
||||
|
||||
request_args = {
|
||||
"model": MODEL_NAME,
|
||||
"input": input_texts,
|
||||
"encoding_format": "float",
|
||||
}
|
||||
|
||||
completion_response = requests.post(server.url_for("pooling"),
|
||||
json=request_args)
|
||||
completion_response.raise_for_status()
|
||||
|
||||
invocation_response = requests.post(server.url_for("invocations"),
|
||||
json=request_args)
|
||||
invocation_response.raise_for_status()
|
||||
|
||||
completion_output = completion_response.json()
|
||||
invocation_output = invocation_response.json()
|
||||
|
||||
assert completion_output.keys() == invocation_output.keys()
|
||||
assert completion_output["data"] == invocation_output["data"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invocations_conversation(server: RemoteOpenAIServer):
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": "The cat sat on the mat.",
|
||||
}, {
|
||||
"role": "assistant",
|
||||
"content": "A feline was resting on a rug.",
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "Stars twinkle brightly in the night sky.",
|
||||
}]
|
||||
|
||||
request_args = {
|
||||
"model": MODEL_NAME,
|
||||
"messages": messages,
|
||||
"encoding_format": "float",
|
||||
}
|
||||
|
||||
chat_response = requests.post(server.url_for("pooling"), json=request_args)
|
||||
chat_response.raise_for_status()
|
||||
|
||||
invocation_response = requests.post(server.url_for("invocations"),
|
||||
json=request_args)
|
||||
invocation_response.raise_for_status()
|
||||
|
||||
chat_output = chat_response.json()
|
||||
invocation_output = invocation_response.json()
|
||||
|
||||
assert chat_output.keys() == invocation_output.keys()
|
||||
assert chat_output["data"] == invocation_output["data"]
|
||||
|
||||
@ -94,3 +94,30 @@ def test_rerank_max_model_len(server: RemoteOpenAIServer, model_name: str):
|
||||
# Assert just a small fragment of the response
|
||||
assert "Please reduce the length of the input." in \
|
||||
rerank_response.text
|
||||
|
||||
|
||||
def test_invocations(server: RemoteOpenAIServer):
|
||||
query = "What is the capital of France?"
|
||||
documents = [
|
||||
"The capital of Brazil is Brasilia.", "The capital of France is Paris."
|
||||
]
|
||||
|
||||
request_args = {
|
||||
"model": MODEL_NAME,
|
||||
"query": query,
|
||||
"documents": documents,
|
||||
}
|
||||
|
||||
rerank_response = requests.post(server.url_for("rerank"),
|
||||
json=request_args)
|
||||
rerank_response.raise_for_status()
|
||||
|
||||
invocation_response = requests.post(server.url_for("invocations"),
|
||||
json=request_args)
|
||||
invocation_response.raise_for_status()
|
||||
|
||||
rerank_output = rerank_response.json()
|
||||
invocation_output = invocation_response.json()
|
||||
|
||||
assert rerank_output.keys() == invocation_output.keys()
|
||||
assert rerank_output["results"] == invocation_output["results"]
|
||||
|
||||
@ -191,3 +191,28 @@ class TestModel:
|
||||
assert score_response.status_code == 400
|
||||
assert "Please, select a smaller truncation size." in \
|
||||
score_response.text
|
||||
|
||||
def test_invocations(self, server: RemoteOpenAIServer, model: dict[str,
|
||||
Any]):
|
||||
text_1 = "What is the capital of France?"
|
||||
text_2 = "The capital of France is Paris."
|
||||
|
||||
request_args = {
|
||||
"model": model["name"],
|
||||
"text_1": text_1,
|
||||
"text_2": text_2,
|
||||
}
|
||||
|
||||
score_response = requests.post(server.url_for("score"),
|
||||
json=request_args)
|
||||
score_response.raise_for_status()
|
||||
|
||||
invocation_response = requests.post(server.url_for("invocations"),
|
||||
json=request_args)
|
||||
invocation_response.raise_for_status()
|
||||
|
||||
score_output = score_response.json()
|
||||
invocation_output = invocation_response.json()
|
||||
|
||||
assert score_output.keys() == invocation_output.keys()
|
||||
assert score_output["data"] == invocation_output["data"]
|
||||
|
||||
@ -0,0 +1,140 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import flashinfer
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if not current_platform.is_device_capability(100):
|
||||
pytest.skip("This TRTLLM kernel requires NVIDIA Blackwell.",
|
||||
allow_module_level=True)
|
||||
|
||||
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
|
||||
|
||||
# KV Cache Layout for TRT-LLM
|
||||
# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)
|
||||
|
||||
NUM_HEADS = [(64, 8), (16, 16), (40, 8), (32, 8)]
|
||||
HEAD_SIZES = [128]
|
||||
BLOCK_SIZES = [16, 32]
|
||||
DTYPES = [torch.float16, torch.bfloat16]
|
||||
NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
|
||||
SOFT_CAPS = [None, 30.0, 50.0]
|
||||
|
||||
|
||||
def to_float8(x, dtype=torch.float8_e4m3fn):
|
||||
finfo = torch.finfo(dtype)
|
||||
min_val, max_val = x.aminmax()
|
||||
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
|
||||
scale = finfo.max / amax * 0.1
|
||||
x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
|
||||
return x_scl_sat.to(dtype), scale.float().reciprocal()
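`to_float8` returns the quantized tensor together with the reciprocal scale, so multiplying the dequantized values by that scale should approximately recover the input. A standalone check of that contract (requires a PyTorch build with float8 dtypes):

```python
import torch


def to_float8(x, dtype=torch.float8_e4m3fn):
    finfo = torch.finfo(dtype)
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    scale = finfo.max / amax * 0.1
    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    return x_scl_sat.to(dtype), scale.float().reciprocal()


x = torch.randn(4, 8)
x_fp8, inv_scale = to_float8(x)
# Dequantize with the returned reciprocal scale; the error should be small.
x_back = x_fp8.to(torch.float32) * inv_scale
print((x - x_back).abs().max())
```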
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("kv_layout", ["HND"])
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
|
||||
@torch.inference_mode
|
||||
def test_flashinfer_trtllm_decode_with_baseline(
|
||||
kv_lens: list[int],
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
block_size: int,
|
||||
soft_cap: Optional[float],
|
||||
kv_layout: str,
|
||||
) -> None:
|
||||
torch.set_default_device("cuda")
|
||||
current_platform.seed_everything(0)
|
||||
num_seqs = len(kv_lens)
|
||||
num_query_heads = num_heads[0]
|
||||
num_kv_heads = num_heads[1]
|
||||
|
||||
assert num_query_heads % num_kv_heads == 0
|
||||
max_kv_len = max(kv_lens)
|
||||
scale = head_size**-0.5
|
||||
|
||||
query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
|
||||
kv_cache_shape = None
|
||||
if kv_layout == "NHD":
|
||||
kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
|
||||
elif kv_layout == "HND":
|
||||
kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
|
||||
else:
|
||||
raise ValueError(f"Invalid kv_layout: {kv_layout}")
|
||||
key_value_cache = torch.randn(kv_cache_shape, dtype=dtype)
|
||||
|
||||
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(0,
|
||||
NUM_BLOCKS,
|
||||
(num_seqs, max_num_blocks_per_seq),
|
||||
dtype=torch.int32)
|
||||
k_scale = v_scale = 1.0
|
||||
kv_indptr = [0]
|
||||
kv_indices = []
|
||||
kv_last_page_lens = []
|
||||
for i in range(num_seqs):
|
||||
seq_len = kv_lens[i]
|
||||
assert seq_len > 0
|
||||
num_blocks = (seq_len + block_size - 1) // block_size
|
||||
kv_indices.extend(block_tables[i, :num_blocks])
|
||||
kv_indptr.append(kv_indptr[-1] + num_blocks)
|
||||
kv_last_page_len = seq_len % block_size
|
||||
if kv_last_page_len == 0:
|
||||
kv_last_page_len = block_size
|
||||
kv_last_page_lens.append(kv_last_page_len)
|
||||
|
||||
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
|
||||
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
|
||||
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
|
||||
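# Worked example of the metadata built above (added comment, not in the diff):
# with block_size=16 and kv_lens=[18, 1], seq 0 needs 2 blocks and its last
# page holds 18 % 16 = 2 tokens, seq 1 needs 1 block holding 1 token, giving
# kv_indptr=[0, 2, 3], kv_last_page_lens=[2, 1], and len(kv_indices)=3.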
|
||||
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
|
||||
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
    workspace_buffer, kv_layout,
    use_tensor_cores=(num_query_heads // num_kv_heads) > 4)
|
||||
wrapper.plan(kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_lens,
|
||||
num_query_heads,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
block_size,
|
||||
"NONE",
|
||||
q_data_type=dtype,
|
||||
kv_data_type=dtype,
|
||||
logits_soft_cap=soft_cap)
|
||||
|
||||
output = wrapper.run(query, key_value_cache, scale)
|
||||
|
||||
# TRTLLM Decode
|
||||
max_kv_len = max(kv_lens)
|
||||
kv_lens_tensor = torch.tensor(kv_lens,
|
||||
dtype=torch.int,
|
||||
device=query.device)
|
||||
output_trtllm = flashinfer.decode.trtllm_batch_decode_with_kv_cache(
|
||||
query.contiguous(),
|
||||
key_value_cache,
|
||||
workspace_buffer,
|
||||
num_query_heads,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
kv_lens_tensor,
|
||||
block_size,
|
||||
max_kv_len,
|
||||
"auto",
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(
    output, output_trtllm, atol=1e-2, rtol=1e-2,
    msg=f"max abs diff: {torch.max(torch.abs(output - output_trtllm))}")
|
||||
@ -6,9 +6,8 @@ from typing import Optional
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm import _custom_ops as ops # noqa: F401
|
||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
||||
causal_conv1d_fn, causal_conv1d_update)
|
||||
@ -144,79 +143,6 @@ def causal_conv1d_opcheck_fn(x: torch.Tensor,
|
||||
x = x.contiguous()
|
||||
bias = bias.contiguous() if bias is not None else None
|
||||
|
||||
opcheck(torch.ops._C.causal_conv1d_fwd,
|
||||
(x, weight, bias, conv_states, cu_seq_len, cache_indices,
|
||||
has_initial_state, activation in ["silu", "swish"], pad_slot_id))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype", [torch.bfloat16, torch.float])
|
||||
@pytest.mark.parametrize("silu_activation", [True])
|
||||
@pytest.mark.parametrize("has_bias", [True])
|
||||
@pytest.mark.parametrize("has_initial_state", [True, False])
|
||||
@pytest.mark.parametrize("width", [4])
|
||||
@pytest.mark.parametrize(
|
||||
'seqlen', [1, 8, 16, 32, 64, 128, 256, 512, 784, 1024, 1025, 2048, 4096])
|
||||
@pytest.mark.parametrize('dim', [64])
|
||||
@pytest.mark.parametrize('batch', [1])
|
||||
def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
|
||||
has_initial_state, itype):
|
||||
device = "cuda"
|
||||
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
||||
if itype == torch.bfloat16:
|
||||
rtol, atol = 1e-2, 5e-2
|
||||
# set seed
|
||||
current_platform.seed_everything(0)
|
||||
x = torch.randn(batch, dim, seqlen, device=device,
|
||||
dtype=itype).contiguous()
|
||||
|
||||
weight = torch.randn(dim, width, device=device, dtype=itype)
|
||||
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
|
||||
if has_initial_state:
|
||||
initial_states = torch.randn(batch,
|
||||
dim,
|
||||
width - 1,
|
||||
device=device,
|
||||
dtype=itype)
|
||||
has_initial_state_tensor = torch.ones(batch,
|
||||
dtype=torch.bool,
|
||||
device=x.device)
|
||||
else:
|
||||
initial_states = None
|
||||
has_initial_state_tensor = None
|
||||
x_ref = x.clone()
|
||||
weight_ref = weight.clone()
|
||||
bias_ref = bias.clone() if bias is not None else None
|
||||
initial_states_ref = initial_states.clone(
|
||||
) if initial_states is not None else None
|
||||
activation = None if not silu_activation else "silu"
|
||||
out = causal_conv1d_fn(x,
|
||||
weight,
|
||||
bias,
|
||||
activation=activation,
|
||||
conv_states=initial_states,
|
||||
has_initial_state=has_initial_state_tensor)
|
||||
out_ref, final_states_ref = causal_conv1d_ref(
|
||||
x_ref,
|
||||
weight_ref,
|
||||
bias_ref,
|
||||
initial_states=initial_states_ref,
|
||||
return_final_states=True,
|
||||
activation=activation)
|
||||
if has_initial_state:
|
||||
assert initial_states is not None and final_states_ref is not None
|
||||
assert torch.allclose(initial_states,
|
||||
final_states_ref,
|
||||
rtol=rtol,
|
||||
atol=atol)
|
||||
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
||||
|
||||
causal_conv1d_opcheck_fn(x,
|
||||
weight,
|
||||
bias,
|
||||
activation=activation,
|
||||
conv_states=initial_states,
|
||||
has_initial_state=has_initial_state_tensor)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("silu_activation", [False, True])
|
||||
@ -255,22 +181,19 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
|
||||
assert torch.equal(conv_state, conv_state_ref)
|
||||
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
||||
|
||||
opcheck(torch.ops._C.causal_conv1d_update,
|
||||
(x, conv_state, weight, bias, activation
|
||||
in ["silu", "swish"], None, None, PAD_SLOT_ID))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype",
|
||||
[torch.float32, torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("silu_activation", [False, True])
|
||||
@pytest.mark.parametrize("has_bias", [False, True])
|
||||
@pytest.mark.parametrize("seqlen", [1, 4, 5])
|
||||
@pytest.mark.parametrize("width", [2, 3, 4])
|
||||
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
|
||||
@pytest.mark.parametrize("seqlen", [1, 3])
|
||||
@pytest.mark.parametrize("width", [3, 4])
|
||||
@pytest.mark.parametrize("dim", [2048 + 16, 4096])
|
||||
# tests correctness in case subset of the sequences are padded
|
||||
@pytest.mark.parametrize("with_padding", [True, False])
|
||||
def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width,
|
||||
seqlen, has_bias,
|
||||
@pytest.mark.parametrize("batch_size", [3])
|
||||
def test_causal_conv1d_update_with_batch_gather(batch_size, with_padding, dim,
|
||||
width, seqlen, has_bias,
|
||||
silu_activation, itype):
|
||||
device = "cuda"
|
||||
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
||||
@ -280,12 +203,15 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width,
|
||||
# set seed
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
batch_size = 3
|
||||
padding = 5 if with_padding else 0
|
||||
padded_batch_size = batch_size + padding
|
||||
# total_entries = number of cache lines
|
||||
total_entries = 10 * batch_size
|
||||
|
||||
x = torch.randn(padded_batch_size, dim, 1, device=device, dtype=itype)
|
||||
# x will be (batch, dim, seqlen) with contiguous along dim-axis
|
||||
x = torch.randn(padded_batch_size, seqlen, dim, device=device,
|
||||
dtype=itype).transpose(1, 2)
|
||||
|
||||
x_ref = x.clone()
|
||||
|
||||
conv_state_indices = torch.randperm(total_entries)[:batch_size].to(
|
||||
@ -300,17 +226,22 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width,
|
||||
[PAD_SLOT_ID] * padding, dtype=torch.int32, device=device)
|
||||
],
|
||||
dim=0)
|
||||
|
||||
# conv_state will be (cache_lines, dim, state_len)
|
||||
# with contiguous along dim-axis
|
||||
conv_state = torch.randn(total_entries,
|
||||
dim,
|
||||
width - 1,
|
||||
dim,
|
||||
device=device,
|
||||
dtype=itype)
|
||||
dtype=itype).transpose(1, 2)
|
||||
|
||||
conv_state_for_padding_test = conv_state.clone()
|
||||
|
||||
weight = torch.randn(dim, width, device=device, dtype=itype)
|
||||
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
|
||||
conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
|
||||
activation = None if not silu_activation else "silu"
|
||||
|
||||
out = causal_conv1d_update(x,
|
||||
conv_state,
|
||||
weight,
|
||||
@ -325,26 +256,21 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width,
|
||||
activation=activation)
|
||||
|
||||
assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
|
||||
assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
|
||||
assert torch.equal(conv_state[unused_states_bool],
|
||||
conv_state_for_padding_test[unused_states_bool])
|
||||
|
||||
opcheck(torch.ops._C.causal_conv1d_update,
|
||||
(x, conv_state, weight, bias, activation
|
||||
in ["silu", "swish"], None, padded_state_indices, PAD_SLOT_ID))
|
||||
assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("silu_activation", [True])
|
||||
@pytest.mark.parametrize("has_bias", [True])
|
||||
@pytest.mark.parametrize("width", [4])
|
||||
@pytest.mark.parametrize(
|
||||
'seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 2049, 4096])
|
||||
@pytest.mark.parametrize('seqlen', [8, 30, 249, 2049, 4096])
|
||||
@pytest.mark.parametrize('dim', [64, 4096])
|
||||
# tests correctness in case subset of the sequences are padded
|
||||
@pytest.mark.parametrize('with_padding', [True, False])
|
||||
def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
|
||||
silu_activation, itype):
|
||||
@pytest.mark.parametrize('batch', [4, 10])
|
||||
def test_causal_conv1d_varlen(batch, with_padding, dim, seqlen, width,
|
||||
has_bias, silu_activation, itype):
|
||||
device = "cuda"
|
||||
torch.cuda.empty_cache()
|
||||
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
||||
@ -353,14 +279,13 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
|
||||
# set seed
|
||||
current_platform.seed_everything(0)
|
||||
seqlens = []
|
||||
batch_size = 4
|
||||
if seqlen < 10:
|
||||
batch_size = 1
|
||||
batch_size = batch
|
||||
padding = 3 if with_padding else 0
|
||||
padded_batch_size = batch_size + padding
|
||||
nsplits = padded_batch_size - 1
|
||||
|
||||
eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
|
||||
|
||||
seqlens.append(
|
||||
torch.diff(
|
||||
torch.cat(
|
||||
@ -373,19 +298,22 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
|
||||
cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32)
|
||||
cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum],
|
||||
dim=0)
|
||||
x = torch.randn(1, 4096 + dim + 64, seqlen, device=device,
|
||||
dtype=itype)[:, 4096:4096 + dim, :]
|
||||
x = rearrange(
|
||||
torch.randn(1, seqlen, 4096 + dim + 64, device=device, dtype=itype),
|
||||
"b s d -> b d s")[:, 4096:4096 + dim, :]
|
||||
|
||||
weight = torch.randn(dim, width, device=device, dtype=itype)
|
||||
|
||||
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
|
||||
x_ref = x.clone()
|
||||
weight_ref = weight.clone()
|
||||
bias_ref = bias.clone() if bias is not None else None
|
||||
activation = None if not silu_activation else "silu"
|
||||
final_states = torch.randn(total_entries,
|
||||
dim,
|
||||
width - 1,
|
||||
dim,
|
||||
device=x.device,
|
||||
dtype=x.dtype)
|
||||
dtype=x.dtype).transpose(1, 2)
|
||||
final_states_ref = final_states.clone()
|
||||
has_initial_states = torch.randint(0,
|
||||
2, (cumsum.shape[0] - 1, ),
|
||||
@ -400,10 +328,16 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
|
||||
[PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
|
||||
],
|
||||
dim=-1)
|
||||
out = causal_conv1d_fn(x.squeeze(0),
|
||||
weight,
|
||||
bias=bias,
|
||||
conv_states=final_states,
|
||||
query_start_loc=cumsum.cuda(),
|
||||
cache_indices=padded_state_indices,
|
||||
has_initial_state=has_initial_states,
|
||||
activation=activation,
|
||||
pad_slot_id=PAD_SLOT_ID)
|
||||
|
||||
out = causal_conv1d_fn(x.squeeze(0), weight, bias, cumsum.cuda(),
|
||||
padded_state_indices, has_initial_states,
|
||||
final_states, activation, PAD_SLOT_ID)
|
||||
out_ref = []
|
||||
out_ref_b = []
|
||||
|
||||
@ -426,13 +360,9 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
|
||||
out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2))
|
||||
out_ref_tensor = torch.cat(out_ref, dim=0)
|
||||
|
||||
unpadded_out = out[:, :out_ref_tensor.shape[-1]]
|
||||
assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol)
|
||||
assert torch.allclose(final_states[state_indices],
|
||||
final_states_ref[state_indices],
|
||||
rtol=rtol,
|
||||
atol=atol)
|
||||
|
||||
causal_conv1d_opcheck_fn(x.squeeze(0), weight, bias, cumsum.cuda(),
|
||||
padded_state_indices, has_initial_states,
|
||||
final_states, activation)
|
||||
unpadded_out = out[:, :out_ref_tensor.shape[-1]]
|
||||
assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol)
|
||||
|
||||
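For reference, a minimal sketch of the causal depthwise conv1d semantics these tests exercise (an illustration assuming x is laid out as (batch, dim, seqlen); it is not the vLLM kernel itself):

import torch
import torch.nn.functional as F

def causal_conv1d_reference(x, weight, bias=None, activation=None):
    # x: (batch, dim, seqlen), weight: (dim, width); depthwise conv with causal padding
    dim, width = weight.shape
    out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
    out = out[..., :x.shape[-1]]  # keep only the causal positions
    return F.silu(out) if activation == "silu" else out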
@ -6,11 +6,11 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from vllm.model_executor.layers.mamba.mamba2_metadata import (
|
||||
_query_start_loc_to_chunk_indices_offsets)
|
||||
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
|
||||
mamba_chunk_scan_combined)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.backends.mamba_attn import (
|
||||
_query_start_loc_to_chunk_indices_offsets)
|
||||
|
||||
# Added by the IBM Team, 2024
|
||||
|
||||
|
||||
0
tests/kernels/moe/modular_kernel_tools/__init__.py
Normal file
160
tests/kernels/moe/modular_kernel_tools/cli_args.py
Normal file
@ -0,0 +1,160 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
|
||||
from .common import Config
|
||||
from .mk_objects import (MK_ALL_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES,
|
||||
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES)
|
||||
|
||||
|
||||
def make_config_arg_parser(description: str):
|
||||
|
||||
def to_pf_class_type(s: str) -> mk.FusedMoEPrepareAndFinalize:
|
||||
for pf in MK_ALL_PREPARE_FINALIZE_TYPES:
|
||||
if pf.__name__ == s:
|
||||
return pf
|
||||
raise ValueError(
|
||||
f"Cannot find a PrepareFinalize type that matches {s}")
|
||||
|
||||
def to_experts_class_type(s: str) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
for fe in MK_FUSED_EXPERT_TYPES:
|
||||
if fe.__name__ == s:
|
||||
return fe
|
||||
raise ValueError(f"Cannot find a FusedExperts type that matches {s}")
|
||||
|
||||
def to_quant_torch_dtype(s: str) -> torch.dtype:
|
||||
if s == "torch.float8_e4m3fn":
|
||||
return torch.float8_e4m3fn
|
||||
raise ValueError(f"Unsupported quant type {s}")
|
||||
|
||||
parser = argparse.ArgumentParser(description=description)
|
||||
|
||||
parser.add_argument(
|
||||
"--world-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Number of ranks that participate in all2all",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pf-type",
|
||||
type=to_pf_class_type,
|
||||
required=True,
|
||||
help=("Choose a PrepareFinalize Type : "
|
||||
f"{[x.__name__ for x in MK_ALL_PREPARE_FINALIZE_TYPES]}"),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--experts-type",
|
||||
type=to_experts_class_type,
|
||||
required=True,
|
||||
help=(f"Choose a FusedExpert type : "
|
||||
f"{[x.__name__ for x in MK_FUSED_EXPERT_TYPES]}"),
|
||||
)
|
||||
parser.add_argument(
|
||||
"-m",
|
||||
nargs="+",
|
||||
type=int,
|
||||
default=[64],
|
||||
help="num tokens per rank",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-k",
|
||||
type=int,
|
||||
default=7168,
|
||||
help="hidden-size",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-n",
|
||||
type=int,
|
||||
default=1024,
|
||||
help="N dimension of the first fused-moe matmul",
|
||||
)
|
||||
parser.add_argument("--num-experts",
|
||||
type=int,
|
||||
default=32,
|
||||
help="Global num experts")
|
||||
parser.add_argument("--topk",
|
||||
nargs="+",
|
||||
type=int,
|
||||
default=[4, 1],
|
||||
help="num topk")
|
||||
parser.add_argument(
|
||||
"--fused-moe-chunk-size",
|
||||
nargs="+",
|
||||
type=int,
|
||||
help="Fused moe chunk size used for the non-batched fused experts impl."
|
||||
)
|
||||
|
||||
# Quant args
|
||||
parser.add_argument("--quant-dtype",
|
||||
type=to_quant_torch_dtype,
|
||||
help="Quant datatype")
|
||||
parser.add_argument("--per-token-quantized-activations",
|
||||
action='store_true',
|
||||
help=("The input activations must be per-token "
|
||||
"quantized"))
|
||||
parser.add_argument("--per-channel-quantized-weights",
|
||||
action="store_true",
|
||||
help="The weights must be per-channel quantized.")
|
||||
parser.add_argument("--block-shape",
|
||||
nargs="+",
|
||||
type=int,
|
||||
help="Quantization block shape")
|
||||
|
||||
# Torch trace profile generation args
|
||||
parser.add_argument("--torch-trace-dir-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Get torch trace for single execution")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def _validate_args(args: argparse.Namespace):
|
||||
|
||||
if args.quant_dtype is not None:
|
||||
assert args.quant_dtype == torch.float8_e4m3fn
|
||||
if args.block_shape is not None:
|
||||
assert len(args.block_shape) == 2, (
|
||||
f"block shape must have 2 elements. got {args.block_shape}")
|
||||
|
||||
if args.pf_type in MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES:
|
||||
assert args.world_size == 1, (
|
||||
"Single GPU objects need world size set to 1")
|
||||
|
||||
if args.torch_trace_dir_path is not None:
|
||||
from pathlib import Path
|
||||
assert Path(args.torch_trace_dir_path).is_dir(), (
|
||||
f"Please create {args.torch_trace_dir_path}")
|
||||
|
||||
|
||||
def make_config(args: argparse.Namespace) -> Config:
|
||||
|
||||
_validate_args(args)
|
||||
|
||||
quant_config = None
|
||||
if args.quant_dtype is not None:
|
||||
quant_config = FusedMoEQuantConfig(
|
||||
quant_dtype=args.quant_dtype,
|
||||
per_act_token_quant=args.per_token_quantized_activations,
|
||||
per_out_ch_quant=args.per_channel_quantized_weights,
|
||||
block_shape=args.block_shape)
|
||||
|
||||
return Config(
|
||||
Ms=args.m,
|
||||
K=args.k,
|
||||
N=args.n,
|
||||
E=args.num_experts,
|
||||
topks=args.topk,
|
||||
dtype=torch.bfloat16, # hard-code
|
||||
quant_config=quant_config,
|
||||
prepare_finalize_type=args.pf_type,
|
||||
fused_experts_type=args.experts_type,
|
||||
fused_moe_chunk_size=args.fused_moe_chunk_size,
|
||||
world_size=args.world_size,
|
||||
torch_trace_dir_path=args.torch_trace_dir_path)
|
||||
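A hedged usage sketch for the helpers above (module path taken from the file listing; the class names passed on the CLI are ones exported by mk_objects.py):

from tests.kernels.moe.modular_kernel_tools.cli_args import (
    make_config, make_config_arg_parser)

parser = make_config_arg_parser("modular kernel sweep")
args = parser.parse_args([
    "--pf-type", "MoEPrepareAndFinalizeNoEP",
    "--experts-type", "TritonExperts",
    "--world-size", "1",
])
config = make_config(args)
print(config.describe())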
641
tests/kernels/moe/modular_kernel_tools/common.py
Normal file
@ -0,0 +1,641 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from tests.kernels.utils import torch_experts
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_dp_group, get_tensor_model_parallel_world_size
|
||||
# Fused experts and PrepareFinalize imports
|
||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||
BatchedDeepGemmExperts)
|
||||
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
|
||||
BatchedTritonOrDeepGemmExperts)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig)
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedTritonExperts, NaiveBatchedExperts)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
from vllm.model_executor.layers.fused_moe.layer import (FusedMoEMethodBase,
|
||||
TritonExperts)
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP)
|
||||
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
|
||||
TritonOrDeepGemmExperts)
|
||||
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
|
||||
|
||||
from .parallel_utils import ProcessGroupInfo
|
||||
from .utils import (make_block_quant_fp8_weights, make_non_quant_weights,
|
||||
make_quant_fp8_weights, per_token_cast_to_fp8)
|
||||
|
||||
if has_pplx():
|
||||
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
|
||||
PplxPrepareAndFinalize)
|
||||
if has_deep_ep():
|
||||
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
|
||||
DeepEPHTPrepareAndFinalize)
|
||||
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
|
||||
DeepEPLLPrepareAndFinalize)
|
||||
|
||||
|
||||
def _describe_tensor(t: Optional[torch.Tensor], name: str) -> str:
|
||||
if t is None:
|
||||
return f"{name} : None"
|
||||
else:
|
||||
return f"{name} : {t.shape} {t.dtype} {t.device}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
Ms: Union[list[int], int]
|
||||
K: int
|
||||
N: int
|
||||
E: int
|
||||
topks: Union[list[int], int]
|
||||
dtype: torch.dtype
|
||||
quant_config: Optional[FusedMoEQuantConfig]
|
||||
|
||||
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize
|
||||
fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute
|
||||
|
||||
fused_moe_chunk_size: Optional[int]
|
||||
world_size: int
|
||||
|
||||
torch_trace_dir_path: Optional[str] = None
|
||||
|
||||
def describe(self) -> str:
|
||||
s = ""
|
||||
s += "== Config: \n"
|
||||
s += f" world_size={self.world_size} \n"
|
||||
s += f" PF={self.prepare_finalize_type.__name__} \n"
|
||||
s += f" FE={self.fused_experts_type.__name__} \n"
|
||||
s += f" topk={self.topks} \n"
|
||||
s += f" dtype={self.dtype} \n"
|
||||
s += f" fused_moe_chunk_size={self.fused_moe_chunk_size} \n"
|
||||
s += " Quant: \n"
|
||||
s += f" fused_moe_chunk_size={self.fused_moe_chunk_size} \n "
|
||||
if self.quant_config is not None:
|
||||
s += f" q_dtype={self.quant_dtype} \n"
|
||||
s += f" q_block_shape={self.quant_block_shape} \n"
|
||||
s += f" q_per_out_ch_quant={self.is_per_out_ch_quant} \n"
|
||||
s += f" q_per_act_token={self.is_per_act_token_quant} \n"
|
||||
else:
|
||||
s += " quant=None \n"
|
||||
return s
|
||||
|
||||
@property
|
||||
def M(self) -> int:
|
||||
assert isinstance(self.Ms, int)
|
||||
return self.Ms
|
||||
|
||||
@property
|
||||
def quant_dtype(self) -> Optional[torch.dtype]:
|
||||
if self.quant_config is None:
|
||||
return None
|
||||
return self.quant_config.quant_dtype
|
||||
|
||||
@property
|
||||
def is_per_act_token_quant(self) -> bool:
|
||||
if self.quant_config is None:
|
||||
return False
|
||||
return self.quant_config.per_act_token_quant
|
||||
|
||||
@property
|
||||
def is_per_tensor_act_quant(self) -> bool:
|
||||
if self.quant_config is None:
|
||||
return False
|
||||
return (not self.is_per_act_token_quant
|
||||
and self.quant_block_shape is None)
|
||||
|
||||
@property
|
||||
def is_per_out_ch_quant(self) -> bool:
|
||||
if self.quant_config is None:
|
||||
return False
|
||||
return self.quant_config.per_out_ch_quant
|
||||
|
||||
@property
|
||||
def quant_block_shape(self) -> Optional[list[int]]:
|
||||
if self.quant_config is None:
|
||||
return None
|
||||
return self.quant_config.block_shape
|
||||
|
||||
@property
|
||||
def topk(self) -> int:
|
||||
assert isinstance(self.topks, int)
|
||||
return self.topks
|
||||
|
||||
@property
|
||||
def topk_ids_dtype(self) -> Optional[torch.dtype]:
|
||||
topk_ids_dtype = None
|
||||
if self.prepare_finalize_type == PplxPrepareAndFinalize:
|
||||
topk_ids_dtype = torch.uint32
|
||||
elif self.prepare_finalize_type in [
|
||||
DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize
|
||||
]:
|
||||
topk_ids_dtype = torch.int64
|
||||
return topk_ids_dtype
|
||||
|
||||
@property
|
||||
def num_local_experts(self) -> int:
|
||||
return self.E // self.world_size
|
||||
|
||||
def make_env_data(self) -> tuple[VllmConfig, dict[Any, Any]]:
|
||||
"""
|
||||
make env data for vllm launch.
|
||||
"""
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.parallel_config.data_parallel_size = self.world_size
|
||||
vllm_config.parallel_config.enable_expert_parallel = True
|
||||
|
||||
env_dict = {
|
||||
"VLLM_ALL2ALL_BACKEND": self.all2all_backend(),
|
||||
"VLLM_USE_DEEP_GEMM": str(int(self.needs_deep_gemm())),
|
||||
}
|
||||
if self.fused_moe_chunk_size is not None:
|
||||
env_dict.update(
|
||||
{"VLLM_FUSED_MOE_CHUNK_SIZE": str(self.fused_moe_chunk_size)})
|
||||
return vllm_config, env_dict
|
||||
|
||||
def is_fp8_block_quantized(self):
|
||||
return (self.quant_dtype == torch.float8_e4m3fn
|
||||
and self.quant_block_shape is not None)
|
||||
|
||||
def is_batched_prepare_finalize(self):
|
||||
return self.prepare_finalize_type in [
|
||||
PplxPrepareAndFinalize, DeepEPLLPrepareAndFinalize
|
||||
]
|
||||
|
||||
def is_batched_fused_experts(self):
|
||||
return self.fused_experts_type in [
|
||||
CutlassExpertsFp8, BatchedDeepGemmExperts, BatchedTritonExperts,
|
||||
NaiveBatchedExperts, BatchedTritonOrDeepGemmExperts
|
||||
]
|
||||
|
||||
def is_standard_fused_experts(self):
|
||||
return self.fused_experts_type in [
|
||||
CutlassExpertsFp8, DeepGemmExperts, TritonOrDeepGemmExperts,
|
||||
TritonExperts
|
||||
]
|
||||
|
||||
def is_fe_16bit_supported(self):
|
||||
return self.fused_experts_type in [
|
||||
BatchedTritonExperts, BatchedTritonOrDeepGemmExperts,
|
||||
NaiveBatchedExperts, TritonExperts
|
||||
]
|
||||
|
||||
def is_fe_fp8_supported(self):
|
||||
return self.fused_experts_type in [
|
||||
BatchedDeepGemmExperts,
|
||||
BatchedTritonExperts,
|
||||
BatchedTritonOrDeepGemmExperts,
|
||||
CutlassExpertsFp8,
|
||||
DeepGemmExperts,
|
||||
TritonExperts,
|
||||
TritonOrDeepGemmExperts,
|
||||
NaiveBatchedExperts,
|
||||
]
|
||||
|
||||
def is_fe_block_fp8_supported(self):
|
||||
return self.fused_experts_type in [
|
||||
BatchedDeepGemmExperts,
|
||||
BatchedTritonOrDeepGemmExperts,
|
||||
DeepGemmExperts,
|
||||
TritonExperts,
|
||||
TritonOrDeepGemmExperts,
|
||||
BatchedTritonExperts,
|
||||
NaiveBatchedExperts,
|
||||
]
|
||||
|
||||
def is_fe_supports_chunking(self):
|
||||
return self.fused_experts_type in [
|
||||
CutlassExpertsFp8, DeepGemmExperts, TritonOrDeepGemmExperts,
|
||||
TritonExperts
|
||||
]
|
||||
|
||||
def needs_deep_gemm(self):
|
||||
return self.fused_experts_type in [
|
||||
BatchedDeepGemmExperts,
|
||||
DeepGemmExperts,
|
||||
]
|
||||
|
||||
def needs_pplx(self):
|
||||
return self.prepare_finalize_type in [PplxPrepareAndFinalize]
|
||||
|
||||
def needs_deep_ep(self):
|
||||
return self.prepare_finalize_type in [
|
||||
DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize
|
||||
]
|
||||
|
||||
def all2all_backend(self):
|
||||
if self.needs_pplx():
|
||||
return "pplx"
|
||||
if self.prepare_finalize_type == DeepEPHTPrepareAndFinalize:
|
||||
return "deepep_high_throughput"
|
||||
if self.prepare_finalize_type == DeepEPLLPrepareAndFinalize:
|
||||
return "deepep_low_latency"
|
||||
return "naive"
|
||||
|
||||
def needs_all2all(self):
|
||||
return self.prepare_finalize_type in [
|
||||
PplxPrepareAndFinalize, DeepEPHTPrepareAndFinalize,
|
||||
DeepEPLLPrepareAndFinalize
|
||||
]
|
||||
|
||||
def is_valid(self):
|
||||
# Check prepare-finalize and fused-experts compatibility
|
||||
if self.is_batched_prepare_finalize():
|
||||
if not self.is_batched_fused_experts():
|
||||
return False
|
||||
else:
|
||||
if not self.is_standard_fused_experts():
|
||||
return False
|
||||
|
||||
use_chunking = self.fused_moe_chunk_size is not None
|
||||
if use_chunking and not self.is_fe_supports_chunking():
|
||||
return False
|
||||
|
||||
# Check quantization sanity
|
||||
if (int(self.is_per_act_token_quant) +
|
||||
int(self.is_per_tensor_act_quant) +
|
||||
int(self.quant_block_shape is not None)) > 1:
|
||||
# invalid quant config
|
||||
return False
|
||||
|
||||
# check bf16 / fp16 support
|
||||
is_16bit = (self.dtype.itemsize == 2 and self.quant_dtype is None)
|
||||
if is_16bit and not self.is_fe_16bit_supported():
|
||||
return False
|
||||
|
||||
# Check fp8 support
|
||||
is_fp8 = self.quant_dtype == torch.float8_e4m3fn
|
||||
if is_fp8 and not self.is_fe_fp8_supported():
|
||||
return False
|
||||
|
||||
# Check fp8 block quantization support
is_block_quantized = self.quant_block_shape is not None
if is_block_quantized and not is_fp8:
return False
if is_block_quantized and not self.is_fe_block_fp8_supported():
return False

# deep_gemm only works with block-quantized weights
if self.needs_deep_gemm() and not is_block_quantized:
return False
|
||||
|
||||
# Check dependencies
|
||||
if self.needs_deep_ep() and not has_deep_ep():
|
||||
return False
|
||||
if self.needs_deep_gemm() and not has_deep_gemm():
|
||||
return False
|
||||
if self.needs_pplx() and not has_pplx(): # noqa: SIM103
|
||||
return False
|
||||
|
||||
return True
|
||||
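# Illustration (added comment, not in the diff): a config with world_size=1,
# prepare_finalize_type=MoEPrepareAndFinalizeNoEP, fused_experts_type=TritonExperts,
# dtype=torch.bfloat16, quant_config=None and fused_moe_chunk_size=None passes
# every check above: both sides are non-batched, 16-bit is supported, and no
# quantization is involved (assuming the optional all2all backends referenced
# in the checks are importable).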
|
||||
|
||||
@dataclass
|
||||
class WeightTensors:
|
||||
w1: torch.Tensor
|
||||
w2: torch.Tensor
|
||||
w1_scale: Optional[torch.Tensor]
|
||||
w2_scale: Optional[torch.Tensor]
|
||||
|
||||
def describe(self):
|
||||
s = ""
|
||||
s += "== Weight Tensors: \n"
|
||||
s += f' - {_describe_tensor(self.w1, "w1")} \n'
|
||||
s += f' - {_describe_tensor(self.w2, "w2")} \n'
|
||||
s += f' - {_describe_tensor(self.w1_scale, "w1_scale")} \n'
|
||||
s += f' - {_describe_tensor(self.w2_scale, "w2_scale")} \n'
|
||||
return s
|
||||
|
||||
def to_current_device(self):
|
||||
self.w1 = self.w1.to(device=torch.cuda.current_device())
|
||||
self.w2 = self.w2.to(device=torch.cuda.current_device())
|
||||
is_quantized = self.w1.dtype == torch.float8_e4m3fn
|
||||
if is_quantized:
|
||||
assert self.w1_scale is not None
|
||||
assert self.w2_scale is not None
|
||||
self.w1_scale = self.w1_scale.to(
|
||||
device=torch.cuda.current_device())
|
||||
self.w2_scale = self.w2_scale.to(
|
||||
device=torch.cuda.current_device())
|
||||
|
||||
def slice_weights(self, rank: int,
|
||||
num_local_experts: int) -> "WeightTensors":
|
||||
s = rank * num_local_experts
|
||||
e = s + num_local_experts
|
||||
w1 = self.w1[s:e, :, :]
|
||||
w2 = self.w2[s:e, :, :]
|
||||
is_quantized = self.w1.dtype == torch.float8_e4m3fn
|
||||
w1_scale, w2_scale = (None, None)
|
||||
if is_quantized:
|
||||
assert self.w1_scale is not None
|
||||
assert self.w2_scale is not None
|
||||
w1_scale = self.w1_scale[s:e, :, :]
|
||||
w2_scale = self.w2_scale[s:e, :, :]
|
||||
return WeightTensors(w1, w2, w1_scale, w2_scale)
|
||||
|
||||
@staticmethod
|
||||
def make(config: Config) -> "WeightTensors":
|
||||
|
||||
if config.quant_dtype is None:
|
||||
# just make normal dtype weights
|
||||
w1, w2 = make_non_quant_weights(e=config.E,
|
||||
n=config.N,
|
||||
k=config.K,
|
||||
dtype=config.dtype)
|
||||
return WeightTensors(w1=w1, w2=w2, w1_scale=None, w2_scale=None)
|
||||
|
||||
assert config.quant_dtype == torch.float8_e4m3fn
|
||||
if not config.is_fp8_block_quantized():
|
||||
w1, w2, w1_scale, w2_scale = make_quant_fp8_weights(
|
||||
e=config.E,
|
||||
n=config.N,
|
||||
k=config.K,
|
||||
per_out_channel_quant=config.is_per_out_ch_quant,
|
||||
)
|
||||
return WeightTensors(w1=w1,
|
||||
w2=w2,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale)
|
||||
|
||||
assert config.quant_block_shape is not None
|
||||
w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights(
|
||||
e=config.E,
|
||||
n=config.N,
|
||||
k=config.K,
|
||||
block_size=config.quant_block_shape,
|
||||
)
|
||||
return WeightTensors(w1=w1,
|
||||
w2=w2,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RankTensors:
|
||||
hidden_states: torch.Tensor
|
||||
hidden_states_scale: Optional[torch.Tensor]
|
||||
|
||||
topk_weights: torch.Tensor
|
||||
topk_ids: torch.Tensor
|
||||
expert_map: Optional[torch.Tensor]
|
||||
|
||||
quant_config: Optional[FusedMoEQuantConfig]
|
||||
|
||||
def describe(self):
|
||||
s = ""
|
||||
s += "== Rank Tensors: \n"
|
||||
s += f' - {_describe_tensor(self.hidden_states, "HS")} \n'
|
||||
s += f' - {_describe_tensor(self.hidden_states_scale, "HS_scale")} \n'
|
||||
s += f' - {_describe_tensor(self.topk_weights, "topk_weights")} \n'
|
||||
s += f' - {_describe_tensor(self.topk_ids, "topk_ids")} \n'
|
||||
s += f' - {_describe_tensor(self.expert_map, "expert_map")} \n'
|
||||
return s
|
||||
|
||||
@staticmethod
|
||||
def make_hidden_states(
|
||||
config: Config) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""
|
||||
Return hidden_states and an optional quantization scale.
|
||||
"""
|
||||
m, k, dtype = (config.M, config.K, config.dtype)
|
||||
a = (torch.randn(
|
||||
(m, k), device=torch.cuda.current_device(), dtype=dtype) / 15.0)
|
||||
|
||||
if config.quant_dtype is None:
|
||||
return a, None
|
||||
|
||||
# We dequant and use that as hidden_states so the tests are stable.
|
||||
# quantizing and dequantizing yield slightly different results
|
||||
# depending on the hardware. Here we quantize and dequantize
# first - so a further quantize and dequantize will yield the same
|
||||
# values.
|
||||
if config.is_per_tensor_act_quant:
|
||||
a_q, a_scales = ops.scaled_fp8_quant(
|
||||
a, use_per_token_if_dynamic=False)
|
||||
return a_q.float().mul(a_scales).to(dtype), a_scales
|
||||
|
||||
if config.is_per_act_token_quant:
|
||||
a_q, a_scales = ops.scaled_fp8_quant(a,
|
||||
use_per_token_if_dynamic=True)
|
||||
return a_q.float().mul(a_scales).to(dtype), None
|
||||
|
||||
assert config.quant_block_shape is not None
|
||||
block_k = config.quant_block_shape[1]
|
||||
a_q, a_scales = per_token_cast_to_fp8(a, block_size=block_k)
|
||||
return a_q.float().view(
|
||||
(-1, block_k)).mul(a_scales.view(-1, 1)).view(m, k).to(dtype), None
|
||||
|
||||
@staticmethod
|
||||
def make(config: Config, pgi: ProcessGroupInfo):
|
||||
|
||||
dtype = config.dtype
|
||||
topk, m, _ = (config.topk, config.M, config.K)
|
||||
hidden_states, hidden_states_scale = RankTensors.make_hidden_states(
|
||||
config)
|
||||
|
||||
num_local_experts, global_num_experts = (config.num_local_experts,
|
||||
config.E)
|
||||
score = torch.randn((m, global_num_experts),
|
||||
device="cuda",
|
||||
dtype=dtype)
|
||||
topk_weights, topk_ids, _ = fused_topk(hidden_states, score, topk,
|
||||
False)
|
||||
topk_ids = topk_ids.to(config.topk_ids_dtype)
|
||||
|
||||
# distribute topk_ids evenly
|
||||
for mi in range(m):
|
||||
topk_ids[mi] = torch.randperm(config.E)[:topk]
|
||||
topk_ids = topk_ids.to(device=torch.cuda.current_device())
|
||||
|
||||
expert_map = None
|
||||
if config.world_size > 1:
|
||||
expert_map = torch.full((global_num_experts, ),
|
||||
fill_value=-1,
|
||||
dtype=torch.int32)
|
||||
s = pgi.rank * num_local_experts
|
||||
e = s + num_local_experts
|
||||
expert_map[s:e] = torch.tensor(list(range(num_local_experts)))
|
||||
expert_map = expert_map.to(device=torch.cuda.current_device(),
|
||||
dtype=torch.int32)
|
||||
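# Example (added comment): with E=32 and world_size=2, rank 1 builds
# expert_map = [-1]*16 + list(range(16)), i.e. global experts 16..31 map to
# local ids 0..15 and all other experts are masked with -1.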
|
||||
return RankTensors(
|
||||
hidden_states=hidden_states,
|
||||
hidden_states_scale=hidden_states_scale,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
expert_map=expert_map,
|
||||
quant_config=config.quant_config,
|
||||
)
|
||||
|
||||
|
||||
def reference_moe_impl(config: Config, weights: WeightTensors,
|
||||
rank_tensors: RankTensors) -> torch.Tensor:
|
||||
|
||||
return torch_experts(a=rank_tensors.hidden_states,
|
||||
w1=weights.w1,
|
||||
w2=weights.w2,
|
||||
topk_weight=rank_tensors.topk_weights,
|
||||
topk_ids=rank_tensors.topk_ids,
|
||||
global_num_experts=config.E,
|
||||
expert_map=None,
|
||||
w1_scale=weights.w1_scale,
|
||||
w2_scale=weights.w2_scale,
|
||||
a1_scale=rank_tensors.hidden_states_scale,
|
||||
quant_dtype=config.quant_dtype,
|
||||
per_act_token_quant=config.is_per_act_token_quant,
|
||||
block_shape=config.quant_block_shape,
|
||||
apply_router_weights_on_input=config.topk == 1)
|
||||
|
||||
|
||||
def make_fused_experts(
|
||||
config: Config, moe: FusedMoEConfig,
|
||||
num_dispatchers: int) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
|
||||
use_fp8 = config.quant_dtype == torch.float8_e4m3fn
|
||||
batch_kwargs = {
|
||||
"max_num_tokens": moe.max_num_tokens,
|
||||
"num_dispatchers": num_dispatchers,
|
||||
}
|
||||
quant_kwargs = {
|
||||
"use_fp8_w8a8": use_fp8,
|
||||
"use_int8_w8a8": False,
|
||||
"use_int8_w8a16": False,
|
||||
"use_int4_w4a16": False,
|
||||
"block_shape": config.quant_block_shape,
|
||||
"per_act_token_quant": config.is_per_act_token_quant,
|
||||
}
|
||||
deepgemm_kwargs = {"allow_deep_gemm": has_deep_gemm()}
|
||||
|
||||
if config.fused_experts_type == BatchedDeepGemmExperts:
|
||||
kwargs = batch_kwargs | {
|
||||
"block_shape": config.quant_block_shape,
|
||||
"per_act_token_quant": config.is_per_act_token_quant,
|
||||
}
|
||||
print(f"Making BatchedDeepGemmExperts {kwargs} ...")
|
||||
experts = BatchedDeepGemmExperts(**kwargs)
|
||||
elif config.fused_experts_type == BatchedTritonExperts:
|
||||
kwargs = batch_kwargs | quant_kwargs
|
||||
print(f"Making BatchedTritonExperts {kwargs} ...")
|
||||
experts = BatchedTritonExperts(**kwargs)
|
||||
elif config.fused_experts_type == BatchedTritonOrDeepGemmExperts:
|
||||
kwargs = batch_kwargs | quant_kwargs | deepgemm_kwargs
|
||||
print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...")
|
||||
experts = BatchedTritonOrDeepGemmExperts(**kwargs)
|
||||
elif config.fused_experts_type == DeepGemmExperts:
|
||||
print("Making DeepGemmExperts () ...")
|
||||
experts = DeepGemmExperts()
|
||||
elif config.fused_experts_type == TritonExperts:
|
||||
kwargs = quant_kwargs
|
||||
print(f"Making TritonExperts {kwargs} ...")
|
||||
experts = TritonExperts(**kwargs)
|
||||
elif config.fused_experts_type == TritonOrDeepGemmExperts:
|
||||
kwargs = quant_kwargs | deepgemm_kwargs
|
||||
print(f"Making TritonOrDeepGemmExperts {kwargs} ...")
|
||||
experts = TritonOrDeepGemmExperts(**kwargs)
|
||||
elif config.fused_experts_type == NaiveBatchedExperts:
|
||||
kwargs = batch_kwargs | quant_kwargs
|
||||
print(f"Making NaiveBatchedExperts {kwargs} ...")
|
||||
experts = NaiveBatchedExperts(**kwargs)
|
||||
elif config.fused_experts_type == CutlassExpertsFp8:
|
||||
use_batched_format = config.is_batched_prepare_finalize()
|
||||
num_experts = (moe.num_local_experts
|
||||
if use_batched_format else moe.num_experts)
|
||||
kwargs = {
|
||||
"max_experts_per_worker": num_experts,
|
||||
"out_dtype": moe.in_dtype,
|
||||
"per_act_token_quant": config.is_per_act_token_quant,
|
||||
"per_out_ch_quant": config.is_per_out_ch_quant,
|
||||
"block_shape": config.quant_block_shape,
|
||||
"num_dispatchers": num_dispatchers,
|
||||
"use_batched_format": use_batched_format
|
||||
}
|
||||
print(f"Making CutlassExpertsFp8 {kwargs} ...")
|
||||
experts = CutlassExpertsFp8(**kwargs)
|
||||
|
||||
return experts
|
||||
|
||||
|
||||
def make_modular_kernel(config: Config,
|
||||
vllm_config: VllmConfig) -> mk.FusedMoEModularKernel:
|
||||
|
||||
def next_power_of_2(x):
|
||||
import math
|
||||
if x == 0:
|
||||
return 1
|
||||
return 2**math.ceil(math.log2(x))
|
||||
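# e.g. next_power_of_2(48) == 64; used below to size moe.max_num_tokens.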
|
||||
# make moe config
|
||||
moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make(
|
||||
tp_size_=get_tensor_model_parallel_world_size(),
|
||||
dp_size_=get_dp_group().world_size,
|
||||
vllm_parallel_config=vllm_config.parallel_config,
|
||||
)
|
||||
moe = FusedMoEConfig(
|
||||
num_experts=config.E,
|
||||
experts_per_token=config.topk,
|
||||
hidden_dim=config.K,
|
||||
num_local_experts=config.num_local_experts,
|
||||
moe_parallel_config=moe_parallel_config,
|
||||
in_dtype=config.dtype,
|
||||
quant_config=config.quant_config,
|
||||
max_num_tokens=next_power_of_2(config.M),
|
||||
)
|
||||
|
||||
# make modular kernel
|
||||
prepare_finalize = None
|
||||
if config.needs_all2all():
|
||||
prepare_finalize = FusedMoEMethodBase.maybe_make_prepare_finalize(moe)
|
||||
assert prepare_finalize is not None
|
||||
else:
|
||||
prepare_finalize = MoEPrepareAndFinalizeNoEP()
|
||||
|
||||
fused_experts = make_fused_experts(config, moe,
|
||||
prepare_finalize.num_dispatchers())
|
||||
|
||||
modular_kernel = mk.FusedMoEModularKernel(
|
||||
prepare_finalize=prepare_finalize, fused_experts=fused_experts)
|
||||
|
||||
return modular_kernel
|
||||
|
||||
|
||||
def run_modular_kernel(
|
||||
pgi: ProcessGroupInfo,
|
||||
vllm_config: VllmConfig,
|
||||
config: Config,
|
||||
weights: WeightTensors,
|
||||
rank_tensors: RankTensors,
|
||||
) -> torch.Tensor:
|
||||
assert isinstance(config.Ms, int)
|
||||
assert isinstance(config.topks, int)
|
||||
|
||||
# weights for rank
|
||||
rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts)
|
||||
|
||||
mk = make_modular_kernel(config, vllm_config)
|
||||
|
||||
mk_kwargs = {
|
||||
"hidden_states": rank_tensors.hidden_states.clone(
|
||||
), # impls might update the tensor in place
|
||||
"w1": rank_weights.w1,
|
||||
"w2": rank_weights.w2,
|
||||
"topk_weights": rank_tensors.topk_weights,
|
||||
"topk_ids": rank_tensors.topk_ids,
|
||||
"expert_map": rank_tensors.expert_map,
|
||||
"w1_scale": rank_weights.w1_scale,
|
||||
"w2_scale": rank_weights.w2_scale,
|
||||
"a1_scale": rank_tensors.hidden_states_scale,
|
||||
"global_num_experts": config.E,
|
||||
"apply_router_weight_on_input": config.topk == 1,
|
||||
}
|
||||
out = mk.forward(**mk_kwargs)
|
||||
|
||||
return out
|
||||
173
tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py
Normal file
@ -0,0 +1,173 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import copy
|
||||
from enum import Enum
|
||||
from itertools import product
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .common import (Config, RankTensors, WeightTensors, reference_moe_impl,
|
||||
run_modular_kernel)
|
||||
from .mk_objects import (MK_FUSED_EXPERT_TYPES,
|
||||
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_QUANT_CONFIGS)
|
||||
from .parallel_utils import ProcessGroupInfo, parallel_launch_with_config
|
||||
|
||||
|
||||
class Result(Enum):
|
||||
PASS = 1
|
||||
FAIL = 2
|
||||
SKIP = 3
|
||||
|
||||
|
||||
def rank_worker(
|
||||
pgi: ProcessGroupInfo,
|
||||
vllm_config: VllmConfig,
|
||||
cpu_group,
|
||||
config: Config,
|
||||
weights: WeightTensors,
|
||||
):
|
||||
current_platform.seed_everything(pgi.rank)
|
||||
|
||||
# sanity check
|
||||
from vllm import envs
|
||||
if config.fused_moe_chunk_size is not None:
|
||||
assert (config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE)
|
||||
|
||||
# get weights to this device
|
||||
weights.to_current_device()
|
||||
|
||||
Ms = config.Ms
|
||||
assert isinstance(Ms, list)
|
||||
TOPKs = config.topks
|
||||
assert isinstance(TOPKs, list)
|
||||
|
||||
for m, topk in product(Ms, TOPKs):
|
||||
print(f"Running m={m}, topk={topk} ...")
|
||||
# override m and topk
|
||||
cfgx = copy.deepcopy(config)
|
||||
cfgx.Ms = m
|
||||
cfgx.topks = topk
|
||||
|
||||
# inputs for rank
|
||||
rank_tensors = RankTensors.make(cfgx, pgi)
|
||||
|
||||
# modular kernel out
|
||||
mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights,
|
||||
rank_tensors)
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
ref_out = reference_moe_impl(cfgx, weights, rank_tensors)
|
||||
|
||||
torch.testing.assert_close(ref_out, mk_out, atol=3e-2, rtol=3e-2)
|
||||
|
||||
|
||||
def make_feature_matrix(csv_file_path: str):
|
||||
|
||||
from dataclasses import asdict
|
||||
|
||||
import pandas as pd
|
||||
|
||||
def add_to_results(config: Config,
|
||||
success: Result,
|
||||
results_df: Optional[pd.DataFrame] = None):
|
||||
config_dict = asdict(config)
|
||||
config_dict['prepare_finalize_type'] = config_dict[
|
||||
'prepare_finalize_type'].__name__
|
||||
config_dict['fused_experts_type'] = config_dict[
|
||||
'fused_experts_type'].__name__
|
||||
config_dict['per_tensor_act_quant'] = config.is_per_tensor_act_quant
|
||||
quant_config_dict = config_dict['quant_config']
|
||||
del config_dict['quant_config']
|
||||
if quant_config_dict is None:
|
||||
quant_config = FusedMoEQuantConfig(None)
|
||||
quant_config_dict = asdict(quant_config)
|
||||
|
||||
config_dict |= quant_config_dict
|
||||
result_dict = config_dict | {'success': success.name}
|
||||
|
||||
result_df = pd.DataFrame([result_dict])
|
||||
if results_df is None:
|
||||
results_df = result_df
|
||||
else:
|
||||
results_df = pd.concat([results_df, result_df], ignore_index=True)
|
||||
|
||||
return results_df
|
||||
|
||||
Ms = [64]
|
||||
Ks = [7168] # hidden sizes
|
||||
Ns = [2048]
|
||||
TOPKs = [[4, 1]]
|
||||
Es = [32]
|
||||
DTYPEs = [torch.bfloat16]
|
||||
PF_TYPES = MK_MULTI_GPU_PREPARE_FINALIZE_TYPES
|
||||
FE_TYPES = MK_FUSED_EXPERT_TYPES
|
||||
Q_TYPES = MK_QUANT_CONFIGS
|
||||
|
||||
combinations = list(
|
||||
product(Ms, Ks, Ns, Es, TOPKs, DTYPEs, PF_TYPES, FE_TYPES, Q_TYPES))
|
||||
|
||||
results_df: Optional[pd.DataFrame] = None
|
||||
for m, k, n, e, topks, dtype, pf_type, experts_type, quant_config in tqdm(
|
||||
combinations): #noqa: E501
|
||||
config = Config(Ms=[m],
|
||||
K=k,
|
||||
N=n,
|
||||
E=e,
|
||||
topks=topks,
|
||||
dtype=dtype,
|
||||
prepare_finalize_type=pf_type,
|
||||
fused_experts_type=experts_type,
|
||||
quant_config=quant_config,
|
||||
world_size=2,
|
||||
fused_moe_chunk_size=None)
|
||||
|
||||
success = None
|
||||
if config.is_valid():
|
||||
print(f"Running config : {config.describe()} ...")
|
||||
try:
|
||||
weights: WeightTensors = WeightTensors.make(config)
|
||||
vllm_config, env_dict = config.make_env_data()
|
||||
parallel_launch_with_config(config.world_size, rank_worker,
|
||||
vllm_config, env_dict, config,
|
||||
weights)
|
||||
success = Result.PASS
|
||||
except Exception as _:
|
||||
success = Result.FAIL
|
||||
else:
|
||||
success = Result.SKIP
|
||||
|
||||
results_df = add_to_results(config, success, results_df)
|
||||
|
||||
if results_df is not None:
|
||||
results_df.to_csv(f"{csv_file_path}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
parser = argparse.ArgumentParser(description=(
|
||||
"Make ModularKernel feature matrix \n"
|
||||
"Example : python3 -m tests.kernels.moe.modular_kernel_tools.make_feature_matrix " #noqa: E501
|
||||
"-f ./feature_matrices/feature_matrix.csv"))
|
||||
|
||||
parser.add_argument("-f",
|
||||
"--feature-matrix-csv-file-path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="File name to Generate a .csv file")
|
||||
args = parser.parse_args()
|
||||
|
||||
csv_path = args.feature_matrix_csv_file_path
|
||||
assert csv_path.endswith(
|
||||
'csv'), f"Need a file path ending with .csv, got {csv_path}"
|
||||
assert Path(csv_path).parent.is_dir(
|
||||
), f"Cannot find parent directory for {Path(csv_path).parent}"
|
||||
|
||||
make_feature_matrix(args.feature_matrix_csv_file_path)
|
||||
87
tests/kernels/moe/modular_kernel_tools/mk_objects.py
Normal file
@ -0,0 +1,87 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
# Fused experts and PrepareFinalize imports
|
||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||
BatchedDeepGemmExperts)
|
||||
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
|
||||
BatchedTritonOrDeepGemmExperts)
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedTritonExperts, NaiveBatchedExperts)
|
||||
from vllm.model_executor.layers.fused_moe.layer import TritonExperts
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP)
|
||||
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
|
||||
TritonOrDeepGemmExperts)
|
||||
from vllm.utils import has_deep_ep, has_pplx
|
||||
|
||||
if has_deep_ep():
|
||||
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
|
||||
DeepEPHTPrepareAndFinalize)
|
||||
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
|
||||
DeepEPLLPrepareAndFinalize)
|
||||
|
||||
if has_pplx():
|
||||
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
|
||||
PplxPrepareAndFinalize)
|
||||
|
||||
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES = []
|
||||
if has_pplx():
|
||||
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES += [PplxPrepareAndFinalize]
|
||||
if has_deep_ep():
|
||||
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES += [
|
||||
DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize
|
||||
]
|
||||
|
||||
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES = [MoEPrepareAndFinalizeNoEP]
|
||||
|
||||
MK_ALL_PREPARE_FINALIZE_TYPES = (MK_MULTI_GPU_PREPARE_FINALIZE_TYPES +
|
||||
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES)
|
||||
|
||||
MK_FUSED_EXPERT_TYPES = [
|
||||
BatchedDeepGemmExperts,
|
||||
BatchedTritonExperts,
|
||||
NaiveBatchedExperts,
|
||||
BatchedTritonOrDeepGemmExperts,
|
||||
CutlassExpertsFp8,
|
||||
DeepGemmExperts,
|
||||
TritonOrDeepGemmExperts,
|
||||
TritonExperts,
|
||||
]
|
||||
|
||||
MK_QUANT_CONFIGS = [
|
||||
None,
|
||||
# per-channel / per-column weights and per-tensor activations
|
||||
FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
|
||||
per_out_ch_quant=True,
|
||||
per_act_token_quant=False,
|
||||
block_shape=None),
|
||||
# per-channel / per-column weights and per-token activations
|
||||
FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
|
||||
per_out_ch_quant=True,
|
||||
per_act_token_quant=True,
|
||||
block_shape=None),
|
||||
# per-tensor weights and per-tensor activations
|
||||
FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
|
||||
per_out_ch_quant=False,
|
||||
per_act_token_quant=False,
|
||||
block_shape=None),
|
||||
# per-tensor weights and per-token activations
|
||||
FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
|
||||
per_out_ch_quant=False,
|
||||
per_act_token_quant=True,
|
||||
block_shape=None),
|
||||
# block-quantized weights and 128 block per-token activations
|
||||
FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
|
||||
per_out_ch_quant=False,
|
||||
per_act_token_quant=False,
|
||||
block_shape=[128, 128]),
|
||||
# TODO (varun) : Should we test the following combinations ?
|
||||
# block-quantized weights and per-token activations
|
||||
# block-quantized weights and per-tensor activations
|
||||
]
|
||||
138
tests/kernels/moe/modular_kernel_tools/parallel_utils.py
Normal file
@ -0,0 +1,138 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import dataclasses
|
||||
import os
|
||||
import traceback
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
from torch.multiprocessing import (
|
||||
spawn) # pyright: ignore[reportPrivateImportUsage]
|
||||
from typing_extensions import Concatenate, ParamSpec
|
||||
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.distributed import (init_distributed_environment,
|
||||
initialize_model_parallel)
|
||||
from vllm.utils import get_open_port
|
||||
|
||||
## Parallel Processes Utils
|
||||
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ProcessGroupInfo:
|
||||
world_size: int
|
||||
world_local_size: int
|
||||
rank: int
|
||||
node_rank: int
|
||||
local_rank: int
|
||||
device: torch.device
|
||||
|
||||
|
||||
def _set_vllm_config(vllm_config: VllmConfig, world_size: int, rank: int,
|
||||
local_rank: int):
|
||||
|
||||
import tempfile
|
||||
temp_file = tempfile.mkstemp()[1]
|
||||
|
||||
set_current_vllm_config(vllm_config)
|
||||
with set_current_vllm_config(vllm_config):
|
||||
init_distributed_environment(
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
distributed_init_method=f"file://{temp_file}",
|
||||
local_rank=local_rank,
|
||||
backend="nccl",
|
||||
)
|
||||
|
||||
initialize_model_parallel(
|
||||
tensor_model_parallel_size=vllm_config.parallel_config.
|
||||
tensor_parallel_size,
|
||||
pipeline_model_parallel_size=vllm_config.parallel_config.
|
||||
pipeline_parallel_size,
|
||||
)
|
||||
cpu_group = torch.distributed.new_group(list(range(world_size)),
|
||||
backend="gloo")
|
||||
return cpu_group
|
||||
|
||||
|
||||
def _worker_parallel_launch(
|
||||
local_rank: int,
|
||||
world_size: int,
|
||||
world_local_size: int,
|
||||
node_rank: int,
|
||||
init_method: str,
|
||||
worker: Callable[Concatenate[ProcessGroupInfo, Optional[VllmConfig], Any,
|
||||
P], None],
|
||||
vllm_config: Optional[VllmConfig],
|
||||
env_dict: Optional[dict],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> None:
|
||||
rank = node_rank * world_local_size + local_rank
|
||||
torch.cuda.set_device(local_rank)
|
||||
device = torch.device("cuda", local_rank)
|
||||
torch.distributed.init_process_group(
|
||||
backend="cpu:gloo,cuda:nccl",
|
||||
init_method=init_method,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
device_id=device,
|
||||
)
|
||||
barrier = torch.tensor([rank], device=device)
|
||||
torch.distributed.all_reduce(barrier)
|
||||
|
||||
if env_dict is not None:
|
||||
os.environ.update(env_dict)
|
||||
|
||||
cpu_group = None
|
||||
if vllm_config is not None:
|
||||
cpu_group = _set_vllm_config(vllm_config, world_size, rank, local_rank)
|
||||
|
||||
try:
|
||||
worker(
|
||||
ProcessGroupInfo(
|
||||
world_size=world_size,
|
||||
world_local_size=world_local_size,
|
||||
rank=rank,
|
||||
node_rank=node_rank,
|
||||
local_rank=local_rank,
|
||||
device=device,
|
||||
),
|
||||
vllm_config,
|
||||
cpu_group,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
except Exception as ex:
|
||||
print(ex)
|
||||
traceback.print_exc()
|
||||
raise
|
||||
finally:
|
||||
torch.distributed.destroy_process_group()
|
||||
|
||||
|
||||
def parallel_launch_with_config(
|
||||
world_size: int,
|
||||
worker: Callable[Concatenate[ProcessGroupInfo, VllmConfig, Any, P], None],
|
||||
vllm_config: VllmConfig,
|
||||
env_dict: dict[Any, Any],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> None:
|
||||
assert not kwargs
|
||||
spawn(
|
||||
_worker_parallel_launch,
|
||||
args=(
|
||||
world_size,
|
||||
world_size,
|
||||
0,
|
||||
f"tcp://{os.getenv('LOCALHOST', 'localhost')}:{get_open_port()}",
|
||||
worker,
|
||||
vllm_config,
|
||||
env_dict,
|
||||
) + args,
|
||||
nprocs=world_size,
|
||||
join=True,
|
||||
)
|
||||
127
tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py
Normal file
@ -0,0 +1,127 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import copy
from itertools import product
from typing import Any, Callable

import torch

from vllm.config import VllmConfig
from vllm.platforms import current_platform

from .common import Config, RankTensors, WeightTensors, make_modular_kernel
from .parallel_utils import ProcessGroupInfo, parallel_launch_with_config


def do_profile(fn: Callable,
               fn_kwargs: dict[Any, Any],
               pgi: ProcessGroupInfo,
               config: Config,
               num_warmups: int = 5):
    for _ in range(num_warmups):
        fn(**fn_kwargs)

    with torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            with_stack=True,
            record_shapes=True,
    ) as tprof:
        fn(**fn_kwargs)
        torch.cuda.synchronize(torch.cuda.current_device())

    # TODO (varun): Add a descriptive trace file name
    tprof.export_chrome_trace(
        f"{config.torch_trace_dir_path}/m{config.M}_{pgi.rank}_trace.json")


def profile_modular_kernel(
    pgi: ProcessGroupInfo,
    vllm_config: VllmConfig,
    config: Config,
    weights: WeightTensors,
    rank_tensors: RankTensors,
) -> None:
    assert isinstance(config.Ms, int)
    assert isinstance(config.topks, int)

    # weights for rank
    rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts)

    # make modular kernel
    mk = make_modular_kernel(config, vllm_config)

    mk_kwargs = {
        "hidden_states": rank_tensors.hidden_states,
        "w1": rank_weights.w1,
        "w2": rank_weights.w2,
        "topk_weights": rank_tensors.topk_weights,
        "topk_ids": rank_tensors.topk_ids,
        "expert_map": rank_tensors.expert_map,
        "w1_scale": rank_weights.w1_scale,
        "w2_scale": rank_weights.w2_scale,
        "a1_scale": rank_tensors.hidden_states_scale,
        "global_num_experts": config.E,
        "apply_router_weight_on_input": config.topk == 1,
    }

    do_profile(mk.forward, mk_kwargs, pgi, config)


def rank_worker(
    pgi: ProcessGroupInfo,
    vllm_config: VllmConfig,
    cpu_group,
    config: Config,
    weights: WeightTensors,
):
    current_platform.seed_everything(pgi.rank)

    # sanity check
    from vllm import envs
    if config.fused_moe_chunk_size is not None:
        assert (config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE)

    # get weights to this device
    weights.to_current_device()

    Ms = config.Ms
    assert isinstance(Ms, list)
    TOPKs = config.topks
    assert isinstance(TOPKs, list)

    for m, topk in product(Ms, TOPKs):
        print(f"Running m={m}, topk={topk} ...")
        # override m and topk
        cfgx = copy.deepcopy(config)
        cfgx.Ms = m
        cfgx.topks = topk

        # inputs for rank
        rank_tensors = RankTensors.make(cfgx, pgi)
        profile_modular_kernel(pgi, vllm_config, cfgx, weights, rank_tensors)


def run(config: Config):
    weights: WeightTensors = WeightTensors.make(config)
    vllm_config, env_dict = config.make_env_data()
    parallel_launch_with_config(config.world_size, rank_worker, vllm_config,
                                env_dict, config, weights)


if __name__ == '__main__':
    from .cli_args import make_config, make_config_arg_parser
    parser = make_config_arg_parser(description=(
        "Run single prepare-finalize & fused-experts combination test"
        "Example : python3 -m tests.kernels.moe.modular_kernel_tools.profile_modular_kernel " #noqa: E501
        "--pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts"
    ))
    args = parser.parse_args()
    assert args.torch_trace_dir_path is not None, (
        "Please pass in a directory to store torch traces")
    config = make_config(args)

    run(config)
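As a reference for the warmup-then-profile structure used by do_profile, the standalone sketch below profiles a toy matmul on CPU and exports a Chrome trace. The toy function and the /tmp output path are illustrative assumptions, not part of the file above.

import torch


def toy_fn(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return a @ b


a, b = torch.randn(256, 256), torch.randn(256, 256)

# Warmup iterations, as in do_profile(num_warmups=5).
for _ in range(5):
    toy_fn(a, b)

# Profile a single call and export a Chrome trace.
with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU],
        with_stack=True,
        record_shapes=True,
) as tprof:
    toy_fn(a, b)

tprof.export_chrome_trace("/tmp/toy_trace.json")  # open in chrome://tracing or Perfetto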
142  tests/kernels/moe/modular_kernel_tools/utils.py  (new file)
@ -0,0 +1,142 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math

import torch

import vllm._custom_ops as ops


def per_token_cast_to_fp8(
        x: torch.Tensor, block_size: int) -> tuple[torch.Tensor, torch.Tensor]:
    assert x.dim() == 2
    m, n = x.shape
    pad_size = (block_size - (n % block_size)) % block_size
    x = torch.nn.functional.pad(x,
                                (0, pad_size), value=0) if pad_size > 0 else x
    x_view = x.view(m, -1, block_size)
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
    return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)


def per_block_cast_to_fp8(
        x: torch.Tensor, block_size_k: int,
        block_size_n: int) -> tuple[torch.Tensor, torch.Tensor]:
    assert x.dim() == 2
    m, n = x.shape
    x_padded = torch.zeros(
        (
            int(math.ceil(m / block_size_k)) * block_size_k,
            int(math.ceil(n / block_size_n)) * block_size_n,
        ),
        dtype=x.dtype,
        device=x.device,
    )
    x_padded[:m, :n] = x
    x_view = x_padded.view(-1, block_size_k,
                           x_padded.size(1) // block_size_k, block_size_n)
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
    x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
    scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
    return x_scaled_sub, scales


def make_non_quant_weights(
    e: int,
    n: int,
    k: int,
    dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Return weights w1, w2
    """
    device = torch.cuda.current_device()
    w1 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 15
    w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 15
    return w1, w2


def make_block_quant_fp8_weights(
    e: int,
    n: int,
    k: int,
    block_size: list[int],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Return weights w1, w2, w1_scale, w2_scale
    """
    dtype = torch.bfloat16
    device = torch.cuda.current_device()

    fp8_info = torch.finfo(torch.float8_e4m3fn)
    fp8_max, fp8_min = fp8_info.max, fp8_info.min

    w1_bf16, w2_bf16 = make_non_quant_weights(e, n, k, dtype)
    w1_bf16 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype)
    w2_bf16 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype)

    block_n, block_k = block_size[0], block_size[1]
    n_tiles_w1 = ((2 * n) + block_n - 1) // block_n
    k_tiles_w1 = (k + block_k - 1) // block_k
    n_tiles_w2 = (k + block_n - 1) // block_n
    k_tiles_w2 = (n + block_k - 1) // block_k

    w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn, device=device)
    w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn, device=device)

    w1_s = torch.empty((e, n_tiles_w1, k_tiles_w1),
                       device=device,
                       dtype=torch.float32)
    w2_s = torch.empty((e, n_tiles_w2, k_tiles_w2),
                       device=device,
                       dtype=torch.float32)

    assert w1_s.shape == (e, (2 * n + (block_n - 1)) // block_n,
                          (k + (block_k - 1)) // block_k)
    assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2]

    for i in range(e):
        w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i],
                                               block_size_k=block_k,
                                               block_size_n=block_n)
        w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i],
                                               block_size_k=block_k,
                                               block_size_n=block_n)

    return w1, w2, w1_s, w2_s


def make_quant_fp8_weights(
    e: int,
    n: int,
    k: int,
    per_out_channel_quant: bool,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Return w1, w2, w1_scale, w2_scale
    """
    q_dtype = torch.float8_e4m3fn

    w1, w2 = make_non_quant_weights(e, n, k, dtype=torch.bfloat16)

    # w1 -> w1_q, w2 -> w2_q
    w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=q_dtype)
    w2_q = torch.empty((e, k, n), device="cuda", dtype=q_dtype)

    n_b_scales = 2 * n if per_out_channel_quant else 1
    k_b_scales = k if per_out_channel_quant else 1
    w1_scale = torch.empty((e, n_b_scales, 1),
                           device="cuda",
                           dtype=torch.float32)
    w2_scale = torch.empty((e, k_b_scales, 1),
                           device="cuda",
                           dtype=torch.float32)

    for expert in range(e):
        w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(
            w1[expert], use_per_token_if_dynamic=per_out_channel_quant)
        w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(
            w2[expert], use_per_token_if_dynamic=per_out_channel_quant)
    return w1_q, w2_q, w1_scale, w2_scale
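A minimal usage sketch for the per_block_cast_to_fp8 helper above, assuming the vLLM repository root is on PYTHONPATH and that the installed PyTorch supports torch.float8_e4m3fn on the chosen device; the 256x512 shape and 128x128 blocks are illustrative.

import torch

from tests.kernels.moe.modular_kernel_tools.utils import per_block_cast_to_fp8

# 256x512 bf16 tensor quantized in 128x128 blocks -> one fp32 scale per block.
x = torch.randn(256, 512, dtype=torch.bfloat16)
x_fp8, scales = per_block_cast_to_fp8(x, block_size_k=128, block_size_n=128)

assert x_fp8.shape == x.shape and x_fp8.dtype == torch.float8_e4m3fn
assert scales.shape == (256 // 128, 512 // 128)

# Block-wise dequantization for a quick sanity check.
x_deq = x_fp8.to(torch.float32).view(2, 128, 4, 128) * scales.view(2, 1, 4, 1)
print((x_deq.reshape(256, 512) - x.float()).abs().max())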
@ -4,7 +4,6 @@
DeepEP test utilities
"""
import dataclasses
import importlib
import os
import traceback
from typing import Callable, Optional
@ -15,10 +14,9 @@ from torch.multiprocessing import (
    spawn)  # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import Concatenate, ParamSpec

from vllm.utils import get_open_port
from vllm.utils import get_open_port, has_deep_ep

has_deep_ep = importlib.util.find_spec("deep_ep") is not None
if has_deep_ep:
if has_deep_ep():
    from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
        DeepEPHTPrepareAndFinalize)
    from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
@ -15,13 +15,13 @@ from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
from vllm.model_executor.layers.fused_moe.fused_moe import (
    fused_topk, modular_triton_fused_moe)
from vllm.platforms import current_platform
from vllm.utils import has_deep_gemm
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used

dg_available = False
try:
    import deep_gemm
    dg_available = True
except ImportError:
    pass
dg_available = has_deep_gemm()

if dg_available:
    from deep_gemm import get_m_alignment_for_contiguous_layout

if current_platform.get_device_capability() < (9, 0):
    pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
@ -224,6 +224,7 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
@pytest.mark.skipif(is_blackwell_deep_gemm_used(), reason="Not E8M0 scale MOE")
@torch.inference_mode()
def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
                                            monkeypatch):
@ -238,8 +239,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
    torch.manual_seed(seed)

    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))

    block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
    block_m = get_m_alignment_for_contiguous_layout()
    block_size = [block_m, block_m]
    dtype = torch.bfloat16
140  tests/kernels/moe/test_count_expert_num_tokens.py  (new file)
@ -0,0 +1,140 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests compute_expert_num_tokens kernels
"""

import dataclasses
from typing import Optional

import pytest
import torch

from vllm.model_executor.layers.fused_moe.utils import count_expert_num_tokens


@dataclasses.dataclass
class TestTensors:

    topk_ids: torch.Tensor
    expert_map: Optional[torch.Tensor] = None

    def to_device(self, device: str):
        self.topk_ids = self.topk_ids.to(device=device)
        if self.expert_map is not None:
            self.expert_map = self.expert_map.to(device=device)

    @staticmethod
    def make(num_tokens: int, num_topk: int, num_experts: int, device: str,
             topk_ids_dtype: torch.dtype) -> "TestTensors":

        # make topk ids
        topk_ids = torch.empty((num_tokens, num_topk),
                               device=device,
                               dtype=torch.int64)
        for x in range(num_tokens):
            topk_ids[x] = torch.randperm(num_experts)[:num_topk]
        topk_ids = topk_ids.to(dtype=torch.int64)
        return TestTensors(topk_ids=topk_ids)

    def with_ep_rank(self, ep_rank: int, num_global_experts: int,
                     num_local_experts: int, device: str):
        # make an expert map
        expert_map = torch.empty((num_global_experts),
                                 device=device,
                                 dtype=torch.int32)
        expert_map.fill_(-1)
        s = ep_rank * num_local_experts
        e = s + num_local_experts
        expert_map[s:e] = torch.tensor(list(range(num_local_experts)),
                                       device=device)

        return TestTensors(topk_ids=self.topk_ids.clone(),
                           expert_map=expert_map)


def ref_impl(tt: TestTensors, expert_num_tokens: torch.Tensor):
    # do the reference in cpu
    tt.to_device("cpu")
    expert_ids, counts = tt.topk_ids.unique(return_counts=True)

    for eid, count in zip(expert_ids, counts):
        if eid != -1 and tt.expert_map is not None:
            eid = tt.expert_map[eid]

        if eid == -1:
            continue

        expert_num_tokens[eid] += count


def do_test_compute_expert_num_tokens(num_tokens: int, num_topk: int,
                                      num_experts: int, ep_size: int,
                                      topk_ids_dtype: torch.dtype):

    assert num_topk <= num_experts

    tt = TestTensors.make(num_tokens,
                          num_topk,
                          num_experts,
                          topk_ids_dtype=topk_ids_dtype,
                          device="cpu")

    num_global_experts = num_experts
    assert num_global_experts % ep_size == 0
    num_local_experts = num_global_experts // ep_size
    for ep_rank in range(ep_size):
        tt_rank = tt.with_ep_rank(ep_rank, num_global_experts,
                                  num_local_experts, "cpu")

        ref_expert_num_tokens = torch.zeros((num_local_experts),
                                            device="cpu",
                                            dtype=torch.int32)
        ref_impl(tt_rank, ref_expert_num_tokens)
        ref_expert_num_tokens = ref_expert_num_tokens.to("cuda")

        tt_rank.to_device("cuda")
        # Test with expert_map
        triton_expert_num_tokens_w_emap = count_expert_num_tokens(
            tt_rank.topk_ids, num_local_experts, tt_rank.expert_map)

        # Test without expert map
        topk_ids = tt_rank.expert_map[tt_rank.topk_ids].to(topk_ids_dtype)
        triton_expert_num_tokens_wo_emap = count_expert_num_tokens(
            topk_ids, num_local_experts, expert_map=None)

        torch.testing.assert_close(ref_expert_num_tokens,
                                   triton_expert_num_tokens_w_emap,
                                   atol=0,
                                   rtol=0)
        torch.testing.assert_close(ref_expert_num_tokens,
                                   triton_expert_num_tokens_wo_emap,
                                   atol=0,
                                   rtol=0)


@pytest.mark.parametrize(
    "num_tokens", [1, 4, 8, 11, 19, 128, 127, 405, 1024, 3333, 6666, 7317])
@pytest.mark.parametrize("num_topk", [2, 6, 8])
@pytest.mark.parametrize("num_experts", [64])
@pytest.mark.parametrize("ep_size", [1, 2, 4])
@pytest.mark.parametrize("topk_ids_dtype", [torch.int64])
def test_compute_expert_num_tokens(num_tokens: int, num_topk: int,
                                   num_experts: int, ep_size: int,
                                   topk_ids_dtype: torch.dtype):
    do_test_compute_expert_num_tokens(num_tokens, num_topk, num_experts,
                                      ep_size, topk_ids_dtype)


@pytest.mark.parametrize("numel", list(range(1, 8192, 11)))
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("ep_size", [2])
@pytest.mark.parametrize("topk_ids_dtype", [torch.int64])
def test_compute_expert_num_tokens_from_numel(numel: int, num_experts: int,
                                              ep_size: int,
                                              topk_ids_dtype: torch.dtype):
    do_test_compute_expert_num_tokens(num_tokens=numel,
                                      num_topk=1,
                                      num_experts=num_experts,
                                      ep_size=ep_size,
                                      topk_ids_dtype=topk_ids_dtype)
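For intuition, the per-expert counting that ref_impl performs (without expert_map remapping) reduces to a bincount over the flattened topk_ids; the sizes in the sketch below are illustrative and it runs on CPU.

import torch

num_experts, num_tokens, num_topk = 8, 16, 2
topk_ids = torch.randint(0, num_experts, (num_tokens, num_topk))

# Number of tokens routed to each expert, i.e. the quantity
# count_expert_num_tokens computes on the GPU.
expert_num_tokens = torch.bincount(topk_ids.flatten(), minlength=num_experts)
assert expert_num_tokens.sum().item() == num_tokens * num_topk
print(expert_num_tokens)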
@ -20,6 +20,7 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEModularKernel)
from vllm.platforms import current_platform
from vllm.utils import has_deep_ep, has_deep_gemm
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used

from .parallel_utils import ProcessGroupInfo, parallel_launch
from .utils import make_test_weights
@ -368,6 +369,8 @@ NUM_EXPERTS = [32]
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@requires_deep_ep
@requires_deep_gemm
@pytest.mark.skipif(is_blackwell_deep_gemm_used(),
                    reason="Skipping test for Blackwell DeepGEMM")
def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
                                topk: int, world_dp_size: tuple[int, int]):
    """
@ -423,6 +426,8 @@ USE_FP8_DISPATCH = [False]
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@requires_deep_ep
@requires_deep_gemm
@pytest.mark.skipif(is_blackwell_deep_gemm_used(),
                    reason="Skipping test for Blackwell DeepGEMM")
def test_ll_deepep_deepgemm_moe(
    mnk: tuple[int, int, int],
    num_experts: int,
@ -13,48 +13,18 @@ import torch

# vLLM fused-expert reference (Triton fallback + DeepGEMM option)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8)
from vllm.utils import cdiv
from vllm.utils import has_deep_gemm
from vllm.utils.deep_gemm import (calc_diff, per_block_cast_to_fp8,
                                  per_token_group_cast_to_fp8)

has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None

if has_deep_gemm:
    import deep_gemm
    BLOCK_M = deep_gemm.get_m_alignment_for_contiguous_layout()
    BLOCK_SIZE = [BLOCK_M, BLOCK_M]
BLOCK_SIZE = [128, 128]

requires_deep_gemm = pytest.mark.skipif(
    not has_deep_gemm,
    not has_deep_gemm(),
    reason="Requires deep_gemm kernels",
)


def calc_diff(x: torch.Tensor, y: torch.Tensor):
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    sim = 2 * (x * y).sum() / denominator
    return 1 - sim


def per_block_cast_to_fp8(
        x: torch.Tensor,
        block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
    assert x.dim() == 2
    m, n = x.shape
    x_padded = torch.zeros(
        (cdiv(m, 128) * 128, cdiv(n, block_size_n) * block_size_n),
        dtype=x.dtype,
        device=x.device)
    x_padded[:m, :n] = x
    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n)
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
    x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
    scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
    return x_scaled_sub, scales


def make_block_quant_fp8_weights(
    e: int,
    n: int,
    k: int,
@ -111,7 +81,7 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
    """
    tokens_bf16 = torch.randn(
        m, k, device="cuda", dtype=torch.bfloat16).clamp_min_(-1).clamp_max_(1)
    _, a1_scale = per_token_group_quant_fp8(tokens_bf16, block_size[1])
    _, a1_scale = per_token_group_cast_to_fp8(tokens_bf16, block_size[1])

    # expert weight tensors
    w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(num_experts, n, k,
@ -155,17 +125,8 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
        block_shape=block_size,
        allow_deep_gemm=True,
    )

    base = out_triton.abs().mean()
    atol = 0.1 * base.clamp(min=1e-2)  # 10% of mean, but not lower than 1e-3
    rtol = 0.05
    # ----- Compare -----
    torch.testing.assert_close(
        out_deepgemm.to(torch.float32),
        out_triton.to(torch.float32),
        rtol=rtol,
        atol=float(atol),
    )
    diff = calc_diff(out_deepgemm, out_triton)
    assert diff < 0.001, f"Diff exceeded 1%: {diff}"


# Note: W1 has shape (E, 2N, K), so N = 512
214  tests/kernels/moe/test_modular_kernel_combinations.py  (new file)
@ -0,0 +1,214 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import copy
from itertools import product
from typing import Optional

import pytest
import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.config import VllmConfig, current_platform, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
    BatchedTritonOrDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.layer import TritonExperts
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
    TritonOrDeepGemmExperts)
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx

from .modular_kernel_tools.common import (Config, RankTensors, WeightTensors,
                                          reference_moe_impl,
                                          run_modular_kernel)
from .modular_kernel_tools.mk_objects import (
    MK_FUSED_EXPERT_TYPES, MK_MULTI_GPU_PREPARE_FINALIZE_TYPES,
    MK_QUANT_CONFIGS, MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES)
from .modular_kernel_tools.parallel_utils import (ProcessGroupInfo,
                                                  parallel_launch_with_config)

# TODO (varun): These requirements are very strict and could be relaxed.
has_all_packages = (has_deep_ep() and has_deep_gemm() and has_pplx())

meets_package_requirements = pytest.mark.skipif(
    not has_all_packages,
    reason="Requires deep_ep & deep_gemm & pplx packages",
)


def rank_worker(
    pgi: ProcessGroupInfo,
    vllm_config: VllmConfig,
    cpu_group,
    config: Config,
    weights: WeightTensors,
):
    current_platform.seed_everything(pgi.rank)

    # sanity check
    from vllm import envs
    if config.fused_moe_chunk_size is not None:
        assert (config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE)

    # get weights to this device
    weights.to_current_device()

    Ms = config.Ms
    assert isinstance(Ms, list)
    TOPKs = config.topks
    assert isinstance(TOPKs, list)

    for m, topk in product(Ms, TOPKs):
        print(f"Running m={m}, topk={topk} ...")
        # override m and topk
        cfgx = copy.deepcopy(config)
        cfgx.Ms = m
        cfgx.topks = topk

        # inputs for rank
        rank_tensors = RankTensors.make(cfgx, pgi)

        # modular kernel out
        mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights,
                                    rank_tensors)

        with set_current_vllm_config(vllm_config):
            ref_out = reference_moe_impl(cfgx, weights, rank_tensors)

        torch.testing.assert_close(ref_out, mk_out, atol=3e-2, rtol=3e-2)


def run(config: Config):
    assert config.is_valid()
    print(f"Testing config \n{config.describe()} ...")

    weights: WeightTensors = WeightTensors.make(config)

    vllm_config, env_dict = config.make_env_data()
    parallel_launch_with_config(config.world_size, rank_worker, vllm_config,
                                env_dict, config, weights)


Ms = [32, 64]
Ks = [7168]  # hidden sizes
Ns = [2048]
TOPKs = [4, 1]
Es = [32]
DTYPEs = [torch.bfloat16]
FUSED_MOE_CHUNK_SIZEs = [None, 16]


def is_nyi_config(config: Config) -> bool:
    # We know these configs to be legitimate. but still fail.

    if (config.fused_experts_type in [
            BatchedTritonExperts, BatchedTritonOrDeepGemmExperts,
            TritonExperts, TritonOrDeepGemmExperts
    ]):
        # The triton kernels expect both per-act-token-quant and
        # per-out-ch-quant or neither.
        unsupported_quant_config = ((config.is_per_act_token_quant +
                                     config.is_per_out_ch_quant) == 1)
        return unsupported_quant_config

    # cutlass kernels dont support expert_maps yet.
    return config.fused_experts_type == CutlassExpertsFp8


@pytest.mark.parametrize("k", Ks)
@pytest.mark.parametrize("n", Ns)
@pytest.mark.parametrize("e", Es)
@pytest.mark.parametrize("dtype", DTYPEs)
@pytest.mark.parametrize("quant_config", MK_QUANT_CONFIGS)
@pytest.mark.parametrize(
    "combination",
    product(MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES))
@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs)
@pytest.mark.parametrize("world_size", [2])
@meets_package_requirements
def test_modular_kernel_combinations_multigpu(
        k: int, n: int, e: int, dtype: torch.dtype,
        quant_config: FusedMoEQuantConfig,
        combination: tuple[mk.FusedMoEPrepareAndFinalize,
                           mk.FusedMoEPermuteExpertsUnpermute],
        fused_moe_chunk_size: Optional[int], world_size: int):

    config = Config(
        Ms=Ms,
        K=k,
        N=n,
        E=e,
        topks=TOPKs,
        dtype=dtype,
        quant_config=quant_config,
        prepare_finalize_type=combination[0],
        fused_experts_type=combination[1],
        fused_moe_chunk_size=fused_moe_chunk_size,
        world_size=world_size,
    )
    if not config.is_valid():
        pytest.skip(f"Tests config {config} is not valid. Skipping ...")

    if is_nyi_config(config):
        pytest.skip(f"Tests config {config} is nyi. Skipping ...")

    print(f"{config.describe()}")
    run(config)


@pytest.mark.parametrize("k", Ks)
@pytest.mark.parametrize("n", Ns)
@pytest.mark.parametrize("e", Es)
@pytest.mark.parametrize("dtype", DTYPEs)
@pytest.mark.parametrize("quant_config", MK_QUANT_CONFIGS)
@pytest.mark.parametrize(
    "combination",
    product(MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES))
@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs)
@pytest.mark.parametrize("world_size", [1])
@meets_package_requirements
def test_modular_kernel_combinations_singlegpu(
        k: int, n: int, e: int, dtype: torch.dtype,
        quant_config: FusedMoEQuantConfig,
        combination: tuple[mk.FusedMoEPrepareAndFinalize,
                           mk.FusedMoEPermuteExpertsUnpermute],
        fused_moe_chunk_size: Optional[int], world_size: int):
    config = Config(
        Ms=Ms,
        K=k,
        N=n,
        E=e,
        topks=TOPKs,
        dtype=dtype,
        quant_config=quant_config,
        prepare_finalize_type=combination[0],
        fused_experts_type=combination[1],
        fused_moe_chunk_size=fused_moe_chunk_size,
        world_size=world_size,
    )

    if not config.is_valid():
        pytest.skip(f"Tests config {config} is not valid. Skipping ...")

    if is_nyi_config(config):
        pytest.skip(f"Tests config {config} is nyi. Skipping ...")

    run(config)


if __name__ == '__main__':
    # Ability to test individual PrepareAndFinalize and FusedExperts combination
    from .modular_kernel_tools.cli_args import (make_config,
                                                make_config_arg_parser)
    parser = make_config_arg_parser(description=(
        "Run single prepare-finalize & fused-experts combination test"
        "Example : python3 -m tests.kernels.moe.test_modular_kernel_combinations " #noqa: E501
        "--pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts"
    ))
    args = parser.parse_args()
    config = make_config(args)

    run(config)
@ -174,6 +174,7 @@ def test_fused_moe(
        use_int8_w8a8=False,
        use_int8_w8a16=False,
        use_int4_w4a16=False,
        use_mxfp4_w4a4=False,
        per_act_token_quant=False,
        block_shape=None)
57  tests/kernels/moe/test_mxfp4_moe.py  (new file)
@ -0,0 +1,57 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import importlib
import importlib.metadata
from dataclasses import dataclass

import pytest
import torch
from packaging import version

QUARK_MXFP4_AVAILABLE = importlib.util.find_spec(
    "quark") is not None and version.parse(
        importlib.metadata.version("amd-quark")) >= version.parse('0.8.99')


@dataclass
class ModelCase:
    model_id: str
    tp: int


@pytest.mark.parametrize('model_case', [
    ModelCase("fxmarty/qwen_1.5-moe-a2.7b-mxfp4", tp=1),
    ModelCase("fxmarty/deepseek_r1_3_layers_mxfp4", tp=8),
    ModelCase("fxmarty/Llama-4-Scout-17B-16E-Instruct-2-layers-mxfp4", tp=1)
])
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE,
                    reason="amd-quark>=0.9 is not available")
def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase):
    if torch.cuda.device_count() < model_case.tp:
        pytest.skip(f"This test requires >={model_case.tp} gpus, got only "
                    f"{torch.cuda.device_count()}")

    with vllm_runner(model_case.model_id,
                     tensor_parallel_size=model_case.tp,
                     load_format="dummy") as llm:

        # TODO: llm.apply_model(check_model) currently relies on V0 internals.
        # Re-enable this later.
        # def check_model(model):
        #     layer = model.model.layers[0]

        #     qkv_proj = layer.self_attn.qkv_proj

        #     assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
        #     assert isinstance(qkv_proj.scheme, QuarkW4A4MXFP4)

        #     assert isinstance(layer.mlp.experts.quant_method,
        #                       QuarkW4A4MXFp4MoEMethod)

        # if model_case.model_id == "fxmarty/qwen_1.5-moe-a2.7b-mxfp4":
        #     llm.apply_model(check_model)

        output = llm.generate_greedy("Today I am in the French Alps and",
                                     max_tokens=20)
        assert output
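The QUARK_MXFP4_AVAILABLE guard above follows a common optional-dependency gating pattern; a generic sketch is shown below, where the package name "somepkg" and the version bound are placeholders rather than anything referenced by this diff, and packaging is assumed to be installed.

import importlib.metadata
import importlib.util

from packaging import version


def has_recent_somepkg(min_version: str = "1.2") -> bool:
    # Check that the (hypothetical) package is importable before asking
    # for its installed version.
    if importlib.util.find_spec("somepkg") is None:
        return False
    return version.parse(
        importlib.metadata.version("somepkg")) >= version.parse(min_version)


print(has_recent_somepkg())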
@ -32,6 +32,8 @@ from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
from vllm.model_executor.layers.fused_moe.fused_moe import get_default_config
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEModularKernel)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
    TopKWeightAndReduceDelegate)
from vllm.platforms import current_platform
from vllm.utils import round_up

@ -371,6 +373,7 @@ def pplx_prepare_finalize(
        chunk_topk_weight,
        chunk_topk_ids,
        False,
        weight_and_reduce_impl=TopKWeightAndReduceDelegate(),
    )

    torch.cuda.synchronize()
Some files were not shown because too many files have changed in this diff.