Compare commits
99 Commits
v0.7.0
...
full_cudag
| Author | SHA1 | Date | |
|---|---|---|---|
| 1244c25908 | |||
| 34fb0cbbd0 | |||
| f33ec6d2d2 | |||
| c33aeecf24 | |||
| 230730c34d | |||
| d151b63b8b | |||
| a1a2aaadb9 | |||
| 1298a400e8 | |||
| ad4a9dc817 | |||
| b9986454fe | |||
| c5932e5dac | |||
| 20579c0fae | |||
| 95460fc513 | |||
| 326fcc8b9f | |||
| e64330910b | |||
| e489ad7a21 | |||
| f256ebe4df | |||
| f8ece6e17f | |||
| abfcdcdf27 | |||
| e497f33491 | |||
| baaa2b24da | |||
| b4e5c03306 | |||
| 3194039c0e | |||
| 4f4d427ac2 | |||
| 1e3698393f | |||
| baeded2569 | |||
| 3e1c76cf3a | |||
| cfa134d247 | |||
| 35b7a05507 | |||
| 1867c258bd | |||
| cb3e73e4c8 | |||
| b1340f9d55 | |||
| 44bbca78d7 | |||
| 60808bd4c7 | |||
| fc542144c4 | |||
| eb5741ad42 | |||
| 145c2ff648 | |||
| 415f19474d | |||
| 89003c4082 | |||
| 60bcef000e | |||
| 847f883232 | |||
| 325f679f32 | |||
| e3f7ff65e7 | |||
| 7a8987dac5 | |||
| cabaf4eff3 | |||
| a1fc18c030 | |||
| 9798b2fb00 | |||
| 4078052f09 | |||
| bd2107e30a | |||
| 9b0c4bab36 | |||
| 41bf5612f5 | |||
| a2769032ca | |||
| f17f1d4608 | |||
| 1c1bb0bbf2 | |||
| e0cc5f259a | |||
| 73aa6cfdf7 | |||
| 27b78c73ca | |||
| b02fd288b2 | |||
| ff7424f491 | |||
| d93bf4da85 | |||
| 036ca94c25 | |||
| ef001d98ef | |||
| 5f671cb4c3 | |||
| bd02164cf9 | |||
| 46fb056749 | |||
| dd6a3a02cb | |||
| a7e3eba66f | |||
| fbb5bd4cef | |||
| 80fcc3ed1c | |||
| c386c43ca3 | |||
| f26d790718 | |||
| 0f657bdc52 | |||
| 3fd1fb63ef | |||
| 925d2f1908 | |||
| 8f58a51358 | |||
| 2079e43bee | |||
| e29d4358ef | |||
| 8cbc424975 | |||
| dd66fd2b01 | |||
| 0f465ab533 | |||
| 23a7cbc88b | |||
| 426a5c3625 | |||
| ddee88d0ff | |||
| 823ab79633 | |||
| 6116ca8cd7 | |||
| 2bc3fbba0c | |||
| 3f1fc7425a | |||
| 01ba927040 | |||
| 103bd17ac5 | |||
| ce69f7f754 | |||
| 624a1e4711 | |||
| 372bf0890b | |||
| bed9efafeb | |||
| d4c9448b26 | |||
| 76732ff701 | |||
| 22bd7296e4 | |||
| 4024253797 | |||
| 7eba374599 | |||
| 0c7e6c1e36 |
@ -1,12 +1,14 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import os
|
||||
import sys
|
||||
import zipfile
|
||||
|
||||
# Read the VLLM_MAX_SIZE_MB environment variable, defaulting to 300 MiB
|
||||
# Read the VLLM_MAX_SIZE_MB environment variable, defaulting to 400 MiB
|
||||
# Note that we have 400 MiB quota, please use it wisely.
|
||||
# See https://github.com/pypi/support/issues/3792 .
|
||||
# Please also sync the value with the one in Dockerfile.
|
||||
VLLM_MAX_SIZE_MB = int(os.environ.get('VLLM_MAX_SIZE_MB', 300))
|
||||
VLLM_MAX_SIZE_MB = int(os.environ.get('VLLM_MAX_SIZE_MB', 400))
|
||||
|
||||
|
||||
def print_top_10_largest_files(zip_file):
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""
|
||||
LM eval harness on model to compare vs HF baseline computed offline.
|
||||
Configs are found in configs/$MODEL.yaml
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import argparse
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from lmdeploy.serve.openai.api_client import APIClient
|
||||
|
||||
api_client = APIClient("http://localhost:8000")
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
|
||||
@ -56,6 +56,11 @@ steps:
|
||||
env:
|
||||
DOCKER_BUILDKIT: "1"
|
||||
|
||||
- input: "Provide Release version here"
|
||||
fields:
|
||||
- text: "What is the release version?"
|
||||
key: "release-version"
|
||||
|
||||
- block: "Build CPU release image"
|
||||
key: block-cpu-release-image-build
|
||||
depends_on: ~
|
||||
@ -66,7 +71,7 @@ steps:
|
||||
queue: cpu_queue_postmerge
|
||||
commands:
|
||||
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$RELEASE_VERSION --progress plain -f Dockerfile.cpu ."
|
||||
- "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$RELEASE_VERSION"
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version) --progress plain -f Dockerfile.cpu ."
|
||||
- "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version)"
|
||||
env:
|
||||
DOCKER_BUILDKIT: "1"
|
||||
|
||||
@ -23,6 +23,6 @@ trap remove_docker_container EXIT
|
||||
remove_docker_container
|
||||
|
||||
# Run the image and test offline inference
|
||||
docker run --name gh200-test --gpus=all --entrypoint="" gh200-test bash -c '
|
||||
python3 examples/offline_inference/basic.py
|
||||
docker run -e HF_TOKEN -v /root/.cache/huggingface:/root/.cache/huggingface --name gh200-test --gpus=all --entrypoint="" gh200-test bash -c '
|
||||
python3 examples/offline_inference/cli.py --model meta-llama/Llama-3.2-1B
|
||||
'
|
||||
|
||||
@ -54,4 +54,4 @@ docker run --rm -it --device=/dev/neuron0 --device=/dev/neuron1 --network host \
|
||||
-e "NEURON_COMPILE_CACHE_URL=${NEURON_COMPILE_CACHE_MOUNT}" \
|
||||
--name "${container_name}" \
|
||||
${image_name} \
|
||||
/bin/bash -c "python3 /workspace/vllm/examples/offline_inference/neuron.py"
|
||||
/bin/bash -c "python3 /workspace/vllm/examples/offline_inference/neuron.py && python3 -m pytest /workspace/vllm/tests/neuron/ -v --capture=tee-sys"
|
||||
|
||||
0
.buildkite/run-tpu-test.sh
Normal file → Executable file
0
.buildkite/run-tpu-test.sh
Normal file → Executable file
@ -50,9 +50,9 @@ steps:
|
||||
- tests/multimodal
|
||||
- tests/test_utils
|
||||
- tests/worker
|
||||
- tests/standalone_tests/lazy_torch_compile.py
|
||||
- tests/standalone_tests/lazy_imports.py
|
||||
commands:
|
||||
- python3 standalone_tests/lazy_torch_compile.py
|
||||
- python3 standalone_tests/lazy_imports.py
|
||||
- pytest -v -s mq_llm_engine # MQLLMEngine
|
||||
- pytest -v -s async_engine # AsyncLLMEngine
|
||||
- NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py
|
||||
@ -183,7 +183,16 @@ steps:
|
||||
- vllm/
|
||||
- tests/v1
|
||||
commands:
|
||||
- VLLM_USE_V1=1 pytest -v -s v1
|
||||
# split the test to avoid interference
|
||||
- VLLM_USE_V1=1 pytest -v -s v1/core
|
||||
- VLLM_USE_V1=1 pytest -v -s v1/engine
|
||||
- VLLM_USE_V1=1 pytest -v -s v1/sample
|
||||
- VLLM_USE_V1=1 pytest -v -s v1/worker
|
||||
- VLLM_USE_V1=1 pytest -v -s v1/test_stats.py
|
||||
- VLLM_USE_V1=1 pytest -v -s v1/test_utils.py
|
||||
# TODO: accuracy does not match, whether setting
|
||||
# VLLM_USE_FLASHINFER_SAMPLER or not on H100.
|
||||
- VLLM_USE_V1=1 pytest -v -s v1/e2e
|
||||
|
||||
- label: Examples Test # 25min
|
||||
working_dir: "/vllm-workspace/examples"
|
||||
@ -340,6 +349,7 @@ steps:
|
||||
- vllm/
|
||||
- tests/models
|
||||
commands:
|
||||
- pytest -v -s models/test_transformers.py
|
||||
- pytest -v -s models/test_registry.py
|
||||
- pytest -v -s models/test_initialization.py
|
||||
|
||||
@ -476,6 +486,7 @@ steps:
|
||||
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
|
||||
- TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)'
|
||||
# Avoid importing model tests that cause CUDA reinitialization error
|
||||
- pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)'
|
||||
- pytest models/encoder_decoder/language/test_bart.py -v -s -m 'distributed(num_gpus=2)'
|
||||
- pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m 'distributed(num_gpus=2)'
|
||||
- pytest models/decoder_only/vision_language/test_models.py -v -s -m 'distributed(num_gpus=2)'
|
||||
|
||||
37
.github/mergify.yml
vendored
37
.github/mergify.yml
vendored
@ -35,6 +35,43 @@ pull_request_rules:
|
||||
add:
|
||||
- frontend
|
||||
|
||||
- name: label-structured-output
|
||||
description: Automatically apply structured-output label
|
||||
conditions:
|
||||
- or:
|
||||
- files~=^vllm/model_executor/guided_decoding/
|
||||
- files=tests/model_executor/test_guided_processors.py
|
||||
- files=tests/entrypoints/llm/test_guided_generate.py
|
||||
- files=benchmarks/benchmark_serving_guided.py
|
||||
- files=benchmarks/benchmark_guided.py
|
||||
actions:
|
||||
label:
|
||||
add:
|
||||
- structured-output
|
||||
|
||||
- name: label-speculative-decoding
|
||||
description: Automatically apply speculative-decoding label
|
||||
conditions:
|
||||
- or:
|
||||
- files~=^vllm/spec_decode/
|
||||
- files=vllm/model_executor/layers/spec_decode_base_sampler.py
|
||||
- files~=^tests/spec_decode/
|
||||
actions:
|
||||
label:
|
||||
add:
|
||||
- speculative-decoding
|
||||
|
||||
- name: label-v1
|
||||
description: Automatically apply v1 label
|
||||
conditions:
|
||||
- or:
|
||||
- files~=^vllm/v1/
|
||||
- files~=^tests/v1/
|
||||
actions:
|
||||
label:
|
||||
add:
|
||||
- v1
|
||||
|
||||
- name: ping author on conflicts and add 'needs-rebase' label
|
||||
conditions:
|
||||
- conflict
|
||||
|
||||
@ -3,18 +3,18 @@ default_stages:
|
||||
- manual # Run in CI
|
||||
repos:
|
||||
- repo: https://github.com/google/yapf
|
||||
rev: v0.32.0
|
||||
rev: v0.43.0
|
||||
hooks:
|
||||
- id: yapf
|
||||
args: [--in-place, --verbose]
|
||||
additional_dependencies: [toml] # TODO: Remove when yapf is upgraded
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.6.5
|
||||
rev: v0.9.3
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--output-format, github]
|
||||
- repo: https://github.com/codespell-project/codespell
|
||||
rev: v2.3.0
|
||||
rev: v2.4.0
|
||||
hooks:
|
||||
- id: codespell
|
||||
exclude: 'benchmarks/sonnet.txt|(build|tests/(lora/data|models/fixtures|prompts))/.*'
|
||||
@ -23,7 +23,7 @@ repos:
|
||||
hooks:
|
||||
- id: isort
|
||||
- repo: https://github.com/pre-commit/mirrors-clang-format
|
||||
rev: v18.1.5
|
||||
rev: v19.1.7
|
||||
hooks:
|
||||
- id: clang-format
|
||||
exclude: 'csrc/(moe/topk_softmax_kernels.cu|quantization/gguf/(ggml-common.h|dequantize.cuh|vecdotq.cuh|mmq.cuh|mmvq.cuh))'
|
||||
@ -35,7 +35,7 @@ repos:
|
||||
- id: pymarkdown
|
||||
files: docs/.*
|
||||
- repo: https://github.com/rhysd/actionlint
|
||||
rev: v1.7.6
|
||||
rev: v1.7.7
|
||||
hooks:
|
||||
- id: actionlint
|
||||
- repo: local
|
||||
@ -85,8 +85,26 @@ repos:
|
||||
entry: tools/png-lint.sh
|
||||
language: script
|
||||
types: [png]
|
||||
- id: signoff-commit
|
||||
name: Sign-off Commit
|
||||
entry: bash
|
||||
args:
|
||||
- -c
|
||||
- |
|
||||
if ! grep -q "^Signed-off-by: $(git config user.name) <$(git config user.email)>" .git/COMMIT_EDITMSG; then
|
||||
printf "\nSigned-off-by: $(git config user.name) <$(git config user.email)>\n" >> .git/COMMIT_EDITMSG
|
||||
fi
|
||||
language: system
|
||||
verbose: true
|
||||
stages: [commit-msg]
|
||||
- id: check-spdx-header
|
||||
name: Check SPDX headers
|
||||
entry: python tools/check_spdx_header.py
|
||||
language: python
|
||||
types: [python]
|
||||
- id: suggestion
|
||||
name: Suggestion
|
||||
entry: bash -c 'echo "To bypass pre-commit hooks, add --no-verify to git commit."'
|
||||
language: system
|
||||
verbose: true
|
||||
pass_filenames: false
|
||||
|
||||
@ -245,7 +245,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
FetchContent_Declare(
|
||||
cutlass
|
||||
GIT_REPOSITORY https://github.com/nvidia/cutlass.git
|
||||
GIT_TAG v3.6.0
|
||||
GIT_TAG v3.7.0
|
||||
GIT_PROGRESS TRUE
|
||||
|
||||
# Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
|
||||
@ -275,7 +275,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# Only build Marlin kernels if we are building for at least some compatible archs.
|
||||
# Keep building Marlin for 9.0 as there are some group sizes and shapes that
|
||||
# are not supported by Machete yet.
|
||||
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0" ${CUDA_ARCHS})
|
||||
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}")
|
||||
if (MARLIN_ARCHS)
|
||||
set(MARLIN_SRCS
|
||||
"csrc/quantization/fp8/fp8_marlin.cu"
|
||||
@ -296,10 +296,15 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
endif()
|
||||
|
||||
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
|
||||
# CUDA 12.0 or later (and only work on Hopper, 9.0/9.0a for now).
|
||||
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}")
|
||||
# CUDA 12.0 or later (and only work on Hopper, 9.0a for now).
|
||||
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
|
||||
set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
|
||||
set(SRCS
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
|
||||
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu"
|
||||
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu"
|
||||
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu"
|
||||
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
|
||||
@ -351,7 +356,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# 2:4 Sparse Kernels
|
||||
|
||||
# The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor
|
||||
# require CUDA 12.2 or later (and only work on Hopper, 9.0/9.0a for now).
|
||||
# require CUDA 12.2 or later (and only work on Hopper, 9.0a for now).
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS)
|
||||
set(SRCS "csrc/sparse/cutlass/sparse_compressor_c3x.cu"
|
||||
"csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
|
||||
@ -446,9 +451,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
endif()
|
||||
|
||||
message(STATUS "Enabling C extension.")
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND VLLM_C_LIBS cuda)
|
||||
endif()
|
||||
define_gpu_extension_target(
|
||||
_C
|
||||
DESTINATION vllm
|
||||
@ -457,7 +459,6 @@ define_gpu_extension_target(
|
||||
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
|
||||
ARCHITECTURES ${VLLM_GPU_ARCHES}
|
||||
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
|
||||
LIBRARIES ${VLLM_C_LIBS}
|
||||
USE_SABI 3
|
||||
WITH_SOABI)
|
||||
|
||||
|
||||
@ -61,7 +61,7 @@ representative at an online or offline/IRL event.
|
||||
|
||||
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
||||
reported to the community leaders responsible for enforcement in the #code-of-conduct
|
||||
channel in the [vLLM Discord](https://discord.com/invite/jz7wjKhh6g).
|
||||
channel in the [vLLM Slack](https://slack.vllm.ai).
|
||||
All complaints will be reviewed and investigated promptly and fairly.
|
||||
|
||||
All community leaders are obligated to respect the privacy and security of the
|
||||
|
||||
25
Dockerfile
25
Dockerfile
@ -127,7 +127,7 @@ RUN --mount=type=cache,target=/root/.cache/ccache \
|
||||
# Check the size of the wheel if RUN_WHEEL_CHECK is true
|
||||
COPY .buildkite/check-wheel-size.py check-wheel-size.py
|
||||
# sync the default value with .buildkite/check-wheel-size.py
|
||||
ARG VLLM_MAX_SIZE_MB=300
|
||||
ARG VLLM_MAX_SIZE_MB=400
|
||||
ENV VLLM_MAX_SIZE_MB=$VLLM_MAX_SIZE_MB
|
||||
ARG RUN_WHEEL_CHECK=true
|
||||
RUN if [ "$RUN_WHEEL_CHECK" = "true" ]; then \
|
||||
@ -149,7 +149,8 @@ RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
|
||||
#################### vLLM installation IMAGE ####################
|
||||
# image with vLLM installed
|
||||
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu22.04 AS vllm-base
|
||||
# TODO: Restore to base image after FlashInfer AOT wheel fixed
|
||||
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 AS vllm-base
|
||||
ARG CUDA_VERSION=12.4.1
|
||||
ARG PYTHON_VERSION=3.12
|
||||
WORKDIR /vllm-workspace
|
||||
@ -194,12 +195,30 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
|
||||
--mount=type=cache,target=/root/.cache/pip \
|
||||
python3 -m pip install dist/*.whl --verbose
|
||||
|
||||
# How to build this FlashInfer wheel:
|
||||
# $ export FLASHINFER_ENABLE_AOT=1
|
||||
# $ # Note we remove 7.0 from the arch list compared to the list below, since FlashInfer only supports sm75+
|
||||
# $ export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.6 8.9 9.0+PTX'
|
||||
# $ git clone https://github.com/flashinfer-ai/flashinfer.git --recursive
|
||||
# $ cd flashinfer
|
||||
# $ git checkout 524304395bd1d8cd7d07db083859523fcaa246a4
|
||||
# $ python3 setup.py bdist_wheel --dist-dir=dist --verbose
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
. /etc/environment && \
|
||||
if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \
|
||||
python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.6/flashinfer-0.1.6+cu121torch2.4-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl; \
|
||||
python3 -m pip install https://wheels.vllm.ai/flashinfer/524304395bd1d8cd7d07db083859523fcaa246a4/flashinfer_python-0.2.0.post1-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl; \
|
||||
fi
|
||||
COPY examples examples
|
||||
|
||||
# Although we build Flashinfer with AOT mode, there's still
|
||||
# some issues w.r.t. JIT compilation. Therefore we need to
|
||||
# install build dependencies for JIT compilation.
|
||||
# TODO: Remove this once FlashInfer AOT wheel is fixed
|
||||
COPY requirements-build.txt requirements-build.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
python3 -m pip install -r requirements-build.txt
|
||||
|
||||
#################### vLLM installation IMAGE ####################
|
||||
|
||||
#################### TEST IMAGE ####################
|
||||
|
||||
@ -4,12 +4,12 @@ USER root
|
||||
|
||||
ENV PATH="/usr/local/cargo/bin:$PATH:/opt/conda/bin/"
|
||||
|
||||
RUN apt-get update -y && apt-get install -y git wget curl vim libnuma-dev libsndfile-dev libprotobuf-dev build-essential ffmpeg libsm6 libxext6 libgl1 libssl-dev
|
||||
RUN apt-get update -y && apt-get install -y git wget kmod curl vim libnuma-dev libsndfile-dev libprotobuf-dev build-essential ffmpeg libsm6 libxext6 libgl1 libssl-dev
|
||||
|
||||
# Some packages in requirements-cpu are installed here
|
||||
# IBM provides optimized packages for ppc64le processors in the open-ce project for mamba
|
||||
# Currently these may not be available for venv or pip directly
|
||||
RUN micromamba install -y -n base -c https://ftp.osuosl.org/pub/open-ce/1.11.0-p10/ -c defaults python=3.10 torchvision-cpu=0.16.2 rust && micromamba clean --all --yes
|
||||
RUN micromamba install -y -n base -c https://ftp.osuosl.org/pub/open-ce/1.11.0-p10/ -c defaults python=3.10 rust && micromamba clean --all --yes
|
||||
|
||||
COPY ./ /workspace/vllm
|
||||
|
||||
@ -21,7 +21,6 @@ RUN --mount=type=bind,source=.git,target=.git \
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
RUSTFLAGS='-L /opt/conda/lib' pip install -v --prefer-binary --extra-index-url https://repo.fury.io/mgiessing \
|
||||
'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 \
|
||||
torch==2.3.1 \
|
||||
-r requirements-cpu.txt \
|
||||
xformers uvloop==0.20.0
|
||||
|
||||
|
||||
@ -10,12 +10,13 @@ Easy, fast, and cheap LLM serving for everyone
|
||||
</h3>
|
||||
|
||||
<p align="center">
|
||||
| <a href="https://docs.vllm.ai"><b>Documentation</b></a> | <a href="https://vllm.ai"><b>Blog</b></a> | <a href="https://arxiv.org/abs/2309.06180"><b>Paper</b></a> | <a href="https://discord.gg/jz7wjKhh6g"><b>Discord</b></a> | <a href="https://x.com/vllm_project"><b>Twitter/X</b></a> | <a href="https://slack.vllm.ai"><b>Developer Slack</b></a> |
|
||||
| <a href="https://docs.vllm.ai"><b>Documentation</b></a> | <a href="https://vllm.ai"><b>Blog</b></a> | <a href="https://arxiv.org/abs/2309.06180"><b>Paper</b></a> | <a href="https://x.com/vllm_project"><b>Twitter/X</b></a> | <a href="https://slack.vllm.ai"><b>Developer Slack</b></a> |
|
||||
</p>
|
||||
|
||||
---
|
||||
|
||||
*Latest News* 🔥
|
||||
- [2025/01] We are excited to announce the alpha release of vLLM V1: A major architectural upgrade with 1.7x speedup! Clean code, optimized execution loop, zero-overhead prefix caching, enhanced multimodal support, and more. Please check out our blog post [here](https://blog.vllm.ai/2025/01/27/v1-alpha-release.html).
|
||||
- [2025/01] We hosted [the eighth vLLM meetup](https://lu.ma/zep56hui) with Google Cloud! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1epVkt4Zu8Jz_S5OhEHPc798emsYh2BwYfRuDDVEF7u4/edit?usp=sharing).
|
||||
- [2024/12] vLLM joins [pytorch ecosystem](https://pytorch.org/blog/vllm-joins-pytorch)! Easy, Fast, and Cheap LLM Serving for Everyone!
|
||||
- [2024/11] We hosted [the seventh vLLM meetup](https://lu.ma/h0qvrajz) with Snowflake! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1e3CxQBV3JsfGp30SwyvS3eM_tW-ghOhJ9PAJGK6KR54/edit?usp=sharing), and Snowflake team [here](https://docs.google.com/presentation/d/1qF3RkDAbOULwz9WK5TOltt2fE9t6uIc_hVNLFAaQX6A/edit?usp=sharing).
|
||||
@ -35,7 +36,7 @@ Easy, fast, and cheap LLM serving for everyone
|
||||
## About
|
||||
vLLM is a fast and easy-to-use library for LLM inference and serving.
|
||||
|
||||
Originally developed in the [Sky Computing Lab](https://sky.cs.berkeley.edu) at UC Berkeley, vLLM has evloved into a community-driven project with contributions from both academia and industry.
|
||||
Originally developed in the [Sky Computing Lab](https://sky.cs.berkeley.edu) at UC Berkeley, vLLM has evolved into a community-driven project with contributions from both academia and industry.
|
||||
|
||||
vLLM is fast with:
|
||||
|
||||
@ -138,8 +139,7 @@ If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs
|
||||
## Contact Us
|
||||
|
||||
* For technical questions and feature requests, please use Github issues or discussions.
|
||||
* For discussing with fellow users, please use Discord.
|
||||
* For coordinating contributions and development, please use Slack.
|
||||
* For discussing with fellow users and coordinating contributions and development, please use Slack.
|
||||
* For security disclosures, please use Github's security advisory feature.
|
||||
* For collaborations and partnerships, please contact us at vllm-questions AT lists.berkeley.edu.
|
||||
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Benchmark guided decoding throughput."""
|
||||
import argparse
|
||||
import dataclasses
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Benchmark the latency of processing a single batch of requests."""
|
||||
import argparse
|
||||
import dataclasses
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""
|
||||
Offline benchmark to test the long document QA throughput.
|
||||
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""
|
||||
Benchmark the efficiency of prefix caching.
|
||||
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Benchmark offline prioritization."""
|
||||
import argparse
|
||||
import dataclasses
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
r"""Benchmark online serving throughput.
|
||||
|
||||
On the server side, run one of the following commands:
|
||||
@ -926,8 +927,8 @@ def main(args: argparse.Namespace):
|
||||
)
|
||||
|
||||
# Traffic
|
||||
result_json["request_rate"] = (
|
||||
args.request_rate if args.request_rate < float("inf") else "inf")
|
||||
result_json["request_rate"] = (args.request_rate if args.request_rate
|
||||
< float("inf") else "inf")
|
||||
result_json["burstiness"] = args.burstiness
|
||||
result_json["max_concurrency"] = args.max_concurrency
|
||||
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
r"""Benchmark online serving throughput with guided decoding.
|
||||
|
||||
On the server side, run one of the following commands:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Benchmark offline inference throughput."""
|
||||
import argparse
|
||||
import dataclasses
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import itertools
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Cutlass bench utils
|
||||
from typing import Iterable, Tuple
|
||||
|
||||
|
||||
@ -1,9 +1,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import itertools
|
||||
import pickle as pkl
|
||||
import time
|
||||
from typing import Callable, Iterable, List, Tuple
|
||||
from typing import Callable, Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.utils.benchmark as TBenchmark
|
||||
@ -12,6 +14,8 @@ from utils import make_rand_tensors
|
||||
from weight_shapes import WEIGHT_SHAPES
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
w8a8_block_fp8_matmul)
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
|
||||
@ -38,8 +42,15 @@ def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
|
||||
|
||||
def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
|
||||
sub_label: str) -> Iterable[TMeasurement]:
|
||||
def bench_int8(
|
||||
dtype: torch.dtype,
|
||||
m: int,
|
||||
k: int,
|
||||
n: int,
|
||||
label: str,
|
||||
sub_label: str,
|
||||
bench_kernels: Optional[List[str]] = None) -> Iterable[TMeasurement]:
|
||||
"""Benchmark INT8-based kernels."""
|
||||
assert dtype == torch.int8
|
||||
a, b = make_rand_tensors(torch.int8, m, n, k)
|
||||
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
@ -48,155 +59,132 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
|
||||
azp = torch.zeros((m, ), device="cuda", dtype=torch.int32)
|
||||
azp_adj = torch.zeros((n, ), device="cuda", dtype=torch.int32)
|
||||
|
||||
bench_fns = {
|
||||
"pytorch_bf16_bf16_bf16_matmul-no-scales":
|
||||
lambda: torch.mm(a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16)
|
||||
),
|
||||
"pytorch_fp16_fp16_fp16_matmul-no-scales":
|
||||
lambda: torch.mm(a.to(dtype=torch.float16), b.to(dtype=torch.float16)),
|
||||
"cutlass_i8_i8_bf16_scaled_mm":
|
||||
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16),
|
||||
"cutlass_i8_i8_bf16_scaled_mm_bias":
|
||||
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16,
|
||||
bias),
|
||||
"cutlass_i8_i8_bf16_scaled_mm_azp":
|
||||
lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch.
|
||||
bfloat16, azp_adj),
|
||||
"cutlass_i8_i8_bf16_scaled_mm_azp_bias":
|
||||
lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch.
|
||||
bfloat16, azp_adj, None, bias),
|
||||
"cutlass_i8_i8_bf16_scaled_mm_azp_pt":
|
||||
lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch.
|
||||
bfloat16, azp_adj, azp),
|
||||
"cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias":
|
||||
lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch.
|
||||
bfloat16, azp_adj, azp, bias),
|
||||
}
|
||||
|
||||
timers = []
|
||||
# pytorch impl - bfloat16
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales",
|
||||
torch.mm, a.to(dtype=torch.bfloat16),
|
||||
b.to(dtype=torch.bfloat16)))
|
||||
|
||||
# pytorch impl - float16
|
||||
timers.append(
|
||||
bench_fn(label, sub_label,
|
||||
"pytorch_fp16_fp16_fp16_matmul-no-scales", torch.mm,
|
||||
a.to(dtype=torch.float16), b.to(dtype=torch.float16)))
|
||||
|
||||
# cutlass impl
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm",
|
||||
ops.cutlass_scaled_mm, a, b, scale_a, scale_b,
|
||||
torch.bfloat16))
|
||||
|
||||
# cutlass with bias
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_bias",
|
||||
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16,
|
||||
bias))
|
||||
|
||||
# cutlass with azp per-tensor
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp",
|
||||
ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b,
|
||||
torch.bfloat16, azp_adj))
|
||||
|
||||
# cutlass with azp per-tensor + bias
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_bias",
|
||||
ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b,
|
||||
torch.bfloat16, azp_adj, None, bias))
|
||||
|
||||
# cutlass with azp per-token
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_pt",
|
||||
ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b,
|
||||
torch.bfloat16, azp_adj, azp))
|
||||
|
||||
# cutlass with azp per-token + bias
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias",
|
||||
ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b,
|
||||
torch.bfloat16, azp_adj, azp, bias))
|
||||
for name, fn in bench_fns.items():
|
||||
# If bench_kernels is None, run all. Otherwise, run only exact matches.
|
||||
if bench_kernels is None or name in bench_kernels:
|
||||
print(f"Running {name}")
|
||||
timers.append(bench_fn(label, sub_label, name, fn))
|
||||
|
||||
return timers
|
||||
|
||||
|
||||
def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
|
||||
sub_label: str) -> Iterable[TMeasurement]:
|
||||
def bench_fp8(
|
||||
dtype: torch.dtype,
|
||||
m: int,
|
||||
k: int,
|
||||
n: int,
|
||||
label: str,
|
||||
sub_label: str,
|
||||
bench_kernels: Optional[List[str]] = None) -> Iterable[TMeasurement]:
|
||||
"""Benchmark FP8-based kernels."""
|
||||
assert dtype == torch.float8_e4m3fn
|
||||
a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k)
|
||||
a_cont = a.contiguous()
|
||||
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
block_scale_a = torch.rand((m, k // 128),
|
||||
device="cuda",
|
||||
dtype=torch.float32)
|
||||
block_scale_b = torch.rand((k // 128, n // 128),
|
||||
device="cuda",
|
||||
dtype=torch.float32)
|
||||
block_scale_a_M_major = block_scale_a.t().contiguous().t()
|
||||
block_scale_b_K_major = block_scale_b.t().contiguous().t()
|
||||
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
print(m, k, n)
|
||||
|
||||
bench_fns = {
|
||||
"pytorch_bf16_bf16_bf16_matmul-no-scales":
|
||||
lambda: torch.mm(a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16)
|
||||
),
|
||||
"pytorch_fp16_fp16_fp16_matmul-no-scales":
|
||||
lambda: torch.mm(a.to(dtype=torch.float16), b.to(dtype=torch.float16)),
|
||||
"pytorch_fp8_fp8_fp16_scaled_mm":
|
||||
lambda: torch._scaled_mm(
|
||||
a, b, scale_a, scale_b, out_dtype=torch.float16),
|
||||
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum":
|
||||
lambda: torch._scaled_mm(a,
|
||||
b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
out_dtype=torch.float16,
|
||||
use_fast_accum=True),
|
||||
"pytorch_fp8_fp8_bf16_scaled_mm":
|
||||
lambda: torch._scaled_mm(
|
||||
a, b, scale_a, scale_b, out_dtype=torch.bfloat16),
|
||||
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum":
|
||||
lambda: torch._scaled_mm(a,
|
||||
b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
out_dtype=torch.bfloat16,
|
||||
use_fast_accum=True),
|
||||
"cutlass_fp8_fp8_bf16_scaled_mm":
|
||||
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16),
|
||||
"cutlass_fp8_fp8_fp16_scaled_mm":
|
||||
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.float16),
|
||||
"cutlass_fp8_fp8_bf16_scaled_mm_bias":
|
||||
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16,
|
||||
bias),
|
||||
"cutlass_fp8_fp8_fp16_scaled_mm_bias":
|
||||
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.float16,
|
||||
bias.to(dtype=torch.float16)),
|
||||
"triton_fp8_fp8_fp16_scaled_mm_blockwise":
|
||||
lambda: w8a8_block_fp8_matmul(a_cont, b.t(), block_scale_a,
|
||||
block_scale_b.t(), (128, 128)),
|
||||
"cutlass_fp8_fp8_fp16_scaled_mm_blockwise":
|
||||
lambda: ops.cutlass_scaled_mm(a, b, block_scale_a_M_major,
|
||||
block_scale_b_K_major, torch.float16),
|
||||
}
|
||||
|
||||
timers = []
|
||||
|
||||
# pytorch impl w. bf16
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales",
|
||||
torch.mm, a.to(dtype=torch.bfloat16, device="cuda"),
|
||||
b.to(dtype=torch.bfloat16, device="cuda")))
|
||||
|
||||
# pytorch impl: bf16 output, without fp8 fast accum
|
||||
timers.append(
|
||||
bench_fn(label,
|
||||
sub_label,
|
||||
"pytorch_fp8_fp8_bf16_scaled_mm",
|
||||
torch._scaled_mm,
|
||||
a,
|
||||
b,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=torch.bfloat16))
|
||||
|
||||
# pytorch impl: bf16 output, with fp8 fast accum
|
||||
timers.append(
|
||||
bench_fn(label,
|
||||
sub_label,
|
||||
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum",
|
||||
torch._scaled_mm,
|
||||
a,
|
||||
b,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=torch.bfloat16,
|
||||
use_fast_accum=True))
|
||||
|
||||
# pytorch impl: fp16 output, without fp8 fast accum
|
||||
timers.append(
|
||||
bench_fn(label,
|
||||
sub_label,
|
||||
"pytorch_fp8_fp8_fp16_scaled_mm",
|
||||
torch._scaled_mm,
|
||||
a,
|
||||
b,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=torch.float16))
|
||||
|
||||
# pytorch impl: fp16 output, with fp8 fast accum
|
||||
timers.append(
|
||||
bench_fn(label,
|
||||
sub_label,
|
||||
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum",
|
||||
torch._scaled_mm,
|
||||
a,
|
||||
b,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=torch.float16,
|
||||
use_fast_accum=True))
|
||||
|
||||
# cutlass impl: bf16 output
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm",
|
||||
ops.cutlass_scaled_mm, a, b, scale_a, scale_b,
|
||||
torch.bfloat16))
|
||||
# cutlass impl: fp16 output
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_mm",
|
||||
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.float16))
|
||||
|
||||
# cutlass impl: bf16 output, with bias
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm_bias",
|
||||
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16,
|
||||
bias))
|
||||
|
||||
# cutlass impl: fp16 output, with bias
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_mm_bias",
|
||||
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.float16,
|
||||
bias.to(dtype=torch.float16)))
|
||||
for name, fn in bench_fns.items():
|
||||
# If bench_kernels is None, run all. Otherwise, run only exact matches.
|
||||
if bench_kernels is None or name in bench_kernels:
|
||||
print(f"Running {name}")
|
||||
timers.append(bench_fn(label, sub_label, name, fn))
|
||||
|
||||
return timers
|
||||
|
||||
|
||||
def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str,
|
||||
sub_label: str) -> Iterable[TMeasurement]:
|
||||
def bench(dtype: torch.dtype,
|
||||
m: int,
|
||||
k: int,
|
||||
n: int,
|
||||
label: str,
|
||||
sub_label: str,
|
||||
bench_kernels: Optional[List[str]] = None) -> Iterable[TMeasurement]:
|
||||
if dtype == torch.int8:
|
||||
return bench_int8(dtype, m, k, n, label, sub_label)
|
||||
return bench_int8(dtype, m, k, n, label, sub_label, bench_kernels)
|
||||
if dtype == torch.float8_e4m3fn:
|
||||
return bench_fp8(dtype, m, k, n, label, sub_label)
|
||||
return bench_fp8(dtype, m, k, n, label, sub_label, bench_kernels)
|
||||
raise ValueError("unsupported type")
|
||||
|
||||
|
||||
@ -207,18 +195,22 @@ def print_timers(timers: Iterable[TMeasurement]):
|
||||
|
||||
|
||||
def run(dtype: torch.dtype,
|
||||
MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
|
||||
MKNs: Iterable[Tuple[int, int, int]],
|
||||
bench_kernels: Optional[List[str]] = None) -> Iterable[TMeasurement]:
|
||||
results = []
|
||||
for m, k, n in MKNs:
|
||||
timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm",
|
||||
f"MKN=({m}x{k}x{n})")
|
||||
timers = bench(dtype,
|
||||
m,
|
||||
k,
|
||||
n,
|
||||
f"scaled-{dtype}-gemm",
|
||||
f"MKN=({m}x{k}x{n})",
|
||||
bench_kernels=bench_kernels)
|
||||
print_timers(timers)
|
||||
results.extend(timers)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# output makers
|
||||
def make_output(data: Iterable[TMeasurement],
|
||||
MKNs: Iterable[Tuple[int, int, int]],
|
||||
base_description: str,
|
||||
@ -232,15 +224,11 @@ def make_output(data: Iterable[TMeasurement],
|
||||
pkl.dump(data, f)
|
||||
|
||||
|
||||
# argparse runners
|
||||
|
||||
|
||||
def run_square_bench(args):
|
||||
dim_sizes = list(
|
||||
range(args.dim_start, args.dim_end + 1, args.dim_increment))
|
||||
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
|
||||
data = run(args.dtype, MKNs)
|
||||
|
||||
data = run(args.dtype, MKNs, bench_kernels=args.kernels)
|
||||
make_output(data, MKNs, f"square_bench-{args.dtype}")
|
||||
|
||||
|
||||
@ -251,8 +239,7 @@ def run_range_bench(args):
|
||||
Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes
|
||||
Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes
|
||||
MKNs = list(zip(Ms, Ks, Ns))
|
||||
data = run(args.dtype, MKNs)
|
||||
|
||||
data = run(args.dtype, MKNs, bench_kernels=args.kernels)
|
||||
make_output(data, MKNs, f"range_bench-{args.dtype}")
|
||||
|
||||
|
||||
@ -278,7 +265,7 @@ def run_model_bench(args):
|
||||
for k, n in KNs:
|
||||
MKNs.append((m, k, n))
|
||||
|
||||
data = run(args.dtype, MKNs)
|
||||
data = run(args.dtype, MKNs, bench_kernels=args.kernels)
|
||||
model_bench_data.append(data)
|
||||
|
||||
# Print all results
|
||||
@ -328,6 +315,15 @@ Benchmark Cutlass GEMM.
|
||||
type=to_torch_dtype,
|
||||
required=True,
|
||||
help="Available options are ['int8', 'fp8']")
|
||||
parser.add_argument(
|
||||
"--kernels",
|
||||
nargs="+",
|
||||
type=str,
|
||||
default=None,
|
||||
help=
|
||||
"Exact names of the kernels to benchmark. If not set, runs all kernels."
|
||||
)
|
||||
|
||||
subparsers = parser.add_subparsers(dest="cmd")
|
||||
|
||||
square_parser = subparsers.add_parser("square_bench")
|
||||
@ -362,4 +358,4 @@ Benchmark Cutlass GEMM.
|
||||
model_parser.set_defaults(func=run_model_bench)
|
||||
|
||||
args = parser.parse_args()
|
||||
args.func(args)
|
||||
args.func(args)
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Weight Shapes are in the format
|
||||
# ([K, N], TP_SPLIT_DIM)
|
||||
# Example:
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import os
|
||||
|
||||
import aiohttp
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import asyncio
|
||||
import itertools
|
||||
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import json
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pickle as pkl
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import os
|
||||
import sys
|
||||
from typing import Optional
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import time
|
||||
|
||||
import torch
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import json
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import itertools
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import argparse
|
||||
import time
|
||||
from datetime import datetime
|
||||
@ -343,9 +345,13 @@ class BenchmarkWorker:
|
||||
op_config = get_moe_configs(num_experts, shard_intermediate_size // 2,
|
||||
dtype_str)
|
||||
if op_config is None:
|
||||
config = get_default_config(num_tokens, num_experts,
|
||||
shard_intermediate_size, hidden_size,
|
||||
topk, dtype_str)
|
||||
config = get_default_config(num_tokens,
|
||||
num_experts,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype_str,
|
||||
is_marlin=False)
|
||||
else:
|
||||
config = op_config[min(op_config.keys(),
|
||||
key=lambda x: abs(x - num_tokens))]
|
||||
@ -450,7 +456,8 @@ def save_configs(configs: Dict[int, BenchmarkConfig], num_experts: int,
|
||||
def main(args: argparse.Namespace):
|
||||
print(args)
|
||||
|
||||
config = AutoConfig.from_pretrained(args.model)
|
||||
config = AutoConfig.from_pretrained(
|
||||
args.model, trust_remote_code=args.trust_remote_code)
|
||||
if config.architectures[0] == "DbrxForCausalLM":
|
||||
E = config.ffn_config.moe_num_experts
|
||||
topk = config.ffn_config.moe_top_k
|
||||
@ -461,6 +468,11 @@ def main(args: argparse.Namespace):
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||
elif config.architectures[0] == "DeepseekV3ForCausalLM":
|
||||
E = config.n_routed_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||
else:
|
||||
# Default: Mixtral.
|
||||
E = config.num_local_experts
|
||||
@ -530,7 +542,11 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--model",
|
||||
type=str,
|
||||
default="mistralai/Mixtral-8x7B-Instruct-v0.1")
|
||||
parser.add_argument("--tp-size", "-tp", type=int, default=2)
|
||||
parser.add_argument("--tp-size",
|
||||
"-tp",
|
||||
"--tensor-parallel-size",
|
||||
type=int,
|
||||
default=2)
|
||||
parser.add_argument("--dtype",
|
||||
type=str,
|
||||
choices=["auto", "fp8_w8a8", "int8_w8a16"],
|
||||
@ -538,6 +554,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--batch-size", type=int, required=False)
|
||||
parser.add_argument("--tune", action="store_true")
|
||||
parser.add_argument("--trust-remote-code", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import random
|
||||
import time
|
||||
from typing import List, Optional
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import time
|
||||
|
||||
import torch
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import itertools
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from itertools import accumulate
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
WEIGHT_SHAPES = {
|
||||
"ideal": [[4 * 256 * 32, 256 * 32]],
|
||||
"mistralai/Mistral-7B-v0.1/TP1": [
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import math
|
||||
import pickle
|
||||
import re
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import dataclasses
|
||||
from typing import Any, Callable, Iterable, Optional
|
||||
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Weight Shapes are in the format
|
||||
# ([K, N], TP_SPLIT_DIM)
|
||||
# Example:
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import cProfile
|
||||
import pstats
|
||||
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
#!/usr/bin/env python3
|
||||
|
||||
#
|
||||
|
||||
@ -259,7 +259,7 @@ endmacro()
|
||||
# in `SRC_CUDA_ARCHS` that is less or equal to the version in `TGT_CUDA_ARCHS`.
|
||||
# We have special handling for 9.0a, if 9.0a is in `SRC_CUDA_ARCHS` and 9.0 is
|
||||
# in `TGT_CUDA_ARCHS` then we should remove 9.0a from `SRC_CUDA_ARCHS` and add
|
||||
# 9.0a to the result.
|
||||
# 9.0a to the result (and remove 9.0 from TGT_CUDA_ARCHS).
|
||||
# The result is stored in `OUT_CUDA_ARCHS`.
|
||||
#
|
||||
# Example:
|
||||
@ -270,34 +270,47 @@ endmacro()
|
||||
#
|
||||
function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
|
||||
list(REMOVE_DUPLICATES SRC_CUDA_ARCHS)
|
||||
set(TGT_CUDA_ARCHS_ ${TGT_CUDA_ARCHS})
|
||||
|
||||
# if 9.0a is in SRC_CUDA_ARCHS and 9.0 is in CUDA_ARCHS then we should
|
||||
# remove 9.0a from SRC_CUDA_ARCHS and add 9.0a to _CUDA_ARCHS
|
||||
set(_CUDA_ARCHS)
|
||||
if ("9.0a" IN_LIST SRC_CUDA_ARCHS)
|
||||
list(REMOVE_ITEM SRC_CUDA_ARCHS "9.0a")
|
||||
if ("9.0" IN_LIST TGT_CUDA_ARCHS)
|
||||
if ("9.0" IN_LIST TGT_CUDA_ARCHS_)
|
||||
list(REMOVE_ITEM TGT_CUDA_ARCHS_ "9.0")
|
||||
set(_CUDA_ARCHS "9.0a")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
list(SORT SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
|
||||
|
||||
# for each ARCH in CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that is
|
||||
# less or eqault to ARCH
|
||||
foreach(_ARCH ${CUDA_ARCHS})
|
||||
set(_TMP_ARCH)
|
||||
foreach(_SRC_ARCH ${SRC_CUDA_ARCHS})
|
||||
if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH)
|
||||
set(_TMP_ARCH ${_SRC_ARCH})
|
||||
else()
|
||||
break()
|
||||
# for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that
|
||||
# is less or equal to ARCH (but has the same major version since SASS binary
|
||||
# compatibility is only forward compatible within the same major version).
|
||||
foreach(_ARCH ${TGT_CUDA_ARCHS_})
|
||||
set(_TMP_ARCH)
|
||||
# Extract the major version of the target arch
|
||||
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" TGT_ARCH_MAJOR "${_ARCH}")
|
||||
foreach(_SRC_ARCH ${SRC_CUDA_ARCHS})
|
||||
# Extract the major version of the source arch
|
||||
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" SRC_ARCH_MAJOR "${_SRC_ARCH}")
|
||||
# Check major-version match AND version-less-or-equal
|
||||
if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH)
|
||||
if (SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR)
|
||||
set(_TMP_ARCH "${_SRC_ARCH}")
|
||||
endif()
|
||||
else()
|
||||
# If we hit a version greater than the target, we can break
|
||||
break()
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
# If we found a matching _TMP_ARCH, append it to _CUDA_ARCHS
|
||||
if (_TMP_ARCH)
|
||||
list(APPEND _CUDA_ARCHS "${_TMP_ARCH}")
|
||||
endif()
|
||||
endforeach()
|
||||
if (_TMP_ARCH)
|
||||
list(APPEND _CUDA_ARCHS ${_TMP_ARCH})
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
list(REMOVE_DUPLICATES _CUDA_ARCHS)
|
||||
set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE)
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# ruff: noqa
|
||||
# code borrowed from https://github.com/pytorch/pytorch/blob/main/torch/utils/collect_env.py
|
||||
|
||||
|
||||
@ -28,6 +28,11 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
|
||||
const std::string& kv_cache_dtype,
|
||||
torch::Tensor& k_scale, torch::Tensor& v_scale);
|
||||
|
||||
void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
|
||||
torch::Tensor& kv_cache, torch::Tensor& slot_mapping,
|
||||
const std::string& kv_cache_dtype,
|
||||
torch::Tensor& scale);
|
||||
|
||||
// Just for unittest
|
||||
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
|
||||
const double scale, const std::string& kv_cache_dtype);
|
||||
|
||||
@ -245,6 +245,51 @@ __global__ void reshape_and_cache_flash_kernel(
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
|
||||
__global__ void concat_and_cache_mla_kernel(
|
||||
const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank]
|
||||
const scalar_t* __restrict__ k_pe, // [num_tokens, pe_dim]
|
||||
cache_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank
|
||||
// + pe_dim)]
|
||||
const int64_t* __restrict__ slot_mapping, // [num_tokens]
|
||||
const int block_stride, //
|
||||
const int kv_c_stride, //
|
||||
const int k_pe_stride, //
|
||||
const int kv_lora_rank, //
|
||||
const int pe_dim, //
|
||||
const int block_size, //
|
||||
const float* scale //
|
||||
) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
const int64_t slot_idx = slot_mapping[token_idx];
|
||||
// NOTE: slot_idx can be -1 if the token is padded
|
||||
if (slot_idx < 0) {
|
||||
return;
|
||||
}
|
||||
const int64_t block_idx = slot_idx / block_size;
|
||||
const int64_t block_offset = slot_idx % block_size;
|
||||
|
||||
auto copy = [&](const scalar_t* __restrict__ src, cache_t* __restrict__ dst,
|
||||
int src_stride, int dst_stride, int size, int offset) {
|
||||
for (int i = threadIdx.x; i < size; i += blockDim.x) {
|
||||
const int64_t src_idx = token_idx * src_stride + i;
|
||||
const int64_t dst_idx = block_idx * block_stride +
|
||||
block_offset * (kv_lora_rank + pe_dim) + i +
|
||||
offset;
|
||||
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
|
||||
dst[dst_idx] = src[src_idx];
|
||||
} else {
|
||||
dst[dst_idx] =
|
||||
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(src[src_idx], *scale);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0);
|
||||
copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
// KV_T is the stored data type of kv-cache.
|
||||
@ -343,6 +388,56 @@ void reshape_and_cache_flash(
|
||||
CALL_RESHAPE_AND_CACHE_FLASH);
|
||||
}
|
||||
|
||||
// KV_T is the stored data type of kv-cache.
|
||||
// CACHE_T is the data type of key and value tensors.
|
||||
// KV_DTYPE is the real data type of kv-cache.
|
||||
#define CALL_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE) \
|
||||
vllm::concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<KV_T*>(kv_c.data_ptr()), \
|
||||
reinterpret_cast<KV_T*>(k_pe.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
|
||||
slot_mapping.data_ptr<int64_t>(), block_stride, kv_c_stride, \
|
||||
k_pe_stride, kv_lora_rank, pe_dim, block_size, \
|
||||
reinterpret_cast<const float*>(scale.data_ptr()));
|
||||
|
||||
void concat_and_cache_mla(
|
||||
torch::Tensor& kv_c, // [num_tokens, kv_lora_rank]
|
||||
torch::Tensor& k_pe, // [num_tokens, pe_dim]
|
||||
torch::Tensor& kv_cache, // [num_blocks, block_size, (kv_lora_rank +
|
||||
// pe_dim)]
|
||||
torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
|
||||
const std::string& kv_cache_dtype, torch::Tensor& scale) {
|
||||
// NOTE(woosuk): In vLLM V1, key.size(0) can be different from
|
||||
// slot_mapping.size(0) because of padding for CUDA graphs.
|
||||
// In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because
|
||||
// both include padding.
|
||||
// In vLLM V1, however, key.size(0) can be larger than slot_mapping.size(0)
|
||||
// since key includes padding for CUDA graphs, while slot_mapping does not.
|
||||
// In this case, slot_mapping.size(0) represents the actual number of tokens
|
||||
// before padding.
|
||||
// For compatibility with both cases, we use slot_mapping.size(0) as the
|
||||
// number of tokens.
|
||||
int num_tokens = slot_mapping.size(0);
|
||||
int kv_lora_rank = kv_c.size(1);
|
||||
int pe_dim = k_pe.size(1);
|
||||
int block_size = kv_cache.size(1);
|
||||
|
||||
TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
|
||||
|
||||
int kv_c_stride = kv_c.stride(0);
|
||||
int k_pe_stride = k_pe.stride(0);
|
||||
int block_stride = kv_cache.stride(0);
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(kv_lora_rank, 512));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_c));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
|
||||
CALL_CONCAT_AND_CACHE_MLA);
|
||||
}
|
||||
|
||||
namespace vllm {
|
||||
|
||||
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
|
||||
|
||||
@ -1,7 +1,14 @@
|
||||
#pragma once
|
||||
|
||||
#include <climits>
|
||||
#include <iostream>
|
||||
|
||||
inline uint32_t next_pow_2(uint32_t const num) {
|
||||
inline constexpr uint32_t next_pow_2(uint32_t const num) {
|
||||
if (num <= 1) return num;
|
||||
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline constexpr std::enable_if_t<std::is_integral_v<T>, T> ceil_div(T a, T b) {
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
@ -38,9 +38,13 @@ struct Signal {
|
||||
alignas(128) FlagType peer_counter[2][kMaxBlocks][8];
|
||||
};
|
||||
|
||||
struct __align__(16) RankData { const void* __restrict__ ptrs[8]; };
|
||||
struct __align__(16) RankData {
|
||||
const void* __restrict__ ptrs[8];
|
||||
};
|
||||
|
||||
struct __align__(16) RankSignals { Signal* signals[8]; };
|
||||
struct __align__(16) RankSignals {
|
||||
Signal* signals[8];
|
||||
};
|
||||
|
||||
// like std::array, but aligned
|
||||
template <typename T, int sz>
|
||||
|
||||
@ -32,3 +32,20 @@ inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
|
||||
}
|
||||
|
||||
int32_t get_sm_version_num();
|
||||
|
||||
/**
|
||||
* A wrapper for a kernel that is used to guard against compilation on
|
||||
* architectures that will never use the kernel. The purpose of this is to
|
||||
* reduce the size of the compiled binary.
|
||||
* __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
|
||||
* into code that will be executed on the device where it is defined.
|
||||
*/
|
||||
template <typename Kernel>
|
||||
struct enable_sm90_or_later : Kernel {
|
||||
template <typename... Args>
|
||||
CUTLASS_DEVICE void operator()(Args&&... args) {
|
||||
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
|
||||
Kernel::operator()(std::forward<Args>(args)...);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
123
csrc/cutlass_extensions/gemm/collective/collective_builder.hpp
Normal file
123
csrc/cutlass_extensions/gemm/collective/collective_builder.hpp
Normal file
@ -0,0 +1,123 @@
|
||||
// Modified from: cutlass/gemm/collective/builders/sm90_gmma_builder.inl
|
||||
// clang-format off
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/gemm/collective/builders/sm90_gmma_builder.inl"
|
||||
|
||||
#include "cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp"
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::collective {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// GMMA_TMA_WS_SS (BlockScaled Builders)
|
||||
template <
|
||||
class ElementA,
|
||||
class GmemLayoutATag,
|
||||
int AlignmentA,
|
||||
class ElementB,
|
||||
class GmemLayoutBTag,
|
||||
int AlignmentB,
|
||||
class ElementAccumulator,
|
||||
class TileShape_MNK,
|
||||
class ClusterShape_MNK,
|
||||
class StageCountType,
|
||||
int ScaleGranularityM
|
||||
>
|
||||
struct CollectiveBuilder<
|
||||
arch::Sm90,
|
||||
arch::OpClassTensorOp,
|
||||
ElementA,
|
||||
GmemLayoutATag,
|
||||
AlignmentA,
|
||||
ElementB,
|
||||
GmemLayoutBTag,
|
||||
AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape_MNK,
|
||||
ClusterShape_MNK,
|
||||
StageCountType,
|
||||
KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>,
|
||||
cute::enable_if_t<
|
||||
not detail::is_use_rmem_A<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>()>
|
||||
> {
|
||||
using KernelScheduleType = KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>;
|
||||
|
||||
static_assert(is_static<TileShape_MNK>::value);
|
||||
static_assert(is_static<ClusterShape_MNK>::value);
|
||||
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
|
||||
static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
|
||||
#endif
|
||||
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
|
||||
"Should meet TMA alignment requirement\n");
|
||||
|
||||
static constexpr bool IsArrayOfPointersGemm = (cute::is_any_of_v<KernelScheduleType,
|
||||
KernelPtrArrayTmaWarpSpecializedCooperative,
|
||||
KernelPtrArrayTmaWarpSpecializedPingpong>);
|
||||
static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
|
||||
static_assert((!IsFP8Input || !IsArrayOfPointersGemm),
|
||||
"KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum is only compatible with FP8 Blocked Scaled version right now.");
|
||||
|
||||
// For fp32 types, map to tf32 MMA value type
|
||||
using ElementAMma = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
|
||||
using ElementBMma = cute::conditional_t<cute::is_same_v<ElementB, float>, tfloat32_t, ElementB>;
|
||||
|
||||
static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementAMma, GmemLayoutATag>();
|
||||
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementBMma, GmemLayoutBTag>();
|
||||
|
||||
static constexpr bool IsCooperative = cute::is_any_of_v<KernelScheduleType,
|
||||
KernelTmaWarpSpecializedCooperative,
|
||||
KernelPtrArrayTmaWarpSpecializedCooperative,
|
||||
KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>>;
|
||||
using AtomLayoutMNK = cute::conditional_t<IsCooperative,
|
||||
Layout<Shape<_2,_1,_1>>, Layout<Shape<_1,_1,_1>>>;
|
||||
|
||||
using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector<
|
||||
ElementAMma, ElementBMma, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{}));
|
||||
|
||||
using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
|
||||
using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
|
||||
|
||||
using SmemLayoutAtomA = decltype(detail::ss_smem_selector<
|
||||
GmmaMajorA, ElementAMma, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
using SmemLayoutAtomB = decltype(detail::ss_smem_selector<
|
||||
GmmaMajorB, ElementBMma, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
|
||||
static constexpr size_t TensorMapStorage = IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0;
|
||||
static constexpr int KernelSmemCarveout = static_cast<int>(TensorMapStorage);
|
||||
|
||||
static constexpr int PipelineStages = detail::compute_stage_count_or_override<detail::sm90_smem_capacity_bytes - KernelSmemCarveout,
|
||||
ElementAMma, ElementBMma, TileShape_MNK>(StageCountType{});
|
||||
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType, ScaleGranularityM>;
|
||||
|
||||
using SmemCopyAtomA = void;
|
||||
using SmemCopyAtomB = void;
|
||||
|
||||
using CollectiveOp = CollectiveMma<
|
||||
DispatchPolicy,
|
||||
TileShape_MNK,
|
||||
ElementA,
|
||||
TagToStrideA_t<GmemLayoutATag>,
|
||||
ElementB,
|
||||
TagToStrideB_t<GmemLayoutBTag>,
|
||||
TiledMma,
|
||||
GmemTiledCopyA,
|
||||
SmemLayoutAtomA,
|
||||
SmemCopyAtomA,
|
||||
cute::identity,
|
||||
GmemTiledCopyB,
|
||||
SmemLayoutAtomB,
|
||||
SmemCopyAtomB,
|
||||
cute::identity
|
||||
>;
|
||||
};
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm::collective
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
183
csrc/cutlass_extensions/gemm/collective/fp8_accumulation.hpp
Normal file
183
csrc/cutlass_extensions/gemm/collective/fp8_accumulation.hpp
Normal file
@ -0,0 +1,183 @@
|
||||
// clang-format off
|
||||
// adapted from: https://github.com/soundOfDestiny/cutlass/blob/a4208aa6958864923505cade9c63eb2a6daf16e5/include/cutlass/gemm/collective/fp8_accumulation.hpp
|
||||
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cute/algorithm/clear.hpp"
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
///////////////////////////////////FP8 Accumulation///////////////////////////
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
/// This class provides API to promote (add) or scale (multiply_add) the results
|
||||
/// from the tensor core accumulators to the main accumulators when the number
|
||||
/// of MMAs reaches the max number of MMA interval specified by user, after that
|
||||
/// the tensor core accumulators are zeroed.
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::collective {
|
||||
|
||||
template <
|
||||
class EngineAccum,
|
||||
class LayoutAccum>
|
||||
struct GmmaFP8AccumulationWithScale {
|
||||
using TensorAccum = cute::Tensor<EngineAccum, LayoutAccum>;
|
||||
using ElementAccumulator = typename EngineAccum::value_type;
|
||||
|
||||
static_assert(is_static<LayoutAccum>::value, "Accumulator Layout should be static");
|
||||
static_assert(is_rmem<TensorAccum>::value , "Accumulator tensor must be rmem resident.");
|
||||
|
||||
private:
|
||||
TensorAccum& accum_;
|
||||
TensorAccum accum_temp_;
|
||||
|
||||
uint32_t accum_promotion_interval_; // defines the max num of executed MMAs after which accum should be promoted.
|
||||
uint32_t mma_count_per_mainloop_iteration_; // num of MMAs per k_tile of mainloop
|
||||
uint32_t mma_count_; // current executed MMAs
|
||||
uint32_t reset_accum_flag_; // accum needs to be zeroed or not.
|
||||
|
||||
// promote or `add` the partial accumulators to main accumulator (FADD).
|
||||
CUTLASS_DEVICE
|
||||
void promote_core() {
|
||||
warpgroup_wait<0>();
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(accum_); ++i) {
|
||||
accum_(i) += accum_temp_(i);
|
||||
}
|
||||
}
|
||||
|
||||
// `multiply` scale the partial accumulators and `add` to main accumulator (FFMA).
|
||||
template <
|
||||
class EngineScale,
|
||||
class LayoutScale>
|
||||
CUTLASS_DEVICE
|
||||
void scale_core(const cute::Tensor<EngineScale, LayoutScale> &scale) {
|
||||
using TensorScale = cute::Tensor<EngineScale, LayoutScale>;
|
||||
|
||||
static_assert(is_static<LayoutScale>::value, "Scale Layout should be static");
|
||||
static_assert(is_rmem<TensorScale>::value , "Scale tensor must be rmem resident.");
|
||||
|
||||
static_assert(LayoutAccum{}.shape() == LayoutScale{}.shape(), "Accumulator and scale must have same shape.");
|
||||
|
||||
warpgroup_wait<0>();
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(accum_); ++i) {
|
||||
accum_(i) += accum_temp_(i) * scale(i);
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
GmmaFP8AccumulationWithScale(
|
||||
TensorAccum &accum,
|
||||
uint32_t accum_promotion_interval,
|
||||
uint32_t mma_count_per_mainloop_iteration)
|
||||
: accum_(accum),
|
||||
accum_promotion_interval_(accum_promotion_interval),
|
||||
mma_count_per_mainloop_iteration_(mma_count_per_mainloop_iteration),
|
||||
mma_count_(0),
|
||||
reset_accum_flag_(0)
|
||||
{
|
||||
accum_temp_ = cute::make_fragment_like(accum);
|
||||
}
|
||||
|
||||
//
|
||||
// Methods (Common)
|
||||
//
|
||||
|
||||
CUTLASS_DEVICE
|
||||
TensorAccum& operator()() {
|
||||
return accum_temp_;
|
||||
}
|
||||
|
||||
/// prepare the MMA accumulators when initialization or zeroing is required.
|
||||
CUTLASS_DEVICE
|
||||
bool prepare_if_needed() {
|
||||
return reset_accum_flag_;
|
||||
}
|
||||
|
||||
//
|
||||
// Methods (for FADD version)
|
||||
//
|
||||
|
||||
/// promote (add) the results from the MMA accumulators to main accumulator if needed.
|
||||
CUTLASS_DEVICE
|
||||
void promote_if_needed() {
|
||||
mma_count_ += mma_count_per_mainloop_iteration_;
|
||||
reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0);
|
||||
if (reset_accum_flag_) {
|
||||
promote_core();
|
||||
mma_count_ = 0;
|
||||
}
|
||||
}
|
||||
|
||||
/// promote (add) the residue results from the MMA accumulators to main accumulator if needed.
|
||||
CUTLASS_DEVICE
|
||||
void promote_residue_if_needed() {
|
||||
if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) {
|
||||
promote_core();
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Methods (for FFMA version)
|
||||
//
|
||||
|
||||
/// scale (multiply_add) the results from the MMA accumulators to main accumulator if needed.
|
||||
template <
|
||||
class EngineScale,
|
||||
class LayoutScale>
|
||||
CUTLASS_DEVICE
|
||||
void scale_if_needed(const cute::Tensor<EngineScale, LayoutScale> &scale) {
|
||||
mma_count_ += mma_count_per_mainloop_iteration_;
|
||||
reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0);
|
||||
if (reset_accum_flag_) {
|
||||
scale_core(scale);
|
||||
mma_count_ = 0;
|
||||
}
|
||||
}
|
||||
|
||||
/// scale (multiply_add) the residue results from the MMA accumulators to main accumulator if needed.
|
||||
template <
|
||||
class EngineScale,
|
||||
class LayoutScale>
|
||||
CUTLASS_DEVICE
|
||||
void scale_residue_if_needed(const cute::Tensor<EngineScale, LayoutScale> &scale) {
|
||||
if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) {
|
||||
scale_core(scale);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cutlass::gemm::collective
|
||||
@ -0,0 +1,730 @@
|
||||
// clang-format off
|
||||
// Adapted (Heavily) from: https://github.com/soundOfDestiny/cutlass/blob/9d997ce0dea4c5fa1a617db6b7ff29aa9235822c/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp
|
||||
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/trace.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cute/arch/cluster_sm90.hpp"
|
||||
#include "cute/arch/copy_sm80.hpp"
|
||||
#include "cute/arch/copy_sm90.hpp"
|
||||
#include "cute/algorithm/functional.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cute/algorithm/gemm.hpp"
|
||||
#include "cute/tensor_predicate.hpp"
|
||||
#include "cute/numeric/arithmetic_tuple.hpp"
|
||||
|
||||
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass_extensions/gemm/collective/fp8_accumulation.hpp"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::collective {
|
||||
using namespace cute;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// WarpSpecialized Mainloop
|
||||
template <
|
||||
int Stages,
|
||||
class ClusterShape,
|
||||
class KernelSchedule,
|
||||
int ScaleGranularityM_,
|
||||
class TileShape_,
|
||||
class ElementA_,
|
||||
class StrideA_,
|
||||
class ElementB_,
|
||||
class StrideB_,
|
||||
class TiledMma_,
|
||||
class GmemTiledCopyA_,
|
||||
class SmemLayoutAtomA_,
|
||||
class SmemCopyAtomA_,
|
||||
class TransformA_,
|
||||
class GmemTiledCopyB_,
|
||||
class SmemLayoutAtomB_,
|
||||
class SmemCopyAtomB_,
|
||||
class TransformB_>
|
||||
struct CollectiveMma<
|
||||
MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8<Stages, ClusterShape, KernelSchedule, ScaleGranularityM_>,
|
||||
TileShape_,
|
||||
ElementA_,
|
||||
StrideA_,
|
||||
ElementB_,
|
||||
StrideB_,
|
||||
TiledMma_,
|
||||
GmemTiledCopyA_,
|
||||
SmemLayoutAtomA_,
|
||||
SmemCopyAtomA_,
|
||||
TransformA_,
|
||||
GmemTiledCopyB_,
|
||||
SmemLayoutAtomB_,
|
||||
SmemCopyAtomB_,
|
||||
TransformB_>
|
||||
{
|
||||
//
|
||||
// Type Aliases
|
||||
//
|
||||
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8<Stages, ClusterShape, KernelSchedule, ScaleGranularityM_>;
|
||||
using TileShape = TileShape_;
|
||||
using ElementA = ElementA_;
|
||||
using StrideA = StrideA_;
|
||||
using ElementB = ElementB_;
|
||||
using StrideB = StrideB_;
|
||||
using TiledMma = TiledMma_;
|
||||
using ElementAccumulator = typename TiledMma::ValTypeC;
|
||||
using ElementBlockScale = ElementAccumulator;
|
||||
using GmemTiledCopyA = GmemTiledCopyA_;
|
||||
using GmemTiledCopyB = GmemTiledCopyB_;
|
||||
using SmemLayoutAtomA = SmemLayoutAtomA_;
|
||||
using SmemLayoutAtomB = SmemLayoutAtomB_;
|
||||
using SmemCopyAtomA = SmemCopyAtomA_;
|
||||
using SmemCopyAtomB = SmemCopyAtomB_;
|
||||
using TransformA = TransformA_;
|
||||
using TransformB = TransformB_;
|
||||
using ArchTag = typename DispatchPolicy::ArchTag;
|
||||
|
||||
using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{}));
|
||||
using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
|
||||
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
|
||||
using PipelineParams = typename MainloopPipeline::Params;
|
||||
|
||||
// Two threads per CTA are producers (1 for operand tile and 32 for scales)
|
||||
static constexpr int NumProducerThreadEvents = 33;
|
||||
|
||||
static constexpr int ScaleGranularityM = ScaleGranularityM_ == 0 ? size<0>(TileShape{}) : ScaleGranularityM_;
|
||||
static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM;
|
||||
|
||||
static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
|
||||
static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
|
||||
static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
|
||||
static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
|
||||
static_assert((size<0>(TileShape{}) % ScaleGranularityM) == 0, "FP8 scaling granularity must evenly divide tile shape along M.");
|
||||
|
||||
// Tile along modes in a way that maximizes the TMA box size.
|
||||
using SmemLayoutA = decltype(tile_to_shape(
|
||||
SmemLayoutAtomA{},
|
||||
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
|
||||
cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
|
||||
using SmemLayoutB = decltype(tile_to_shape(
|
||||
SmemLayoutAtomB{},
|
||||
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
|
||||
cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
|
||||
|
||||
// Block scaling gmem-to-smem copy atom
|
||||
using SmemBlockScalingCopyAtomA = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementBlockScale>, ElementBlockScale>;
|
||||
using SmemBlockScalingCopyAtomB = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementBlockScale>, ElementBlockScale>;
|
||||
|
||||
// Block scaling smem layout
|
||||
using SmemLayoutScaleA = Layout<Shape<Int<ScaleMsPerTile>, Int<DispatchPolicy::Stages>>>;
|
||||
using SmemLayoutScaleB = Layout<Shape<Int<DispatchPolicy::Stages>>, Stride<_1>>; // `ScaleNsPerTile` is always 1.
|
||||
|
||||
static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more.");
|
||||
static_assert(cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value &&
|
||||
cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeB>::value,
|
||||
"MMA atom must source both A and B operand from smem_desc for this mainloop.");
|
||||
static_assert(cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>,
|
||||
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
|
||||
static_assert(cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>,
|
||||
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
|
||||
static_assert(cute::is_same_v<ElementAccumulator, ElementBlockScale>,
|
||||
"ElementAccumulator and ElementBlockScale should be same datatype");
|
||||
|
||||
struct SharedStorage
|
||||
{
|
||||
struct TensorStorage : cute::aligned_struct<128> {
|
||||
cute::array_aligned<typename TiledMma::ValTypeA, cute::cosize_v<SmemLayoutA>> smem_A; // mxk
|
||||
cute::array_aligned<typename TiledMma::ValTypeB, cute::cosize_v<SmemLayoutB>> smem_B; // nxk
|
||||
cute::array_aligned<ElementBlockScale, cute::cosize_v<SmemLayoutScaleA>> smem_scale_A; // ScaleMsPerTile x k
|
||||
cute::array_aligned<ElementBlockScale, cute::cosize_v<SmemLayoutScaleB>> smem_scale_B; // 1xk
|
||||
} tensors;
|
||||
|
||||
using PipelineStorage = typename MainloopPipeline::SharedStorage;
|
||||
PipelineStorage pipeline;
|
||||
};
|
||||
using TensorStorage = typename SharedStorage::TensorStorage;
|
||||
using PipelineStorage = typename SharedStorage::PipelineStorage;
|
||||
|
||||
// Host side kernel arguments
|
||||
struct Arguments {
|
||||
ElementA const* ptr_A;
|
||||
StrideA dA;
|
||||
ElementB const* ptr_B;
|
||||
StrideB dB;
|
||||
ElementBlockScale const* ptr_scale_A;
|
||||
ElementBlockScale const* ptr_scale_B;
|
||||
};
|
||||
|
||||
// Device side kernel params
|
||||
struct Params {
|
||||
// Assumption: StrideA is congruent with Problem_MK
|
||||
using TMA_A = decltype(make_tma_copy_A_sm90(
|
||||
GmemTiledCopyA{},
|
||||
make_tensor(static_cast<ElementA const*>(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}),
|
||||
SmemLayoutA{}(_,_,0),
|
||||
TileShape{},
|
||||
ClusterShape{}));
|
||||
// Assumption: StrideB is congruent with Problem_NK
|
||||
using TMA_B = decltype(make_tma_copy_B_sm90(
|
||||
GmemTiledCopyB{},
|
||||
make_tensor(static_cast<ElementB const*>(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}),
|
||||
SmemLayoutB{}(_,_,0),
|
||||
TileShape{},
|
||||
ClusterShape{}));
|
||||
TMA_A tma_load_a;
|
||||
TMA_B tma_load_b;
|
||||
uint32_t tma_transaction_bytes = TmaTransactionBytes;
|
||||
uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK;
|
||||
uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK;
|
||||
// Block scaling factors for A and B
|
||||
ElementBlockScale const* ptr_scale_A;
|
||||
ElementBlockScale const* ptr_scale_B;
|
||||
};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
template <class ProblemShape>
|
||||
static constexpr Params
|
||||
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
|
||||
(void) workspace;
|
||||
|
||||
// Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK)
|
||||
auto problem_shape_MNKL = append<4>(problem_shape, 1);
|
||||
auto [M,N,K,L] = problem_shape_MNKL;
|
||||
|
||||
auto ptr_A = reinterpret_cast<ElementA const*>(args.ptr_A);
|
||||
auto ptr_B = reinterpret_cast<ElementB const*>(args.ptr_B);
|
||||
|
||||
Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA));
|
||||
Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB));
|
||||
typename Params::TMA_A tma_load_a = make_tma_copy_A_sm90(
|
||||
GmemTiledCopyA{},
|
||||
tensor_a,
|
||||
SmemLayoutA{}(_,_,cute::Int<0>{}),
|
||||
TileShape{},
|
||||
ClusterShape{});
|
||||
typename Params::TMA_B tma_load_b = make_tma_copy_B_sm90(
|
||||
GmemTiledCopyB{},
|
||||
tensor_b,
|
||||
SmemLayoutB{}(_,_,cute::Int<0>{}),
|
||||
TileShape{},
|
||||
ClusterShape{});
|
||||
uint32_t transaction_bytes_mk = TmaTransactionBytesMK;
|
||||
uint32_t transaction_bytes_nk = TmaTransactionBytesNK;
|
||||
uint32_t transaction_bytes = transaction_bytes_mk + transaction_bytes_nk;
|
||||
|
||||
return {
|
||||
tma_load_a,
|
||||
tma_load_b,
|
||||
transaction_bytes,
|
||||
transaction_bytes_mk,
|
||||
transaction_bytes_nk,
|
||||
args.ptr_scale_A,
|
||||
args.ptr_scale_B
|
||||
};
|
||||
}
|
||||
|
||||
template<class ProblemShape>
|
||||
static bool
|
||||
can_implement(
|
||||
ProblemShape const& problem_shape,
|
||||
[[maybe_unused]] Arguments const& args) {
|
||||
constexpr int tma_alignment_bits = 128;
|
||||
auto problem_shape_MNKL = append<4>(problem_shape, 1);
|
||||
auto [M,N,K,L] = problem_shape_MNKL;
|
||||
|
||||
bool implementable = true;
|
||||
constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
|
||||
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), StrideA{});
|
||||
constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
|
||||
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N,K,L), StrideB{});
|
||||
|
||||
if (!implementable) {
|
||||
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
|
||||
}
|
||||
return implementable;
|
||||
}
|
||||
|
||||
static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
|
||||
static constexpr int K_PIPE_MMAS = 1;
|
||||
static constexpr uint32_t TmaTransactionBytesMK =
|
||||
cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast<uint32_t>(sizeof_bits<ElementA>::value));
|
||||
static constexpr uint32_t TmaTransactionBytesNK =
|
||||
cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast<uint32_t>(sizeof_bits<ElementB>::value));
|
||||
static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK;
|
||||
|
||||
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
|
||||
CUTLASS_DEVICE
|
||||
static void prefetch_tma_descriptors(Params const& mainloop_params)
|
||||
{
|
||||
cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor());
|
||||
cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor());
|
||||
}
|
||||
|
||||
/// Set up the data needed by this collective for load and mma.
|
||||
/// Returns a tuple of tensors. The collective and the kernel layer have the contract
|
||||
/// Returned tuple must contain at least two elements, with the first two elements being:
|
||||
/// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l)
|
||||
/// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l)
|
||||
template <class ProblemShape_MNKL>
|
||||
CUTLASS_DEVICE auto
|
||||
load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const {
|
||||
using X = Underscore;
|
||||
// Separate out problem shape for convenience
|
||||
auto [M,N,K,L] = problem_shape_MNKL;
|
||||
|
||||
// TMA requires special handling of strides to deal with coord codomain mapping
|
||||
// Represent the full tensors -- get these from TMA
|
||||
Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l)
|
||||
Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l)
|
||||
|
||||
// Make tiled views, defer the slice
|
||||
Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l)
|
||||
Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l)
|
||||
|
||||
constexpr auto scales_m = Int<ScaleMsPerTile>{};
|
||||
auto tM = get<2>(gA_mkl.shape());
|
||||
auto tN = get<2>(gB_nkl.shape());
|
||||
auto tK = get<3>(gA_mkl.shape());
|
||||
|
||||
// Make the tiled views of scale tensors
|
||||
auto scaleA_shape = make_shape(M / ScaleGranularityM, tK, L); // (scale_m,k,l)
|
||||
auto scaleA_layout = make_ordered_layout(scaleA_shape, Step<_0, _1, _2>{});
|
||||
auto scaleB_shape = make_shape(tN, tK, L); // (n,k,l)
|
||||
auto scaleB_layout = make_ordered_layout(scaleB_shape, Step<_1, _0, _2>{});
|
||||
|
||||
// Note that mScaleA_mkl and mScaleB_nkl are already blocked tiled in the `m` host and
|
||||
// gScaleA_mkl and gScaleB_nkl in `g` global memory are same as mScaleA_mkl and mScaleB_nkl.
|
||||
Tensor mScaleA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_A), scaleA_layout); // (scale_m,k,l)
|
||||
Tensor mScaleB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_B), scaleB_layout); // (n,k,l)
|
||||
|
||||
return cute::make_tuple(gA_mkl, gB_nkl, mScaleA_mkl, mScaleB_nkl);
|
||||
}
|
||||
|
||||
/// Perform a collective-scoped matrix multiply-accumulate
|
||||
/// Producer Perspective
|
||||
template <
|
||||
class TensorA, class TensorB,
|
||||
class TensorScaleA, class TensorScaleB,
|
||||
class KTileIterator, class BlockCoord
|
||||
>
|
||||
CUTLASS_DEVICE void
|
||||
load(
|
||||
Params const& mainloop_params,
|
||||
MainloopPipeline pipeline,
|
||||
PipelineState smem_pipe_write,
|
||||
cute::tuple<TensorA, TensorB, TensorScaleA, TensorScaleB> const& load_inputs,
|
||||
BlockCoord const& blk_coord,
|
||||
KTileIterator k_tile_iter, int k_tile_count,
|
||||
int thread_idx,
|
||||
uint32_t block_rank_in_cluster,
|
||||
TensorStorage& shared_tensors) {
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
// Blockscaling: Tma loads for load_input and CpAsync for load_scale
|
||||
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
||||
Tensor sScaleA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), SmemLayoutScaleA{}); // (ScaleMsPerTile,k)
|
||||
Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (k)
|
||||
|
||||
//
|
||||
// Prepare the TMA loads for A and B
|
||||
//
|
||||
|
||||
constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
|
||||
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
|
||||
|
||||
Tensor gA_mkl = get<0>(load_inputs);
|
||||
Tensor gB_nkl = get<1>(load_inputs);
|
||||
|
||||
auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
|
||||
auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
|
||||
|
||||
// Partition the inputs based on the current block coordinates.
|
||||
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
|
||||
Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
|
||||
Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k)
|
||||
|
||||
|
||||
// Block scaling: load_scale has scaling tensors in global memory which are not tiled
|
||||
Tensor mScaleA_mkl = get<2>(load_inputs);
|
||||
Tensor mScaleB_nkl = get<3>(load_inputs);
|
||||
auto scales_m = get<0>(mScaleA_mkl.shape());
|
||||
|
||||
Tensor cScaleA_mkl = make_identity_tensor(mScaleA_mkl.shape());
|
||||
|
||||
Tensor gScaleA = local_tile(
|
||||
mScaleA_mkl, make_tile(Int<ScaleMsPerTile>{}),
|
||||
make_coord(m_coord,_,l_coord)); // (ScaleMsPerTile,k,1)
|
||||
Tensor cScaleA = local_tile(
|
||||
cScaleA_mkl, make_tile(Int<ScaleMsPerTile>{}),
|
||||
make_coord(m_coord,_,l_coord));
|
||||
Tensor gScaleB = mScaleB_nkl(n_coord,_,l_coord); // (1,k,1)
|
||||
|
||||
// TODO: test `scale_copy_a` with `ScaleMsPerTile` < 128
|
||||
TiledCopy scale_copy_a = make_tiled_copy(SmemBlockScalingCopyAtomA{},
|
||||
Layout<Shape<_32, _1>>{}, Layout<Shape<_4, _1>>{}); // (1,1,1)
|
||||
TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{},
|
||||
Layout<Shape<_1>>{}, Layout<Shape<_1>>{}); // (1,1,1)
|
||||
ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x);
|
||||
ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(threadIdx.x);
|
||||
|
||||
Tensor tAgA_ScaleA = thr_scale_copy_a.partition_S(gScaleA);
|
||||
Tensor tAcA_ScaleA = thr_scale_copy_a.partition_S(cScaleA);
|
||||
Tensor tAsA_ScaleA = thr_scale_copy_a.partition_D(sScaleA);
|
||||
|
||||
Tensor tBgB_ScaleB = thr_scale_copy_b.partition_S(gScaleB);
|
||||
Tensor tBsB_ScaleB = thr_scale_copy_b.partition_D(sScaleB);
|
||||
|
||||
// Applies the mapping from block_tma_a
|
||||
Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
|
||||
Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
|
||||
|
||||
Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
|
||||
Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
|
||||
|
||||
uint16_t mcast_mask_a = 0;
|
||||
uint16_t mcast_mask_b = 0;
|
||||
|
||||
// Issue TmaLoads for GEMM operands A/B and CpAsync for scale tensors
|
||||
// Maps the tile -> block, value
|
||||
if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
|
||||
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
|
||||
for (int n = 0; n < size<1>(block_layout); ++n) {
|
||||
mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{}));
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
|
||||
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
|
||||
for (int m = 0; m < size<0>(block_layout); ++m) {
|
||||
mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{}));
|
||||
}
|
||||
}
|
||||
|
||||
// Allocate predicate tensors for a_scales (since we can't guarantee that
|
||||
// all scales are valid, since we could have a partial tiles along M)
|
||||
Tensor tApA_ScaleA = make_tensor<bool>(shape(tAsA_ScaleA(_,_,0)));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(tApA_ScaleA); ++i) {
|
||||
tApA_ScaleA(i) = get<0>(tAcA_ScaleA(i)) < scales_m;
|
||||
}
|
||||
|
||||
// Mainloop
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for ( ; k_tile_count > 0; --k_tile_count) {
|
||||
// LOCK smem_pipe_write for _writing_
|
||||
pipeline.producer_acquire(smem_pipe_write);
|
||||
|
||||
//
|
||||
// Copy gmem to smem for *k_tile_iter
|
||||
//
|
||||
int write_stage = smem_pipe_write.index();
|
||||
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
|
||||
BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
|
||||
|
||||
// Copy operands A and B from global memory to shared memory
|
||||
if (lane_predicate) copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
|
||||
if (lane_predicate) copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
|
||||
|
||||
// Copy scale tensors from global memory to shared memory
|
||||
copy_if(scale_copy_a, tApA_ScaleA, tAgA_ScaleA(_,_,*k_tile_iter), tAsA_ScaleA(_,_,write_stage));
|
||||
copy(scale_copy_b, tBgB_ScaleB(_,*k_tile_iter), tBsB_ScaleB(_,write_stage));
|
||||
pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive_noinc);
|
||||
|
||||
++k_tile_iter;
|
||||
|
||||
// Advance smem_pipe_write
|
||||
++smem_pipe_write;
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
|
||||
CUTLASS_DEVICE void
|
||||
load_tail(
|
||||
MainloopPipeline pipeline,
|
||||
PipelineState smem_pipe_write) {
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
// Issue the epilogue waits
|
||||
if (lane_predicate) {
|
||||
/* This helps avoid early exit of blocks in Cluster
|
||||
* Waits for all stages to either be released (all
|
||||
* Consumer UNLOCKs), or if the stage was never used
|
||||
* then would just be acquired since the phase was
|
||||
* still inverted from make_producer_start_state
|
||||
*/
|
||||
pipeline.producer_tail(smem_pipe_write);
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a collective-scoped matrix multiply-accumulate
|
||||
/// Consumer Perspective
|
||||
template <
|
||||
class FrgTensorC
|
||||
>
|
||||
CUTLASS_DEVICE void
|
||||
mma(MainloopPipeline pipeline,
|
||||
PipelineState smem_pipe_read,
|
||||
FrgTensorC& accum,
|
||||
int k_tile_count,
|
||||
int thread_idx,
|
||||
TensorStorage& shared_tensors,
|
||||
Params const& mainloop_params) {
|
||||
|
||||
|
||||
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
|
||||
static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
|
||||
static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
|
||||
static_assert(cute::is_void_v<SmemCopyAtomA>,
|
||||
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
|
||||
static_assert(cute::is_void_v<SmemCopyAtomB>,
|
||||
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
|
||||
|
||||
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
||||
|
||||
// Block scaling
|
||||
Tensor sScaleAViewAsC = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()),
|
||||
Layout<
|
||||
Shape<Shape<Int<ScaleGranularityM>, Int<ScaleMsPerTile>>, cute::tuple_element_t<1, TileShape>, Int<DispatchPolicy::Stages>>,
|
||||
Stride<Stride<_0, _1>, _0, Int<ScaleMsPerTile>>
|
||||
>{}); // ((ScaleGranularityM,ScaleMsPerTile),n,k)
|
||||
Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (k)
|
||||
|
||||
//
|
||||
// Define C accumulators and A/B partitioning
|
||||
//
|
||||
|
||||
// Layout of warp group to thread mapping
|
||||
|
||||
static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and
|
||||
stride<0>(typename TiledMma::BLayout{}) == 0 and
|
||||
size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and
|
||||
size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup,
|
||||
"Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup");
|
||||
|
||||
constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup;
|
||||
Layout warp_group_thread_layout = make_layout(Int<MmaWarpGroups>{},
|
||||
Int<NumThreadsPerWarpGroup>{});
|
||||
|
||||
int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0);
|
||||
|
||||
TiledMma tiled_mma;
|
||||
auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx));
|
||||
|
||||
Tensor tCsScaleAViewAsC = tiled_mma.get_slice(thread_idx).partition_C(sScaleAViewAsC); // (MMA,MMA_M,MMA_N,PIPE), `thread_mma` above is correct when partitioning A and B, but it is not correct when partitioning C.
|
||||
|
||||
Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
|
||||
Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
|
||||
|
||||
// Allocate "fragments/descriptors"
|
||||
Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
|
||||
Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
|
||||
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K
|
||||
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE
|
||||
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sA)); // PIPE
|
||||
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sB)); // PIPE
|
||||
|
||||
//
|
||||
// PIPELINED MAIN LOOP
|
||||
//
|
||||
static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX),
|
||||
"ERROR : Incorrect number of MMAs in flight");
|
||||
|
||||
// We release buffers to producer warps(dma load) with some mmas in flight
|
||||
PipelineState smem_pipe_release = smem_pipe_read;
|
||||
|
||||
// Per block scale values for operand A and B
|
||||
|
||||
using RegLayoutScaleAViewAsC = decltype(make_layout_like(tCsScaleAViewAsC(_, _, _, 0).layout())); // `make_layout_like` makes a compact layout.
|
||||
using RegLayoutScaleAEssential = decltype(filter_zeros(RegLayoutScaleAViewAsC{}.stride(), RegLayoutScaleAViewAsC{}.shape())); // an interface to traverse the underlying storage for the compact layout mentioned above
|
||||
|
||||
Tensor tCrScaleAViewAsC = make_tensor<ElementBlockScale>(RegLayoutScaleAViewAsC{}); // (MMA,MMA_M,MMA_N)
|
||||
ElementBlockScale scale_b;
|
||||
|
||||
// Prologue GMMAs
|
||||
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
|
||||
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
||||
|
||||
GmmaFP8AccumulationWithScale accumulation(accum, size<2>(TileShape{}) / size<2>(typename TiledMma::AtomShape_MNK{}), size<2>(tCrA));
|
||||
warpgroup_fence_operand(accumulation());
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue)
|
||||
{
|
||||
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
|
||||
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
||||
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
||||
|
||||
if (accumulation.prepare_if_needed()) {
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
||||
}
|
||||
|
||||
int read_stage = smem_pipe_read.index();
|
||||
|
||||
// Load per block scale values from shared memory to registers.
|
||||
scale_b = sScaleB[read_stage];
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) {
|
||||
tCrScaleAViewAsC.data()[i] = tCsScaleAViewAsC(_, _, _, read_stage)(idx2crd(i, RegLayoutScaleAEssential{}));
|
||||
}
|
||||
if constexpr (ScaleMsPerTile == 1) {
|
||||
static_assert(size(RegLayoutScaleAEssential{}) == 1);
|
||||
tCrScaleAViewAsC.data()[0] = __shfl_sync(0xffffffff, tCrScaleAViewAsC.data()[0] * scale_b, 0); // `tCrScaleAViewAsC.data()[0]` are all same in a warp group when `ScaleMsPerTile == 1`.
|
||||
} else {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) {
|
||||
tCrScaleAViewAsC.data()[i] = tCrScaleAViewAsC.data()[i] * scale_b;
|
||||
}
|
||||
}
|
||||
|
||||
warpgroup_arrive();
|
||||
// Unroll the K mode manually to set scale D to 1
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
||||
// (V,M,K) x (V,N,K) => (V,M,N)
|
||||
cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation());
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||
}
|
||||
warpgroup_commit_batch();
|
||||
|
||||
// Block scale the accumulators with reg tensor `tCrScaleAViewAsC`
|
||||
accumulation.scale_if_needed(tCrScaleAViewAsC);
|
||||
|
||||
++smem_pipe_read;
|
||||
}
|
||||
|
||||
warpgroup_fence_operand(accumulation());
|
||||
// Mainloop GMMAs
|
||||
k_tile_count -= prologue_mma_count;
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for ( ; k_tile_count > 0; --k_tile_count)
|
||||
{
|
||||
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
|
||||
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
||||
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
||||
|
||||
//
|
||||
// Compute on k_tile
|
||||
//
|
||||
|
||||
int read_stage = smem_pipe_read.index();
|
||||
|
||||
// Load per block scale values from shared memory to registers (at most twice per block along M and exactly once per block along N)
|
||||
scale_b = sScaleB[read_stage];
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) {
|
||||
tCrScaleAViewAsC.data()[i] = tCsScaleAViewAsC(_, _, _, read_stage)(idx2crd(i, RegLayoutScaleAEssential{}));
|
||||
}
|
||||
if constexpr (ScaleMsPerTile == 1) {
|
||||
static_assert(size(RegLayoutScaleAEssential{}) == 1);
|
||||
tCrScaleAViewAsC.data()[0] = __shfl_sync(0xffffffff, tCrScaleAViewAsC.data()[0] * scale_b, 0); // `tCrScaleAViewAsC.data()[0]` are all same in a warp group when `ScaleMsPerTile == 1`.
|
||||
} else {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) {
|
||||
tCrScaleAViewAsC.data()[i] = tCrScaleAViewAsC.data()[i] * scale_b;
|
||||
}
|
||||
}
|
||||
|
||||
if (accumulation.prepare_if_needed()) {
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
||||
}
|
||||
|
||||
warpgroup_fence_operand(accumulation());
|
||||
warpgroup_arrive();
|
||||
// Unroll the K mode manually to set scale D to 1
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
||||
// (V,M,K) x (V,N,K) => (V,M,N)
|
||||
cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation());
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||
}
|
||||
warpgroup_commit_batch();
|
||||
|
||||
/// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed
|
||||
warpgroup_wait<K_PIPE_MMAS>();
|
||||
warpgroup_fence_operand(accumulation());
|
||||
|
||||
// Block scale the accumulators with reg tensor `tCrScaleAViewAsC`
|
||||
accumulation.scale_if_needed(tCrScaleAViewAsC);
|
||||
|
||||
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
|
||||
|
||||
// Advance smem_pipe_read and smem_pipe_release
|
||||
++smem_pipe_read;
|
||||
++smem_pipe_release;
|
||||
}
|
||||
|
||||
accumulation.scale_residue_if_needed(tCrScaleAViewAsC);
|
||||
|
||||
warpgroup_fence_operand(accumulation());
|
||||
}
|
||||
|
||||
/// Perform a Consumer Epilogue to release all buffers
|
||||
CUTLASS_DEVICE void
|
||||
mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) {
|
||||
// Prologue GMMAs
|
||||
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
|
||||
k_tile_count -= prologue_mma_count;
|
||||
|
||||
smem_pipe_release.advance(k_tile_count);
|
||||
|
||||
// Wait on all GMMAs to complete
|
||||
warpgroup_wait<0>();
|
||||
|
||||
for (int count = 0; count < prologue_mma_count; ++count) {
|
||||
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
|
||||
++smem_pipe_release;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm::collective
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
39
csrc/cutlass_extensions/gemm/dispatch_policy.hpp
Normal file
39
csrc/cutlass_extensions/gemm/dispatch_policy.hpp
Normal file
@ -0,0 +1,39 @@
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
|
||||
namespace cutlass::gemm {
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// FP8 related policies (including Blocked Scaled Accumulation)
|
||||
// `ScaleGranularityM` specifies scaling granularity along M, while zero-value
|
||||
// `ScaleGranularityM` indicates that scaling granularity is
|
||||
// `size<0>(TileShape_MNK{})` along M.
|
||||
template <int ScaleGranularityM = 0>
|
||||
struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum
|
||||
: KernelTmaWarpSpecializedCooperative {};
|
||||
|
||||
// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp
|
||||
// specialized dynamic schedule For FP8 kernels with Block Scaling
|
||||
template <int Stages_, class ClusterShape_ = Shape<_1, _1, _1>,
|
||||
class KernelSchedule = KernelTmaWarpSpecialized,
|
||||
int ScaleGranularityM =
|
||||
0 // `ScaleGranularityM` specifies scaling granularity along M,
|
||||
// while zero-value `ScaleGranularityM` indicates that scaling
|
||||
// granularity is `size<0>(TileShape_MNK{})` along M.
|
||||
>
|
||||
struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8
|
||||
: MainloopSm90TmaGmmaWarpSpecialized<Stages_, ClusterShape_,
|
||||
KernelSchedule> {
|
||||
static_assert(
|
||||
cute::is_same_v<
|
||||
KernelSchedule,
|
||||
KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<
|
||||
ScaleGranularityM>>,
|
||||
"KernelSchedule must be one of the warp specialized policies");
|
||||
};
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm
|
||||
@ -1,6 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
|
||||
|
||||
namespace cutlass::gemm::collective {
|
||||
using namespace cute;
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import enum
|
||||
from typing import Dict, Union
|
||||
|
||||
|
||||
@ -138,8 +138,8 @@ __device__ inline FragB dequant<vllm::kU4B8.id()>(int q) {
|
||||
const int HI = 0x00f000f0;
|
||||
const int EX = 0x64006400;
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
|
||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
|
||||
int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
|
||||
int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
|
||||
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
|
||||
// directly into `SUB` and `ADD`.
|
||||
const int SUB = 0x64086408;
|
||||
@ -182,8 +182,8 @@ __device__ inline FragB dequant<vllm::kU4.id()>(int q) {
|
||||
const int HI = 0x00f000f0;
|
||||
const int EX = 0x64006400;
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
|
||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
|
||||
int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
|
||||
int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
|
||||
|
||||
const int SUB = 0x64006400;
|
||||
const int MUL = 0x2c002c00;
|
||||
|
||||
@ -197,6 +197,72 @@ __global__ void moe_align_block_size_global_mem_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
// taken from
|
||||
// https://github.com/sgl-project/sglang/commit/ded9fcd09a43d5e7d5bb31a2bc3e9fc21bf65d2a
|
||||
template <typename scalar_t>
|
||||
__global__ void sgl_moe_align_block_size_kernel(
|
||||
scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
|
||||
int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts,
|
||||
int32_t block_size, size_t numel, int32_t* cumsum) {
|
||||
__shared__ int32_t shared_counts[32][8];
|
||||
__shared__ int32_t local_offsets[256];
|
||||
|
||||
const int warp_id = threadIdx.x / WARP_SIZE;
|
||||
const int lane_id = threadIdx.x % WARP_SIZE;
|
||||
const int experts_per_warp = 8;
|
||||
const int my_expert_start = warp_id * experts_per_warp;
|
||||
|
||||
for (int i = 0; i < experts_per_warp; ++i) {
|
||||
if (my_expert_start + i < num_experts) {
|
||||
shared_counts[warp_id][i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
|
||||
const size_t start_idx = threadIdx.x * tokens_per_thread;
|
||||
|
||||
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
|
||||
int expert_id = topk_ids[i];
|
||||
int warp_idx = expert_id / experts_per_warp;
|
||||
int expert_offset = expert_id % experts_per_warp;
|
||||
atomicAdd(&shared_counts[warp_idx][expert_offset], 1);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
cumsum[0] = 0;
|
||||
for (int i = 1; i <= num_experts; ++i) {
|
||||
int expert_count = 0;
|
||||
int warp_idx = (i - 1) / experts_per_warp;
|
||||
int expert_offset = (i - 1) % experts_per_warp;
|
||||
expert_count = shared_counts[warp_idx][expert_offset];
|
||||
|
||||
cumsum[i] =
|
||||
cumsum[i - 1] + CEILDIV(expert_count, block_size) * block_size;
|
||||
}
|
||||
*total_tokens_post_pad = cumsum[num_experts];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x < num_experts) {
|
||||
for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
|
||||
i += block_size) {
|
||||
expert_ids[i / block_size] = threadIdx.x;
|
||||
}
|
||||
local_offsets[threadIdx.x] = cumsum[threadIdx.x];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
|
||||
int32_t expert_id = topk_ids[i];
|
||||
int32_t rank_post_pad = atomicAdd(&local_offsets[expert_id], 1);
|
||||
sorted_token_ids[rank_post_pad] = i;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, int TOPK>
|
||||
__global__ void moe_sum_kernel(
|
||||
scalar_t* __restrict__ out, // [..., d]
|
||||
@ -305,6 +371,32 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
||||
}
|
||||
}
|
||||
|
||||
void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
||||
int64_t block_size,
|
||||
torch::Tensor sorted_token_ids,
|
||||
torch::Tensor experts_ids,
|
||||
torch::Tensor num_tokens_post_pad) {
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_INTEGRAL_TYPES(
|
||||
topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] {
|
||||
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
|
||||
// tensors
|
||||
auto options_int =
|
||||
torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device());
|
||||
// torch::Tensor token_cnts_buffer =
|
||||
// torch::empty({(num_experts + 1) * num_experts}, options_int);
|
||||
torch::Tensor cumsum_buffer =
|
||||
torch::empty({num_experts + 1}, options_int);
|
||||
|
||||
auto kernel = vllm::moe::sgl_moe_align_block_size_kernel<scalar_t>;
|
||||
kernel<<<1, 1024, 0, stream>>>(
|
||||
topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
|
||||
experts_ids.data_ptr<int32_t>(),
|
||||
num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
|
||||
topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>());
|
||||
});
|
||||
}
|
||||
|
||||
void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size]
|
||||
torch::Tensor& output) // [num_tokens, hidden_size]
|
||||
{
|
||||
|
||||
@ -12,3 +12,9 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
||||
int64_t block_size, torch::Tensor sorted_token_ids,
|
||||
torch::Tensor experts_ids,
|
||||
torch::Tensor num_tokens_post_pad);
|
||||
|
||||
void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
||||
int64_t block_size,
|
||||
torch::Tensor sorted_token_ids,
|
||||
torch::Tensor experts_ids,
|
||||
torch::Tensor num_tokens_post_pad);
|
||||
|
||||
@ -22,6 +22,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
||||
" Tensor! num_tokens_post_pad) -> ()");
|
||||
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
|
||||
|
||||
// temporarily adapted from
|
||||
// https://github.com/sgl-project/sglang/commit/ded9fcd09a43d5e7d5bb31a2bc3e9fc21bf65d2a
|
||||
m.def(
|
||||
"sgl_moe_align_block_size(Tensor topk_ids, int num_experts,"
|
||||
" int block_size, Tensor! sorted_token_ids,"
|
||||
" Tensor! experts_ids,"
|
||||
" Tensor! num_tokens_post_pad) -> ()");
|
||||
m.impl("sgl_moe_align_block_size", torch::kCUDA, &sgl_moe_align_block_size);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
m.def(
|
||||
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
|
||||
|
||||
@ -153,6 +153,7 @@ torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type,
|
||||
|
||||
#ifndef USE_ROCM
|
||||
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
|
||||
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability);
|
||||
|
||||
void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, torch::Tensor const& a_scales,
|
||||
|
||||
93
csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh
Normal file
93
csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh
Normal file
@ -0,0 +1,93 @@
|
||||
#pragma once
|
||||
|
||||
// clang-format will break include orders
|
||||
// clang-format off
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
|
||||
#include "core/math.hpp"
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
// clang-format on
|
||||
|
||||
namespace vllm::c3x {
|
||||
|
||||
static inline cute::Shape<int, int, int, int> get_problem_shape(
|
||||
torch::Tensor const& a, torch::Tensor const& b) {
|
||||
int32_t m = a.size(0), n = b.size(1), k = a.size(1);
|
||||
return {m, n, k, 1};
|
||||
}
|
||||
|
||||
template <typename GemmKernel>
|
||||
void cutlass_gemm_caller(torch::Device device,
|
||||
cute::Shape<int, int, int, int> prob_shape,
|
||||
typename GemmKernel::MainloopArguments mainloop_args,
|
||||
typename GemmKernel::EpilogueArguments epilogue_args) {
|
||||
typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
prob_shape, mainloop_args, epilogue_args};
|
||||
|
||||
// Launch the CUTLASS GEMM kernel.
|
||||
using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
GemmOp gemm_op;
|
||||
CUTLASS_CHECK(gemm_op.can_implement(args));
|
||||
|
||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(device);
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream(device.index());
|
||||
|
||||
cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
template <typename Gemm, typename... EpilogueArgs>
|
||||
void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_params) {
|
||||
using ElementAB = typename Gemm::ElementAB;
|
||||
using ElementD = typename Gemm::ElementD;
|
||||
using GemmKernel = typename Gemm::GemmKernel;
|
||||
|
||||
int64_t lda = a.stride(0);
|
||||
int64_t ldb = b.stride(1);
|
||||
int64_t ldc = out.stride(0);
|
||||
|
||||
using StrideA = cute::Stride<int64_t, cute::Int<1>, int64_t>;
|
||||
using StrideB = cute::Stride<int64_t, cute::Int<1>, int64_t>;
|
||||
using StrideC = typename Gemm::StrideC;
|
||||
|
||||
StrideA a_stride{lda, cute::Int<1>{}, 0};
|
||||
StrideB b_stride{ldb, cute::Int<1>{}, 0};
|
||||
StrideC c_stride{ldc, cute::Int<1>{}, cute::Int<0>{}};
|
||||
|
||||
typename GemmKernel::ProblemShape prob_shape = get_problem_shape(a, b);
|
||||
|
||||
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
|
||||
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
|
||||
typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,
|
||||
b_stride};
|
||||
|
||||
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
|
||||
typename GemmKernel::EpilogueArguments epilogue_args{
|
||||
Gemm::Epilogue::prepare_args(
|
||||
std::forward<EpilogueArgs>(epilogue_params)...),
|
||||
c_ptr, c_stride, c_ptr, c_stride};
|
||||
|
||||
cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
|
||||
epilogue_args);
|
||||
}
|
||||
|
||||
} // namespace vllm::c3x
|
||||
@ -2,9 +2,6 @@
|
||||
|
||||
// clang-format will break include orders
|
||||
// clang-format off
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
@ -32,21 +29,6 @@ using namespace cute;
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// A wrapper for the GEMM kernel that is used to guard against compilation on
|
||||
// architectures that will never use the kernel. The purpose of this is to
|
||||
// reduce the size of the compiled binary.
|
||||
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
|
||||
// into code that will be executed on the device where it is defined.
|
||||
template <typename Kernel>
|
||||
struct enable_sm90_or_later : Kernel {
|
||||
template <typename... Args>
|
||||
CUTLASS_DEVICE void operator()(Args&&... args) {
|
||||
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
|
||||
Kernel::operator()(std::forward<Args>(args)...);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ElementAB_, typename ElementD_,
|
||||
template <typename, typename, typename> typename Epilogue_,
|
||||
typename TileShape, typename ClusterShape, typename KernelSchedule,
|
||||
@ -101,60 +83,4 @@ struct cutlass_3x_gemm {
|
||||
struct GemmKernel : public KernelType {};
|
||||
};
|
||||
|
||||
template <typename Gemm, typename... EpilogueArgs>
|
||||
void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_params) {
|
||||
using ElementAB = typename Gemm::ElementAB;
|
||||
using ElementD = typename Gemm::ElementD;
|
||||
|
||||
int32_t m = a.size(0);
|
||||
int32_t n = b.size(1);
|
||||
int32_t k = a.size(1);
|
||||
|
||||
int64_t lda = a.stride(0);
|
||||
int64_t ldb = b.stride(1);
|
||||
int64_t ldc = out.stride(0);
|
||||
|
||||
using StrideA = Stride<int64_t, Int<1>, int64_t>;
|
||||
using StrideB = Stride<int64_t, Int<1>, int64_t>;
|
||||
using StrideC = typename Gemm::StrideC;
|
||||
|
||||
StrideA a_stride{lda, Int<1>{}, 0};
|
||||
StrideB b_stride{ldb, Int<1>{}, 0};
|
||||
StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
|
||||
|
||||
using GemmKernel = typename Gemm::GemmKernel;
|
||||
typename GemmKernel::ProblemShape prob_shape{m, n, k, 1};
|
||||
|
||||
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
|
||||
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
|
||||
typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,
|
||||
b_stride};
|
||||
|
||||
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
|
||||
typename GemmKernel::EpilogueArguments epilogue_args{
|
||||
Gemm::Epilogue::prepare_args(
|
||||
std::forward<EpilogueArgs>(epilogue_params)...),
|
||||
c_ptr, c_stride, c_ptr, c_stride};
|
||||
|
||||
typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
prob_shape, mainloop_args, epilogue_args};
|
||||
|
||||
// Launch the CUTLASS GEMM kernel.
|
||||
using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
GemmOp gemm_op;
|
||||
CUTLASS_CHECK(gemm_op.can_implement(args));
|
||||
|
||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
|
||||
|
||||
cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@ -0,0 +1,24 @@
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_sm90_int8_dispatch.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_azp_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
std::optional<torch::Tensor> const& azp,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
if (azp) {
|
||||
return cutlass_scaled_mm_sm90_int8_epilogue<
|
||||
c3x::ScaledEpilogueBiasAzpToken>(out, a, b, a_scales, b_scales, azp_adj,
|
||||
*azp, bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogueBiasAzp>(
|
||||
out, a, b, a_scales, b_scales, azp_adj, bias);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@ -0,0 +1,24 @@
|
||||
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_blockwise_sm90_fp8_dispatch.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_blockwise_sm90_fp8(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::bfloat16_t>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::half_t>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@ -0,0 +1,168 @@
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
|
||||
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
|
||||
|
||||
#include "cutlass_gemm_caller.cuh"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename OutType, int GroupSizeM_, int GroupSizeN_, int GroupSizeK_,
|
||||
int TileSizeM_ = 128, class ClusterShape = Shape<_1, _2, _1>>
|
||||
struct cutlass_3x_gemm_fp8_blockwise {
|
||||
using GroupSizeM = Int<GroupSizeM_>;
|
||||
using GroupSizeN = Int<GroupSizeN_>;
|
||||
using GroupSizeK = Int<GroupSizeK_>;
|
||||
using TileSizeM = Int<TileSizeM_>;
|
||||
|
||||
static_assert(TileSizeM_ % GroupSizeM_ == 0,
|
||||
"TileSizeM must be a multiple of GroupSizeM");
|
||||
|
||||
using ElementAB = cutlass::float_e4m3_t;
|
||||
|
||||
using ElementA = ElementAB;
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
|
||||
|
||||
using ElementB = ElementAB;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
|
||||
|
||||
using ElementD = OutType;
|
||||
using StrideD = Stride<int64_t, Int<1>, Int<0>>;
|
||||
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
|
||||
using ElementC = void;
|
||||
using StrideC = StrideD;
|
||||
static constexpr int AlignmentC = AlignmentD;
|
||||
|
||||
using ElementAccumulator = float;
|
||||
using ElementBlockScale = float;
|
||||
using ElementCompute = float;
|
||||
using ArchTag = cutlass::arch::Sm90;
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||
using TileShape = Shape<TileSizeM, GroupSizeN, GroupSizeK>;
|
||||
|
||||
using KernelSchedule = cutlass::gemm::
|
||||
KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<
|
||||
GroupSizeM_>;
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
|
||||
using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT<
|
||||
cutlass::epilogue::fusion::Sm90AccFetch>;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType,
|
||||
ElementAccumulator, ElementCompute, ElementC, StrideC, AlignmentC,
|
||||
ElementD, StrideD, AlignmentD, EpilogueSchedule,
|
||||
StoreEpilogueCompute>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB,
|
||||
LayoutB, AlignmentB, ElementAccumulator, TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule>::CollectiveOp;
|
||||
|
||||
using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
|
||||
cutlass::gemm::PersistentScheduler>>;
|
||||
|
||||
struct GemmKernel : public KernelType {};
|
||||
|
||||
using StrideA = typename GemmKernel::StrideA;
|
||||
using StrideB = typename GemmKernel::StrideB;
|
||||
};
|
||||
|
||||
template <typename Gemm>
|
||||
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
using GemmKernel = typename Gemm::GemmKernel;
|
||||
|
||||
using ElementAB = typename Gemm::ElementAB;
|
||||
using ElementD = typename Gemm::ElementD;
|
||||
|
||||
auto prob_shape = c3x::get_problem_shape(a, b);
|
||||
int32_t m = get<0>(prob_shape), n = get<1>(prob_shape),
|
||||
k = get<2>(prob_shape);
|
||||
|
||||
int64_t lda = a.stride(0);
|
||||
int64_t ldb = b.stride(1);
|
||||
int64_t ldc = out.stride(0);
|
||||
|
||||
using StrideA = Stride<int64_t, Int<1>, int64_t>;
|
||||
using StrideB = Stride<int64_t, Int<1>, int64_t>;
|
||||
using StrideC = typename Gemm::StrideC;
|
||||
|
||||
StrideA a_stride{lda, Int<1>{}, 0};
|
||||
StrideB b_stride{ldb, Int<1>{}, 0};
|
||||
StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
|
||||
|
||||
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
|
||||
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
|
||||
auto a_scales_ptr = static_cast<float*>(a_scales.data_ptr());
|
||||
auto b_scales_ptr = static_cast<float*>(b_scales.data_ptr());
|
||||
|
||||
// Check is the t is contiguous and is 1D or 2D with one of the dimensions
|
||||
// being 1 (i.e. a row or column vector)
|
||||
auto is_contiguous_vector = [](const torch::Tensor& t) {
|
||||
auto t_sizes = t.sizes();
|
||||
return t.is_contiguous() &&
|
||||
(t.dim() == 1 ||
|
||||
(t.dim() == 2 &&
|
||||
*std::min_element(t_sizes.begin(), t_sizes.end()) == 1));
|
||||
};
|
||||
|
||||
// TODO(lucas): lets clean-up the kernel so that we pass in Strides so
|
||||
// we don't have to deal with enforcing implicit layouts
|
||||
TORCH_CHECK(a_scales.size(0) == m / Gemm::GroupSizeM::value);
|
||||
TORCH_CHECK(a_scales.size(1) == k / Gemm::GroupSizeK::value);
|
||||
TORCH_CHECK(a_scales.stride(0) == 1 || is_contiguous_vector(a_scales),
|
||||
"a_scales must be M major");
|
||||
TORCH_CHECK(b_scales.size(0) == k / Gemm::GroupSizeK::value);
|
||||
TORCH_CHECK(b_scales.size(1) == n / Gemm::GroupSizeN::value);
|
||||
TORCH_CHECK(b_scales.stride(0) == 1 || is_contiguous_vector(b_scales),
|
||||
"b_scales must be K major");
|
||||
typename GemmKernel::MainloopArguments mainloop_args{
|
||||
a_ptr, a_stride, b_ptr, b_stride, a_scales_ptr, b_scales_ptr};
|
||||
|
||||
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
|
||||
typename GemmKernel::EpilogueArguments epilogue_args{
|
||||
{}, c_ptr, c_stride, c_ptr, c_stride};
|
||||
|
||||
c3x::cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
|
||||
epilogue_args);
|
||||
}
|
||||
|
||||
template <typename OutType>
|
||||
void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
cutlass_gemm_caller_blockwise<
|
||||
cutlass_3x_gemm_fp8_blockwise<OutType, 1, 128, 128>>(out, a, b, a_scales,
|
||||
b_scales);
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
33
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
Normal file
33
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
Normal file
@ -0,0 +1,33 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_sm90_fp8(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_azp_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
std::optional<torch::Tensor> const& azp,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_blockwise_sm90_fp8(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales);
|
||||
|
||||
} // namespace vllm
|
||||
24
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu
Normal file
24
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu
Normal file
@ -0,0 +1,24 @@
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_sm90_fp8_dispatch.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_sm90_fp8(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
if (bias) {
|
||||
TORCH_CHECK(bias->dtype() == out.dtype(),
|
||||
"currently bias dtype must match output dtype ", out.dtype());
|
||||
return cutlass_scaled_mm_sm90_fp8_epilogue<c3x::ScaledEpilogueBias>(
|
||||
out, a, b, a_scales, b_scales, *bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm90_fp8_epilogue<c3x::ScaledEpilogue>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "scaled_mm_c3x.cuh"
|
||||
#include "scaled_mm.cuh"
|
||||
#include "cutlass_gemm_caller.cuh"
|
||||
|
||||
/**
|
||||
* This file defines Gemm kernel configurations for SM90 (fp8) based on the Gemm
|
||||
@ -9,6 +10,8 @@
|
||||
|
||||
namespace vllm {
|
||||
|
||||
using c3x::cutlass_gemm_caller;
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_fp8_config_default {
|
||||
@ -93,4 +96,25 @@ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
|
||||
}
|
||||
}
|
||||
|
||||
template <template <typename, typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_scaled_mm_sm90_fp8_epilogue(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_args) {
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::bfloat16_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
24
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu
Normal file
24
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu
Normal file
@ -0,0 +1,24 @@
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_sm90_int8_dispatch.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
if (bias) {
|
||||
TORCH_CHECK(bias->dtype() == out.dtype(),
|
||||
"currently bias dtype must match output dtype ", out.dtype());
|
||||
return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogueBias>(
|
||||
out, a, b, a_scales, b_scales, *bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogue>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "scaled_mm_c3x.cuh"
|
||||
#include "scaled_mm.cuh"
|
||||
#include "cutlass_gemm_caller.cuh"
|
||||
|
||||
/**
|
||||
* This file defines Gemm kernel configurations for SM90 (int8) based on the
|
||||
@ -9,6 +10,8 @@
|
||||
|
||||
namespace vllm {
|
||||
|
||||
using c3x::cutlass_gemm_caller;
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_int8_config_default {
|
||||
@ -137,4 +140,24 @@ inline void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out,
|
||||
}
|
||||
}
|
||||
|
||||
template <template <typename, typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_scaled_mm_sm90_int8_epilogue(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_args) {
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
|
||||
Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@ -1,52 +1,13 @@
|
||||
#include <cudaTypedefs.h>
|
||||
#include "c3x/scaled_mm_kernels.hpp"
|
||||
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
|
||||
|
||||
#include "scaled_mm_c3x_sm90_fp8_dispatch.cuh"
|
||||
#include "scaled_mm_c3x_sm90_int8_dispatch.cuh"
|
||||
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
using namespace vllm;
|
||||
#include "core/math.hpp"
|
||||
|
||||
/*
|
||||
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
|
||||
NVIDIA GPUs with sm90a (Hopper) or later.
|
||||
*/
|
||||
|
||||
template <template <typename, typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_scaled_mm_sm90_epilogue(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_args) {
|
||||
if (a.dtype() == torch::kInt8) {
|
||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
|
||||
Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::bfloat16_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
@ -54,14 +15,56 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
if (bias) {
|
||||
TORCH_CHECK(bias->dtype() == c.dtype(),
|
||||
"currently bias dtype must match output dtype ", c.dtype());
|
||||
return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogueBias>(
|
||||
c, a, b, a_scales, b_scales, *bias);
|
||||
|
||||
using GroupShape = std::array<int64_t, 2>;
|
||||
|
||||
int M = a.size(0), N = b.size(1), K = a.size(1);
|
||||
|
||||
GroupShape a_scale_group_shape = [&, &s = a_scales]() -> GroupShape {
|
||||
if (s.numel() == 1) return {M, K}; // tensor-wise
|
||||
if (s.dim() == 2)
|
||||
return {ceil_div(a.size(0), s.size(0)), ceil_div(a.size(1), s.size(1))};
|
||||
TORCH_CHECK(false, "Unsupported scale shape for scale_a");
|
||||
}();
|
||||
|
||||
GroupShape b_scale_group_shape = [&, &s = b_scales]() -> GroupShape {
|
||||
if (s.numel() == 1) return {K, N}; // tensor-wise
|
||||
if (s.dim() == 2)
|
||||
return {ceil_div(b.size(0), s.size(0)), ceil_div(b.size(1), s.size(1))};
|
||||
TORCH_CHECK(false, "Unsupported scale shape for scale_b");
|
||||
}();
|
||||
|
||||
if ((a_scale_group_shape == GroupShape{M, K} ||
|
||||
a_scale_group_shape == GroupShape{1, K}) &&
|
||||
(b_scale_group_shape == GroupShape{K, N} ||
|
||||
b_scale_group_shape == GroupShape{K, 1})) {
|
||||
// "standard per-tensor/per-token/per-channel" scaling
|
||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
if (a.dtype() == torch::kFloat8_e4m3fn) {
|
||||
vllm::cutlass_scaled_mm_sm90_fp8(c, a, b, a_scales, b_scales, bias);
|
||||
} else {
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
vllm::cutlass_scaled_mm_sm90_int8(c, a, b, a_scales, b_scales, bias);
|
||||
}
|
||||
} else if (a_scale_group_shape == GroupShape{1, 128} &&
|
||||
b_scale_group_shape == GroupShape{128, 128}) {
|
||||
// 1x128 per-token group scales for activations
|
||||
// 128x128 blockwise scales for weights
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn &&
|
||||
b.dtype() == torch::kFloat8_e4m3fn,
|
||||
"Currently only FP8 is supported for A group shape 1x128 and "
|
||||
"B group shape 128x128");
|
||||
TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
|
||||
|
||||
vllm::cutlass_scaled_mm_blockwise_sm90_fp8(c, a, b, a_scales, b_scales);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogue>(
|
||||
c, a, b, a_scales, b_scales);
|
||||
TORCH_CHECK(false,
|
||||
"Unsupported scale group shapes for CUTLASS 3.x GEMM.\n "
|
||||
"a_scale_group_shape must be [1, 128], got: [",
|
||||
a_scale_group_shape[0], ", ", a_scale_group_shape[1],
|
||||
"]\n"
|
||||
"b_scale_group_shape must be [128, 128], got: [",
|
||||
b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
|
||||
}
|
||||
}
|
||||
|
||||
@ -75,13 +78,6 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
if (azp) {
|
||||
return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogueBiasAzpToken>(
|
||||
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogueBiasAzp>(
|
||||
out, a, b, a_scales, b_scales, azp_adj, bias);
|
||||
}
|
||||
vllm::cutlass_scaled_mm_azp_sm90_int8(out, a, b, a_scales, b_scales, azp_adj,
|
||||
azp, bias);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@ -81,6 +81,19 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
|
||||
// CUTLASS block-quantized FP8 kernels need at least CUDA 12.0
|
||||
// and at least SM90 (Hopper)
|
||||
|
||||
#if defined CUDA_VERSION
|
||||
if (cuda_device_capability >= 90) {
|
||||
return CUDA_VERSION >= 12000;
|
||||
}
|
||||
#endif
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
@ -89,15 +102,12 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
|
||||
TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
|
||||
TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
|
||||
b.size(1) == c.size(1));
|
||||
TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
|
||||
TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));
|
||||
|
||||
// Check for strides and alignment
|
||||
TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
|
||||
TORCH_CHECK(b.stride(0) == 1); // Column-major
|
||||
TORCH_CHECK(c.stride(0) % 16 == 0 &&
|
||||
b.stride(1) % 16 == 0); // 16 Byte Alignment
|
||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
|
||||
if (bias) {
|
||||
TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
|
||||
@ -215,4 +225,4 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
|
||||
"No compiled cutlass_scaled_mm_azp for a compute capability less than "
|
||||
"CUDA device capability: ",
|
||||
version_num);
|
||||
}
|
||||
}
|
||||
|
||||
@ -173,8 +173,8 @@ dequant<half, vllm::kU4B8.id()>(int q) {
|
||||
const int HI = 0x00f000f0;
|
||||
const int EX = 0x64006400;
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
|
||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
|
||||
int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
|
||||
int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
|
||||
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
|
||||
// directly into `SUB` and `ADD`.
|
||||
const int SUB = 0x64086408;
|
||||
@ -197,9 +197,9 @@ dequant<nv_bfloat16, vllm::kU4B8.id()>(int q) {
|
||||
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
|
||||
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||
int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
|
||||
q >>= 4;
|
||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||
int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
|
||||
|
||||
typename ScalarType<nv_bfloat16>::FragB frag_b;
|
||||
static constexpr uint32_t MUL = 0x3F803F80;
|
||||
@ -221,8 +221,8 @@ dequant<half, vllm::kU4.id()>(int q) {
|
||||
const int HI = 0x00f000f0;
|
||||
const int EX = 0x64006400;
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
|
||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
|
||||
int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
|
||||
int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
|
||||
|
||||
const int SUB = 0x64006400;
|
||||
const int MUL = 0x2c002c00;
|
||||
@ -244,9 +244,9 @@ dequant<nv_bfloat16, vllm::kU4.id()>(int q) {
|
||||
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
|
||||
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||
int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
|
||||
q >>= 4;
|
||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||
int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
|
||||
|
||||
typename ScalarType<nv_bfloat16>::FragB frag_b;
|
||||
static constexpr uint32_t MUL = 0x3F803F80;
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
|
||||
@ -272,6 +272,10 @@ struct MacheteCollectiveMma {
|
||||
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
|
||||
|
||||
using PipelineParams = typename MainloopPipeline::Params;
|
||||
|
||||
// One threads per CTA are producers (1 for operand tile)
|
||||
static constexpr int NumProducerThreadEvents = 1;
|
||||
|
||||
using ScaleTileShape = decltype(make_shape(shape<0>(TileShape{}),
|
||||
shape<1>(SmemLayoutAtomScale{})));
|
||||
|
||||
|
||||
@ -96,8 +96,8 @@ __device__ inline FragB dequant(int q) {
|
||||
const int HI = 0x00f000f0;
|
||||
const int EX = 0x64006400;
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
|
||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
|
||||
int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
|
||||
int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
|
||||
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
|
||||
// directly into `SUB` and `ADD`.
|
||||
const int SUB = 0x64086408;
|
||||
|
||||
@ -141,8 +141,8 @@ __device__ inline FragB dequant_per_group(int q, FragS_GROUP& frag_s, int i) {
|
||||
static constexpr uint32_t HI = 0x00f000f0;
|
||||
static constexpr uint32_t EX = 0x64006400;
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
uint32_t t0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
|
||||
uint32_t t1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
|
||||
uint32_t t0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
|
||||
uint32_t t1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
|
||||
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
|
||||
// directly into `SUB` and `ADD`.
|
||||
static constexpr uint32_t SUB = 0x64086408;
|
||||
|
||||
@ -127,8 +127,8 @@ __device__ inline FragB dequant_4bit(int q) {
|
||||
const int HI = 0x00f000f0;
|
||||
const int EX = 0x64006400;
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
|
||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
|
||||
int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
|
||||
int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
|
||||
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
|
||||
// directly into `SUB` and `ADD`.
|
||||
const int SUB = 0x64086408;
|
||||
|
||||
@ -907,7 +907,9 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
|
||||
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
|
||||
// max_num_partitions, head_size]
|
||||
const int* __restrict__ context_lens, // [num_seqs]
|
||||
const int max_num_partitions){UNREACHABLE_CODE}
|
||||
const int max_num_partitions) {
|
||||
UNREACHABLE_CODE
|
||||
}
|
||||
|
||||
#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
|
||||
|
||||
|
||||
@ -324,6 +324,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
|
||||
ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
|
||||
|
||||
// Check if cutlass scaled_mm supports block quantization (used by DeepSeekV3)
|
||||
ops.def(
|
||||
"cutlass_scaled_mm_supports_block_fp8(int cuda_device_capability) -> "
|
||||
"bool");
|
||||
ops.impl("cutlass_scaled_mm_supports_block_fp8",
|
||||
&cutlass_scaled_mm_supports_fp8);
|
||||
|
||||
// Check if cutlass sparse scaled_mm is supported for CUDA devices of the
|
||||
// given capability
|
||||
ops.def(
|
||||
@ -463,6 +470,15 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
cache_ops.impl("reshape_and_cache_flash", torch::kCUDA,
|
||||
&reshape_and_cache_flash);
|
||||
|
||||
// Concat kv_c and k_pe and cache them.
|
||||
cache_ops.def(
|
||||
"concat_and_cache_mla(Tensor kv_c, Tensor k_pe,"
|
||||
" Tensor! kv_cache,"
|
||||
" Tensor slot_mapping,"
|
||||
" str kv_cache_dtype,"
|
||||
" Tensor scale) -> ()");
|
||||
cache_ops.impl("concat_and_cache_mla", torch::kCUDA, &concat_and_cache_mla);
|
||||
|
||||
// Convert the key and value cache to fp8 data type.
|
||||
cache_ops.def(
|
||||
"convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
sphinx==6.2.1
|
||||
sphinx-argparse==0.4.0
|
||||
sphinx-book-theme==1.0.1
|
||||
sphinx-copybutton==0.5.2
|
||||
myst-parser==3.0.1
|
||||
sphinx-argparse==0.4.0
|
||||
sphinx-design==0.6.1
|
||||
sphinx-togglebutton==0.3.2
|
||||
myst-parser==3.0.1
|
||||
msgspec
|
||||
cloudpickle
|
||||
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
// Add RunLLM widget
|
||||
document.addEventListener("DOMContentLoaded", function () {
|
||||
var script = document.createElement("script");
|
||||
script.type = "module";
|
||||
@ -15,4 +16,23 @@ document.addEventListener("DOMContentLoaded", function () {
|
||||
|
||||
script.async = true;
|
||||
document.head.appendChild(script);
|
||||
});
|
||||
});
|
||||
|
||||
// Update URL search params when tab is clicked
|
||||
document.addEventListener("DOMContentLoaded", function () {
|
||||
const tabs = document.querySelectorAll(".sd-tab-label");
|
||||
|
||||
function updateURL(tab) {
|
||||
const syncGroup = tab.getAttribute("data-sync-group");
|
||||
const syncId = tab.getAttribute("data-sync-id");
|
||||
if (syncGroup && syncId) {
|
||||
const url = new URL(window.location);
|
||||
url.searchParams.set(syncGroup, syncId);
|
||||
window.history.replaceState(null, "", url);
|
||||
}
|
||||
}
|
||||
|
||||
tabs.forEach(tab => {
|
||||
tab.addEventListener("click", () => updateURL(tab));
|
||||
});
|
||||
});
|
||||
|
||||
@ -8,10 +8,10 @@
|
||||
.. currentmodule:: vllm.engine
|
||||
```
|
||||
|
||||
```{toctree}
|
||||
:::{toctree}
|
||||
:caption: Engines
|
||||
:maxdepth: 2
|
||||
|
||||
llm_engine
|
||||
async_llm_engine
|
||||
```
|
||||
:::
|
||||
|
||||
@ -2,10 +2,10 @@
|
||||
|
||||
## Submodules
|
||||
|
||||
```{toctree}
|
||||
:::{toctree}
|
||||
:maxdepth: 1
|
||||
|
||||
interfaces_base
|
||||
interfaces
|
||||
adapters
|
||||
```
|
||||
:::
|
||||
|
||||
@ -17,7 +17,7 @@ Looking to add your own multi-modal model? Please follow the instructions listed
|
||||
|
||||
## Submodules
|
||||
|
||||
```{toctree}
|
||||
:::{toctree}
|
||||
:maxdepth: 1
|
||||
|
||||
inputs
|
||||
@ -25,4 +25,4 @@ parse
|
||||
processing
|
||||
profiling
|
||||
registry
|
||||
```
|
||||
:::
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
# Offline Inference
|
||||
|
||||
```{toctree}
|
||||
:::{toctree}
|
||||
:caption: Contents
|
||||
:maxdepth: 1
|
||||
|
||||
llm
|
||||
llm_inputs
|
||||
```
|
||||
:::
|
||||
|
||||
BIN
docs/source/assets/design/v1/prefix_caching/example-time-1.png
Normal file
BIN
docs/source/assets/design/v1/prefix_caching/example-time-1.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 34 KiB |
BIN
docs/source/assets/design/v1/prefix_caching/example-time-3.png
Normal file
BIN
docs/source/assets/design/v1/prefix_caching/example-time-3.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 36 KiB |
BIN
docs/source/assets/design/v1/prefix_caching/example-time-4.png
Normal file
BIN
docs/source/assets/design/v1/prefix_caching/example-time-4.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 41 KiB |
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user