Compare commits
12 Commits
v0.7.1
...
mla_cuda_g
| Author | SHA1 | Date | |
|---|---|---|---|
| 0a02744dc8 | |||
| 984ffddda6 | |||
| 135c404fbb | |||
| 7241acbd64 | |||
| 2b140debbb | |||
| 2326814c11 | |||
| 534cd0006d | |||
| aa19f297d2 | |||
| 4880a43d20 | |||
| 3895bba85a | |||
| f23d126a07 | |||
| ec8c1cf732 |
@ -56,11 +56,6 @@ steps:
|
||||
env:
|
||||
DOCKER_BUILDKIT: "1"
|
||||
|
||||
- input: "Provide Release version here"
|
||||
fields:
|
||||
- text: "What is the release version?"
|
||||
key: "release-version"
|
||||
|
||||
- block: "Build CPU release image"
|
||||
key: block-cpu-release-image-build
|
||||
depends_on: ~
|
||||
@ -71,7 +66,7 @@ steps:
|
||||
queue: cpu_queue_postmerge
|
||||
commands:
|
||||
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version) --progress plain -f Dockerfile.cpu ."
|
||||
- "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version)"
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$RELEASE_VERSION --progress plain -f Dockerfile.cpu ."
|
||||
- "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$RELEASE_VERSION"
|
||||
env:
|
||||
DOCKER_BUILDKIT: "1"
|
||||
|
||||
37
.github/mergify.yml
vendored
@ -35,43 +35,6 @@ pull_request_rules:
|
||||
add:
|
||||
- frontend
|
||||
|
||||
- name: label-structured-output
|
||||
description: Automatically apply structured-output label
|
||||
conditions:
|
||||
- or:
|
||||
- files~=^vllm/model_executor/guided_decoding/
|
||||
- files=tests/model_executor/test_guided_processors.py
|
||||
- files=tests/entrypoints/llm/test_guided_generate.py
|
||||
- files=benchmarks/benchmark_serving_guided.py
|
||||
- files=benchmarks/benchmark_guided.py
|
||||
actions:
|
||||
label:
|
||||
add:
|
||||
- structured-output
|
||||
|
||||
- name: label-speculative-decoding
|
||||
description: Automatically apply speculative-decoding label
|
||||
conditions:
|
||||
- or:
|
||||
- files~=^vllm/spec_decode/
|
||||
- files=vllm/model_executor/layers/spec_decode_base_sampler.py
|
||||
- files~=^tests/spec_decode/
|
||||
actions:
|
||||
label:
|
||||
add:
|
||||
- speculative-decoding
|
||||
|
||||
- name: label-v1
|
||||
description: Automatically apply v1 label
|
||||
conditions:
|
||||
- or:
|
||||
- files~=^vllm/v1/
|
||||
- files~=^tests/v1/
|
||||
actions:
|
||||
label:
|
||||
add:
|
||||
- v1
|
||||
|
||||
- name: ping author on conflicts and add 'needs-rebase' label
|
||||
conditions:
|
||||
- conflict
|
||||
|
||||
@ -85,22 +85,9 @@ repos:
|
||||
entry: tools/png-lint.sh
|
||||
language: script
|
||||
types: [png]
|
||||
- id: signoff-commit
|
||||
name: Sign-off Commit
|
||||
entry: bash
|
||||
args:
|
||||
- -c
|
||||
- |
|
||||
if ! grep -q "^Signed-off-by: $(git config user.name) <$(git config user.email)>" .git/COMMIT_EDITMSG; then
|
||||
printf "\nSigned-off-by: $(git config user.name) <$(git config user.email)>\n" >> .git/COMMIT_EDITMSG
|
||||
fi
|
||||
language: system
|
||||
verbose: true
|
||||
stages: [commit-msg]
|
||||
- id: suggestion
|
||||
name: Suggestion
|
||||
entry: bash -c 'echo "To bypass pre-commit hooks, add --no-verify to git commit."'
|
||||
language: system
|
||||
verbose: true
|
||||
pass_filenames: false
|
||||
|
||||
|
||||
@ -245,7 +245,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
FetchContent_Declare(
|
||||
cutlass
|
||||
GIT_REPOSITORY https://github.com/nvidia/cutlass.git
|
||||
GIT_TAG v3.7.0
|
||||
GIT_TAG v3.6.0
|
||||
GIT_PROGRESS TRUE
|
||||
|
||||
# Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
|
||||
@ -299,12 +299,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# CUDA 12.0 or later (and only work on Hopper, 9.0a for now).
|
||||
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
|
||||
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu"
|
||||
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu"
|
||||
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu"
|
||||
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu")
|
||||
set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
|
||||
|
||||
@ -4,12 +4,12 @@ USER root
|
||||
|
||||
ENV PATH="/usr/local/cargo/bin:$PATH:/opt/conda/bin/"
|
||||
|
||||
RUN apt-get update -y && apt-get install -y git wget kmod curl vim libnuma-dev libsndfile-dev libprotobuf-dev build-essential ffmpeg libsm6 libxext6 libgl1 libssl-dev
|
||||
RUN apt-get update -y && apt-get install -y git wget curl vim libnuma-dev libsndfile-dev libprotobuf-dev build-essential ffmpeg libsm6 libxext6 libgl1 libssl-dev
|
||||
|
||||
# Some packages in requirements-cpu are installed here
|
||||
# IBM provides optimized packages for ppc64le processors in the open-ce project for mamba
|
||||
# Currently these may not be available for venv or pip directly
|
||||
RUN micromamba install -y -n base -c https://ftp.osuosl.org/pub/open-ce/1.11.0-p10/ -c defaults python=3.10 rust && micromamba clean --all --yes
|
||||
RUN micromamba install -y -n base -c https://ftp.osuosl.org/pub/open-ce/1.11.0-p10/ -c defaults python=3.10 torchvision-cpu=0.16.2 rust && micromamba clean --all --yes
|
||||
|
||||
COPY ./ /workspace/vllm
|
||||
|
||||
@ -21,6 +21,7 @@ RUN --mount=type=bind,source=.git,target=.git \
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
RUSTFLAGS='-L /opt/conda/lib' pip install -v --prefer-binary --extra-index-url https://repo.fury.io/mgiessing \
|
||||
'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 \
|
||||
torch==2.3.1 \
|
||||
-r requirements-cpu.txt \
|
||||
xformers uvloop==0.20.0
|
||||
|
||||
|
||||
@ -3,7 +3,7 @@ import copy
|
||||
import itertools
|
||||
import pickle as pkl
|
||||
import time
|
||||
from typing import Callable, Iterable, List, Optional, Tuple
|
||||
from typing import Callable, Iterable, List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.utils.benchmark as TBenchmark
|
||||
@ -12,8 +12,6 @@ from utils import make_rand_tensors
|
||||
from weight_shapes import WEIGHT_SHAPES
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
w8a8_block_fp8_matmul)
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
|
||||
@ -40,15 +38,8 @@ def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
|
||||
|
||||
def bench_int8(
|
||||
dtype: torch.dtype,
|
||||
m: int,
|
||||
k: int,
|
||||
n: int,
|
||||
label: str,
|
||||
sub_label: str,
|
||||
bench_kernels: Optional[List[str]] = None) -> Iterable[TMeasurement]:
|
||||
"""Benchmark INT8-based kernels."""
|
||||
def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
|
||||
sub_label: str) -> Iterable[TMeasurement]:
|
||||
assert dtype == torch.int8
|
||||
a, b = make_rand_tensors(torch.int8, m, n, k)
|
||||
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
@ -57,132 +48,155 @@ def bench_int8(
|
||||
azp = torch.zeros((m, ), device="cuda", dtype=torch.int32)
|
||||
azp_adj = torch.zeros((n, ), device="cuda", dtype=torch.int32)
|
||||
|
||||
bench_fns = {
|
||||
"pytorch_bf16_bf16_bf16_matmul-no-scales":
|
||||
lambda: torch.mm(a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16)
|
||||
),
|
||||
"pytorch_fp16_fp16_fp16_matmul-no-scales":
|
||||
lambda: torch.mm(a.to(dtype=torch.float16), b.to(dtype=torch.float16)),
|
||||
"cutlass_i8_i8_bf16_scaled_mm":
|
||||
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16),
|
||||
"cutlass_i8_i8_bf16_scaled_mm_bias":
|
||||
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16,
|
||||
bias),
|
||||
"cutlass_i8_i8_bf16_scaled_mm_azp":
|
||||
lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch.
|
||||
bfloat16, azp_adj),
|
||||
"cutlass_i8_i8_bf16_scaled_mm_azp_bias":
|
||||
lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch.
|
||||
bfloat16, azp_adj, None, bias),
|
||||
"cutlass_i8_i8_bf16_scaled_mm_azp_pt":
|
||||
lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch.
|
||||
bfloat16, azp_adj, azp),
|
||||
"cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias":
|
||||
lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch.
|
||||
bfloat16, azp_adj, azp, bias),
|
||||
}
|
||||
|
||||
timers = []
|
||||
for name, fn in bench_fns.items():
|
||||
# If bench_kernels is None, run all. Otherwise, run only exact matches.
|
||||
if bench_kernels is None or name in bench_kernels:
|
||||
print(f"Running {name}")
|
||||
timers.append(bench_fn(label, sub_label, name, fn))
|
||||
# pytorch impl - bfloat16
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales",
|
||||
torch.mm, a.to(dtype=torch.bfloat16),
|
||||
b.to(dtype=torch.bfloat16)))
|
||||
|
||||
# pytorch impl - float16
|
||||
timers.append(
|
||||
bench_fn(label, sub_label,
|
||||
"pytorch_fp16_fp16_fp16_matmul-no-scales", torch.mm,
|
||||
a.to(dtype=torch.float16), b.to(dtype=torch.float16)))
|
||||
|
||||
# cutlass impl
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm",
|
||||
ops.cutlass_scaled_mm, a, b, scale_a, scale_b,
|
||||
torch.bfloat16))
|
||||
|
||||
# cutlass with bias
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_bias",
|
||||
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16,
|
||||
bias))
|
||||
|
||||
# cutlass with azp per-tensor
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp",
|
||||
ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b,
|
||||
torch.bfloat16, azp_adj))
|
||||
|
||||
# cutlass with azp per-tensor + bias
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_bias",
|
||||
ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b,
|
||||
torch.bfloat16, azp_adj, None, bias))
|
||||
|
||||
# cutlass with azp per-token
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_pt",
|
||||
ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b,
|
||||
torch.bfloat16, azp_adj, azp))
|
||||
|
||||
# cutlass with azp per-token + bias
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias",
|
||||
ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b,
|
||||
torch.bfloat16, azp_adj, azp, bias))
|
||||
|
||||
return timers
|
||||
|
||||
|
||||
def bench_fp8(
|
||||
dtype: torch.dtype,
|
||||
m: int,
|
||||
k: int,
|
||||
n: int,
|
||||
label: str,
|
||||
sub_label: str,
|
||||
bench_kernels: Optional[List[str]] = None) -> Iterable[TMeasurement]:
|
||||
"""Benchmark FP8-based kernels."""
|
||||
def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
|
||||
sub_label: str) -> Iterable[TMeasurement]:
|
||||
assert dtype == torch.float8_e4m3fn
|
||||
a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k)
|
||||
a_cont = a.contiguous()
|
||||
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
block_scale_a = torch.rand((m, k // 128),
|
||||
device="cuda",
|
||||
dtype=torch.float32)
|
||||
block_scale_b = torch.rand((k // 128, n // 128),
|
||||
device="cuda",
|
||||
dtype=torch.float32)
|
||||
block_scale_a_M_major = block_scale_a.t().contiguous().t()
|
||||
block_scale_b_K_major = block_scale_b.t().contiguous().t()
|
||||
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
print(m, k, n)
|
||||
|
||||
bench_fns = {
|
||||
"pytorch_bf16_bf16_bf16_matmul-no-scales":
|
||||
lambda: torch.mm(a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16)
|
||||
),
|
||||
"pytorch_fp16_fp16_fp16_matmul-no-scales":
|
||||
lambda: torch.mm(a.to(dtype=torch.float16), b.to(dtype=torch.float16)),
|
||||
"pytorch_fp8_fp8_fp16_scaled_mm":
|
||||
lambda: torch._scaled_mm(
|
||||
a, b, scale_a, scale_b, out_dtype=torch.float16),
|
||||
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum":
|
||||
lambda: torch._scaled_mm(a,
|
||||
b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
out_dtype=torch.float16,
|
||||
use_fast_accum=True),
|
||||
"pytorch_fp8_fp8_bf16_scaled_mm":
|
||||
lambda: torch._scaled_mm(
|
||||
a, b, scale_a, scale_b, out_dtype=torch.bfloat16),
|
||||
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum":
|
||||
lambda: torch._scaled_mm(a,
|
||||
b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
out_dtype=torch.bfloat16,
|
||||
use_fast_accum=True),
|
||||
"cutlass_fp8_fp8_bf16_scaled_mm":
|
||||
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16),
|
||||
"cutlass_fp8_fp8_fp16_scaled_mm":
|
||||
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.float16),
|
||||
"cutlass_fp8_fp8_bf16_scaled_mm_bias":
|
||||
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16,
|
||||
bias),
|
||||
"cutlass_fp8_fp8_fp16_scaled_mm_bias":
|
||||
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.float16,
|
||||
bias.to(dtype=torch.float16)),
|
||||
"triton_fp8_fp8_fp16_scaled_mm_blockwise":
|
||||
lambda: w8a8_block_fp8_matmul(a_cont, b.t(), block_scale_a,
|
||||
block_scale_b.t(), (128, 128)),
|
||||
"cutlass_fp8_fp8_fp16_scaled_mm_blockwise":
|
||||
lambda: ops.cutlass_scaled_mm(a, b, block_scale_a_M_major,
|
||||
block_scale_b_K_major, torch.float16),
|
||||
}
|
||||
|
||||
timers = []
|
||||
for name, fn in bench_fns.items():
|
||||
# If bench_kernels is None, run all. Otherwise, run only exact matches.
|
||||
if bench_kernels is None or name in bench_kernels:
|
||||
print(f"Running {name}")
|
||||
timers.append(bench_fn(label, sub_label, name, fn))
|
||||
|
||||
# pytorch impl w. bf16
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales",
|
||||
torch.mm, a.to(dtype=torch.bfloat16, device="cuda"),
|
||||
b.to(dtype=torch.bfloat16, device="cuda")))
|
||||
|
||||
# pytorch impl: bf16 output, without fp8 fast accum
|
||||
timers.append(
|
||||
bench_fn(label,
|
||||
sub_label,
|
||||
"pytorch_fp8_fp8_bf16_scaled_mm",
|
||||
torch._scaled_mm,
|
||||
a,
|
||||
b,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=torch.bfloat16))
|
||||
|
||||
# pytorch impl: bf16 output, with fp8 fast accum
|
||||
timers.append(
|
||||
bench_fn(label,
|
||||
sub_label,
|
||||
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum",
|
||||
torch._scaled_mm,
|
||||
a,
|
||||
b,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=torch.bfloat16,
|
||||
use_fast_accum=True))
|
||||
|
||||
# pytorch impl: fp16 output, without fp8 fast accum
|
||||
timers.append(
|
||||
bench_fn(label,
|
||||
sub_label,
|
||||
"pytorch_fp8_fp8_fp16_scaled_mm",
|
||||
torch._scaled_mm,
|
||||
a,
|
||||
b,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=torch.float16))
|
||||
|
||||
# pytorch impl: fp16 output, with fp8 fast accum
|
||||
timers.append(
|
||||
bench_fn(label,
|
||||
sub_label,
|
||||
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum",
|
||||
torch._scaled_mm,
|
||||
a,
|
||||
b,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=torch.float16,
|
||||
use_fast_accum=True))
|
||||
|
||||
# cutlass impl: bf16 output
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm",
|
||||
ops.cutlass_scaled_mm, a, b, scale_a, scale_b,
|
||||
torch.bfloat16))
|
||||
# cutlass impl: fp16 output
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_mm",
|
||||
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.float16))
|
||||
|
||||
# cutlass impl: bf16 output, with bias
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm_bias",
|
||||
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16,
|
||||
bias))
|
||||
|
||||
# cutlass impl: fp16 output, with bias
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_mm_bias",
|
||||
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.float16,
|
||||
bias.to(dtype=torch.float16)))
|
||||
|
||||
return timers
|
||||
|
||||
|
||||
def bench(dtype: torch.dtype,
|
||||
m: int,
|
||||
k: int,
|
||||
n: int,
|
||||
label: str,
|
||||
sub_label: str,
|
||||
bench_kernels: Optional[List[str]] = None) -> Iterable[TMeasurement]:
|
||||
def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str,
|
||||
sub_label: str) -> Iterable[TMeasurement]:
|
||||
if dtype == torch.int8:
|
||||
return bench_int8(dtype, m, k, n, label, sub_label, bench_kernels)
|
||||
return bench_int8(dtype, m, k, n, label, sub_label)
|
||||
if dtype == torch.float8_e4m3fn:
|
||||
return bench_fp8(dtype, m, k, n, label, sub_label, bench_kernels)
|
||||
return bench_fp8(dtype, m, k, n, label, sub_label)
|
||||
raise ValueError("unsupported type")
|
||||
|
||||
|
||||
@ -193,22 +207,18 @@ def print_timers(timers: Iterable[TMeasurement]):
|
||||
|
||||
|
||||
def run(dtype: torch.dtype,
|
||||
MKNs: Iterable[Tuple[int, int, int]],
|
||||
bench_kernels: Optional[List[str]] = None) -> Iterable[TMeasurement]:
|
||||
MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
|
||||
results = []
|
||||
for m, k, n in MKNs:
|
||||
timers = bench(dtype,
|
||||
m,
|
||||
k,
|
||||
n,
|
||||
f"scaled-{dtype}-gemm",
|
||||
f"MKN=({m}x{k}x{n})",
|
||||
bench_kernels=bench_kernels)
|
||||
timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm",
|
||||
f"MKN=({m}x{k}x{n})")
|
||||
print_timers(timers)
|
||||
results.extend(timers)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# output makers
|
||||
def make_output(data: Iterable[TMeasurement],
|
||||
MKNs: Iterable[Tuple[int, int, int]],
|
||||
base_description: str,
|
||||
@ -222,11 +232,15 @@ def make_output(data: Iterable[TMeasurement],
|
||||
pkl.dump(data, f)
|
||||
|
||||
|
||||
# argparse runners
|
||||
|
||||
|
||||
def run_square_bench(args):
|
||||
dim_sizes = list(
|
||||
range(args.dim_start, args.dim_end + 1, args.dim_increment))
|
||||
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
|
||||
data = run(args.dtype, MKNs, bench_kernels=args.kernels)
|
||||
data = run(args.dtype, MKNs)
|
||||
|
||||
make_output(data, MKNs, f"square_bench-{args.dtype}")
|
||||
|
||||
|
||||
@ -237,7 +251,8 @@ def run_range_bench(args):
|
||||
Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes
|
||||
Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes
|
||||
MKNs = list(zip(Ms, Ks, Ns))
|
||||
data = run(args.dtype, MKNs, bench_kernels=args.kernels)
|
||||
data = run(args.dtype, MKNs)
|
||||
|
||||
make_output(data, MKNs, f"range_bench-{args.dtype}")
|
||||
|
||||
|
||||
@ -263,7 +278,7 @@ def run_model_bench(args):
|
||||
for k, n in KNs:
|
||||
MKNs.append((m, k, n))
|
||||
|
||||
data = run(args.dtype, MKNs, bench_kernels=args.kernels)
|
||||
data = run(args.dtype, MKNs)
|
||||
model_bench_data.append(data)
|
||||
|
||||
# Print all results
|
||||
@ -313,15 +328,6 @@ Benchmark Cutlass GEMM.
|
||||
type=to_torch_dtype,
|
||||
required=True,
|
||||
help="Available options are ['int8', 'fp8']")
|
||||
parser.add_argument(
|
||||
"--kernels",
|
||||
nargs="+",
|
||||
type=str,
|
||||
default=None,
|
||||
help=
|
||||
"Exact names of the kernels to benchmark. If not set, runs all kernels."
|
||||
)
|
||||
|
||||
subparsers = parser.add_subparsers(dest="cmd")
|
||||
|
||||
square_parser = subparsers.add_parser("square_bench")
|
||||
@ -356,4 +362,4 @@ Benchmark Cutlass GEMM.
|
||||
model_parser.set_defaults(func=run_model_bench)
|
||||
|
||||
args = parser.parse_args()
|
||||
args.func(args)
|
||||
args.func(args)
|
||||
@ -343,13 +343,9 @@ class BenchmarkWorker:
|
||||
op_config = get_moe_configs(num_experts, shard_intermediate_size // 2,
|
||||
dtype_str)
|
||||
if op_config is None:
|
||||
config = get_default_config(num_tokens,
|
||||
num_experts,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype_str,
|
||||
is_marlin=False)
|
||||
config = get_default_config(num_tokens, num_experts,
|
||||
shard_intermediate_size, hidden_size,
|
||||
topk, dtype_str)
|
||||
else:
|
||||
config = op_config[min(op_config.keys(),
|
||||
key=lambda x: abs(x - num_tokens))]
|
||||
@ -540,11 +536,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--model",
|
||||
type=str,
|
||||
default="mistralai/Mixtral-8x7B-Instruct-v0.1")
|
||||
parser.add_argument("--tp-size",
|
||||
"-tp",
|
||||
"--tensor-parallel-size",
|
||||
type=int,
|
||||
default=2)
|
||||
parser.add_argument("--tp-size", "-tp", type=int, default=2)
|
||||
parser.add_argument("--dtype",
|
||||
type=str,
|
||||
choices=["auto", "fp8_w8a8", "int8_w8a16"],
|
||||
|
||||
@ -1,14 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <climits>
|
||||
#include <iostream>
|
||||
|
||||
inline constexpr uint32_t next_pow_2(uint32_t const num) {
|
||||
inline uint32_t next_pow_2(uint32_t const num) {
|
||||
if (num <= 1) return num;
|
||||
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline constexpr std::enable_if_t<std::is_integral_v<T>, T> ceil_div(T a, T b) {
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
@ -32,20 +32,3 @@ inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
|
||||
}
|
||||
|
||||
int32_t get_sm_version_num();
|
||||
|
||||
/**
|
||||
* A wrapper for a kernel that is used to guard against compilation on
|
||||
* architectures that will never use the kernel. The purpose of this is to
|
||||
* reduce the size of the compiled binary.
|
||||
* __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
|
||||
* into code that will be executed on the device where it is defined.
|
||||
*/
|
||||
template <typename Kernel>
|
||||
struct enable_sm90_or_later : Kernel {
|
||||
template <typename... Args>
|
||||
CUTLASS_DEVICE void operator()(Args&&... args) {
|
||||
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
|
||||
Kernel::operator()(std::forward<Args>(args)...);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
@ -1,123 +0,0 @@
|
||||
// Modified from: cutlass/gemm/collective/builders/sm90_gmma_builder.inl
|
||||
// clang-format off
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/gemm/collective/builders/sm90_gmma_builder.inl"
|
||||
|
||||
#include "cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp"
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::collective {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// GMMA_TMA_WS_SS (BlockScaled Builders)
|
||||
template <
|
||||
class ElementA,
|
||||
class GmemLayoutATag,
|
||||
int AlignmentA,
|
||||
class ElementB,
|
||||
class GmemLayoutBTag,
|
||||
int AlignmentB,
|
||||
class ElementAccumulator,
|
||||
class TileShape_MNK,
|
||||
class ClusterShape_MNK,
|
||||
class StageCountType,
|
||||
int ScaleGranularityM
|
||||
>
|
||||
struct CollectiveBuilder<
|
||||
arch::Sm90,
|
||||
arch::OpClassTensorOp,
|
||||
ElementA,
|
||||
GmemLayoutATag,
|
||||
AlignmentA,
|
||||
ElementB,
|
||||
GmemLayoutBTag,
|
||||
AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape_MNK,
|
||||
ClusterShape_MNK,
|
||||
StageCountType,
|
||||
KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>,
|
||||
cute::enable_if_t<
|
||||
not detail::is_use_rmem_A<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>()>
|
||||
> {
|
||||
using KernelScheduleType = KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>;
|
||||
|
||||
static_assert(is_static<TileShape_MNK>::value);
|
||||
static_assert(is_static<ClusterShape_MNK>::value);
|
||||
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
|
||||
static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
|
||||
#endif
|
||||
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
|
||||
"Should meet TMA alignment requirement\n");
|
||||
|
||||
static constexpr bool IsArrayOfPointersGemm = (cute::is_any_of_v<KernelScheduleType,
|
||||
KernelPtrArrayTmaWarpSpecializedCooperative,
|
||||
KernelPtrArrayTmaWarpSpecializedPingpong>);
|
||||
static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
|
||||
static_assert((!IsFP8Input || !IsArrayOfPointersGemm),
|
||||
"KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum is only compatible with FP8 Blocked Scaled version right now.");
|
||||
|
||||
// For fp32 types, map to tf32 MMA value type
|
||||
using ElementAMma = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
|
||||
using ElementBMma = cute::conditional_t<cute::is_same_v<ElementB, float>, tfloat32_t, ElementB>;
|
||||
|
||||
static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementAMma, GmemLayoutATag>();
|
||||
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementBMma, GmemLayoutBTag>();
|
||||
|
||||
static constexpr bool IsCooperative = cute::is_any_of_v<KernelScheduleType,
|
||||
KernelTmaWarpSpecializedCooperative,
|
||||
KernelPtrArrayTmaWarpSpecializedCooperative,
|
||||
KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>>;
|
||||
using AtomLayoutMNK = cute::conditional_t<IsCooperative,
|
||||
Layout<Shape<_2,_1,_1>>, Layout<Shape<_1,_1,_1>>>;
|
||||
|
||||
using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector<
|
||||
ElementAMma, ElementBMma, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{}));
|
||||
|
||||
using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
|
||||
using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
|
||||
|
||||
using SmemLayoutAtomA = decltype(detail::ss_smem_selector<
|
||||
GmmaMajorA, ElementAMma, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
using SmemLayoutAtomB = decltype(detail::ss_smem_selector<
|
||||
GmmaMajorB, ElementBMma, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
|
||||
static constexpr size_t TensorMapStorage = IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0;
|
||||
static constexpr int KernelSmemCarveout = static_cast<int>(TensorMapStorage);
|
||||
|
||||
static constexpr int PipelineStages = detail::compute_stage_count_or_override<detail::sm90_smem_capacity_bytes - KernelSmemCarveout,
|
||||
ElementAMma, ElementBMma, TileShape_MNK>(StageCountType{});
|
||||
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType, ScaleGranularityM>;
|
||||
|
||||
using SmemCopyAtomA = void;
|
||||
using SmemCopyAtomB = void;
|
||||
|
||||
using CollectiveOp = CollectiveMma<
|
||||
DispatchPolicy,
|
||||
TileShape_MNK,
|
||||
ElementA,
|
||||
TagToStrideA_t<GmemLayoutATag>,
|
||||
ElementB,
|
||||
TagToStrideB_t<GmemLayoutBTag>,
|
||||
TiledMma,
|
||||
GmemTiledCopyA,
|
||||
SmemLayoutAtomA,
|
||||
SmemCopyAtomA,
|
||||
cute::identity,
|
||||
GmemTiledCopyB,
|
||||
SmemLayoutAtomB,
|
||||
SmemCopyAtomB,
|
||||
cute::identity
|
||||
>;
|
||||
};
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm::collective
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -1,183 +0,0 @@
|
||||
// clang-format off
|
||||
// adapted from: https://github.com/soundOfDestiny/cutlass/blob/a4208aa6958864923505cade9c63eb2a6daf16e5/include/cutlass/gemm/collective/fp8_accumulation.hpp
|
||||
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cute/algorithm/clear.hpp"
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
///////////////////////////////////FP8 Accumulation///////////////////////////
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
/// This class provides API to promote (add) or scale (multiply_add) the results
|
||||
/// from the tensor core accumulators to the main accumulators when the number
|
||||
/// of MMAs reaches the max number of MMA interval specified by user, after that
|
||||
/// the tensor core accumulators are zeroed.
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::collective {
|
||||
|
||||
template <
|
||||
class EngineAccum,
|
||||
class LayoutAccum>
|
||||
struct GmmaFP8AccumulationWithScale {
|
||||
using TensorAccum = cute::Tensor<EngineAccum, LayoutAccum>;
|
||||
using ElementAccumulator = typename EngineAccum::value_type;
|
||||
|
||||
static_assert(is_static<LayoutAccum>::value, "Accumulator Layout should be static");
|
||||
static_assert(is_rmem<TensorAccum>::value , "Accumulator tensor must be rmem resident.");
|
||||
|
||||
private:
|
||||
TensorAccum& accum_;
|
||||
TensorAccum accum_temp_;
|
||||
|
||||
uint32_t accum_promotion_interval_; // defines the max num of executed MMAs after which accum should be promoted.
|
||||
uint32_t mma_count_per_mainloop_iteration_; // num of MMAs per k_tile of mainloop
|
||||
uint32_t mma_count_; // current executed MMAs
|
||||
uint32_t reset_accum_flag_; // accum needs to be zeroed or not.
|
||||
|
||||
// promote or `add` the partial accumulators to main accumulator (FADD).
|
||||
CUTLASS_DEVICE
|
||||
void promote_core() {
|
||||
warpgroup_wait<0>();
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(accum_); ++i) {
|
||||
accum_(i) += accum_temp_(i);
|
||||
}
|
||||
}
|
||||
|
||||
// `multiply` scale the partial accumulators and `add` to main accumulator (FFMA).
|
||||
template <
|
||||
class EngineScale,
|
||||
class LayoutScale>
|
||||
CUTLASS_DEVICE
|
||||
void scale_core(const cute::Tensor<EngineScale, LayoutScale> &scale) {
|
||||
using TensorScale = cute::Tensor<EngineScale, LayoutScale>;
|
||||
|
||||
static_assert(is_static<LayoutScale>::value, "Scale Layout should be static");
|
||||
static_assert(is_rmem<TensorScale>::value , "Scale tensor must be rmem resident.");
|
||||
|
||||
static_assert(LayoutAccum{}.shape() == LayoutScale{}.shape(), "Accumulator and scale must have same shape.");
|
||||
|
||||
warpgroup_wait<0>();
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(accum_); ++i) {
|
||||
accum_(i) += accum_temp_(i) * scale(i);
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
GmmaFP8AccumulationWithScale(
|
||||
TensorAccum &accum,
|
||||
uint32_t accum_promotion_interval,
|
||||
uint32_t mma_count_per_mainloop_iteration)
|
||||
: accum_(accum),
|
||||
accum_promotion_interval_(accum_promotion_interval),
|
||||
mma_count_per_mainloop_iteration_(mma_count_per_mainloop_iteration),
|
||||
mma_count_(0),
|
||||
reset_accum_flag_(0)
|
||||
{
|
||||
accum_temp_ = cute::make_fragment_like(accum);
|
||||
}
|
||||
|
||||
//
|
||||
// Methods (Common)
|
||||
//
|
||||
|
||||
CUTLASS_DEVICE
|
||||
TensorAccum& operator()() {
|
||||
return accum_temp_;
|
||||
}
|
||||
|
||||
/// prepare the MMA accumulators when initialization or zeroing is required.
|
||||
CUTLASS_DEVICE
|
||||
bool prepare_if_needed() {
|
||||
return reset_accum_flag_;
|
||||
}
|
||||
|
||||
//
|
||||
// Methods (for FADD version)
|
||||
//
|
||||
|
||||
/// promote (add) the results from the MMA accumulators to main accumulator if needed.
|
||||
CUTLASS_DEVICE
|
||||
void promote_if_needed() {
|
||||
mma_count_ += mma_count_per_mainloop_iteration_;
|
||||
reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0);
|
||||
if (reset_accum_flag_) {
|
||||
promote_core();
|
||||
mma_count_ = 0;
|
||||
}
|
||||
}
|
||||
|
||||
/// promote (add) the residue results from the MMA accumulators to main accumulator if needed.
|
||||
CUTLASS_DEVICE
|
||||
void promote_residue_if_needed() {
|
||||
if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) {
|
||||
promote_core();
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Methods (for FFMA version)
|
||||
//
|
||||
|
||||
/// scale (multiply_add) the results from the MMA accumulators to main accumulator if needed.
|
||||
template <
|
||||
class EngineScale,
|
||||
class LayoutScale>
|
||||
CUTLASS_DEVICE
|
||||
void scale_if_needed(const cute::Tensor<EngineScale, LayoutScale> &scale) {
|
||||
mma_count_ += mma_count_per_mainloop_iteration_;
|
||||
reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0);
|
||||
if (reset_accum_flag_) {
|
||||
scale_core(scale);
|
||||
mma_count_ = 0;
|
||||
}
|
||||
}
|
||||
|
||||
/// scale (multiply_add) the residue results from the MMA accumulators to main accumulator if needed.
|
||||
template <
|
||||
class EngineScale,
|
||||
class LayoutScale>
|
||||
CUTLASS_DEVICE
|
||||
void scale_residue_if_needed(const cute::Tensor<EngineScale, LayoutScale> &scale) {
|
||||
if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) {
|
||||
scale_core(scale);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cutlass::gemm::collective
|
||||
@ -1,730 +0,0 @@
|
||||
// clang-format off
|
||||
// Adapted (Heavily) from: https://github.com/soundOfDestiny/cutlass/blob/9d997ce0dea4c5fa1a617db6b7ff29aa9235822c/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp
|
||||
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/trace.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cute/arch/cluster_sm90.hpp"
|
||||
#include "cute/arch/copy_sm80.hpp"
|
||||
#include "cute/arch/copy_sm90.hpp"
|
||||
#include "cute/algorithm/functional.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cute/algorithm/gemm.hpp"
|
||||
#include "cute/tensor_predicate.hpp"
|
||||
#include "cute/numeric/arithmetic_tuple.hpp"
|
||||
|
||||
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass_extensions/gemm/collective/fp8_accumulation.hpp"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::collective {
|
||||
using namespace cute;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// WarpSpecialized Mainloop
|
||||
template <
|
||||
int Stages,
|
||||
class ClusterShape,
|
||||
class KernelSchedule,
|
||||
int ScaleGranularityM_,
|
||||
class TileShape_,
|
||||
class ElementA_,
|
||||
class StrideA_,
|
||||
class ElementB_,
|
||||
class StrideB_,
|
||||
class TiledMma_,
|
||||
class GmemTiledCopyA_,
|
||||
class SmemLayoutAtomA_,
|
||||
class SmemCopyAtomA_,
|
||||
class TransformA_,
|
||||
class GmemTiledCopyB_,
|
||||
class SmemLayoutAtomB_,
|
||||
class SmemCopyAtomB_,
|
||||
class TransformB_>
|
||||
struct CollectiveMma<
|
||||
MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8<Stages, ClusterShape, KernelSchedule, ScaleGranularityM_>,
|
||||
TileShape_,
|
||||
ElementA_,
|
||||
StrideA_,
|
||||
ElementB_,
|
||||
StrideB_,
|
||||
TiledMma_,
|
||||
GmemTiledCopyA_,
|
||||
SmemLayoutAtomA_,
|
||||
SmemCopyAtomA_,
|
||||
TransformA_,
|
||||
GmemTiledCopyB_,
|
||||
SmemLayoutAtomB_,
|
||||
SmemCopyAtomB_,
|
||||
TransformB_>
|
||||
{
|
||||
//
|
||||
// Type Aliases
|
||||
//
|
||||
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8<Stages, ClusterShape, KernelSchedule, ScaleGranularityM_>;
|
||||
using TileShape = TileShape_;
|
||||
using ElementA = ElementA_;
|
||||
using StrideA = StrideA_;
|
||||
using ElementB = ElementB_;
|
||||
using StrideB = StrideB_;
|
||||
using TiledMma = TiledMma_;
|
||||
using ElementAccumulator = typename TiledMma::ValTypeC;
|
||||
using ElementBlockScale = ElementAccumulator;
|
||||
using GmemTiledCopyA = GmemTiledCopyA_;
|
||||
using GmemTiledCopyB = GmemTiledCopyB_;
|
||||
using SmemLayoutAtomA = SmemLayoutAtomA_;
|
||||
using SmemLayoutAtomB = SmemLayoutAtomB_;
|
||||
using SmemCopyAtomA = SmemCopyAtomA_;
|
||||
using SmemCopyAtomB = SmemCopyAtomB_;
|
||||
using TransformA = TransformA_;
|
||||
using TransformB = TransformB_;
|
||||
using ArchTag = typename DispatchPolicy::ArchTag;
|
||||
|
||||
using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{}));
|
||||
using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
|
||||
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
|
||||
using PipelineParams = typename MainloopPipeline::Params;
|
||||
|
||||
// Two threads per CTA are producers (1 for operand tile and 32 for scales)
|
||||
static constexpr int NumProducerThreadEvents = 33;
|
||||
|
||||
static constexpr int ScaleGranularityM = ScaleGranularityM_ == 0 ? size<0>(TileShape{}) : ScaleGranularityM_;
|
||||
static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM;
|
||||
|
||||
static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
|
||||
static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
|
||||
static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
|
||||
static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
|
||||
static_assert((size<0>(TileShape{}) % ScaleGranularityM) == 0, "FP8 scaling granularity must evenly divide tile shape along M.");
|
||||
|
||||
// Tile along modes in a way that maximizes the TMA box size.
|
||||
using SmemLayoutA = decltype(tile_to_shape(
|
||||
SmemLayoutAtomA{},
|
||||
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
|
||||
cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
|
||||
using SmemLayoutB = decltype(tile_to_shape(
|
||||
SmemLayoutAtomB{},
|
||||
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
|
||||
cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
|
||||
|
||||
// Block scaling gmem-to-smem copy atom
|
||||
using SmemBlockScalingCopyAtomA = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementBlockScale>, ElementBlockScale>;
|
||||
using SmemBlockScalingCopyAtomB = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementBlockScale>, ElementBlockScale>;
|
||||
|
||||
// Block scaling smem layout
|
||||
using SmemLayoutScaleA = Layout<Shape<Int<ScaleMsPerTile>, Int<DispatchPolicy::Stages>>>;
|
||||
using SmemLayoutScaleB = Layout<Shape<Int<DispatchPolicy::Stages>>, Stride<_1>>; // `ScaleNsPerTile` is always 1.
|
||||
|
||||
static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more.");
|
||||
static_assert(cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value &&
|
||||
cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeB>::value,
|
||||
"MMA atom must source both A and B operand from smem_desc for this mainloop.");
|
||||
static_assert(cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>,
|
||||
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
|
||||
static_assert(cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>,
|
||||
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
|
||||
static_assert(cute::is_same_v<ElementAccumulator, ElementBlockScale>,
|
||||
"ElementAccumulator and ElementBlockScale should be same datatype");
|
||||
|
||||
struct SharedStorage
|
||||
{
|
||||
struct TensorStorage : cute::aligned_struct<128> {
|
||||
cute::array_aligned<typename TiledMma::ValTypeA, cute::cosize_v<SmemLayoutA>> smem_A; // mxk
|
||||
cute::array_aligned<typename TiledMma::ValTypeB, cute::cosize_v<SmemLayoutB>> smem_B; // nxk
|
||||
cute::array_aligned<ElementBlockScale, cute::cosize_v<SmemLayoutScaleA>> smem_scale_A; // ScaleMsPerTile x k
|
||||
cute::array_aligned<ElementBlockScale, cute::cosize_v<SmemLayoutScaleB>> smem_scale_B; // 1xk
|
||||
} tensors;
|
||||
|
||||
using PipelineStorage = typename MainloopPipeline::SharedStorage;
|
||||
PipelineStorage pipeline;
|
||||
};
|
||||
using TensorStorage = typename SharedStorage::TensorStorage;
|
||||
using PipelineStorage = typename SharedStorage::PipelineStorage;
|
||||
|
||||
// Host side kernel arguments
|
||||
struct Arguments {
|
||||
ElementA const* ptr_A;
|
||||
StrideA dA;
|
||||
ElementB const* ptr_B;
|
||||
StrideB dB;
|
||||
ElementBlockScale const* ptr_scale_A;
|
||||
ElementBlockScale const* ptr_scale_B;
|
||||
};
|
||||
|
||||
// Device side kernel params
|
||||
struct Params {
|
||||
// Assumption: StrideA is congruent with Problem_MK
|
||||
using TMA_A = decltype(make_tma_copy_A_sm90(
|
||||
GmemTiledCopyA{},
|
||||
make_tensor(static_cast<ElementA const*>(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}),
|
||||
SmemLayoutA{}(_,_,0),
|
||||
TileShape{},
|
||||
ClusterShape{}));
|
||||
// Assumption: StrideB is congruent with Problem_NK
|
||||
using TMA_B = decltype(make_tma_copy_B_sm90(
|
||||
GmemTiledCopyB{},
|
||||
make_tensor(static_cast<ElementB const*>(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}),
|
||||
SmemLayoutB{}(_,_,0),
|
||||
TileShape{},
|
||||
ClusterShape{}));
|
||||
TMA_A tma_load_a;
|
||||
TMA_B tma_load_b;
|
||||
uint32_t tma_transaction_bytes = TmaTransactionBytes;
|
||||
uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK;
|
||||
uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK;
|
||||
// Block scaling factors for A and B
|
||||
ElementBlockScale const* ptr_scale_A;
|
||||
ElementBlockScale const* ptr_scale_B;
|
||||
};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
template <class ProblemShape>
|
||||
static constexpr Params
|
||||
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
|
||||
(void) workspace;
|
||||
|
||||
// Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK)
|
||||
auto problem_shape_MNKL = append<4>(problem_shape, 1);
|
||||
auto [M,N,K,L] = problem_shape_MNKL;
|
||||
|
||||
auto ptr_A = reinterpret_cast<ElementA const*>(args.ptr_A);
|
||||
auto ptr_B = reinterpret_cast<ElementB const*>(args.ptr_B);
|
||||
|
||||
Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA));
|
||||
Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB));
|
||||
typename Params::TMA_A tma_load_a = make_tma_copy_A_sm90(
|
||||
GmemTiledCopyA{},
|
||||
tensor_a,
|
||||
SmemLayoutA{}(_,_,cute::Int<0>{}),
|
||||
TileShape{},
|
||||
ClusterShape{});
|
||||
typename Params::TMA_B tma_load_b = make_tma_copy_B_sm90(
|
||||
GmemTiledCopyB{},
|
||||
tensor_b,
|
||||
SmemLayoutB{}(_,_,cute::Int<0>{}),
|
||||
TileShape{},
|
||||
ClusterShape{});
|
||||
uint32_t transaction_bytes_mk = TmaTransactionBytesMK;
|
||||
uint32_t transaction_bytes_nk = TmaTransactionBytesNK;
|
||||
uint32_t transaction_bytes = transaction_bytes_mk + transaction_bytes_nk;
|
||||
|
||||
return {
|
||||
tma_load_a,
|
||||
tma_load_b,
|
||||
transaction_bytes,
|
||||
transaction_bytes_mk,
|
||||
transaction_bytes_nk,
|
||||
args.ptr_scale_A,
|
||||
args.ptr_scale_B
|
||||
};
|
||||
}
|
||||
|
||||
template<class ProblemShape>
|
||||
static bool
|
||||
can_implement(
|
||||
ProblemShape const& problem_shape,
|
||||
[[maybe_unused]] Arguments const& args) {
|
||||
constexpr int tma_alignment_bits = 128;
|
||||
auto problem_shape_MNKL = append<4>(problem_shape, 1);
|
||||
auto [M,N,K,L] = problem_shape_MNKL;
|
||||
|
||||
bool implementable = true;
|
||||
constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
|
||||
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), StrideA{});
|
||||
constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
|
||||
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N,K,L), StrideB{});
|
||||
|
||||
if (!implementable) {
|
||||
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
|
||||
}
|
||||
return implementable;
|
||||
}
|
||||
|
||||
static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
|
||||
static constexpr int K_PIPE_MMAS = 1;
|
||||
static constexpr uint32_t TmaTransactionBytesMK =
|
||||
cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast<uint32_t>(sizeof_bits<ElementA>::value));
|
||||
static constexpr uint32_t TmaTransactionBytesNK =
|
||||
cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast<uint32_t>(sizeof_bits<ElementB>::value));
|
||||
static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK;
|
||||
|
||||
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
|
||||
CUTLASS_DEVICE
|
||||
static void prefetch_tma_descriptors(Params const& mainloop_params)
|
||||
{
|
||||
cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor());
|
||||
cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor());
|
||||
}
|
||||
|
||||
/// Set up the data needed by this collective for load and mma.
|
||||
/// Returns a tuple of tensors. The collective and the kernel layer have the contract
|
||||
/// Returned tuple must contain at least two elements, with the first two elements being:
|
||||
/// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l)
|
||||
/// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l)
|
||||
template <class ProblemShape_MNKL>
|
||||
CUTLASS_DEVICE auto
|
||||
load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const {
|
||||
using X = Underscore;
|
||||
// Separate out problem shape for convenience
|
||||
auto [M,N,K,L] = problem_shape_MNKL;
|
||||
|
||||
// TMA requires special handling of strides to deal with coord codomain mapping
|
||||
// Represent the full tensors -- get these from TMA
|
||||
Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l)
|
||||
Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l)
|
||||
|
||||
// Make tiled views, defer the slice
|
||||
Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l)
|
||||
Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l)
|
||||
|
||||
constexpr auto scales_m = Int<ScaleMsPerTile>{};
|
||||
auto tM = get<2>(gA_mkl.shape());
|
||||
auto tN = get<2>(gB_nkl.shape());
|
||||
auto tK = get<3>(gA_mkl.shape());
|
||||
|
||||
// Make the tiled views of scale tensors
|
||||
auto scaleA_shape = make_shape(M / ScaleGranularityM, tK, L); // (scale_m,k,l)
|
||||
auto scaleA_layout = make_ordered_layout(scaleA_shape, Step<_0, _1, _2>{});
|
||||
auto scaleB_shape = make_shape(tN, tK, L); // (n,k,l)
|
||||
auto scaleB_layout = make_ordered_layout(scaleB_shape, Step<_1, _0, _2>{});
|
||||
|
||||
// Note that mScaleA_mkl and mScaleB_nkl are already blocked tiled in the `m` host and
|
||||
// gScaleA_mkl and gScaleB_nkl in `g` global memory are same as mScaleA_mkl and mScaleB_nkl.
|
||||
Tensor mScaleA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_A), scaleA_layout); // (scale_m,k,l)
|
||||
Tensor mScaleB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_B), scaleB_layout); // (n,k,l)
|
||||
|
||||
return cute::make_tuple(gA_mkl, gB_nkl, mScaleA_mkl, mScaleB_nkl);
|
||||
}
|
||||
|
||||
/// Perform a collective-scoped matrix multiply-accumulate
|
||||
/// Producer Perspective
|
||||
template <
|
||||
class TensorA, class TensorB,
|
||||
class TensorScaleA, class TensorScaleB,
|
||||
class KTileIterator, class BlockCoord
|
||||
>
|
||||
CUTLASS_DEVICE void
|
||||
load(
|
||||
Params const& mainloop_params,
|
||||
MainloopPipeline pipeline,
|
||||
PipelineState smem_pipe_write,
|
||||
cute::tuple<TensorA, TensorB, TensorScaleA, TensorScaleB> const& load_inputs,
|
||||
BlockCoord const& blk_coord,
|
||||
KTileIterator k_tile_iter, int k_tile_count,
|
||||
int thread_idx,
|
||||
uint32_t block_rank_in_cluster,
|
||||
TensorStorage& shared_tensors) {
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
// Blockscaling: Tma loads for load_input and CpAsync for load_scale
|
||||
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
||||
Tensor sScaleA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), SmemLayoutScaleA{}); // (ScaleMsPerTile,k)
|
||||
Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (k)
|
||||
|
||||
//
|
||||
// Prepare the TMA loads for A and B
|
||||
//
|
||||
|
||||
constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
|
||||
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
|
||||
|
||||
Tensor gA_mkl = get<0>(load_inputs);
|
||||
Tensor gB_nkl = get<1>(load_inputs);
|
||||
|
||||
auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
|
||||
auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
|
||||
|
||||
// Partition the inputs based on the current block coordinates.
|
||||
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
|
||||
Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
|
||||
Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k)
|
||||
|
||||
|
||||
// Block scaling: load_scale has scaling tensors in global memory which are not tiled
|
||||
Tensor mScaleA_mkl = get<2>(load_inputs);
|
||||
Tensor mScaleB_nkl = get<3>(load_inputs);
|
||||
auto scales_m = get<0>(mScaleA_mkl.shape());
|
||||
|
||||
Tensor cScaleA_mkl = make_identity_tensor(mScaleA_mkl.shape());
|
||||
|
||||
Tensor gScaleA = local_tile(
|
||||
mScaleA_mkl, make_tile(Int<ScaleMsPerTile>{}),
|
||||
make_coord(m_coord,_,l_coord)); // (ScaleMsPerTile,k,1)
|
||||
Tensor cScaleA = local_tile(
|
||||
cScaleA_mkl, make_tile(Int<ScaleMsPerTile>{}),
|
||||
make_coord(m_coord,_,l_coord));
|
||||
Tensor gScaleB = mScaleB_nkl(n_coord,_,l_coord); // (1,k,1)
|
||||
|
||||
// TODO: test `scale_copy_a` with `ScaleMsPerTile` < 128
|
||||
TiledCopy scale_copy_a = make_tiled_copy(SmemBlockScalingCopyAtomA{},
|
||||
Layout<Shape<_32, _1>>{}, Layout<Shape<_4, _1>>{}); // (1,1,1)
|
||||
TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{},
|
||||
Layout<Shape<_1>>{}, Layout<Shape<_1>>{}); // (1,1,1)
|
||||
ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x);
|
||||
ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(threadIdx.x);
|
||||
|
||||
Tensor tAgA_ScaleA = thr_scale_copy_a.partition_S(gScaleA);
|
||||
Tensor tAcA_ScaleA = thr_scale_copy_a.partition_S(cScaleA);
|
||||
Tensor tAsA_ScaleA = thr_scale_copy_a.partition_D(sScaleA);
|
||||
|
||||
Tensor tBgB_ScaleB = thr_scale_copy_b.partition_S(gScaleB);
|
||||
Tensor tBsB_ScaleB = thr_scale_copy_b.partition_D(sScaleB);
|
||||
|
||||
// Applies the mapping from block_tma_a
|
||||
Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
|
||||
Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
|
||||
|
||||
Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
|
||||
Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
|
||||
|
||||
uint16_t mcast_mask_a = 0;
|
||||
uint16_t mcast_mask_b = 0;
|
||||
|
||||
// Issue TmaLoads for GEMM operands A/B and CpAsync for scale tensors
|
||||
// Maps the tile -> block, value
|
||||
if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
|
||||
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
|
||||
for (int n = 0; n < size<1>(block_layout); ++n) {
|
||||
mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{}));
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
|
||||
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
|
||||
for (int m = 0; m < size<0>(block_layout); ++m) {
|
||||
mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{}));
|
||||
}
|
||||
}
|
||||
|
||||
// Allocate predicate tensors for a_scales (since we can't guarantee that
|
||||
// all scales are valid, since we could have a partial tiles along M)
|
||||
Tensor tApA_ScaleA = make_tensor<bool>(shape(tAsA_ScaleA(_,_,0)));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(tApA_ScaleA); ++i) {
|
||||
tApA_ScaleA(i) = get<0>(tAcA_ScaleA(i)) < scales_m;
|
||||
}
|
||||
|
||||
// Mainloop
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for ( ; k_tile_count > 0; --k_tile_count) {
|
||||
// LOCK smem_pipe_write for _writing_
|
||||
pipeline.producer_acquire(smem_pipe_write);
|
||||
|
||||
//
|
||||
// Copy gmem to smem for *k_tile_iter
|
||||
//
|
||||
int write_stage = smem_pipe_write.index();
|
||||
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
|
||||
BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
|
||||
|
||||
// Copy operands A and B from global memory to shared memory
|
||||
if (lane_predicate) copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
|
||||
if (lane_predicate) copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
|
||||
|
||||
// Copy scale tensors from global memory to shared memory
|
||||
copy_if(scale_copy_a, tApA_ScaleA, tAgA_ScaleA(_,_,*k_tile_iter), tAsA_ScaleA(_,_,write_stage));
|
||||
copy(scale_copy_b, tBgB_ScaleB(_,*k_tile_iter), tBsB_ScaleB(_,write_stage));
|
||||
pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive_noinc);
|
||||
|
||||
++k_tile_iter;
|
||||
|
||||
// Advance smem_pipe_write
|
||||
++smem_pipe_write;
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
|
||||
CUTLASS_DEVICE void
|
||||
load_tail(
|
||||
MainloopPipeline pipeline,
|
||||
PipelineState smem_pipe_write) {
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
// Issue the epilogue waits
|
||||
if (lane_predicate) {
|
||||
/* This helps avoid early exit of blocks in Cluster
|
||||
* Waits for all stages to either be released (all
|
||||
* Consumer UNLOCKs), or if the stage was never used
|
||||
* then would just be acquired since the phase was
|
||||
* still inverted from make_producer_start_state
|
||||
*/
|
||||
pipeline.producer_tail(smem_pipe_write);
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a collective-scoped matrix multiply-accumulate
|
||||
/// Consumer Perspective
|
||||
template <
|
||||
class FrgTensorC
|
||||
>
|
||||
CUTLASS_DEVICE void
|
||||
mma(MainloopPipeline pipeline,
|
||||
PipelineState smem_pipe_read,
|
||||
FrgTensorC& accum,
|
||||
int k_tile_count,
|
||||
int thread_idx,
|
||||
TensorStorage& shared_tensors,
|
||||
Params const& mainloop_params) {
|
||||
|
||||
|
||||
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
|
||||
static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
|
||||
static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
|
||||
static_assert(cute::is_void_v<SmemCopyAtomA>,
|
||||
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
|
||||
static_assert(cute::is_void_v<SmemCopyAtomB>,
|
||||
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
|
||||
|
||||
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
||||
|
||||
// Block scaling
|
||||
Tensor sScaleAViewAsC = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()),
|
||||
Layout<
|
||||
Shape<Shape<Int<ScaleGranularityM>, Int<ScaleMsPerTile>>, cute::tuple_element_t<1, TileShape>, Int<DispatchPolicy::Stages>>,
|
||||
Stride<Stride<_0, _1>, _0, Int<ScaleMsPerTile>>
|
||||
>{}); // ((ScaleGranularityM,ScaleMsPerTile),n,k)
|
||||
Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (k)
|
||||
|
||||
//
|
||||
// Define C accumulators and A/B partitioning
|
||||
//
|
||||
|
||||
// Layout of warp group to thread mapping
|
||||
|
||||
static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and
|
||||
stride<0>(typename TiledMma::BLayout{}) == 0 and
|
||||
size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and
|
||||
size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup,
|
||||
"Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup");
|
||||
|
||||
constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup;
|
||||
Layout warp_group_thread_layout = make_layout(Int<MmaWarpGroups>{},
|
||||
Int<NumThreadsPerWarpGroup>{});
|
||||
|
||||
int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0);
|
||||
|
||||
TiledMma tiled_mma;
|
||||
auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx));
|
||||
|
||||
Tensor tCsScaleAViewAsC = tiled_mma.get_slice(thread_idx).partition_C(sScaleAViewAsC); // (MMA,MMA_M,MMA_N,PIPE), `thread_mma` above is correct when partitioning A and B, but it is not correct when partitioning C.
|
||||
|
||||
Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
|
||||
Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
|
||||
|
||||
// Allocate "fragments/descriptors"
|
||||
Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
|
||||
Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
|
||||
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K
|
||||
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE
|
||||
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sA)); // PIPE
|
||||
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sB)); // PIPE
|
||||
|
||||
//
|
||||
// PIPELINED MAIN LOOP
|
||||
//
|
||||
static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX),
|
||||
"ERROR : Incorrect number of MMAs in flight");
|
||||
|
||||
// We release buffers to producer warps(dma load) with some mmas in flight
|
||||
PipelineState smem_pipe_release = smem_pipe_read;
|
||||
|
||||
// Per block scale values for operand A and B
|
||||
|
||||
using RegLayoutScaleAViewAsC = decltype(make_layout_like(tCsScaleAViewAsC(_, _, _, 0).layout())); // `make_layout_like` makes a compact layout.
|
||||
using RegLayoutScaleAEssential = decltype(filter_zeros(RegLayoutScaleAViewAsC{}.stride(), RegLayoutScaleAViewAsC{}.shape())); // an interface to traverse the underlying storage for the compact layout mentioned above
|
||||
|
||||
Tensor tCrScaleAViewAsC = make_tensor<ElementBlockScale>(RegLayoutScaleAViewAsC{}); // (MMA,MMA_M,MMA_N)
|
||||
ElementBlockScale scale_b;
|
||||
|
||||
// Prologue GMMAs
|
||||
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
|
||||
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
||||
|
||||
GmmaFP8AccumulationWithScale accumulation(accum, size<2>(TileShape{}) / size<2>(typename TiledMma::AtomShape_MNK{}), size<2>(tCrA));
|
||||
warpgroup_fence_operand(accumulation());
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue)
|
||||
{
|
||||
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
|
||||
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
||||
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
||||
|
||||
if (accumulation.prepare_if_needed()) {
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
||||
}
|
||||
|
||||
int read_stage = smem_pipe_read.index();
|
||||
|
||||
// Load per block scale values from shared memory to registers.
|
||||
scale_b = sScaleB[read_stage];
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) {
|
||||
tCrScaleAViewAsC.data()[i] = tCsScaleAViewAsC(_, _, _, read_stage)(idx2crd(i, RegLayoutScaleAEssential{}));
|
||||
}
|
||||
if constexpr (ScaleMsPerTile == 1) {
|
||||
static_assert(size(RegLayoutScaleAEssential{}) == 1);
|
||||
tCrScaleAViewAsC.data()[0] = __shfl_sync(0xffffffff, tCrScaleAViewAsC.data()[0] * scale_b, 0); // `tCrScaleAViewAsC.data()[0]` are all same in a warp group when `ScaleMsPerTile == 1`.
|
||||
} else {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) {
|
||||
tCrScaleAViewAsC.data()[i] = tCrScaleAViewAsC.data()[i] * scale_b;
|
||||
}
|
||||
}
|
||||
|
||||
warpgroup_arrive();
|
||||
// Unroll the K mode manually to set scale D to 1
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
||||
// (V,M,K) x (V,N,K) => (V,M,N)
|
||||
cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation());
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||
}
|
||||
warpgroup_commit_batch();
|
||||
|
||||
// Block scale the accumulators with reg tensor `tCrScaleAViewAsC`
|
||||
accumulation.scale_if_needed(tCrScaleAViewAsC);
|
||||
|
||||
++smem_pipe_read;
|
||||
}
|
||||
|
||||
warpgroup_fence_operand(accumulation());
|
||||
// Mainloop GMMAs
|
||||
k_tile_count -= prologue_mma_count;
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for ( ; k_tile_count > 0; --k_tile_count)
|
||||
{
|
||||
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
|
||||
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
||||
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
||||
|
||||
//
|
||||
// Compute on k_tile
|
||||
//
|
||||
|
||||
int read_stage = smem_pipe_read.index();
|
||||
|
||||
// Load per block scale values from shared memory to registers (at most twice per block along M and exactly once per block along N)
|
||||
scale_b = sScaleB[read_stage];
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) {
|
||||
tCrScaleAViewAsC.data()[i] = tCsScaleAViewAsC(_, _, _, read_stage)(idx2crd(i, RegLayoutScaleAEssential{}));
|
||||
}
|
||||
if constexpr (ScaleMsPerTile == 1) {
|
||||
static_assert(size(RegLayoutScaleAEssential{}) == 1);
|
||||
tCrScaleAViewAsC.data()[0] = __shfl_sync(0xffffffff, tCrScaleAViewAsC.data()[0] * scale_b, 0); // `tCrScaleAViewAsC.data()[0]` are all same in a warp group when `ScaleMsPerTile == 1`.
|
||||
} else {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) {
|
||||
tCrScaleAViewAsC.data()[i] = tCrScaleAViewAsC.data()[i] * scale_b;
|
||||
}
|
||||
}
|
||||
|
||||
if (accumulation.prepare_if_needed()) {
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
||||
}
|
||||
|
||||
warpgroup_fence_operand(accumulation());
|
||||
warpgroup_arrive();
|
||||
// Unroll the K mode manually to set scale D to 1
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
||||
// (V,M,K) x (V,N,K) => (V,M,N)
|
||||
cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation());
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||
}
|
||||
warpgroup_commit_batch();
|
||||
|
||||
/// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed
|
||||
warpgroup_wait<K_PIPE_MMAS>();
|
||||
warpgroup_fence_operand(accumulation());
|
||||
|
||||
// Block scale the accumulators with reg tensor `tCrScaleAViewAsC`
|
||||
accumulation.scale_if_needed(tCrScaleAViewAsC);
|
||||
|
||||
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
|
||||
|
||||
// Advance smem_pipe_read and smem_pipe_release
|
||||
++smem_pipe_read;
|
||||
++smem_pipe_release;
|
||||
}
|
||||
|
||||
accumulation.scale_residue_if_needed(tCrScaleAViewAsC);
|
||||
|
||||
warpgroup_fence_operand(accumulation());
|
||||
}
|
||||
|
||||
/// Perform a Consumer Epilogue to release all buffers
|
||||
CUTLASS_DEVICE void
|
||||
mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) {
|
||||
// Prologue GMMAs
|
||||
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
|
||||
k_tile_count -= prologue_mma_count;
|
||||
|
||||
smem_pipe_release.advance(k_tile_count);
|
||||
|
||||
// Wait on all GMMAs to complete
|
||||
warpgroup_wait<0>();
|
||||
|
||||
for (int count = 0; count < prologue_mma_count; ++count) {
|
||||
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
|
||||
++smem_pipe_release;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm::collective
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -1,39 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
|
||||
namespace cutlass::gemm {
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// FP8 related policies (including Blocked Scaled Accumulation)
|
||||
// `ScaleGranularityM` specifies scaling granularity along M, while zero-value
|
||||
// `ScaleGranularityM` indicates that scaling granularity is
|
||||
// `size<0>(TileShape_MNK{})` along M.
|
||||
template <int ScaleGranularityM = 0>
|
||||
struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum
|
||||
: KernelTmaWarpSpecializedCooperative {};
|
||||
|
||||
// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp
|
||||
// specialized dynamic schedule For FP8 kernels with Block Scaling
|
||||
template <int Stages_, class ClusterShape_ = Shape<_1, _1, _1>,
|
||||
class KernelSchedule = KernelTmaWarpSpecialized,
|
||||
int ScaleGranularityM =
|
||||
0 // `ScaleGranularityM` specifies scaling granularity along M,
|
||||
// while zero-value `ScaleGranularityM` indicates that scaling
|
||||
// granularity is `size<0>(TileShape_MNK{})` along M.
|
||||
>
|
||||
struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8
|
||||
: MainloopSm90TmaGmmaWarpSpecialized<Stages_, ClusterShape_,
|
||||
KernelSchedule> {
|
||||
static_assert(
|
||||
cute::is_same_v<
|
||||
KernelSchedule,
|
||||
KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<
|
||||
ScaleGranularityM>>,
|
||||
"KernelSchedule must be one of the warp specialized policies");
|
||||
};
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm
|
||||
@ -1,6 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
|
||||
namespace cutlass::gemm::collective {
|
||||
using namespace cute;
|
||||
|
||||
@ -153,7 +153,6 @@ torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type,
|
||||
|
||||
#ifndef USE_ROCM
|
||||
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
|
||||
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability);
|
||||
|
||||
void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, torch::Tensor const& a_scales,
|
||||
|
||||
@ -1,93 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
// clang-format will break include orders
|
||||
// clang-format off
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
|
||||
#include "core/math.hpp"
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
// clang-format on
|
||||
|
||||
namespace vllm::c3x {
|
||||
|
||||
static inline cute::Shape<int, int, int, int> get_problem_shape(
|
||||
torch::Tensor const& a, torch::Tensor const& b) {
|
||||
int32_t m = a.size(0), n = b.size(1), k = a.size(1);
|
||||
return {m, n, k, 1};
|
||||
}
|
||||
|
||||
template <typename GemmKernel>
|
||||
void cutlass_gemm_caller(torch::Device device,
|
||||
cute::Shape<int, int, int, int> prob_shape,
|
||||
typename GemmKernel::MainloopArguments mainloop_args,
|
||||
typename GemmKernel::EpilogueArguments epilogue_args) {
|
||||
typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
prob_shape, mainloop_args, epilogue_args};
|
||||
|
||||
// Launch the CUTLASS GEMM kernel.
|
||||
using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
GemmOp gemm_op;
|
||||
CUTLASS_CHECK(gemm_op.can_implement(args));
|
||||
|
||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(device);
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream(device.index());
|
||||
|
||||
cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
template <typename Gemm, typename... EpilogueArgs>
|
||||
void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_params) {
|
||||
using ElementAB = typename Gemm::ElementAB;
|
||||
using ElementD = typename Gemm::ElementD;
|
||||
using GemmKernel = typename Gemm::GemmKernel;
|
||||
|
||||
int64_t lda = a.stride(0);
|
||||
int64_t ldb = b.stride(1);
|
||||
int64_t ldc = out.stride(0);
|
||||
|
||||
using StrideA = cute::Stride<int64_t, cute::Int<1>, int64_t>;
|
||||
using StrideB = cute::Stride<int64_t, cute::Int<1>, int64_t>;
|
||||
using StrideC = typename Gemm::StrideC;
|
||||
|
||||
StrideA a_stride{lda, cute::Int<1>{}, 0};
|
||||
StrideB b_stride{ldb, cute::Int<1>{}, 0};
|
||||
StrideC c_stride{ldc, cute::Int<1>{}, cute::Int<0>{}};
|
||||
|
||||
typename GemmKernel::ProblemShape prob_shape = get_problem_shape(a, b);
|
||||
|
||||
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
|
||||
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
|
||||
typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,
|
||||
b_stride};
|
||||
|
||||
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
|
||||
typename GemmKernel::EpilogueArguments epilogue_args{
|
||||
Gemm::Epilogue::prepare_args(
|
||||
std::forward<EpilogueArgs>(epilogue_params)...),
|
||||
c_ptr, c_stride, c_ptr, c_stride};
|
||||
|
||||
cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
|
||||
epilogue_args);
|
||||
}
|
||||
|
||||
} // namespace vllm::c3x
|
||||
@ -1,24 +0,0 @@
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_sm90_int8_dispatch.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_azp_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
std::optional<torch::Tensor> const& azp,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
if (azp) {
|
||||
return cutlass_scaled_mm_sm90_int8_epilogue<
|
||||
c3x::ScaledEpilogueBiasAzpToken>(out, a, b, a_scales, b_scales, azp_adj,
|
||||
*azp, bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogueBiasAzp>(
|
||||
out, a, b, a_scales, b_scales, azp_adj, bias);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@ -1,24 +0,0 @@
|
||||
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_blockwise_sm90_fp8_dispatch.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_blockwise_sm90_fp8(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::bfloat16_t>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::half_t>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@ -1,168 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
|
||||
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
|
||||
|
||||
#include "cutlass_gemm_caller.cuh"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename OutType, int GroupSizeM_, int GroupSizeN_, int GroupSizeK_,
|
||||
int TileSizeM_ = 128, class ClusterShape = Shape<_1, _2, _1>>
|
||||
struct cutlass_3x_gemm_fp8_blockwise {
|
||||
using GroupSizeM = Int<GroupSizeM_>;
|
||||
using GroupSizeN = Int<GroupSizeN_>;
|
||||
using GroupSizeK = Int<GroupSizeK_>;
|
||||
using TileSizeM = Int<TileSizeM_>;
|
||||
|
||||
static_assert(TileSizeM_ % GroupSizeM_ == 0,
|
||||
"TileSizeM must be a multiple of GroupSizeM");
|
||||
|
||||
using ElementAB = cutlass::float_e4m3_t;
|
||||
|
||||
using ElementA = ElementAB;
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
|
||||
|
||||
using ElementB = ElementAB;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
|
||||
|
||||
using ElementD = OutType;
|
||||
using StrideD = Stride<int64_t, Int<1>, Int<0>>;
|
||||
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
|
||||
using ElementC = void;
|
||||
using StrideC = StrideD;
|
||||
static constexpr int AlignmentC = AlignmentD;
|
||||
|
||||
using ElementAccumulator = float;
|
||||
using ElementBlockScale = float;
|
||||
using ElementCompute = float;
|
||||
using ArchTag = cutlass::arch::Sm90;
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||
using TileShape = Shape<TileSizeM, GroupSizeN, GroupSizeK>;
|
||||
|
||||
using KernelSchedule = cutlass::gemm::
|
||||
KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<
|
||||
GroupSizeM_>;
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
|
||||
using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT<
|
||||
cutlass::epilogue::fusion::Sm90AccFetch>;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType,
|
||||
ElementAccumulator, ElementCompute, ElementC, StrideC, AlignmentC,
|
||||
ElementD, StrideD, AlignmentD, EpilogueSchedule,
|
||||
StoreEpilogueCompute>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB,
|
||||
LayoutB, AlignmentB, ElementAccumulator, TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule>::CollectiveOp;
|
||||
|
||||
using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
|
||||
cutlass::gemm::PersistentScheduler>>;
|
||||
|
||||
struct GemmKernel : public KernelType {};
|
||||
|
||||
using StrideA = typename GemmKernel::StrideA;
|
||||
using StrideB = typename GemmKernel::StrideB;
|
||||
};
|
||||
|
||||
template <typename Gemm>
|
||||
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
using GemmKernel = typename Gemm::GemmKernel;
|
||||
|
||||
using ElementAB = typename Gemm::ElementAB;
|
||||
using ElementD = typename Gemm::ElementD;
|
||||
|
||||
auto prob_shape = c3x::get_problem_shape(a, b);
|
||||
int32_t m = get<0>(prob_shape), n = get<1>(prob_shape),
|
||||
k = get<2>(prob_shape);
|
||||
|
||||
int64_t lda = a.stride(0);
|
||||
int64_t ldb = b.stride(1);
|
||||
int64_t ldc = out.stride(0);
|
||||
|
||||
using StrideA = Stride<int64_t, Int<1>, int64_t>;
|
||||
using StrideB = Stride<int64_t, Int<1>, int64_t>;
|
||||
using StrideC = typename Gemm::StrideC;
|
||||
|
||||
StrideA a_stride{lda, Int<1>{}, 0};
|
||||
StrideB b_stride{ldb, Int<1>{}, 0};
|
||||
StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
|
||||
|
||||
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
|
||||
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
|
||||
auto a_scales_ptr = static_cast<float*>(a_scales.data_ptr());
|
||||
auto b_scales_ptr = static_cast<float*>(b_scales.data_ptr());
|
||||
|
||||
// Check is the t is contiguous and is 1D or 2D with one of the dimensions
|
||||
// being 1 (i.e. a row or column vector)
|
||||
auto is_contiguous_vector = [](const torch::Tensor& t) {
|
||||
auto t_sizes = t.sizes();
|
||||
return t.is_contiguous() &&
|
||||
(t.dim() == 1 ||
|
||||
(t.dim() == 2 &&
|
||||
*std::min_element(t_sizes.begin(), t_sizes.end()) == 1));
|
||||
};
|
||||
|
||||
// TODO(lucas): lets clean-up the kernel so that we pass in Strides so
|
||||
// we don't have to deal with enforcing implicit layouts
|
||||
TORCH_CHECK(a_scales.size(0) == m / Gemm::GroupSizeM::value);
|
||||
TORCH_CHECK(a_scales.size(1) == k / Gemm::GroupSizeK::value);
|
||||
TORCH_CHECK(a_scales.stride(0) == 1 || is_contiguous_vector(a_scales),
|
||||
"a_scales must be M major");
|
||||
TORCH_CHECK(b_scales.size(0) == k / Gemm::GroupSizeK::value);
|
||||
TORCH_CHECK(b_scales.size(1) == n / Gemm::GroupSizeN::value);
|
||||
TORCH_CHECK(b_scales.stride(0) == 1 || is_contiguous_vector(b_scales),
|
||||
"b_scales must be K major");
|
||||
typename GemmKernel::MainloopArguments mainloop_args{
|
||||
a_ptr, a_stride, b_ptr, b_stride, a_scales_ptr, b_scales_ptr};
|
||||
|
||||
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
|
||||
typename GemmKernel::EpilogueArguments epilogue_args{
|
||||
{}, c_ptr, c_stride, c_ptr, c_stride};
|
||||
|
||||
c3x::cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
|
||||
epilogue_args);
|
||||
}
|
||||
|
||||
template <typename OutType>
|
||||
void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
cutlass_gemm_caller_blockwise<
|
||||
cutlass_3x_gemm_fp8_blockwise<OutType, 1, 128, 128>>(out, a, b, a_scales,
|
||||
b_scales);
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@ -1,33 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_sm90_fp8(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_azp_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
std::optional<torch::Tensor> const& azp,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_blockwise_sm90_fp8(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales);
|
||||
|
||||
} // namespace vllm
|
||||
@ -1,24 +0,0 @@
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_sm90_fp8_dispatch.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_sm90_fp8(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
if (bias) {
|
||||
TORCH_CHECK(bias->dtype() == out.dtype(),
|
||||
"currently bias dtype must match output dtype ", out.dtype());
|
||||
return cutlass_scaled_mm_sm90_fp8_epilogue<c3x::ScaledEpilogueBias>(
|
||||
out, a, b, a_scales, b_scales, *bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm90_fp8_epilogue<c3x::ScaledEpilogue>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@ -1,24 +0,0 @@
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_sm90_int8_dispatch.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
if (bias) {
|
||||
TORCH_CHECK(bias->dtype() == out.dtype(),
|
||||
"currently bias dtype must match output dtype ", out.dtype());
|
||||
return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogueBias>(
|
||||
out, a, b, a_scales, b_scales, *bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogue>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@ -1,13 +1,52 @@
|
||||
#include <cudaTypedefs.h>
|
||||
#include "c3x/scaled_mm_kernels.hpp"
|
||||
|
||||
#include "core/math.hpp"
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
|
||||
|
||||
#include "scaled_mm_c3x_sm90_fp8_dispatch.cuh"
|
||||
#include "scaled_mm_c3x_sm90_int8_dispatch.cuh"
|
||||
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
using namespace vllm;
|
||||
|
||||
/*
|
||||
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
|
||||
NVIDIA GPUs with sm90a (Hopper) or later.
|
||||
*/
|
||||
|
||||
template <template <typename, typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_scaled_mm_sm90_epilogue(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_args) {
|
||||
if (a.dtype() == torch::kInt8) {
|
||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
|
||||
Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::bfloat16_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
@ -15,56 +54,14 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
using GroupShape = std::array<int64_t, 2>;
|
||||
|
||||
int M = a.size(0), N = b.size(1), K = a.size(1);
|
||||
|
||||
GroupShape a_scale_group_shape = [&, &s = a_scales]() -> GroupShape {
|
||||
if (s.numel() == 1) return {M, K}; // tensor-wise
|
||||
if (s.dim() == 2)
|
||||
return {ceil_div(a.size(0), s.size(0)), ceil_div(a.size(1), s.size(1))};
|
||||
TORCH_CHECK(false, "Unsupported scale shape for scale_a");
|
||||
}();
|
||||
|
||||
GroupShape b_scale_group_shape = [&, &s = b_scales]() -> GroupShape {
|
||||
if (s.numel() == 1) return {K, N}; // tensor-wise
|
||||
if (s.dim() == 2)
|
||||
return {ceil_div(b.size(0), s.size(0)), ceil_div(b.size(1), s.size(1))};
|
||||
TORCH_CHECK(false, "Unsupported scale shape for scale_b");
|
||||
}();
|
||||
|
||||
if ((a_scale_group_shape == GroupShape{M, K} ||
|
||||
a_scale_group_shape == GroupShape{1, K}) &&
|
||||
(b_scale_group_shape == GroupShape{K, N} ||
|
||||
b_scale_group_shape == GroupShape{K, 1})) {
|
||||
// "standard per-tensor/per-token/per-channel" scaling
|
||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
if (a.dtype() == torch::kFloat8_e4m3fn) {
|
||||
vllm::cutlass_scaled_mm_sm90_fp8(c, a, b, a_scales, b_scales, bias);
|
||||
} else {
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
vllm::cutlass_scaled_mm_sm90_int8(c, a, b, a_scales, b_scales, bias);
|
||||
}
|
||||
} else if (a_scale_group_shape == GroupShape{1, 128} &&
|
||||
b_scale_group_shape == GroupShape{128, 128}) {
|
||||
// 1x128 per-token group scales for activations
|
||||
// 128x128 blockwise scales for weights
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn &&
|
||||
b.dtype() == torch::kFloat8_e4m3fn,
|
||||
"Currently only FP8 is supported for A group shape 1x128 and "
|
||||
"B group shape 128x128");
|
||||
TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
|
||||
|
||||
vllm::cutlass_scaled_mm_blockwise_sm90_fp8(c, a, b, a_scales, b_scales);
|
||||
if (bias) {
|
||||
TORCH_CHECK(bias->dtype() == c.dtype(),
|
||||
"currently bias dtype must match output dtype ", c.dtype());
|
||||
return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogueBias>(
|
||||
c, a, b, a_scales, b_scales, *bias);
|
||||
} else {
|
||||
TORCH_CHECK(false,
|
||||
"Unsupported scale group shapes for CUTLASS 3.x GEMM.\n "
|
||||
"a_scale_group_shape must be [1, 128], got: [",
|
||||
a_scale_group_shape[0], ", ", a_scale_group_shape[1],
|
||||
"]\n"
|
||||
"b_scale_group_shape must be [128, 128], got: [",
|
||||
b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
|
||||
return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogue>(
|
||||
c, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
@ -78,6 +75,13 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
vllm::cutlass_scaled_mm_azp_sm90_int8(out, a, b, a_scales, b_scales, azp_adj,
|
||||
azp, bias);
|
||||
if (azp) {
|
||||
return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogueBiasAzpToken>(
|
||||
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogueBiasAzp>(
|
||||
out, a, b, a_scales, b_scales, azp_adj, bias);
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@ -2,6 +2,9 @@
|
||||
|
||||
// clang-format will break include orders
|
||||
// clang-format off
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
@ -29,6 +32,21 @@ using namespace cute;
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// A wrapper for the GEMM kernel that is used to guard against compilation on
|
||||
// architectures that will never use the kernel. The purpose of this is to
|
||||
// reduce the size of the compiled binary.
|
||||
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
|
||||
// into code that will be executed on the device where it is defined.
|
||||
template <typename Kernel>
|
||||
struct enable_sm90_or_later : Kernel {
|
||||
template <typename... Args>
|
||||
CUTLASS_DEVICE void operator()(Args&&... args) {
|
||||
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
|
||||
Kernel::operator()(std::forward<Args>(args)...);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ElementAB_, typename ElementD_,
|
||||
template <typename, typename, typename> typename Epilogue_,
|
||||
typename TileShape, typename ClusterShape, typename KernelSchedule,
|
||||
@ -83,4 +101,60 @@ struct cutlass_3x_gemm {
|
||||
struct GemmKernel : public KernelType {};
|
||||
};
|
||||
|
||||
template <typename Gemm, typename... EpilogueArgs>
|
||||
void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_params) {
|
||||
using ElementAB = typename Gemm::ElementAB;
|
||||
using ElementD = typename Gemm::ElementD;
|
||||
|
||||
int32_t m = a.size(0);
|
||||
int32_t n = b.size(1);
|
||||
int32_t k = a.size(1);
|
||||
|
||||
int64_t lda = a.stride(0);
|
||||
int64_t ldb = b.stride(1);
|
||||
int64_t ldc = out.stride(0);
|
||||
|
||||
using StrideA = Stride<int64_t, Int<1>, int64_t>;
|
||||
using StrideB = Stride<int64_t, Int<1>, int64_t>;
|
||||
using StrideC = typename Gemm::StrideC;
|
||||
|
||||
StrideA a_stride{lda, Int<1>{}, 0};
|
||||
StrideB b_stride{ldb, Int<1>{}, 0};
|
||||
StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
|
||||
|
||||
using GemmKernel = typename Gemm::GemmKernel;
|
||||
typename GemmKernel::ProblemShape prob_shape{m, n, k, 1};
|
||||
|
||||
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
|
||||
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
|
||||
typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,
|
||||
b_stride};
|
||||
|
||||
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
|
||||
typename GemmKernel::EpilogueArguments epilogue_args{
|
||||
Gemm::Epilogue::prepare_args(
|
||||
std::forward<EpilogueArgs>(epilogue_params)...),
|
||||
c_ptr, c_stride, c_ptr, c_stride};
|
||||
|
||||
typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
prob_shape, mainloop_args, epilogue_args};
|
||||
|
||||
// Launch the CUTLASS GEMM kernel.
|
||||
using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
GemmOp gemm_op;
|
||||
CUTLASS_CHECK(gemm_op.can_implement(args));
|
||||
|
||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
|
||||
|
||||
cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@ -1,7 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include "scaled_mm.cuh"
|
||||
#include "cutlass_gemm_caller.cuh"
|
||||
#include "scaled_mm_c3x.cuh"
|
||||
|
||||
/**
|
||||
* This file defines Gemm kernel configurations for SM90 (fp8) based on the Gemm
|
||||
@ -10,8 +9,6 @@
|
||||
|
||||
namespace vllm {
|
||||
|
||||
using c3x::cutlass_gemm_caller;
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_fp8_config_default {
|
||||
@ -96,25 +93,4 @@ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
|
||||
}
|
||||
}
|
||||
|
||||
template <template <typename, typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_scaled_mm_sm90_fp8_epilogue(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_args) {
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::bfloat16_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@ -1,7 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include "scaled_mm.cuh"
|
||||
#include "cutlass_gemm_caller.cuh"
|
||||
#include "scaled_mm_c3x.cuh"
|
||||
|
||||
/**
|
||||
* This file defines Gemm kernel configurations for SM90 (int8) based on the
|
||||
@ -10,8 +9,6 @@
|
||||
|
||||
namespace vllm {
|
||||
|
||||
using c3x::cutlass_gemm_caller;
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_int8_config_default {
|
||||
@ -140,24 +137,4 @@ inline void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out,
|
||||
}
|
||||
}
|
||||
|
||||
template <template <typename, typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_scaled_mm_sm90_int8_epilogue(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_args) {
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
|
||||
Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@ -81,19 +81,6 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
|
||||
// CUTLASS block-quantized FP8 kernels need at least CUDA 12.0
|
||||
// and at least SM90 (Hopper)
|
||||
|
||||
#if defined CUDA_VERSION
|
||||
if (cuda_device_capability >= 90) {
|
||||
return CUDA_VERSION >= 12000;
|
||||
}
|
||||
#endif
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
@ -102,12 +89,15 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
|
||||
TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
|
||||
TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
|
||||
b.size(1) == c.size(1));
|
||||
TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
|
||||
TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));
|
||||
|
||||
// Check for strides and alignment
|
||||
TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
|
||||
TORCH_CHECK(b.stride(0) == 1); // Column-major
|
||||
TORCH_CHECK(c.stride(0) % 16 == 0 &&
|
||||
b.stride(1) % 16 == 0); // 16 Byte Alignment
|
||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
|
||||
if (bias) {
|
||||
TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
|
||||
@ -225,4 +215,4 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
|
||||
"No compiled cutlass_scaled_mm_azp for a compute capability less than "
|
||||
"CUDA device capability: ",
|
||||
version_num);
|
||||
}
|
||||
}
|
||||
@ -272,10 +272,6 @@ struct MacheteCollectiveMma {
|
||||
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
|
||||
|
||||
using PipelineParams = typename MainloopPipeline::Params;
|
||||
|
||||
// One threads per CTA are producers (1 for operand tile)
|
||||
static constexpr int NumProducerThreadEvents = 1;
|
||||
|
||||
using ScaleTileShape = decltype(make_shape(shape<0>(TileShape{}),
|
||||
shape<1>(SmemLayoutAtomScale{})));
|
||||
|
||||
|
||||
@ -324,13 +324,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
|
||||
ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
|
||||
|
||||
// Check if cutlass scaled_mm supports block quantization (used by DeepSeekV3)
|
||||
ops.def(
|
||||
"cutlass_scaled_mm_supports_block_fp8(int cuda_device_capability) -> "
|
||||
"bool");
|
||||
ops.impl("cutlass_scaled_mm_supports_block_fp8",
|
||||
&cutlass_scaled_mm_supports_fp8);
|
||||
|
||||
// Check if cutlass sparse scaled_mm is supported for CUDA devices of the
|
||||
// given capability
|
||||
ops.def(
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
// Add RunLLM widget
|
||||
document.addEventListener("DOMContentLoaded", function () {
|
||||
var script = document.createElement("script");
|
||||
script.type = "module";
|
||||
@ -16,23 +15,4 @@ document.addEventListener("DOMContentLoaded", function () {
|
||||
|
||||
script.async = true;
|
||||
document.head.appendChild(script);
|
||||
});
|
||||
|
||||
// Update URL search params when tab is clicked
|
||||
document.addEventListener("DOMContentLoaded", function () {
|
||||
const tabs = document.querySelectorAll(".sd-tab-label");
|
||||
|
||||
function updateURL(tab) {
|
||||
const syncGroup = tab.getAttribute("data-sync-group");
|
||||
const syncId = tab.getAttribute("data-sync-id");
|
||||
if (syncGroup && syncId) {
|
||||
const url = new URL(window.location);
|
||||
url.searchParams.set(syncGroup, syncId);
|
||||
window.history.replaceState(null, "", url);
|
||||
}
|
||||
}
|
||||
|
||||
tabs.forEach(tab => {
|
||||
tab.addEventListener("click", () => updateURL(tab));
|
||||
});
|
||||
});
|
||||
});
|
||||
|
Before Width: | Height: | Size: 34 KiB |
|
Before Width: | Height: | Size: 36 KiB |
|
Before Width: | Height: | Size: 41 KiB |
|
Before Width: | Height: | Size: 39 KiB |
|
Before Width: | Height: | Size: 25 KiB |
|
Before Width: | Height: | Size: 32 KiB |
|
Before Width: | Height: | Size: 18 KiB |
|
Before Width: | Height: | Size: 32 KiB |
|
Before Width: | Height: | Size: 17 KiB |
@ -70,7 +70,6 @@ copybutton_prompt_is_regexp = True
|
||||
html_title = project
|
||||
html_theme = 'sphinx_book_theme'
|
||||
html_logo = 'assets/logos/vllm-logo-text-light.png'
|
||||
html_favicon = 'assets/logos/vllm-logo-only-light.ico'
|
||||
html_theme_options = {
|
||||
'path_to_docs': 'docs/source',
|
||||
'repository_url': 'https://github.com/vllm-project/vllm',
|
||||
|
||||
@ -26,7 +26,7 @@ Check out the [building from source](#build-from-source) documentation for detai
|
||||
pip install -r requirements-dev.txt
|
||||
|
||||
# Linting, formatting and static type checking
|
||||
pre-commit install --hook-type pre-commit --hook-type commit-msg
|
||||
pre-commit install
|
||||
|
||||
# You can manually run pre-commit with
|
||||
pre-commit run --all-files
|
||||
|
||||
@ -1,228 +0,0 @@
|
||||
# Automatic Prefix Caching
|
||||
|
||||
Prefix caching kv-cache blocks is a popular optimization in LLM inference to avoid redundant prompt computations. The core idea is simple – we cache the kv-cache blocks of processed requests, and reuse these blocks when a new request comes in with the same prefix as previous requests. Since prefix caching is almost a free lunch and won’t change model outputs, it has been widely used by many public endpoints (e.g., OpenAI, Anthropic, etc) and most open source LLM inference frameworks (e.g., SGLang).
|
||||
|
||||
While there are many ways to implement prefix caching, vLLM chooses a hash-based approach. Specifically, we hash each kv-cache block by the tokens in the block and the tokens in the prefix before the block:
|
||||
|
||||
```text
|
||||
Block 1 Block 2 Block 3
|
||||
[A gentle breeze stirred] [the leaves as children] [laughed in the distance]
|
||||
Block 1: |<--- block tokens ---->|
|
||||
Block 2: |<------- prefix ------>| |<--- block tokens --->|
|
||||
Block 3: |<------------------ prefix -------------------->| |<--- block tokens ---->|
|
||||
```
|
||||
|
||||
In the example above, the KV cache in the first block can be uniquely identified with the token “A gentle breeze stirred”. The third block can be uniquely identified with the tokens in the block “laughed in the distance”, along with the prefix tokens “A gentle breeze stirred the leaves as children”. Therefore, we can build the block hash of `hash(tuple[components])`, where components are:
|
||||
|
||||
* Parent hash value: The hash value of the parent hash block.
|
||||
* Block tokens: A tuple of tokens in this block. The reason to include the exact tokens is to reduce potential hash value collision.
|
||||
* Extra hashes: Other values required to make this block unique, such as LoRA IDs and multi-modality input hashes (see the example below).
|
||||
|
||||
Note 1: We only cache full blocks.
|
||||
|
||||
Note 2: The above hash key structure is not 100% collision free. Theoretically it’s still possible for the different prefix tokens to have the same hash value, but this should be nearly impossible to happen. Of course, contributions are welcome if you have an awesome idea to eliminate collusion entirely.
|
||||
|
||||
**A hashing example with multi-modality inputs**
|
||||
In this example, we illustrate how prefix caching works with multi-modality inputs (e.g., images). Assuming we have a request with the following messages:
|
||||
|
||||
```text
|
||||
messages = [
|
||||
{"role": "user",
|
||||
"content": [
|
||||
{"type": "text",
|
||||
"text": "What's in this image?"
|
||||
},
|
||||
{"type": "image_url",
|
||||
"image_url": {"url": image_url},
|
||||
},
|
||||
]},
|
||||
]
|
||||
```
|
||||
|
||||
It will become the following prompt:
|
||||
|
||||
```text
|
||||
Prompt:
|
||||
<s>[INST]What's in this image?\n[IMG][/INST]
|
||||
|
||||
Tokenized prompt:
|
||||
[1, 3, 7493, 1681, 1294, 1593, 3937, 9551, 10, 4]
|
||||
|
||||
Prompt with placeholders (<P>):
|
||||
[1, 3, 7493, 1681, 1294, 1593, 3937, 9551, <P>, <P>, ..., <P>, 4]
|
||||
```
|
||||
|
||||
As we can see, after the tokenization, the `[IMG]` will be replaced by a sequence of placeholder tokens, and these placeholders will be replaced by image embeddings during prefill. The challenge for prefix caching to support this case is we need to differentiate images from the placeholders. To address this problem, we encode the image hash generated by the frontend image processor. For example, the hash of the blocks in the above prompt would be (assuming block size 16, and we have 41 placeholder tokens):
|
||||
|
||||
```text
|
||||
Block 0
|
||||
Parent hash: None
|
||||
Token IDs: 1, 3, 7493, 1681, 1294, 1593, 3937, 9551, <p>, ..., <p>
|
||||
Extra hash: <image hash>
|
||||
Block 1
|
||||
Parent hash: Block 0 hash
|
||||
Token IDs: <p>, ..., <p>
|
||||
Extra hash: <image hash>
|
||||
Block 2
|
||||
Parent hash: Block 1 hash
|
||||
Token IDs: <p>, ..., <p>
|
||||
Extra hash: <image hash>
|
||||
Block 3
|
||||
Parent hash: Block 2 hash
|
||||
Token IDs: <p>, ..., <p>, 4
|
||||
Extra hash: <image hash>
|
||||
```
|
||||
|
||||
In the rest of this document, we first introduce the data structure used for prefix caching in vLLM v1, followed by the prefix caching workflow of major KV cache operators (e.g., allocate, append, free, eviction). Finally, we use an example to illustrate the end to end prefix caching workflow.
|
||||
|
||||
## Data Structure
|
||||
|
||||
The prefix caching in vLLM v1 is implemented in the KV cache manager. The basic building block is the “Block” data class (simplified):
|
||||
|
||||
```python
|
||||
class KVCacheBlock:
|
||||
# The block ID (immutable)
|
||||
block_id: int
|
||||
# The block hash (will be assigned when the block is full,
|
||||
# and will be reset when the block is evicted).
|
||||
block_hash: BlockHashType
|
||||
# The number of requests using this block now.
|
||||
ref_cnt: int
|
||||
|
||||
# The pointers to form a doubly linked list for the free queue.
|
||||
prev_free_block: Optional["KVCacheBlock"] = None
|
||||
next_free_block: Optional["KVCacheBlock"] = None
|
||||
```
|
||||
|
||||
There are two design points to highlight:
|
||||
|
||||
1. We allocate all KVCacheBlock when initializing the KV cache manager to be a block pool. This avoids Python object creation overheads and can easily track all blocks all the time.
|
||||
2. We introduce doubly linked list pointers directly in the KVCacheBlock, so that we could construct a free queue directly. This gives us two benefits:
|
||||
1. We could have O(1) complexity moving elements in the middle to the tail.
|
||||
2. We could avoid introducing another Python queue (e.g., `deque`) which has a wrapper to the elements.
|
||||
|
||||
As a result, we will have the following components when the KV cache manager is initialized:
|
||||
|
||||
:::{image} /assets/design/v1/prefix_caching/overview.png
|
||||
:alt: Component Overview
|
||||
:::
|
||||
|
||||
* Block Pool: A list of KVCacheBlock.
|
||||
* Free Block Queue: Only store the pointers of head and tail blocks for manipulations.
|
||||
* Cache blocks: Mapping from hash key to block IDs.
|
||||
* Request blocks: Mapping from request ID to allocated block IDs.
|
||||
|
||||
## Operations
|
||||
|
||||
### Block Allocation
|
||||
|
||||
**New request:** Workflow for the scheduler to schedule a new request with KV cache block allocation:
|
||||
|
||||
1. The scheduler calls `kv_cache_manager.get_computed_blocks()` to get a sequence of blocks that have already been computed. This is done by hashing the prompt tokens in the request and looking up Cache Blocks.
|
||||
2. The scheduler calls `kv_cache_manager.allocate_slots()`. It does the following steps:
|
||||
1. Compute the number of new required blocks, and return if there are no sufficient blocks to allocate.
|
||||
2. “Touch” the computed blocks. It increases the reference count of the computed block by one, and removes the block from the free queue if the block wasn’t used by other requests. This is to avoid these computed blocks being evicted. See the example in the next section for illustration.
|
||||
3. Allocate new blocks by popping the heads of the free queue. If the head block is a cached block, this also “evicts” the block so that no other requests can reuse it anymore from now on.
|
||||
4. If an allocated block is already full of tokens, we immediately add it to the Cache Block, so that the block can be reused by other requests in the same batch.
|
||||
|
||||
**Running request:** Workflow for the scheduler to schedule a running request with KV cache block allocation:
|
||||
|
||||
1. The scheduler calls `kv_cache_manager.append_slots()`. It does the following steps:
|
||||
1. Compute the number of new required blocks, and return if there are no sufficient blocks to allocate.
|
||||
2. Allocate new blocks by popping the heads of the free queue. If the head block is a cached block, this also “evicts” the block so that no other requests can reuse it anymore from now on.
|
||||
3. Append token IDs to the slots in existing blocks as well as the new blocks. If a block is full, we add it to the Cache Block to cache it.
|
||||
|
||||
**Duplicated blocks**
|
||||
Assuming block size is 4 and you send a request (Request 1\) with prompt ABCDEF and decoding length 3:
|
||||
|
||||
```text
|
||||
Prompt: [A, B, C, D, E, F]
|
||||
Output: [G, H, I]
|
||||
|
||||
Time 0:
|
||||
Tokens: [A, B, C, D, E, F, G]
|
||||
Block Table: [0 (ABCD), 1 (EFG)]
|
||||
Cache Blocks: 0
|
||||
Time 1:
|
||||
Tokens: [A, B, C, D, E, F, G, H]
|
||||
Block Table: [0 (ABCD), 1 (EFGH)]
|
||||
Cache Blocks: 0, 1
|
||||
Time 2:
|
||||
Tokens: [A, B, C, D, E, F, G, H, I]
|
||||
Block Table: [0 (ABCD), 1 (EFGH), 2 (I)]
|
||||
Cache Blocks: 0, 1
|
||||
```
|
||||
|
||||
Now block 0 and block 1 are cached, and we send the same request again (Request 2\) with greedy sampling, so that it will produce exactly the same outputs as the Request 1:
|
||||
|
||||
```text
|
||||
Prompt: [A, B, C, D, E, F]
|
||||
Output: [G, H, I]
|
||||
|
||||
Time 0:
|
||||
Tokens: [A, B, C, D, E, F, G]
|
||||
Block Table: [0 (ABCD), 3 (EFG)]
|
||||
Cache Blocks: 0, 1
|
||||
Time 1:
|
||||
Tokens: [A, B, C, D, E, F, G, H]
|
||||
Block Table: [0 (ABCD), 3 (EFGH)]
|
||||
Cache Blocks: 0, 1, 3
|
||||
```
|
||||
|
||||
As can be seen, block 3 is a new full block and is cached. However, it is redundant as block 1, meaning that we cached the same block twice. In v0, when detecting block 3 is duplicated, we free block 3 and let Request 2 use block 1 instead, so its block table becomes `[0, 1]` in Time 1. However, the block table in vLLM v1 is append-only, meaning that changing the block table from `[0, 3]` to `[0, 1]` is not allowed. As a result, we will have duplicated blocks for the hash key E-H. This duplication will be eliminated when the request is freed.
|
||||
|
||||
### Free
|
||||
|
||||
When a request is finished, we free all its blocks if no other requests are using them (reference count = 0). In this example, we free request 1 and block 2, 3, 4, 8 associated with it. We can see that the freed blocks are added to the tail of the free queue in the *reverse* order. This is because the last block of a request must hash more tokens and is less likely to be reused by other requests. As a result, it should be evicted first.
|
||||
|
||||
:::{image} /assets/design/v1/prefix_caching/free.png
|
||||
:alt: Free Queue after Free a Request
|
||||
:::
|
||||
|
||||
### Eviction (LRU)
|
||||
|
||||
When the head block (least recently used block) of the free queue is cached, we have to evict the block to prevent it from being used by other requests. Specifically, eviction involves the following steps:
|
||||
|
||||
1. Pop the block from the head of the free queue. This is the LRU black to be evicted.
|
||||
2. Remove the block ID from the Cache Block.
|
||||
3. Remove the block hash.
|
||||
|
||||
## Example
|
||||
|
||||
In this example, we assume the block size is 4 (each block can cache 4 tokens), and we have 10 blocks in the KV-cache manager in total.
|
||||
|
||||
**Time 1: The cache is empty and a new request comes in.** We allocate 4 blocks. 3 of them are already full and cached. The fourth block is partially full with 2 of 4 tokens.
|
||||
|
||||
:::{image} /assets/design/v1/prefix_caching/example-time-1.png
|
||||
:alt: Example Time 1
|
||||
:::
|
||||
|
||||
**Time 3: Request 0 makes the block 3 full and asks for a new block to keep decoding.** We cache block 3 and allocate block 4.
|
||||
|
||||
:::{image} /assets/design/v1/prefix_caching/example-time-3.png
|
||||
:alt: Example Time 3
|
||||
:::
|
||||
|
||||
**Time 4: Request 1 comes in with the 14 prompt tokens, where the first 11 tokens are the same as request 0.** We can see that only 2 blocks (11 tokens) hit the cache, because the 3rd block only matches 3 of 4 tokens.
|
||||
|
||||
:::{image} /assets/design/v1/prefix_caching/example-time-4.png
|
||||
:alt: Example Time 4
|
||||
:::
|
||||
|
||||
**Time 5: Request 0 is finished and free.** Blocks 2, 3 and 4 are added to the free queue in the reverse order (but block 2 and 3 are still cached). Block 0 and 1 are not added to the free queue because they are being used by Request 1.
|
||||
|
||||
:::{image} /assets/design/v1/prefix_caching/example-time-5.png
|
||||
:alt: Example Time 5
|
||||
:::
|
||||
|
||||
**Time 6: Request 1 is finished and free.**
|
||||
|
||||
:::{image} /assets/design/v1/prefix_caching/example-time-6.png
|
||||
:alt: Example Time 6
|
||||
:::
|
||||
|
||||
**Time 7: Request 2 comes in with the 33 prompt tokens, where the first 16 tokens are the same as request 0\.** Note that even the block order in the free queue was `7 - 8 - 9 - 4 - 3 - 2 - 6 - 5 - 1 - 0`, the cache hit blocks (i.e., 0, 1, 2) are touched and removed from the queue before allocation, so the free queue becomes `7 - 8 - 9 - 4 - 3 - 6 - 5`. As a result, the allocated blocks are 0 (cached), 1 (cached), 2 (cached), 7, 8, 9, 4, 3 (evicted).
|
||||
|
||||
:::{image} /assets/design/v1/prefix_caching/example-time-7.png
|
||||
:alt: Example Time 7
|
||||
:::
|
||||
@ -12,7 +12,6 @@ supported_hardware
|
||||
auto_awq
|
||||
bnb
|
||||
gguf
|
||||
int4
|
||||
int8
|
||||
fp8
|
||||
quantized_kvcache
|
||||
|
||||
@ -1,166 +0,0 @@
|
||||
(int4)=
|
||||
|
||||
# INT4 W4A16
|
||||
|
||||
vLLM supports quantizing weights to INT4 for memory savings and inference acceleration. This quantization method is particularly useful for reducing model size and maintaining low latency in workloads with low queries per second (QPS).
|
||||
|
||||
Please visit the HF collection of [quantized INT4 checkpoints of popular LLMs ready to use with vLLM](https://huggingface.co/collections/neuralmagic/int4-llms-for-vllm-668ec34bf3c9fa45f857df2c).
|
||||
|
||||
:::{note}
|
||||
INT4 computation is supported on NVIDIA GPUs with compute capability > 8.0 (Ampere, Ada Lovelace, Hopper, Blackwell).
|
||||
:::
|
||||
|
||||
## Prerequisites
|
||||
|
||||
To use INT4 quantization with vLLM, you'll need to install the [llm-compressor](https://github.com/vllm-project/llm-compressor/) library:
|
||||
|
||||
```console
|
||||
pip install llmcompressor
|
||||
```
|
||||
|
||||
## Quantization Process
|
||||
|
||||
The quantization process involves four main steps:
|
||||
|
||||
1. Loading the model
|
||||
2. Preparing calibration data
|
||||
3. Applying quantization
|
||||
4. Evaluating accuracy in vLLM
|
||||
|
||||
### 1. Loading the Model
|
||||
|
||||
Load your model and tokenizer using the standard `transformers` AutoModel classes:
|
||||
|
||||
```python
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
MODEL_ID, device_map="auto", torch_dtype="auto",
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
||||
```
|
||||
|
||||
### 2. Preparing Calibration Data
|
||||
|
||||
When quantizing weights to INT4, you need sample data to estimate the weight updates and calibrated scales.
|
||||
It's best to use calibration data that closely matches your deployment data.
|
||||
For a general-purpose instruction-tuned model, you can use a dataset like `ultrachat`:
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
|
||||
NUM_CALIBRATION_SAMPLES = 512
|
||||
MAX_SEQUENCE_LENGTH = 2048
|
||||
|
||||
# Load and preprocess the dataset
|
||||
ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
|
||||
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))
|
||||
|
||||
def preprocess(example):
|
||||
return {"text": tokenizer.apply_chat_template(example["messages"], tokenize=False)}
|
||||
ds = ds.map(preprocess)
|
||||
|
||||
def tokenize(sample):
|
||||
return tokenizer(sample["text"], padding=False, max_length=MAX_SEQUENCE_LENGTH, truncation=True, add_special_tokens=False)
|
||||
ds = ds.map(tokenize, remove_columns=ds.column_names)
|
||||
```
|
||||
|
||||
### 3. Applying Quantization
|
||||
|
||||
Now, apply the quantization algorithms:
|
||||
|
||||
```python
|
||||
from llmcompressor.transformers import oneshot
|
||||
from llmcompressor.modifiers.quantization import GPTQModifier
|
||||
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
|
||||
|
||||
# Configure the quantization algorithms
|
||||
recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"])
|
||||
|
||||
# Apply quantization
|
||||
oneshot(
|
||||
model=model,
|
||||
dataset=ds,
|
||||
recipe=recipe,
|
||||
max_seq_length=MAX_SEQUENCE_LENGTH,
|
||||
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
|
||||
)
|
||||
|
||||
# Save the compressed model
|
||||
SAVE_DIR = MODEL_ID.split("/")[1] + "-W4A16-G128"
|
||||
model.save_pretrained(SAVE_DIR, save_compressed=True)
|
||||
tokenizer.save_pretrained(SAVE_DIR)
|
||||
```
|
||||
|
||||
This process creates a W4A16 model with weights quantized to 4-bit integers.
|
||||
|
||||
### 4. Evaluating Accuracy
|
||||
|
||||
After quantization, you can load and run the model in vLLM:
|
||||
|
||||
```python
|
||||
from vllm import LLM
|
||||
model = LLM("./Meta-Llama-3-8B-Instruct-W4A16-G128")
|
||||
```
|
||||
|
||||
To evaluate accuracy, you can use `lm_eval`:
|
||||
|
||||
```console
|
||||
$ lm_eval --model vllm \
|
||||
--model_args pretrained="./Meta-Llama-3-8B-Instruct-W4A16-G128",add_bos_token=true \
|
||||
--tasks gsm8k \
|
||||
--num_fewshot 5 \
|
||||
--limit 250 \
|
||||
--batch_size 'auto'
|
||||
```
|
||||
|
||||
:::{note}
|
||||
Quantized models can be sensitive to the presence of the `bos` token. Make sure to include the `add_bos_token=True` argument when running evaluations.
|
||||
:::
|
||||
|
||||
## Best Practices
|
||||
|
||||
- Start with 512 samples for calibration data, and increase if accuracy drops
|
||||
- Ensure the calibration data contains a high variety of samples to prevent overfitting towards a specific use case
|
||||
- Use a sequence length of 2048 as a starting point
|
||||
- Employ the chat template or instruction template that the model was trained with
|
||||
- If you've fine-tuned a model, consider using a sample of your training data for calibration
|
||||
- Tune key hyperparameters to the quantization algorithm:
|
||||
- `dampening_frac` sets how much influence the GPTQ algorithm has. Lower values can improve accuracy, but can lead to numerical instabilities that cause the algorithm to fail.
|
||||
- `actorder` sets the activation ordering. When compressing the weights of a layer weight, the order in which channels are quantized matters. Setting `actorder="weight"` can improve accuracy without added latency.
|
||||
|
||||
The following is an example of an expanded quantization recipe you can tune to your own use case:
|
||||
|
||||
```python
|
||||
from compressed_tensors.quantization import (
|
||||
QuantizationArgs,
|
||||
QuantizationScheme,
|
||||
QuantizationStrategy,
|
||||
QuantizationType,
|
||||
)
|
||||
recipe = GPTQModifier(
|
||||
targets="Linear",
|
||||
config_groups={
|
||||
"config_group": QuantizationScheme(
|
||||
targets=["Linear"],
|
||||
weights=QuantizationArgs(
|
||||
num_bits=4,
|
||||
type=QuantizationType.INT,
|
||||
strategy=QuantizationStrategy.GROUP,
|
||||
group_size=128,
|
||||
symmetric=True,
|
||||
dynamic=False,
|
||||
actorder="weight",
|
||||
),
|
||||
),
|
||||
},
|
||||
ignore=["lm_head"],
|
||||
update_size=NUM_CALIBRATION_SAMPLES,
|
||||
dampening_frac=0.01
|
||||
)
|
||||
```
|
||||
|
||||
## Troubleshooting and Support
|
||||
|
||||
If you encounter any issues or have feature requests, please open an issue on the [`vllm-project/llm-compressor`](https://github.com/vllm-project/llm-compressor) GitHub repository. The full INT4 quantization example in `llm-compressor` is available [here](https://github.com/vllm-project/llm-compressor/blob/main/examples/quantization_w4a16/llama3_example.py).
|
||||
@ -8,7 +8,7 @@ This quantization method is particularly useful for reducing model size while ma
|
||||
Please visit the HF collection of [quantized INT8 checkpoints of popular LLMs ready to use with vLLM](https://huggingface.co/collections/neuralmagic/int8-llms-for-vllm-668ec32c049dca0369816415).
|
||||
|
||||
:::{note}
|
||||
INT8 computation is supported on NVIDIA GPUs with compute capability > 7.5 (Turing, Ampere, Ada Lovelace, Hopper, Blackwell).
|
||||
INT8 computation is supported on NVIDIA GPUs with compute capability > 7.5 (Turing, Ampere, Ada Lovelace, Hopper).
|
||||
:::
|
||||
|
||||
## Prerequisites
|
||||
@ -132,4 +132,4 @@ Quantized models can be sensitive to the presence of the `bos` token. Make sure
|
||||
|
||||
## Troubleshooting and Support
|
||||
|
||||
If you encounter any issues or have feature requests, please open an issue on the [`vllm-project/llm-compressor`](https://github.com/vllm-project/llm-compressor) GitHub repository.
|
||||
If you encounter any issues or have feature requests, please open an issue on the `vllm-project/llm-compressor` GitHub repository.
|
||||
|
||||
@ -2,10 +2,6 @@
|
||||
|
||||
This tab provides instructions on running vLLM with Intel Gaudi devices.
|
||||
|
||||
:::{attention}
|
||||
There are no pre-built wheels or images for this device, so you must build vLLM from source.
|
||||
:::
|
||||
|
||||
## Requirements
|
||||
|
||||
- OS: Ubuntu 22.04 LTS
|
||||
|
||||
@ -5,8 +5,7 @@ vLLM is a Python library that supports the following AI accelerators. Select you
|
||||
:::::{tab-set}
|
||||
:sync-group: device
|
||||
|
||||
::::{tab-item} Google TPU
|
||||
:selected:
|
||||
::::{tab-item} TPU
|
||||
:sync: tpu
|
||||
|
||||
:::{include} tpu.inc.md
|
||||
@ -26,7 +25,7 @@ vLLM is a Python library that supports the following AI accelerators. Select you
|
||||
|
||||
::::
|
||||
|
||||
::::{tab-item} AWS Neuron
|
||||
::::{tab-item} Neuron
|
||||
:sync: neuron
|
||||
|
||||
:::{include} neuron.inc.md
|
||||
@ -53,7 +52,7 @@ vLLM is a Python library that supports the following AI accelerators. Select you
|
||||
:::::{tab-set}
|
||||
:sync-group: device
|
||||
|
||||
::::{tab-item} Google TPU
|
||||
::::{tab-item} TPU
|
||||
:sync: tpu
|
||||
|
||||
:::{include} tpu.inc.md
|
||||
@ -73,7 +72,7 @@ vLLM is a Python library that supports the following AI accelerators. Select you
|
||||
|
||||
::::
|
||||
|
||||
::::{tab-item} AWS Neuron
|
||||
::::{tab-item} Neuron
|
||||
:sync: neuron
|
||||
|
||||
:::{include} neuron.inc.md
|
||||
@ -100,7 +99,7 @@ vLLM is a Python library that supports the following AI accelerators. Select you
|
||||
:::::{tab-set}
|
||||
:sync-group: device
|
||||
|
||||
::::{tab-item} Google TPU
|
||||
::::{tab-item} TPU
|
||||
:sync: tpu
|
||||
|
||||
:::{include} tpu.inc.md
|
||||
@ -120,7 +119,7 @@ vLLM is a Python library that supports the following AI accelerators. Select you
|
||||
|
||||
::::
|
||||
|
||||
::::{tab-item} AWS Neuron
|
||||
::::{tab-item} Neuron
|
||||
:sync: neuron
|
||||
|
||||
:::{include} neuron.inc.md
|
||||
@ -147,7 +146,7 @@ vLLM is a Python library that supports the following AI accelerators. Select you
|
||||
:::::{tab-set}
|
||||
:sync-group: device
|
||||
|
||||
::::{tab-item} Google TPU
|
||||
::::{tab-item} TPU
|
||||
:sync: tpu
|
||||
|
||||
:::{include} tpu.inc.md
|
||||
@ -167,7 +166,7 @@ vLLM is a Python library that supports the following AI accelerators. Select you
|
||||
|
||||
::::
|
||||
|
||||
::::{tab-item} AWS Neuron
|
||||
::::{tab-item} Neuron
|
||||
:sync: neuron
|
||||
|
||||
:::{include} neuron.inc.md
|
||||
@ -194,7 +193,7 @@ vLLM is a Python library that supports the following AI accelerators. Select you
|
||||
:::::{tab-set}
|
||||
:sync-group: device
|
||||
|
||||
::::{tab-item} Google TPU
|
||||
::::{tab-item} TPU
|
||||
:sync: tpu
|
||||
|
||||
:::{include} tpu.inc.md
|
||||
@ -214,7 +213,7 @@ vLLM is a Python library that supports the following AI accelerators. Select you
|
||||
|
||||
::::
|
||||
|
||||
::::{tab-item} AWS Neuron
|
||||
::::{tab-item} Neuron
|
||||
:sync: neuron
|
||||
|
||||
:::{include} neuron.inc.md
|
||||
@ -243,7 +242,7 @@ vLLM is a Python library that supports the following AI accelerators. Select you
|
||||
:::::{tab-set}
|
||||
:sync-group: device
|
||||
|
||||
::::{tab-item} Google TPU
|
||||
::::{tab-item} TPU
|
||||
:sync: tpu
|
||||
|
||||
:::{include} tpu.inc.md
|
||||
@ -263,7 +262,7 @@ vLLM is a Python library that supports the following AI accelerators. Select you
|
||||
|
||||
::::
|
||||
|
||||
::::{tab-item} AWS Neuron
|
||||
::::{tab-item} Neuron
|
||||
:sync: neuron
|
||||
|
||||
:::{include} neuron.inc.md
|
||||
@ -290,7 +289,7 @@ vLLM is a Python library that supports the following AI accelerators. Select you
|
||||
:::::{tab-set}
|
||||
:sync-group: device
|
||||
|
||||
::::{tab-item} Google TPU
|
||||
::::{tab-item} TPU
|
||||
:sync: tpu
|
||||
|
||||
:::{include} tpu.inc.md
|
||||
@ -310,7 +309,7 @@ vLLM is a Python library that supports the following AI accelerators. Select you
|
||||
|
||||
::::
|
||||
|
||||
::::{tab-item} AWS Neuron
|
||||
::::{tab-item} Neuron
|
||||
:sync: neuron
|
||||
|
||||
:::{include} neuron.inc.md
|
||||
@ -337,7 +336,7 @@ vLLM is a Python library that supports the following AI accelerators. Select you
|
||||
:::::{tab-set}
|
||||
:sync-group: device
|
||||
|
||||
::::{tab-item} Google TPU
|
||||
::::{tab-item} TPU
|
||||
:sync: tpu
|
||||
|
||||
:::{include} tpu.inc.md
|
||||
@ -355,7 +354,7 @@ vLLM is a Python library that supports the following AI accelerators. Select you
|
||||
|
||||
::::
|
||||
|
||||
::::{tab-item} AWS Neuron
|
||||
::::{tab-item} Neuron
|
||||
:sync: neuron
|
||||
|
||||
:::{include} neuron.inc.md
|
||||
|
||||
@ -4,10 +4,6 @@ vLLM 0.3.3 onwards supports model inferencing and serving on AWS Trainium/Infere
|
||||
Paged Attention and Chunked Prefill are currently in development and will be available soon.
|
||||
Data types currently supported in Neuron SDK are FP16 and BF16.
|
||||
|
||||
:::{attention}
|
||||
There are no pre-built wheels or images for this device, so you must build vLLM from source.
|
||||
:::
|
||||
|
||||
## Requirements
|
||||
|
||||
- OS: Linux
|
||||
|
||||
@ -2,10 +2,6 @@
|
||||
|
||||
vLLM powered by OpenVINO supports all LLM models from [vLLM supported models list](#supported-models) and can perform optimal model serving on all x86-64 CPUs with, at least, AVX2 support, as well as on both integrated and discrete Intel® GPUs ([the list of supported GPUs](https://docs.openvino.ai/2024/about-openvino/release-notes-openvino/system-requirements.html#gpu)).
|
||||
|
||||
:::{attention}
|
||||
There are no pre-built wheels or images for this device, so you must build vLLM from source.
|
||||
:::
|
||||
|
||||
## Requirements
|
||||
|
||||
- OS: Linux
|
||||
|
||||
@ -30,10 +30,6 @@ For TPU pricing information, see [Cloud TPU pricing](https://cloud.google.com/tp
|
||||
You may need additional persistent storage for your TPU VMs. For more
|
||||
information, see [Storage options for Cloud TPU data](https://cloud.devsite.corp.google.com/tpu/docs/storage-options).
|
||||
|
||||
:::{attention}
|
||||
There are no pre-built wheels for this device, so you must either use the pre-built Docker image or build vLLM from source.
|
||||
:::
|
||||
|
||||
## Requirements
|
||||
|
||||
- Google Cloud TPU VM
|
||||
|
||||
@ -4,10 +4,6 @@ vLLM has experimental support for macOS with Apple silicon. For now, users shall
|
||||
|
||||
Currently the CPU implementation for macOS supports FP32 and FP16 datatypes.
|
||||
|
||||
:::{attention}
|
||||
There are no pre-built wheels or images for this device, so you must build vLLM from source.
|
||||
:::
|
||||
|
||||
## Requirements
|
||||
|
||||
- OS: `macOS Sonoma` or later
|
||||
|
||||
@ -4,10 +4,6 @@ vLLM has been adapted to work on ARM64 CPUs with NEON support, leveraging the CP
|
||||
|
||||
ARM CPU backend currently supports Float32, FP16 and BFloat16 datatypes.
|
||||
|
||||
:::{attention}
|
||||
There are no pre-built wheels or images for this device, so you must build vLLM from source.
|
||||
:::
|
||||
|
||||
## Requirements
|
||||
|
||||
- OS: Linux
|
||||
|
||||
@ -5,8 +5,7 @@ vLLM is a Python library that supports the following CPU variants. Select your C
|
||||
:::::{tab-set}
|
||||
:sync-group: device
|
||||
|
||||
::::{tab-item} Intel/AMD x86
|
||||
:selected:
|
||||
::::{tab-item} x86
|
||||
:sync: x86
|
||||
|
||||
:::{include} x86.inc.md
|
||||
@ -16,7 +15,7 @@ vLLM is a Python library that supports the following CPU variants. Select your C
|
||||
|
||||
::::
|
||||
|
||||
::::{tab-item} ARM AArch64
|
||||
::::{tab-item} ARM
|
||||
:sync: arm
|
||||
|
||||
:::{include} arm.inc.md
|
||||
@ -45,7 +44,7 @@ vLLM is a Python library that supports the following CPU variants. Select your C
|
||||
:::::{tab-set}
|
||||
:sync-group: device
|
||||
|
||||
::::{tab-item} Intel/AMD x86
|
||||
::::{tab-item} x86
|
||||
:sync: x86
|
||||
|
||||
:::{include} x86.inc.md
|
||||
@ -55,7 +54,7 @@ vLLM is a Python library that supports the following CPU variants. Select your C
|
||||
|
||||
::::
|
||||
|
||||
::::{tab-item} ARM AArch64
|
||||
::::{tab-item} ARM
|
||||
:sync: arm
|
||||
|
||||
:::{include} arm.inc.md
|
||||
@ -93,7 +92,7 @@ Currently, there are no pre-built CPU wheels.
|
||||
:::::{tab-set}
|
||||
:sync-group: device
|
||||
|
||||
::::{tab-item} Intel/AMD x86
|
||||
::::{tab-item} x86
|
||||
:sync: x86
|
||||
|
||||
:::{include} x86.inc.md
|
||||
@ -103,7 +102,7 @@ Currently, there are no pre-built CPU wheels.
|
||||
|
||||
::::
|
||||
|
||||
::::{tab-item} ARM AArch64
|
||||
::::{tab-item} ARM
|
||||
:sync: arm
|
||||
|
||||
:::{include} arm.inc.md
|
||||
|
||||
@ -2,20 +2,12 @@
|
||||
|
||||
vLLM initially supports basic model inferencing and serving on x86 CPU platform, with data types FP32, FP16 and BF16.
|
||||
|
||||
:::{attention}
|
||||
There are no pre-built wheels or images for this device, so you must build vLLM from source.
|
||||
:::
|
||||
|
||||
## Requirements
|
||||
|
||||
- OS: Linux
|
||||
- Compiler: `gcc/g++ >= 12.3.0` (optional, recommended)
|
||||
- Instruction Set Architecture (ISA): AVX512 (optional, recommended)
|
||||
|
||||
:::{tip}
|
||||
[Intel Extension for PyTorch (IPEX)](https://github.com/intel/intel-extension-for-pytorch) extends PyTorch with up-to-date features optimizations for an extra performance boost on Intel hardware.
|
||||
:::
|
||||
|
||||
## Set up using Python
|
||||
|
||||
### Pre-built wheels
|
||||
@ -37,3 +29,7 @@ There are no pre-built wheels or images for this device, so you must build vLLM
|
||||
### Build image from source
|
||||
|
||||
## Extra information
|
||||
|
||||
## Intel Extension for PyTorch
|
||||
|
||||
- [Intel Extension for PyTorch (IPEX)](https://github.com/intel/intel-extension-for-pytorch) extends PyTorch with up-to-date features optimizations for an extra performance boost on Intel hardware.
|
||||
|
||||
@ -5,8 +5,7 @@ vLLM is a Python library that supports the following GPU variants. Select your G
|
||||
:::::{tab-set}
|
||||
:sync-group: device
|
||||
|
||||
::::{tab-item} NVIDIA CUDA
|
||||
:selected:
|
||||
::::{tab-item} CUDA
|
||||
:sync: cuda
|
||||
|
||||
:::{include} cuda.inc.md
|
||||
@ -16,7 +15,7 @@ vLLM is a Python library that supports the following GPU variants. Select your G
|
||||
|
||||
::::
|
||||
|
||||
::::{tab-item} AMD ROCm
|
||||
::::{tab-item} ROCm
|
||||
:sync: rocm
|
||||
|
||||
:::{include} rocm.inc.md
|
||||
@ -26,7 +25,7 @@ vLLM is a Python library that supports the following GPU variants. Select your G
|
||||
|
||||
::::
|
||||
|
||||
::::{tab-item} Intel XPU
|
||||
::::{tab-item} XPU
|
||||
:sync: xpu
|
||||
|
||||
:::{include} xpu.inc.md
|
||||
@ -46,7 +45,7 @@ vLLM is a Python library that supports the following GPU variants. Select your G
|
||||
:::::{tab-set}
|
||||
:sync-group: device
|
||||
|
||||
::::{tab-item} NVIDIA CUDA
|
||||
::::{tab-item} CUDA
|
||||
:sync: cuda
|
||||
|
||||
:::{include} cuda.inc.md
|
||||
@ -56,7 +55,7 @@ vLLM is a Python library that supports the following GPU variants. Select your G
|
||||
|
||||
::::
|
||||
|
||||
::::{tab-item} AMD ROCm
|
||||
::::{tab-item} ROCm
|
||||
:sync: rocm
|
||||
|
||||
:::{include} rocm.inc.md
|
||||
@ -66,7 +65,7 @@ vLLM is a Python library that supports the following GPU variants. Select your G
|
||||
|
||||
::::
|
||||
|
||||
::::{tab-item} Intel XPU
|
||||
::::{tab-item} XPU
|
||||
:sync: xpu
|
||||
|
||||
:::{include} xpu.inc.md
|
||||
@ -88,7 +87,7 @@ vLLM is a Python library that supports the following GPU variants. Select your G
|
||||
:::::{tab-set}
|
||||
:sync-group: device
|
||||
|
||||
::::{tab-item} NVIDIA CUDA
|
||||
::::{tab-item} CUDA
|
||||
:sync: cuda
|
||||
|
||||
:::{include} cuda.inc.md
|
||||
@ -98,14 +97,14 @@ vLLM is a Python library that supports the following GPU variants. Select your G
|
||||
|
||||
::::
|
||||
|
||||
::::{tab-item} AMD ROCm
|
||||
::::{tab-item} ROCm
|
||||
:sync: rocm
|
||||
|
||||
There is no extra information on creating a new Python environment for this device.
|
||||
|
||||
::::
|
||||
|
||||
::::{tab-item} Intel XPU
|
||||
::::{tab-item} XPU
|
||||
:sync: xpu
|
||||
|
||||
There is no extra information on creating a new Python environment for this device.
|
||||
@ -119,7 +118,7 @@ There is no extra information on creating a new Python environment for this devi
|
||||
:::::{tab-set}
|
||||
:sync-group: device
|
||||
|
||||
::::{tab-item} NVIDIA CUDA
|
||||
::::{tab-item} CUDA
|
||||
:sync: cuda
|
||||
|
||||
:::{include} cuda.inc.md
|
||||
@ -129,7 +128,7 @@ There is no extra information on creating a new Python environment for this devi
|
||||
|
||||
::::
|
||||
|
||||
::::{tab-item} AMD ROCm
|
||||
::::{tab-item} ROCm
|
||||
:sync: rocm
|
||||
|
||||
:::{include} rocm.inc.md
|
||||
@ -139,7 +138,7 @@ There is no extra information on creating a new Python environment for this devi
|
||||
|
||||
::::
|
||||
|
||||
::::{tab-item} Intel XPU
|
||||
::::{tab-item} XPU
|
||||
:sync: xpu
|
||||
|
||||
:::{include} xpu.inc.md
|
||||
@ -158,7 +157,7 @@ There is no extra information on creating a new Python environment for this devi
|
||||
:::::{tab-set}
|
||||
:sync-group: device
|
||||
|
||||
::::{tab-item} NVIDIA CUDA
|
||||
::::{tab-item} CUDA
|
||||
:sync: cuda
|
||||
|
||||
:::{include} cuda.inc.md
|
||||
@ -168,7 +167,7 @@ There is no extra information on creating a new Python environment for this devi
|
||||
|
||||
::::
|
||||
|
||||
::::{tab-item} AMD ROCm
|
||||
::::{tab-item} ROCm
|
||||
:sync: rocm
|
||||
|
||||
:::{include} rocm.inc.md
|
||||
@ -178,7 +177,7 @@ There is no extra information on creating a new Python environment for this devi
|
||||
|
||||
::::
|
||||
|
||||
::::{tab-item} Intel XPU
|
||||
::::{tab-item} XPU
|
||||
:sync: xpu
|
||||
|
||||
:::{include} xpu.inc.md
|
||||
@ -197,7 +196,7 @@ There is no extra information on creating a new Python environment for this devi
|
||||
:::::{tab-set}
|
||||
:sync-group: device
|
||||
|
||||
::::{tab-item} NVIDIA CUDA
|
||||
::::{tab-item} CUDA
|
||||
:sync: cuda
|
||||
|
||||
:::{include} cuda.inc.md
|
||||
@ -207,7 +206,7 @@ There is no extra information on creating a new Python environment for this devi
|
||||
|
||||
::::
|
||||
|
||||
::::{tab-item} AMD ROCm
|
||||
::::{tab-item} ROCm
|
||||
:sync: rocm
|
||||
|
||||
:::{include} rocm.inc.md
|
||||
@ -217,7 +216,7 @@ There is no extra information on creating a new Python environment for this devi
|
||||
|
||||
::::
|
||||
|
||||
::::{tab-item} Intel XPU
|
||||
::::{tab-item} XPU
|
||||
:sync: xpu
|
||||
|
||||
:::{include} xpu.inc.md
|
||||
@ -234,7 +233,7 @@ There is no extra information on creating a new Python environment for this devi
|
||||
:::::{tab-set}
|
||||
:sync-group: device
|
||||
|
||||
::::{tab-item} NVIDIA CUDA
|
||||
::::{tab-item} CUDA
|
||||
:sync: cuda
|
||||
|
||||
:::{include} cuda.inc.md
|
||||
@ -244,7 +243,7 @@ There is no extra information on creating a new Python environment for this devi
|
||||
|
||||
::::
|
||||
|
||||
::::{tab-item} AMD ROCm
|
||||
::::{tab-item} ROCm
|
||||
:sync: rocm
|
||||
|
||||
:::{include} rocm.inc.md
|
||||
@ -254,7 +253,7 @@ There is no extra information on creating a new Python environment for this devi
|
||||
|
||||
::::
|
||||
|
||||
::::{tab-item} Intel XPU
|
||||
::::{tab-item} XPU
|
||||
:sync: xpu
|
||||
|
||||
:::{include} xpu.inc.md
|
||||
@ -271,7 +270,7 @@ There is no extra information on creating a new Python environment for this devi
|
||||
:::::{tab-set}
|
||||
:sync-group: device
|
||||
|
||||
::::{tab-item} NVIDIA CUDA
|
||||
::::{tab-item} CUDA
|
||||
:sync: cuda
|
||||
|
||||
:::{include} cuda.inc.md
|
||||
@ -280,7 +279,7 @@ There is no extra information on creating a new Python environment for this devi
|
||||
|
||||
::::
|
||||
|
||||
::::{tab-item} AMD ROCm
|
||||
::::{tab-item} ROCm
|
||||
:sync: rocm
|
||||
|
||||
:::{include} rocm.inc.md
|
||||
@ -289,7 +288,7 @@ There is no extra information on creating a new Python environment for this devi
|
||||
|
||||
::::
|
||||
|
||||
::::{tab-item} Intel XPU
|
||||
::::{tab-item} XPU
|
||||
:sync: xpu
|
||||
|
||||
:::{include} xpu.inc.md
|
||||
|
||||
@ -2,10 +2,6 @@
|
||||
|
||||
vLLM supports AMD GPUs with ROCm 6.2.
|
||||
|
||||
:::{attention}
|
||||
There are no pre-built wheels for this device, so you must either use the pre-built Docker image or build vLLM from source.
|
||||
:::
|
||||
|
||||
## Requirements
|
||||
|
||||
- GPU: MI200s (gfx90a), MI300 (gfx942), Radeon RX 7900 series (gfx1100)
|
||||
@ -17,6 +13,14 @@ There are no pre-built wheels for this device, so you must either use the pre-bu
|
||||
|
||||
Currently, there are no pre-built ROCm wheels.
|
||||
|
||||
However, the [AMD Infinity hub for vLLM](https://hub.docker.com/r/rocm/vllm/tags) offers a prebuilt, optimized
|
||||
docker image designed for validating inference performance on the AMD Instinct™ MI300X accelerator.
|
||||
|
||||
:::{tip}
|
||||
Please check [LLM inference performance validation on AMD Instinct MI300X](https://rocm.docs.amd.com/en/latest/how-to/performance-validation/mi300x/vllm-benchmark.html)
|
||||
for instructions on how to use this prebuilt docker image.
|
||||
:::
|
||||
|
||||
### Build wheel from source
|
||||
|
||||
0. Install prerequisites (skip if you are already in an environment/docker with the following installed):
|
||||
@ -108,13 +112,7 @@ Currently, there are no pre-built ROCm wheels.
|
||||
|
||||
### Pre-built images
|
||||
|
||||
The [AMD Infinity hub for vLLM](https://hub.docker.com/r/rocm/vllm/tags) offers a prebuilt, optimized
|
||||
docker image designed for validating inference performance on the AMD Instinct™ MI300X accelerator.
|
||||
|
||||
:::{tip}
|
||||
Please check [LLM inference performance validation on AMD Instinct MI300X](https://rocm.docs.amd.com/en/latest/how-to/performance-validation/mi300x/vllm-benchmark.html)
|
||||
for instructions on how to use this prebuilt docker image.
|
||||
:::
|
||||
Currently, there are no pre-built ROCm images.
|
||||
|
||||
### Build image from source
|
||||
|
||||
|
||||
@ -2,10 +2,6 @@
|
||||
|
||||
vLLM initially supports basic model inferencing and serving on Intel GPU platform.
|
||||
|
||||
:::{attention}
|
||||
There are no pre-built wheels or images for this device, so you must build vLLM from source.
|
||||
:::
|
||||
|
||||
## Requirements
|
||||
|
||||
- Supported Hardware: Intel Data Center GPU, Intel ARC GPU
|
||||
|
||||
@ -6,23 +6,8 @@ vLLM supports the following hardware platforms:
|
||||
|
||||
:::{toctree}
|
||||
:maxdepth: 1
|
||||
:hidden:
|
||||
|
||||
gpu/index
|
||||
cpu/index
|
||||
ai_accelerator/index
|
||||
:::
|
||||
|
||||
- <project:gpu/index.md>
|
||||
- NVIDIA CUDA
|
||||
- AMD ROCm
|
||||
- Intel XPU
|
||||
- <project:cpu/index.md>
|
||||
- Intel/AMD x86
|
||||
- ARM AArch64
|
||||
- Apple silicon
|
||||
- <project:ai_accelerator/index.md>
|
||||
- Google TPU
|
||||
- Intel Gaudi
|
||||
- AWS Neuron
|
||||
- OpenVINO
|
||||
|
||||
@ -153,13 +153,6 @@ design/automatic_prefix_caching
|
||||
design/multiprocessing
|
||||
:::
|
||||
|
||||
:::{toctree}
|
||||
:caption: V1 Design Documents
|
||||
:maxdepth: 2
|
||||
|
||||
design/v1/prefix_caching
|
||||
:::
|
||||
|
||||
% How to contribute to the vLLM project
|
||||
|
||||
:::{toctree}
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
#!/bin/bash
|
||||
|
||||
echo "vLLM linting system has been moved from format.sh to pre-commit hook."
|
||||
echo "Please run 'pip install -r requirements-lint.txt', followed by"
|
||||
echo "'pre-commit install --hook-type pre-commit --hook-type commit-msg' to install the pre-commit hook."
|
||||
echo "Please run 'pip install -r requirements-lint.txt' and 'pre-commit install' to install the pre-commit hook."
|
||||
echo "Then linters will run automatically before each commit."
|
||||
|
||||
@ -5,7 +5,7 @@ requests >= 2.26.0
|
||||
tqdm
|
||||
blake3
|
||||
py-cpuinfo
|
||||
transformers >= 4.48.2 # Required for Bamba.
|
||||
transformers >= 4.45.2 # Required for Llama 3.2 and Qwen2-VL.
|
||||
tokenizers >= 0.19.1 # Required for Llama 3.
|
||||
protobuf # Required by LlamaTokenizer.
|
||||
fastapi >= 0.107.0, < 0.113.0; python_version < '3.9'
|
||||
|
||||
@ -3,13 +3,7 @@
|
||||
|
||||
# Dependencies for CPUs
|
||||
torch==2.5.1+cpu; platform_machine != "ppc64le" and platform_machine != "aarch64" and platform_system != "Darwin"
|
||||
torch==2.5.1; platform_machine == "ppc64le" or platform_machine == "aarch64" or platform_system == "Darwin"
|
||||
|
||||
# required for the image processor of minicpm-o-2_6, this must be updated alongside torch
|
||||
torchaudio; platform_machine != "ppc64le"
|
||||
torchaudio==2.5.1; platform_machine == "ppc64le"
|
||||
|
||||
# required for the image processor of phi3v, this must be updated alongside torch
|
||||
torchvision; platform_machine != "ppc64le"
|
||||
torchvision==0.20.1; platform_machine == "ppc64le"
|
||||
torch==2.5.1; platform_machine == "aarch64" or platform_system == "Darwin"
|
||||
torchaudio; platform_machine != "ppc64le" # required for the image processor of minicpm-o-2_6, this must be updated alongside torch
|
||||
torchvision; platform_machine != "ppc64le" # required for the image processor of phi3v, this must be updated alongside torch
|
||||
datasets # for benchmark scripts
|
||||
|
||||
@ -27,7 +27,7 @@ matplotlib # required for qwen-vl test
|
||||
mistral_common[opencv] >= 1.5.0 # required for pixtral test
|
||||
datamodel_code_generator # required for minicpm3 test
|
||||
lm-eval[api]==0.4.4 # required for model evaluation test
|
||||
transformers==4.48.2
|
||||
|
||||
# quantization
|
||||
bitsandbytes>=0.45.0
|
||||
buildkite-test-collector==0.1.9
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
# This file is autogenerated by pip-compile with Python 3.12
|
||||
# by the following command:
|
||||
#
|
||||
# python3.12 -m piptools compile requirements-test.in -o requirements-test.txt
|
||||
# python3.12 -m piptools compile requirements-test.in -o requirements-test.txt
|
||||
#
|
||||
absl-py==2.1.0
|
||||
# via rouge-score
|
||||
@ -617,9 +617,8 @@ tqdm==4.66.6
|
||||
# transformers
|
||||
tqdm-multiprocess==0.0.11
|
||||
# via lm-eval
|
||||
transformers==4.48.2
|
||||
transformers==4.47.0
|
||||
# via
|
||||
# -r requirements-test.in
|
||||
# genai-perf
|
||||
# lm-eval
|
||||
# peft
|
||||
|
||||
6
setup.py
@ -608,11 +608,7 @@ if _build_custom_ops():
|
||||
ext_modules.append(CMakeExtension(name="vllm._C"))
|
||||
|
||||
package_data = {
|
||||
"vllm": [
|
||||
"py.typed",
|
||||
"model_executor/layers/fused_moe/configs/*.json",
|
||||
"model_executor/layers/quantization/utils/configs/*.json",
|
||||
]
|
||||
"vllm": ["py.typed", "model_executor/layers/fused_moe/configs/*.json"]
|
||||
}
|
||||
|
||||
if _no_device():
|
||||
|
||||
@ -10,7 +10,6 @@ import torch
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import cdiv
|
||||
|
||||
from .utils import baseline_scaled_mm, to_fp8, to_int8
|
||||
|
||||
@ -40,11 +39,6 @@ CUDA_DEVICES = [
|
||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||
]
|
||||
|
||||
# -1 means full extent in that dimension
|
||||
TENSORWISE_GROUP_SHAPE = (-1, -1)
|
||||
PER_TOKEN_GROUP_SHAPE = (1, -1)
|
||||
PER_OUT_CH_GROUP_SHAPE = (-1, 1)
|
||||
|
||||
capability = current_platform.get_device_capability()
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
|
||||
@ -53,22 +47,11 @@ def rand_int8(shape: tuple, device: str = "cuda"):
|
||||
return to_int8(torch.rand(shape, device=device) * 255 - 128)
|
||||
|
||||
|
||||
def group_scale_helper(shape, group_shape):
|
||||
return [shape[i] if s < 0 else s for i, s in enumerate(group_shape)]
|
||||
|
||||
|
||||
def scale_shape(shape, group_shape):
|
||||
assert len(shape) == len(group_shape)
|
||||
group_shape = group_scale_helper(shape, group_shape)
|
||||
return tuple(
|
||||
cdiv(shape[i], group_shape[i]) for i in range(len(group_shape)))
|
||||
|
||||
|
||||
def cutlass_fp8_gemm_helper(m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
a_scale_group_shape: tuple,
|
||||
b_scale_group_shape: tuple,
|
||||
per_token_act_quant: bool,
|
||||
per_out_channel_weight_quant: bool,
|
||||
use_bias: bool,
|
||||
out_dtype: Type[torch.dtype] = torch.bfloat16,
|
||||
device: str = "cuda"):
|
||||
@ -77,17 +60,13 @@ def cutlass_fp8_gemm_helper(m: int,
|
||||
a = to_fp8(torch.randn((m, k), device=device))
|
||||
b = to_fp8(torch.randn((n, k), device=device).t())
|
||||
|
||||
a_scales_shape = scale_shape(a.shape, a_scale_group_shape)
|
||||
b_scales_shape = scale_shape(b.shape, b_scale_group_shape)
|
||||
|
||||
scale_a = (torch.randn(a_scales_shape, device=device, dtype=torch.float32))
|
||||
scale_b = (torch.randn(b_scales_shape, device=device, dtype=torch.float32))
|
||||
|
||||
# make scales M-major for blockwise quant, doesn't affect 1D scales
|
||||
scale_a = scale_a.t().contiguous().t()
|
||||
# make scales K-major for blockwise quant, doesn't affect 1D scales
|
||||
scale_b = scale_b.t().contiguous().t()
|
||||
m_a_scales = m if per_token_act_quant else 1
|
||||
n_b_scales = n if per_out_channel_weight_quant else 1
|
||||
|
||||
scale_a = (torch.randn((m_a_scales, 1), device=device,
|
||||
dtype=torch.float32))
|
||||
scale_b = (torch.randn((1, n_b_scales), device=device,
|
||||
dtype=torch.float32))
|
||||
if use_bias:
|
||||
bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10
|
||||
else:
|
||||
@ -105,8 +84,8 @@ def cutlass_fp8_gemm_helper(m: int,
|
||||
def cutlass_int8_gemm_helper(m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
a_scale_group_shape: tuple,
|
||||
b_scale_group_shape: tuple,
|
||||
per_token_act_quant: bool,
|
||||
per_out_channel_weight_quant: bool,
|
||||
use_bias: bool,
|
||||
out_dtype: Type[torch.dtype] = torch.bfloat16,
|
||||
device: str = "cuda"):
|
||||
@ -115,11 +94,13 @@ def cutlass_int8_gemm_helper(m: int,
|
||||
a = to_int8(torch.randn((m, k), device=device) * 5)
|
||||
b = to_int8(torch.randn((n, k), device=device).t() * 5)
|
||||
|
||||
a_scales_shape = scale_shape(a.shape, a_scale_group_shape)
|
||||
b_scales_shape = scale_shape(b.shape, b_scale_group_shape)
|
||||
m_a_scales = m if per_token_act_quant else 1
|
||||
n_b_scales = n if per_out_channel_weight_quant else 1
|
||||
|
||||
scale_a = (torch.randn(a_scales_shape, device=device, dtype=torch.float32))
|
||||
scale_b = (torch.randn(b_scales_shape, device=device, dtype=torch.float32))
|
||||
scale_a = (torch.randn((m_a_scales, 1), device=device,
|
||||
dtype=torch.float32))
|
||||
scale_b = (torch.randn((1, n_b_scales), device=device,
|
||||
dtype=torch.float32))
|
||||
|
||||
if use_bias:
|
||||
bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10
|
||||
@ -136,135 +117,82 @@ def cutlass_int8_gemm_helper(m: int,
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("a_scale_group_shape",
|
||||
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
|
||||
@pytest.mark.parametrize("b_scale_group_shape",
|
||||
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
@pytest.mark.skipif(not current_platform.has_device_capability(89),
|
||||
reason="FP8 is not supported on this GPU type.")
|
||||
def test_cutlass_fp8_gemm(m: int, n: int, k: int, a_scale_group_shape,
|
||||
b_scale_group_shape, use_bias: bool):
|
||||
cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape,
|
||||
use_bias)
|
||||
def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool,
|
||||
per_out_ch: bool, use_bias: bool):
|
||||
cutlass_fp8_gemm_helper(m, n, k, per_act_token, per_out_ch, use_bias)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("a_scale_group_shape,b_scale_group_shape",
|
||||
[((1, 128), (128, 128))])
|
||||
@pytest.mark.parametrize("use_bias", [False])
|
||||
@pytest.mark.skipif(not current_platform.has_device_capability(90),
|
||||
reason="FP8 blockwise is not supported on this GPU type.")
|
||||
def test_cutlass_fp8_blockwise_scale_gemm(m: int, n: int, k: int,
|
||||
a_scale_group_shape,
|
||||
b_scale_group_shape, use_bias: bool):
|
||||
if k % b_scale_group_shape[0] != 0 or n % b_scale_group_shape[1] != 0:
|
||||
return
|
||||
if m % a_scale_group_shape[0] != 0 or k % a_scale_group_shape[1] != 0:
|
||||
return
|
||||
cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape,
|
||||
use_bias)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("a_scale_group_shape",
|
||||
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
|
||||
@pytest.mark.parametrize("b_scale_group_shape",
|
||||
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
def test_cutlass_int8_gemm(m: int, n: int, k: int, a_scale_group_shape,
|
||||
b_scale_group_shape, use_bias: bool):
|
||||
cutlass_int8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape,
|
||||
use_bias)
|
||||
def test_cutlass_int8_gemm(m: int, n: int, k: int, per_act_token: bool,
|
||||
per_out_ch: bool, use_bias: bool):
|
||||
cutlass_int8_gemm_helper(m, n, k, per_act_token, per_out_ch, use_bias)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("a_scale_group_shape",
|
||||
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
|
||||
@pytest.mark.parametrize("b_scale_group_shape",
|
||||
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
def test_cutlass_int8_gemm_output_dtype(a_scale_group_shape,
|
||||
b_scale_group_shape,
|
||||
def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
|
||||
out_dtype: Type[torch.dtype],
|
||||
use_bias: bool):
|
||||
cutlass_int8_gemm_helper(512,
|
||||
512,
|
||||
512,
|
||||
a_scale_group_shape,
|
||||
b_scale_group_shape,
|
||||
per_act_token,
|
||||
per_out_ch,
|
||||
use_bias,
|
||||
out_dtype=out_dtype)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("a_scale_group_shape",
|
||||
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
|
||||
@pytest.mark.parametrize("b_scale_group_shape",
|
||||
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
@pytest.mark.skipif(not current_platform.has_device_capability(89),
|
||||
reason="FP8 is not supported on this GPU type.")
|
||||
def test_cutlass_fp8_gemm_output_dtype(a_scale_group_shape,
|
||||
b_scale_group_shape,
|
||||
def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
|
||||
out_dtype: Type[torch.dtype],
|
||||
use_bias: bool):
|
||||
cutlass_fp8_gemm_helper(512,
|
||||
512,
|
||||
512,
|
||||
a_scale_group_shape,
|
||||
b_scale_group_shape,
|
||||
per_act_token,
|
||||
per_out_ch,
|
||||
use_bias,
|
||||
out_dtype=out_dtype)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("a_scale_group_shape,b_scale_group_shape",
|
||||
[((1, 128), (128, 128))])
|
||||
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize("use_bias", [False])
|
||||
@pytest.mark.skipif(not current_platform.has_device_capability(90),
|
||||
reason="FP8 blockwise is not supported on this GPU type.")
|
||||
def test_cutlass_fp8_blockwise_scale_gemm_dtype(a_scale_group_shape,
|
||||
b_scale_group_shape,
|
||||
out_dtype: Type[torch.dtype],
|
||||
use_bias: bool):
|
||||
cutlass_fp8_gemm_helper(512,
|
||||
512,
|
||||
512,
|
||||
a_scale_group_shape,
|
||||
b_scale_group_shape,
|
||||
use_bias,
|
||||
out_dtype=out_dtype)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("a_scale_group_shape",
|
||||
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
|
||||
@pytest.mark.parametrize("b_scale_group_shape",
|
||||
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.skipif(not current_platform.has_device_capability(89),
|
||||
reason="FP8 is not supported on this GPU type.")
|
||||
def test_cutlass_fp8_gemm_devices(a_scale_group_shape, b_scale_group_shape,
|
||||
def test_cutlass_fp8_gemm_devices(per_act_token: bool, per_out_ch: bool,
|
||||
use_bias: bool, device: str):
|
||||
cutlass_fp8_gemm_helper(512, 512, 512, a_scale_group_shape,
|
||||
b_scale_group_shape, use_bias, torch.bfloat16,
|
||||
device)
|
||||
cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, use_bias,
|
||||
torch.bfloat16, device)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("a_scale_group_shape",
|
||||
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
|
||||
@pytest.mark.parametrize("b_scale_group_shape",
|
||||
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_cutlass_int8_gemm_devices(a_scale_group_shape, b_scale_group_shape,
|
||||
def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
|
||||
use_bias: bool, device: str):
|
||||
cutlass_int8_gemm_helper(512,
|
||||
512,
|
||||
512,
|
||||
a_scale_group_shape,
|
||||
b_scale_group_shape,
|
||||
per_act_token,
|
||||
per_out_ch,
|
||||
use_bias,
|
||||
out_dtype=torch.bfloat16,
|
||||
device=device)
|
||||
@ -275,32 +203,28 @@ def test_cutlass_int8_gemm_devices(a_scale_group_shape, b_scale_group_shape,
|
||||
# of a large power of two. In any case, the kernel will have a naive fallback
|
||||
# when N and K are not divisible by 16. But M is the number of tokens and the
|
||||
# kernel must handle any M thrown at it.
|
||||
@pytest.mark.parametrize("a_scale_group_shape",
|
||||
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
|
||||
@pytest.mark.parametrize("b_scale_group_shape",
|
||||
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
@pytest.mark.skipif(not current_platform.has_device_capability(89),
|
||||
reason="FP8 is not supported on this GPU type.")
|
||||
def test_cutlass_fp8_gemm_m_sweep(a_scale_group_shape, b_scale_group_shape,
|
||||
def test_cutlass_fp8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
|
||||
use_bias: bool):
|
||||
for nk in range(32, 128, 32):
|
||||
for m in range(1, 128):
|
||||
cutlass_fp8_gemm_helper(m, nk, nk, a_scale_group_shape,
|
||||
b_scale_group_shape, use_bias)
|
||||
cutlass_fp8_gemm_helper(m, nk, nk, per_act_token, per_out_ch,
|
||||
use_bias)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("a_scale_group_shape",
|
||||
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
|
||||
@pytest.mark.parametrize("b_scale_group_shape",
|
||||
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
def test_cutlass_int8_gemm_m_sweep(a_scale_group_shape, b_scale_group_shape,
|
||||
def test_cutlass_int8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
|
||||
use_bias: bool):
|
||||
for nk in range(32, 128, 32):
|
||||
for m in range(1, 128):
|
||||
cutlass_int8_gemm_helper(m, nk, nk, a_scale_group_shape,
|
||||
b_scale_group_shape, use_bias)
|
||||
cutlass_int8_gemm_helper(m, nk, nk, per_act_token, per_out_ch,
|
||||
use_bias)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [32, 64, 128])
|
||||
|
||||
@ -1119,36 +1119,8 @@ def baseline_scaled_mm(a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
out_dtype: Type[torch.dtype],
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
# We treat N-dimensional group scaling as extended numpy-style broadcasting
|
||||
# in numpy simply stretches dimensions with an extent of 1 to match the
|
||||
# the target shape by repeating the data along that dimension (broadcasting)
|
||||
# , we extend these semantics to say if the extent of a dimension in the
|
||||
# source shape is not 1 and does not match the target shape we repeat each
|
||||
# element along that dimension src_shape[dim] // target_shape[dim] times
|
||||
# example if we have:
|
||||
# a = [[1, 2], and target_shape = (2, 4)
|
||||
# [3, 4]]
|
||||
# then we would expand a to:
|
||||
# a = [[1, 1, 2, 2],
|
||||
# [3, 3, 4, 4]]
|
||||
# NOTE this function this function does not explicitly broadcast dimensions
|
||||
# with an extent of 1, since this can be done implicitly by pytorch
|
||||
def group_broadcast(t, shape):
|
||||
for i, s in enumerate(shape):
|
||||
if t.shape[i] != s and t.shape[i] != 1:
|
||||
assert s % t.shape[i] == 0
|
||||
t = t.unsqueeze(i + 1)\
|
||||
.expand(*t.shape[:i+1], s // t.shape[i], *t.shape[i+1:])\
|
||||
.flatten(i, i + 1)
|
||||
return t
|
||||
|
||||
scale_a = group_broadcast(scale_a, a.shape)
|
||||
scale_b = group_broadcast(scale_b, b.shape)
|
||||
|
||||
output = torch.mm((scale_a * a.to(dtype=torch.float32)),
|
||||
(scale_b * b.to(dtype=torch.float32))).to(out_dtype)
|
||||
|
||||
output = (scale_a * (scale_b * (torch.mm(
|
||||
a.to(dtype=torch.float32), b.to(dtype=torch.float32))))).to(out_dtype)
|
||||
if bias is not None:
|
||||
output = output + bias
|
||||
|
||||
|
||||
@ -5,5 +5,5 @@ class DummyPlatform(CudaPlatform):
|
||||
device_name = "DummyDevice"
|
||||
|
||||
def get_attn_backend_cls(self, backend_name, head_size, dtype,
|
||||
kv_cache_dtype, block_size, use_v1, use_mla):
|
||||
kv_cache_dtype, block_size, use_v1):
|
||||
return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend" # noqa E501
|
||||
|
||||
@ -192,7 +192,7 @@ def test_hash_block_tokens():
|
||||
extra_keys)
|
||||
assert isinstance(block_hash, BlockHashType)
|
||||
assert block_hash.hash_value == hash(
|
||||
(parent_block_hash, curr_block_token_ids, extra_keys))
|
||||
(parent_block_hash, *curr_block_token_ids))
|
||||
assert block_hash.token_ids == curr_block_token_ids
|
||||
assert block_hash.extra_keys == extra_keys
|
||||
|
||||
@ -227,38 +227,6 @@ def test_hash_request_tokens():
|
||||
assert block_hashes[1].extra_keys == ("hash2", )
|
||||
|
||||
|
||||
def test_hash_tokens_different_mm_input():
|
||||
request1 = make_request(
|
||||
request_id=0,
|
||||
prompt_token_ids=[_ for _ in range(6)],
|
||||
mm_positions=[{
|
||||
"offset": 0,
|
||||
"length": 3
|
||||
}, {
|
||||
"offset": 3,
|
||||
"length": 3
|
||||
}],
|
||||
mm_hashes=["hash1", "hash2"],
|
||||
)
|
||||
request2 = make_request(
|
||||
request_id=1,
|
||||
prompt_token_ids=[_ for _ in range(6)],
|
||||
mm_positions=[{
|
||||
"offset": 0,
|
||||
"length": 3
|
||||
}, {
|
||||
"offset": 3,
|
||||
"length": 3
|
||||
}],
|
||||
mm_hashes=["hash3", "hash2"],
|
||||
)
|
||||
block_size = 3
|
||||
block_hashes1 = hash_request_tokens(block_size, request1)
|
||||
block_hashes2 = hash_request_tokens(block_size, request2)
|
||||
assert block_hashes1[0] != block_hashes2[0]
|
||||
assert block_hashes1[1] != block_hashes2[1]
|
||||
|
||||
|
||||
def test_hash_request_tokens_no_mm_inputs():
|
||||
request = make_request(
|
||||
request_id=0,
|
||||
|
||||
@ -20,7 +20,7 @@ compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main
|
||||
compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
|
||||
compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
|
||||
compressed-tensors, nm-testing/TinyLlama-1.1B-Chat-v1.0-actorder-group, main
|
||||
#compressed-tensors, mgoin/DeepSeek-Coder-V2-Lite-Instruct-FP8, main
|
||||
compressed-tensors, mgoin/DeepSeek-Coder-V2-Lite-Instruct-FP8, main
|
||||
compressed-tensors, nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-FP8-Dynamic-testing, main, 90
|
||||
compressed-tensors, nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-W8A8-testing, main, 90
|
||||
awq, casperhansen/mixtral-instruct-awq, main
|
||||
|
||||
@ -3,7 +3,7 @@ SUCCESS=0
|
||||
|
||||
while getopts "c:" OPT; do
|
||||
case ${OPT} in
|
||||
c )
|
||||
c )
|
||||
CONFIG="$OPTARG"
|
||||
;;
|
||||
\? )
|
||||
@ -18,14 +18,9 @@ IFS=$'\n' read -d '' -r -a MODEL_CONFIGS < "$CONFIG"
|
||||
|
||||
for MODEL_CONFIG in "${MODEL_CONFIGS[@]}"
|
||||
do
|
||||
if [[ $MODEL_CONFIG == \#* ]]; then
|
||||
echo "=== SKIPPING MODEL: $MODEL_CONFIG ==="
|
||||
continue
|
||||
fi
|
||||
|
||||
LOCAL_SUCCESS=0
|
||||
IFS=', ' read -r -a array <<< "$MODEL_CONFIG"
|
||||
|
||||
|
||||
echo "=== RUNNING MODEL: $MODEL_CONFIG ==="
|
||||
|
||||
export QUANTIZATION=${array[0]}
|
||||
|
||||
@ -435,39 +435,12 @@ def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
|
||||
return torch.ops._C.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
|
||||
|
||||
|
||||
def cutlass_scaled_mm_supports_block_fp8(cuda_device_capability: int) -> bool:
|
||||
return torch.ops._C.cutlass_scaled_mm_supports_block_fp8(
|
||||
cuda_device_capability)
|
||||
|
||||
|
||||
def cutlass_scaled_mm(a: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
"""
|
||||
`cutlass_scaled_mm` implements a fused version of
|
||||
`output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
|
||||
where scale_a * a and scale_b * b are implemented using numpy-style
|
||||
broadcasting.
|
||||
|
||||
In order to support blockwise scaling like found in DeepSeek V3 we also
|
||||
support extended "group" broadcast rules. We extend the numpy-style
|
||||
broadcasting rules with the following rule:
|
||||
"if the extent of a dimension in the source shape is between 1 and
|
||||
corresponding extent in the target shape we repeat each element along
|
||||
that dimension src_shape[dim] // target_shape[dim] times consecutively"
|
||||
example if we have:
|
||||
a = [[1, 2], and target_shape = (2, 4)
|
||||
[3, 4]]
|
||||
then we would expand a to:
|
||||
a = [[1, 1, 2, 2],
|
||||
[3, 3, 4, 4]]
|
||||
currently we only support the case:
|
||||
scale_a.shape * [1, 128] == a.shape
|
||||
scale_b.shape * [128, 128] == b.shape
|
||||
"""
|
||||
assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
|
||||
assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
|
||||
assert bias is None or bias.shape[0] == b.shape[
|
||||
|
||||
@ -168,7 +168,8 @@ class AttentionState(ABC, Generic[T]):
|
||||
|
||||
@abstractmethod
|
||||
@contextmanager
|
||||
def graph_capture(self, max_batch_size: int):
|
||||
def graph_capture(self, max_batch_size: int,
|
||||
positions: Optional[torch.Tensor]):
|
||||
"""Context manager used when capturing CUDA graphs."""
|
||||
yield
|
||||
|
||||
@ -268,9 +269,9 @@ class AttentionImpl(ABC, Generic[T]):
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
query: torch.Tensor, # For MLA hidden_states_or_cq
|
||||
key: torch.Tensor, # For MLA kv_c_normed
|
||||
value: torch.Tensor, # For MLA k_pe
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: T,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
@ -278,7 +279,7 @@ class AttentionImpl(ABC, Generic[T]):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
|
||||
class MLAAttentionImpl(AttentionImpl):
|
||||
|
||||
@abstractmethod
|
||||
def forward(
|
||||
|
||||
@ -213,7 +213,10 @@ class FlashInferState(AttentionState):
|
||||
return self._decode_wrapper
|
||||
|
||||
@contextmanager
|
||||
def graph_capture(self, max_batch_size: int):
|
||||
def graph_capture(self, max_batch_size: int,
|
||||
positions: Optional[torch.Tensor]):
|
||||
assert positions is None
|
||||
|
||||
self._is_graph_capturing = True
|
||||
self._graph_decode_wrapper = None
|
||||
self._graph_slot_mapping = torch.full((max_batch_size, ),
|
||||
|
||||
@ -1,47 +1,35 @@
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Generic, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import QuantizationStrategy
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.abstract import (AttentionLayer,
|
||||
AttentionMetadata,
|
||||
MLAAttentionImpl, T)
|
||||
from vllm.distributed import (get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce)
|
||||
MLAAttentionImpl)
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearBase, RowParallelLinear,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
|
||||
CompressedTensorsLinearMethod)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsW8A8Fp8)
|
||||
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
scaled_dequantize, scaled_quantize)
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
||||
|
||||
|
||||
@dataclass
|
||||
class MLACommonMetadata(AttentionMetadata):
|
||||
# Input positions for rotrary embeddings since for MLA the rotary
|
||||
@dataclass(kw_only=True)
|
||||
class MLAMetadataCommon(AttentionMetadata):
|
||||
# Input positions for rotrary embeddings since for MLA the rotarty
|
||||
# position embeddings are applied inside the attention backend
|
||||
input_positions: torch.Tensor
|
||||
|
||||
|
||||
class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
||||
class MLACommonImpl(MLAAttentionImpl):
|
||||
"""
|
||||
Common class for implementing repeated parts
|
||||
|
||||
Common class for implementing repeated parts
|
||||
|
||||
Main reference: DeepseekV2 paper, and FlashInfer Implementation
|
||||
(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).
|
||||
|
||||
|
||||
Deepseek's MLA attention works the following way:
|
||||
* Use a single latent vector to represent the entire KV cache.
|
||||
* The attention "simulates" a multi-head attention, while the compute is
|
||||
@ -58,7 +46,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
||||
* V: V head dim.
|
||||
* kv_c: latent/compressed KV
|
||||
* q_c: latent/compressed Q
|
||||
|
||||
|
||||
#
|
||||
# Outside the MLA attention backend
|
||||
#
|
||||
@ -67,21 +55,21 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
||||
kv_c_k_pe (B, Lkv+R).
|
||||
2. The kv_c_k_pe is split into kv_c (B, Lkv) and k_pe (B, R). cq
|
||||
and kv_c are normalized.
|
||||
|
||||
|
||||
#
|
||||
# Inside the MLA attention backend
|
||||
#
|
||||
|
||||
* if prefill:
|
||||
|
||||
3. The q_c is then projected up into the multi-head version.
|
||||
* q_c goes from (B, Lq) to (B, N, (P+R)), which is split into q_nope
|
||||
(B, N, P) and q_pe (B, N, R).
|
||||
|
||||
3. The q_c is then projected up into the multi-head version.
|
||||
* q_c goes from (B, Lq) to (B, N, (P+R)), which is split into q_nope
|
||||
(B, N, P) and q_pe (B, N, R).
|
||||
4. q_pe, k_pe are then passed through rotary embeddings.
|
||||
5. kv_c and k_pe are concatenated and inserted into the cache
|
||||
6. The kv_c is then projected up into the multi-head version.
|
||||
* kv_c goes from (B, Lkv) to (B, N, (P+V)) which has the nope
|
||||
dimensions for K and V, which is split into k_nope (B, N, P)
|
||||
6. The kv_c is then projected up into the multi-head version.
|
||||
* kv_c goes from (B, Lkv) to (B, N, (P+V)) which has the nope
|
||||
dimensions for K and V, which is split into k_nope (B, N, P)
|
||||
and v (B, N, V).
|
||||
7. q (B, N, (P+R)) and k (B, N, (P+R)) matrices are assembled from
|
||||
q_nope, q_pe, k_nope, k_pe.
|
||||
@ -124,7 +112,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
||||
From @tsu-bin's calculation, we only want to use the absorption technique
|
||||
for decode. The prefill algorithm should still use the up-projected MHA
|
||||
for less flops and memory usage.
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -150,7 +138,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
||||
# q_proj should be q_b_proj if q_lora_rank is not None, but from an
|
||||
# attention backend perspective we rely on the layer to pass in the
|
||||
# correct matrix
|
||||
q_proj: ColumnParallelLinear,
|
||||
q_proj: Optional[ColumnParallelLinear],
|
||||
kv_b_proj: ColumnParallelLinear,
|
||||
o_proj: RowParallelLinear,
|
||||
) -> None:
|
||||
@ -174,19 +162,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
||||
|
||||
def _v_up_proj_and_o_proj(self, x):
|
||||
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
|
||||
if is_fp8(self.W_UV_O):
|
||||
output_parallel = apply_fp8_linear_generic(
|
||||
x.flatten(start_dim=1), self.W_UV_O, self.W_UV_O_scales,
|
||||
self.reqaunt_input_group_shape,
|
||||
self.reqaunt_weight_group_shape)
|
||||
else:
|
||||
output_parallel = torch.matmul(x.flatten(start_dim=1),
|
||||
self.W_UV_O)
|
||||
if self.tp_size > 1:
|
||||
output = tensor_model_parallel_all_reduce(output_parallel)
|
||||
else:
|
||||
output = output_parallel
|
||||
return output
|
||||
return self.o_proj_absorbed(
|
||||
x.reshape(-1, self.num_heads * self.kv_lora_rank))[0]
|
||||
else:
|
||||
x = torch.einsum("bnl,lnv->bnv", x, self.W_UV)
|
||||
return self.o_proj(x.reshape(-1,
|
||||
@ -194,12 +171,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
||||
|
||||
def _q_proj_and_k_up_proj(self, x):
|
||||
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
|
||||
if is_fp8(self.W_Q_UK):
|
||||
return apply_fp8_linear_generic(
|
||||
x, self.W_Q_UK, self.W_Q_UK_scales,
|
||||
self.reqaunt_input_group_shape,
|
||||
self.reqaunt_weight_group_shape).view(
|
||||
-1, self.num_heads, self.kv_lora_rank)
|
||||
return torch.matmul(x, self.W_Q_UK)\
|
||||
.view(-1, self.num_heads, self.kv_lora_rank)
|
||||
else:
|
||||
@ -208,91 +179,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
||||
return torch.einsum("bnp,lnp->bnl", x, self.W_UK)\
|
||||
.view(-1, self.num_heads, self.kv_lora_rank)
|
||||
|
||||
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
||||
|
||||
def is_layer_fp8(layer: LinearBase) -> bool:
|
||||
return isinstance(layer.quant_method, Fp8LinearMethod) or\
|
||||
(isinstance(layer.quant_method, CompressedTensorsLinearMethod)\
|
||||
and isinstance(layer.scheme, CompressedTensorsW8A8Fp8))
|
||||
|
||||
def quantization_scheme_supported(layer: LinearBase) -> bool:
|
||||
return isinstance(layer.quant_method, UnquantizedLinearMethod) or \
|
||||
is_layer_fp8(layer)
|
||||
|
||||
# TODO(lucas) This is very gross, we need a more wide scale refactor of
|
||||
# all the FP8 code with a more standard way of
|
||||
# defining schemes/group-shapes, we should also potentially force
|
||||
# quant_methods to support a decompress function
|
||||
#
|
||||
# returns input_group_shape, weight_group_shape
|
||||
def get_scale_group_shapes_for_fp8(layer: LinearBase) -> \
|
||||
Tuple[Tuple[int, int], Tuple[int, int]]:
|
||||
if isinstance(layer.quant_method, Fp8LinearMethod):
|
||||
if layer.quant_method.block_quant is not None:
|
||||
weight_block_size = \
|
||||
layer.quant_method.quant_config.weight_block_size
|
||||
# per-token-group (1, X), block-quantized (X, Y)
|
||||
return (1, weight_block_size[-1]), weight_block_size
|
||||
else:
|
||||
return (-1, -1), (-1, -1) # per-tensor, per-tensor
|
||||
elif isinstance(layer.quant_method, CompressedTensorsLinearMethod)\
|
||||
and isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
|
||||
# this is hacky but we always assume the for
|
||||
# CompressedTensorsW8A8Fp8 the input is dynamic per-token
|
||||
# we ignore if it is static-per-tensor since we are going to
|
||||
# requantize after later anyways
|
||||
strategy = layer.scheme.strategy
|
||||
if strategy == QuantizationStrategy.TENSOR:
|
||||
return (1, -1), (-1, -1) # per-token, per-tensor
|
||||
elif strategy == QuantizationStrategy.CHANNEL:
|
||||
return (1, -1), (-1, 1) # per-token, per-channel
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"QuantizationStrategy.{strategy} is not supported for "
|
||||
"fp8 MLA, please run with VLLM_MLA_DISABLE=1")
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Can't determine scale group shapes for "
|
||||
f"{layer.quant_method}, please run with VLLM_MLA_DISABLE=1"
|
||||
)
|
||||
|
||||
def get_scales(layer: LinearBase) -> torch.Tensor:
|
||||
if hasattr(layer, "weight_scale_inv"):
|
||||
return layer.weight_scale_inv
|
||||
return layer.weight_scale
|
||||
|
||||
def get_and_maybe_dequant_weights(layer: LinearBase):
|
||||
if is_layer_fp8(layer):
|
||||
if isinstance(layer.quant_method, \
|
||||
CompressedTensorsLinearMethod) and \
|
||||
isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
|
||||
# NOTE(lucas): note sure why but `CompressedTensorsW8A8Fp8`
|
||||
# seems to store weights as (input, output) instead of
|
||||
# (output, input) so we need to transpose
|
||||
weight = layer.weight.T # standardize to (output, input)
|
||||
else:
|
||||
weight = layer.weight
|
||||
_, weight_scale_group_shape = \
|
||||
get_scale_group_shapes_for_fp8(layer)
|
||||
scales = get_scales(layer)
|
||||
|
||||
return scaled_dequantize(weight, scales,
|
||||
weight_scale_group_shape)
|
||||
else:
|
||||
return layer.weight
|
||||
|
||||
if not (quantization_scheme_supported(self.kv_b_proj) and\
|
||||
quantization_scheme_supported(self.q_proj) and\
|
||||
quantization_scheme_supported(self.o_proj)):
|
||||
raise NotImplementedError(
|
||||
"Only FP8 and UnquantizedLinearMethod are supported for MLA"
|
||||
", please run with VLLM_MLA_DISABLE=1")
|
||||
|
||||
weight_dtype = self.kv_b_proj.weight.dtype
|
||||
assert self.o_proj.weight.dtype == weight_dtype
|
||||
assert self.q_proj.weight.dtype == weight_dtype
|
||||
|
||||
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
|
||||
def process_weights_after_loading(self):
|
||||
kv_b_proj_weight = self.kv_b_proj.weight.T
|
||||
assert kv_b_proj_weight.shape == (
|
||||
self.kv_lora_rank,
|
||||
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
|
||||
@ -310,35 +198,18 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
||||
W_UK, W_UV = kv_b_proj_weight.split(
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
|
||||
q_proj_weight = get_and_maybe_dequant_weights(self.q_proj).T\
|
||||
q_proj = self.q_proj.weight.T\
|
||||
.view(-1, self.num_heads, self.qk_head_dim)
|
||||
|
||||
# can be W_Q or W_UQ depending q_lora_rank, the former if
|
||||
# q_lora_rank is None, the latter otherwise. From the Attention backend
|
||||
# perspective though we call these both W_Q and rely on the layer
|
||||
# to pass in the correct matrix
|
||||
W_Q = q_proj_weight[..., :self.qk_nope_head_dim]
|
||||
self.W_QR = q_proj_weight[..., self.qk_nope_head_dim:]\
|
||||
W_Q = q_proj[..., :self.qk_nope_head_dim]
|
||||
self.W_QR = q_proj[..., self.qk_nope_head_dim:]\
|
||||
.flatten(start_dim=1).contiguous()
|
||||
|
||||
# W_QR is small so for simplicity we dont bother requantizing it
|
||||
self.W_QR = self.W_QR.to(act_dtype)
|
||||
|
||||
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
|
||||
requantization_enabled = not envs.VLLM_MLA_DISABLE_REQUANTIZATION
|
||||
if is_fp8(weight_dtype) and requantization_enabled:
|
||||
# This assumes it wise to requantize using the same group shapes
|
||||
# (i.e. strategy, per-tensor, per-channel, block etc.) that the
|
||||
# weights were originally quantized
|
||||
requant_input_group_shape, requant_weight_group_shape = \
|
||||
get_scale_group_shapes_for_fp8(self.q_proj)
|
||||
assert (requant_input_group_shape, requant_weight_group_shape)\
|
||||
== get_scale_group_shapes_for_fp8(self.kv_b_proj)
|
||||
assert (requant_input_group_shape, requant_weight_group_shape)\
|
||||
== get_scale_group_shapes_for_fp8(self.o_proj)
|
||||
self.reqaunt_input_group_shape = requant_input_group_shape
|
||||
self.reqaunt_weight_group_shape = requant_weight_group_shape
|
||||
|
||||
#
|
||||
# Perform matrix-absorption following
|
||||
# https://github.com/flashinfer-ai/flashinfer/pull/551
|
||||
@ -352,44 +223,25 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
||||
# latter otherwise
|
||||
# basically if q_lora_rank is none we are absorbing into q_proj
|
||||
# instead of UQ
|
||||
W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\
|
||||
self.W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\
|
||||
.flatten(start_dim=1).contiguous()
|
||||
|
||||
if is_fp8(weight_dtype) and requantization_enabled:
|
||||
W_Q_UK, W_Q_UK_scales = scaled_quantize(
|
||||
W_Q_UK,
|
||||
self.reqaunt_weight_group_shape,
|
||||
quant_dtype=current_platform_fp8_dtype)
|
||||
# For FP8 save the transpose so we can use
|
||||
# `apply_w8a8_block_fp8_linear` directly
|
||||
self.W_Q_UK = W_Q_UK.T.contiguous()
|
||||
self.W_Q_UK_scales = W_Q_UK_scales.T.contiguous()
|
||||
else:
|
||||
self.W_Q_UK = W_Q_UK.to(act_dtype)
|
||||
|
||||
W_O = get_and_maybe_dequant_weights(self.o_proj)\
|
||||
W_O = self.o_proj.weight\
|
||||
.view(-1, self.num_heads, self.v_head_dim)
|
||||
W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\
|
||||
self.W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\
|
||||
.flatten(start_dim=0, end_dim=1).contiguous()
|
||||
|
||||
if is_fp8(weight_dtype) and requantization_enabled:
|
||||
W_UV_O, W_UV_O_scales = scaled_quantize(
|
||||
W_UV_O,
|
||||
self.reqaunt_weight_group_shape,
|
||||
quant_dtype=current_platform_fp8_dtype)
|
||||
# For FP8 save the transpose so we can use
|
||||
# `apply_w8a8_block_fp8_linear` directly
|
||||
self.W_UV_O = W_UV_O.T.contiguous()
|
||||
self.W_UV_O_scales = W_UV_O_scales.T.contiguous()
|
||||
else:
|
||||
self.W_UV_O = W_UV_O.to(act_dtype)
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.o_proj_absorbed = RowParallelLinear(
|
||||
self.W_UV_O.shape[0] * tp_size,
|
||||
self.W_UV_O.shape[1],
|
||||
bias=False,
|
||||
# TODO(lucas) figure out how to properly forward quant_method
|
||||
#quant_config=self.o_proj.quant_method,
|
||||
)
|
||||
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.o_proj_absorbed.weight = torch.nn.Parameter(self.W_UV_O.T)
|
||||
else:
|
||||
if is_fp8(weight_dtype):
|
||||
raise NotImplementedError(
|
||||
"Currently fp8 requires matrix absorption")
|
||||
|
||||
self.W_UV = W_UV
|
||||
self.W_UK = W_UK
|
||||
self.W_Q = W_Q.flatten(start_dim=1)
|
||||
@ -400,7 +252,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
||||
q: torch.Tensor,
|
||||
kv_c_normed: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
attn_metadata: T,
|
||||
attn_metadata: MLAMetadataCommon,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
@ -410,7 +262,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: T,
|
||||
attn_metadata: MLAMetadataCommon,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
@ -421,7 +273,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
||||
k_c_normed: torch.Tensor, # key in unified attn
|
||||
k_pe: torch.Tensor, # value in unified attn
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: T,
|
||||
attn_metadata: MLAMetadataCommon,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if output is not None:
|
||||
@ -437,7 +289,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
||||
|
||||
# Restore head dim (for rotary embedding)
|
||||
k_pe = k_pe.unsqueeze(1)
|
||||
assert hasattr(attn_metadata, "input_positions")
|
||||
|
||||
if is_decode:
|
||||
q_nope = self._q_proj_and_k_up_proj(hidden_states_or_q_c)
|
||||
|
||||
@ -90,17 +90,6 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
seq_lens: Optional[List[int]]
|
||||
# seq_lens stored as a tensor.
|
||||
seq_lens_tensor: Optional[torch.Tensor]
|
||||
# Maximum sequence length among prefill batch. 0 if there are decoding
|
||||
# requests only.
|
||||
max_prefill_seq_len: int
|
||||
# Maximum sequence length among decode batch. 0 if there are prefill
|
||||
# requests only.
|
||||
max_decode_seq_len: int
|
||||
|
||||
# Whether or not if cuda graph is enabled.
|
||||
# Cuda-graph is currently enabled for decoding only.
|
||||
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
|
||||
use_cuda_graph: bool
|
||||
|
||||
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
||||
# |---------- N-1 iteration --------|
|
||||
@ -111,18 +100,30 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
# |-- query_len ---|
|
||||
|
||||
# Maximum query length in the batch. None for decoding.
|
||||
max_query_len: Optional[int] = None
|
||||
max_query_len: Optional[int]
|
||||
# Maximum sequence length among prefill batch. 0 if there are decoding
|
||||
# requests only.
|
||||
max_prefill_seq_len: int
|
||||
# Maximum sequence length among decode batch. 0 if there are prefill
|
||||
# requests only.
|
||||
max_decode_seq_len: int
|
||||
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
|
||||
# the batch, used to index into subquery. E.g., if the subquery length
|
||||
# is [4, 6], it is [0, 4, 10].
|
||||
query_start_loc: Optional[torch.Tensor] = None
|
||||
query_start_loc: Optional[torch.Tensor]
|
||||
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
||||
# the batch, used to index into sequence. E.g., if the sequence length is
|
||||
# [4, 6], it is [0, 4, 10].
|
||||
seq_start_loc: Optional[torch.Tensor] = None
|
||||
seq_start_loc: Optional[torch.Tensor]
|
||||
|
||||
# Whether or not if cuda graph is enabled.
|
||||
# Cuda-graph is currently enabled for decoding only.
|
||||
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
|
||||
use_cuda_graph: bool
|
||||
|
||||
# (batch_size,) A tensor of context lengths (tokens that are computed
|
||||
# so far).
|
||||
context_lens_tensor: Optional[torch.Tensor] = None
|
||||
context_lens_tensor: Optional[torch.Tensor]
|
||||
|
||||
# Max number of query tokens among request in the batch.
|
||||
max_decode_query_len: Optional[int] = None
|
||||
@ -130,23 +131,6 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
_cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None
|
||||
_cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None
|
||||
|
||||
# Begin encoder attn & enc/dec cross-attn fields...
|
||||
|
||||
# Encoder sequence lengths representation
|
||||
encoder_seq_lens: Optional[List[int]] = None
|
||||
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
|
||||
|
||||
# Maximum sequence length among encoder sequences
|
||||
max_encoder_seq_len: Optional[int] = None
|
||||
|
||||
# Number of tokens input to encoder
|
||||
num_encoder_tokens: Optional[int] = None
|
||||
|
||||
# Cross-attention memory-mapping data structures: slot mapping
|
||||
# and block tables
|
||||
cross_slot_mapping: Optional[torch.Tensor] = None
|
||||
cross_block_tables: Optional[torch.Tensor] = None
|
||||
|
||||
@property
|
||||
def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
|
||||
if self.num_prefills == 0:
|
||||
@ -157,7 +141,10 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
|
||||
assert self.seq_lens is not None
|
||||
assert self.seq_lens_tensor is not None
|
||||
assert self.query_start_loc is not None
|
||||
assert self.context_lens_tensor is not None
|
||||
assert self.block_tables is not None
|
||||
assert self.seq_start_loc is not None
|
||||
|
||||
self._cached_prefill_metadata = ROCmFlashAttentionMetadata(
|
||||
num_prefills=self.num_prefills,
|
||||
@ -172,20 +159,12 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
max_query_len=self.max_query_len,
|
||||
max_prefill_seq_len=self.max_prefill_seq_len,
|
||||
max_decode_seq_len=0,
|
||||
query_start_loc=None if self.query_start_loc is None else
|
||||
self.query_start_loc[:self.num_prefills + 1],
|
||||
seq_start_loc=None if self.seq_start_loc is None else
|
||||
self.seq_start_loc[:self.num_prefills + 1],
|
||||
context_lens_tensor=None if self.context_lens_tensor is None else
|
||||
self.context_lens_tensor[:self.num_prefills],
|
||||
query_start_loc=self.query_start_loc[:self.num_prefills + 1],
|
||||
seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
|
||||
context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
|
||||
block_tables=self.block_tables[:self.num_prefills],
|
||||
use_cuda_graph=False,
|
||||
# Begin encoder & cross attn fields below...
|
||||
encoder_seq_lens=self.encoder_seq_lens,
|
||||
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
|
||||
max_encoder_seq_len=self.max_encoder_seq_len,
|
||||
cross_slot_mapping=self.cross_slot_mapping,
|
||||
cross_block_tables=self.cross_block_tables)
|
||||
)
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
@property
|
||||
@ -215,12 +194,7 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
context_lens_tensor=None,
|
||||
block_tables=self.block_tables[self.num_prefills:],
|
||||
use_cuda_graph=self.use_cuda_graph,
|
||||
# Begin encoder & cross attn fields below...
|
||||
encoder_seq_lens=self.encoder_seq_lens,
|
||||
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
|
||||
max_encoder_seq_len=self.max_encoder_seq_len,
|
||||
cross_slot_mapping=self.cross_slot_mapping,
|
||||
cross_block_tables=self.cross_block_tables)
|
||||
)
|
||||
# Batch may be composed of prefill|decodes, adjust query start indices
|
||||
# to refer to the start of decodes when the two are split apart.
|
||||
# E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
|
||||
@ -330,97 +304,6 @@ def _make_alibi_bias(alibi_slopes: torch.Tensor,
|
||||
return attn_biases
|
||||
|
||||
|
||||
def _get_seq_len_block_table_args(
|
||||
attn_metadata: ROCmFlashAttentionMetadata,
|
||||
attn_type: str,
|
||||
) -> tuple:
|
||||
'''
|
||||
The particular choice of sequence-length
|
||||
attributes which should be extracted from attn_metadata is dependent
|
||||
on the type of attention operation.
|
||||
|
||||
Decoder attn -> select entirely decoder self-attention-related fields
|
||||
Encoder/decoder cross-attn -> select encoder sequence lengths
|
||||
Encoder attn -> select encoder sequence lengths fields
|
||||
|
||||
Arguments:
|
||||
|
||||
* attn_metadata: Attention metadata structure associated with attention op
|
||||
* attn_type: encoder attention, decoder self-attention,
|
||||
encoder/decoder cross-attention
|
||||
|
||||
Returns:
|
||||
|
||||
* Appropriate sequence-lengths tensors for query and key
|
||||
* Appropriate max sequence-length scalar
|
||||
'''
|
||||
|
||||
partial_prefix_sum = 0
|
||||
if attn_type == AttentionType.ENCODER:
|
||||
assert attn_metadata.encoder_seq_lens is not None
|
||||
assert attn_metadata.encoder_seq_lens_tensor is not None
|
||||
query_seq_start_loc = torch.tensor(
|
||||
[0] + [
|
||||
partial_prefix_sum := partial_prefix_sum + i
|
||||
for i in attn_metadata.encoder_seq_lens
|
||||
],
|
||||
device=attn_metadata.encoder_seq_lens_tensor.device,
|
||||
dtype=attn_metadata.encoder_seq_lens_tensor.dtype)
|
||||
causal_mask = False
|
||||
|
||||
# No block tables associated with encoder attention
|
||||
return (query_seq_start_loc, attn_metadata.max_encoder_seq_len,
|
||||
query_seq_start_loc, attn_metadata.max_encoder_seq_len,
|
||||
attn_metadata.encoder_seq_lens, causal_mask)
|
||||
elif attn_type == AttentionType.DECODER:
|
||||
# Decoder self-attention
|
||||
# Choose max_seq_len based on whether we are in prompt_run
|
||||
assert attn_metadata.seq_lens is not None
|
||||
assert attn_metadata.seq_lens_tensor is not None
|
||||
query_seq_start_loc = torch.tensor(
|
||||
[0] + [
|
||||
partial_prefix_sum := partial_prefix_sum + i
|
||||
for i in attn_metadata.seq_lens
|
||||
],
|
||||
device=attn_metadata.seq_lens_tensor.device,
|
||||
dtype=attn_metadata.seq_lens_tensor.dtype)
|
||||
max_seq_len = attn_metadata.max_prefill_seq_len
|
||||
causal_mask = True
|
||||
|
||||
return (query_seq_start_loc, max_seq_len, query_seq_start_loc,
|
||||
max_seq_len, attn_metadata.seq_lens, causal_mask)
|
||||
elif attn_type == AttentionType.ENCODER_DECODER:
|
||||
assert attn_metadata.seq_lens is not None
|
||||
assert attn_metadata.encoder_seq_lens_tensor is not None
|
||||
query_start_loc = torch.tensor(
|
||||
[0] + [
|
||||
partial_prefix_sum := partial_prefix_sum + i
|
||||
for i in attn_metadata.seq_lens
|
||||
],
|
||||
device=attn_metadata.encoder_seq_lens_tensor.device,
|
||||
dtype=attn_metadata.encoder_seq_lens_tensor.dtype)
|
||||
|
||||
partial_prefix_sum = 0
|
||||
assert attn_metadata.encoder_seq_lens is not None
|
||||
assert attn_metadata.seq_lens_tensor is not None
|
||||
key_seq_start_loc = torch.tensor(
|
||||
[0] + [
|
||||
partial_prefix_sum := partial_prefix_sum + i
|
||||
for i in attn_metadata.encoder_seq_lens
|
||||
],
|
||||
device=attn_metadata.seq_lens_tensor.device,
|
||||
dtype=attn_metadata.seq_lens_tensor.dtype)
|
||||
causal_mask = False
|
||||
|
||||
# Enc/dec cross-attention KVs match encoder sequence length;
|
||||
# cross-attention utilizes special "cross" block tables
|
||||
return (query_start_loc, attn_metadata.max_prefill_seq_len,
|
||||
key_seq_start_loc, attn_metadata.max_encoder_seq_len,
|
||||
attn_metadata.seq_lens, causal_mask)
|
||||
else:
|
||||
raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
||||
|
||||
|
||||
class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
"""
|
||||
If the input tensors contain prompt tokens, the layout is as follows:
|
||||
@ -463,13 +346,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
if blocksparse_params is not None:
|
||||
raise ValueError(
|
||||
"ROCmFlashAttention does not support blocksparse attention.")
|
||||
|
||||
if logits_soft_cap is None:
|
||||
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
|
||||
self.logits_soft_cap = 0.0
|
||||
else:
|
||||
self.logits_soft_cap = logits_soft_cap
|
||||
self.attn_type = attn_type
|
||||
if logits_soft_cap is not None:
|
||||
raise ValueError(
|
||||
"ROCmFlashAttention does not support attention logits soft "
|
||||
"capping.")
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
@ -494,14 +374,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
|
||||
self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
|
||||
if self.use_triton_flash_attn:
|
||||
if logits_soft_cap is not None:
|
||||
raise ValueError(
|
||||
"ROCm Triton FlashAttention does not support attention"
|
||||
"logits soft capping."
|
||||
" please try using the ROCm CK "
|
||||
"FA backend instead by setting the env var "
|
||||
"`VLLM_USE_TRITON_FLASH_ATTN=0`")
|
||||
|
||||
from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
|
||||
triton_attention)
|
||||
self.attn_func = triton_attention
|
||||
@ -526,13 +398,14 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
self.use_naive_attn = True
|
||||
|
||||
if self.use_naive_attn:
|
||||
if logits_soft_cap is not None:
|
||||
raise ValueError(
|
||||
"ROCm Naive FlashAttention does not support"
|
||||
"attention logits soft capping.")
|
||||
|
||||
self.attn_func = _sdpa_attention
|
||||
logger.debug("Using naive (SDPA) attention in ROCmBackend")
|
||||
logger.debug("Using naive attention in ROCmBackend")
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"ROCmFlashAttentionImpl")
|
||||
|
||||
def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
|
||||
@ -554,37 +427,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with FlashAttention and PagedAttention.
|
||||
|
||||
For decoder-only models: query, key and value must be non-None.
|
||||
|
||||
For encoder/decoder models:
|
||||
* ROCmFlashAttentionImpl.forward() may be invoked for both self- and
|
||||
cross-attention layers.
|
||||
* For self-attention: query, key and value must be non-None.
|
||||
* For cross-attention:
|
||||
* Query must be non-None
|
||||
* During prefill, key and value must be non-None; key and value
|
||||
get cached for use during decode.
|
||||
* During decode, key and value may be None, since:
|
||||
(1) key and value tensors were cached during prefill, and
|
||||
(2) cross-attention key and value tensors do not grow during
|
||||
decode
|
||||
|
||||
A note on how the attn_type (attention type enum) argument impacts
|
||||
attention forward() behavior:
|
||||
|
||||
* DECODER: normal decoder-only behavior;
|
||||
use decoder self-attention block table
|
||||
* ENCODER: no KV caching; pass encoder sequence
|
||||
attributes (encoder_seq_lens/encoder_seq_lens_tensor/
|
||||
max_encoder_seq_len) to kernel, in lieu of decoder
|
||||
sequence attributes (seq_lens/seq_lens_tensor/max_seq_len)
|
||||
* ENCODER_DECODER: cross-attention behavior;
|
||||
use cross-attention block table for caching KVs derived
|
||||
from encoder hidden states; since KV sequence lengths
|
||||
will match encoder sequence lengths, pass encoder sequence
|
||||
attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/
|
||||
max_encoder_seq_len)
|
||||
|
||||
Args:
|
||||
query: shape = [num_tokens, num_heads * head_size]
|
||||
key: shape = [num_tokens, num_kv_heads * head_size]
|
||||
@ -593,80 +435,54 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
NOTE: kv_cache will be an empty tensor with shape [0]
|
||||
for profiling run.
|
||||
attn_metadata: Metadata for attention.
|
||||
attn_type: Select attention type, between encoder attention,
|
||||
decoder self-attention, or encoder/decoder cross-
|
||||
attention. Defaults to decoder self-attention,
|
||||
which is the vLLM default generally
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
# Reminder: Please update docs/source/features/compatibility_matrix.md
|
||||
# If the feature combo become valid
|
||||
num_tokens, hidden_size = query.shape
|
||||
# Reshape the query, key, and value tensors.
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
if key is not None:
|
||||
assert value is not None
|
||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size)
|
||||
else:
|
||||
assert value is None
|
||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size)
|
||||
|
||||
if self.attn_type != AttentionType.ENCODER and kv_cache.numel() > 0:
|
||||
if kv_cache.numel() > 0:
|
||||
key_cache, value_cache = PagedAttention.split_kv_cache(
|
||||
kv_cache, self.num_kv_heads, self.head_size)
|
||||
|
||||
if key is not None and value is not None:
|
||||
# Reshape the input keys and values and store them in the
|
||||
# cache. If kv_cache is not provided, the new key and value
|
||||
# tensors are not cached. This happens during the initial
|
||||
# memory profiling run.
|
||||
PagedAttention.write_to_paged_cache(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping
|
||||
if self.attn_type != AttentionType.ENCODER_DECODER else
|
||||
attn_metadata.cross_slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# If kv_cache is not provided, the new key and value tensors are
|
||||
# not cached. This happens during the initial memory profiling run.
|
||||
PagedAttention.write_to_paged_cache(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
if self.attn_type != AttentionType.ENCODER:
|
||||
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
||||
else:
|
||||
assert attn_metadata.num_encoder_tokens is not None
|
||||
num_prefill_tokens = attn_metadata.num_encoder_tokens
|
||||
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
|
||||
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
|
||||
|
||||
output = torch.empty_like(query)
|
||||
# Query for decode. KV is not needed because it is already cached.
|
||||
decode_query = query[num_prefill_tokens:]
|
||||
# QKV for prefill.
|
||||
query = query[:num_prefill_tokens]
|
||||
key = key[:num_prefill_tokens]
|
||||
value = value[:num_prefill_tokens]
|
||||
|
||||
if key is not None and value is not None \
|
||||
and self.attn_type != AttentionType.ENCODER_DECODER:
|
||||
key = key[:num_prefill_tokens]
|
||||
value = value[:num_prefill_tokens]
|
||||
assert query.shape[0] == num_prefill_tokens
|
||||
assert decode_query.shape[0] == num_decode_tokens
|
||||
|
||||
if prefill_meta := attn_metadata.prefill_metadata:
|
||||
# Prompt run.
|
||||
# normal attention and DECODER
|
||||
if self.attn_type == AttentionType.DECODER and (
|
||||
kv_cache.numel() == 0 or prefill_meta.block_tables is None
|
||||
or prefill_meta.block_tables.numel() == 0):
|
||||
(query_seq_start_loc, query_max_seq_len, key_seq_start_loc,
|
||||
key_max_seq_len, seq_lens,
|
||||
causal_mask) = (prefill_meta.seq_start_loc,
|
||||
prefill_meta.max_prefill_seq_len,
|
||||
prefill_meta.seq_start_loc,
|
||||
prefill_meta.max_prefill_seq_len,
|
||||
attn_metadata.seq_lens, True)
|
||||
# prefix-enabled attention and ENCODER/ENCODER_DECODER
|
||||
else:
|
||||
(query_seq_start_loc, query_max_seq_len, key_seq_start_loc,
|
||||
key_max_seq_len, seq_lens,
|
||||
causal_mask) = _get_seq_len_block_table_args(
|
||||
prefill_meta, self.attn_type)
|
||||
# Prompt run.
|
||||
assert prefill_meta.seq_lens is not None
|
||||
if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0:
|
||||
# triton attention
|
||||
# When block_tables are not filled, it means q and k are the
|
||||
@ -677,18 +493,18 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
attn_masks = _make_alibi_bias(
|
||||
self.alibi_slopes,
|
||||
query.dtype,
|
||||
seq_lens,
|
||||
attn_metadata.seq_lens,
|
||||
make_attn_mask=False) # type: ignore
|
||||
out, _ = self.attn_func(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
None,
|
||||
query_seq_start_loc,
|
||||
key_seq_start_loc,
|
||||
query_max_seq_len,
|
||||
key_max_seq_len,
|
||||
causal_mask,
|
||||
prefill_meta.seq_start_loc,
|
||||
prefill_meta.seq_start_loc,
|
||||
prefill_meta.max_prefill_seq_len,
|
||||
prefill_meta.max_prefill_seq_len,
|
||||
True,
|
||||
self.scale,
|
||||
attn_masks[0][None]
|
||||
if attn_masks is not None else None,
|
||||
@ -712,12 +528,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
query_seq_start_loc,
|
||||
num_prefill_tokens,
|
||||
prefill_meta.seq_lens,
|
||||
num_tokens,
|
||||
self.num_heads,
|
||||
self.head_size,
|
||||
self.scale,
|
||||
causal_mask,
|
||||
attn_masks,
|
||||
)
|
||||
else:
|
||||
@ -725,23 +540,19 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
cu_seqlens_q=query_seq_start_loc,
|
||||
cu_seqlens_k=key_seq_start_loc,
|
||||
cu_seqlens_q=prefill_meta.seq_start_loc,
|
||||
cu_seqlens_k=prefill_meta.seq_start_loc,
|
||||
max_seqlen_q=prefill_meta.max_prefill_seq_len,
|
||||
max_seqlen_k=key_max_seq_len,
|
||||
max_seqlen_k=prefill_meta.max_prefill_seq_len,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
window_size=self.sliding_window,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
softcap=self.logits_soft_cap,
|
||||
)
|
||||
|
||||
# common code for prefill
|
||||
assert output[:num_prefill_tokens].shape == out.shape
|
||||
if output.shape[0] > num_prefill_tokens:
|
||||
output[:num_prefill_tokens] = out
|
||||
else:
|
||||
output = out
|
||||
output[:num_prefill_tokens] = out
|
||||
else:
|
||||
# prefix-enabled attention
|
||||
output[:num_prefill_tokens] = PagedAttention.forward_prefix(
|
||||
@ -772,10 +583,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
decode_query.dtype, head_size, block_size, gqa_ratio,
|
||||
decode_meta.max_decode_seq_len)
|
||||
if use_custom:
|
||||
max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type
|
||||
!= AttentionType.ENCODER_DECODER else
|
||||
decode_meta.max_encoder_seq_len)
|
||||
assert max_seq_len is not None
|
||||
max_seq_len = decode_meta.max_decode_seq_len
|
||||
max_num_partitions = (
|
||||
(max_seq_len + _PARTITION_SIZE_ROCM - 1) //
|
||||
_PARTITION_SIZE_ROCM)
|
||||
@ -791,12 +599,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
device=output.device,
|
||||
)
|
||||
max_logits = torch.empty_like(exp_sums)
|
||||
if num_prefill_tokens > 0:
|
||||
out = output[num_prefill_tokens:]
|
||||
else:
|
||||
out = output
|
||||
ops.paged_attention_rocm(
|
||||
out,
|
||||
output[num_prefill_tokens:],
|
||||
exp_sums,
|
||||
max_logits,
|
||||
tmp_output,
|
||||
@ -805,12 +609,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
value_cache,
|
||||
self.num_kv_heads,
|
||||
self.scale,
|
||||
decode_meta.block_tables
|
||||
if self.attn_type != AttentionType.ENCODER_DECODER else
|
||||
decode_meta.cross_block_tables,
|
||||
decode_meta.seq_lens_tensor
|
||||
if self.attn_type != AttentionType.ENCODER_DECODER else
|
||||
decode_meta.encoder_seq_lens_tensor,
|
||||
decode_meta.block_tables,
|
||||
decode_meta.seq_lens_tensor,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
self.alibi_slopes,
|
||||
@ -823,15 +623,9 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
decode_query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
decode_meta.block_tables
|
||||
if self.attn_type != AttentionType.ENCODER_DECODER else
|
||||
decode_meta.cross_block_tables,
|
||||
decode_meta.seq_lens_tensor
|
||||
if self.attn_type != AttentionType.ENCODER_DECODER else
|
||||
decode_meta.encoder_seq_lens_tensor,
|
||||
decode_meta.max_decode_seq_len
|
||||
if self.attn_type != AttentionType.ENCODER_DECODER else
|
||||
decode_meta.max_encoder_seq_len,
|
||||
decode_meta.block_tables,
|
||||
decode_meta.seq_lens_tensor,
|
||||
decode_meta.max_decode_seq_len,
|
||||
self.kv_cache_dtype,
|
||||
self.num_kv_heads,
|
||||
self.scale,
|
||||
@ -841,7 +635,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
)
|
||||
|
||||
# Reshape the output tensor.
|
||||
return output.view(-1, self.num_heads * self.head_size)
|
||||
return output.view(num_tokens, hidden_size)
|
||||
|
||||
|
||||
def _sdpa_attention(
|
||||
|
||||
@ -20,7 +20,7 @@ from vllm.attention.backends.abstract import (AttentionBackend,
|
||||
AttentionMetadata,
|
||||
AttentionMetadataBuilder,
|
||||
AttentionState, AttentionType)
|
||||
from vllm.attention.backends.mla.utils import MLACommonImpl, MLACommonMetadata
|
||||
from vllm.attention.backends.mla.utils import MLACommonImpl, MLAMetadataCommon
|
||||
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
|
||||
compute_slot_mapping_start_idx,
|
||||
is_block_tables_empty)
|
||||
@ -57,12 +57,14 @@ class TritonMLABackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int, # assumed to be 1 for MLA
|
||||
head_size: int,
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int, # assumed to be 1 for MLA
|
||||
kv_lora_rank: int, # passed via head_size
|
||||
) -> Tuple[int, ...]:
|
||||
return (num_blocks, block_size, head_size)
|
||||
# TODO(lucas): remove hardcoding k_pe size as 1/8th of kv_lora_rank
|
||||
k_pe_size = kv_lora_rank // 8
|
||||
return (num_blocks, block_size, kv_lora_rank + k_pe_size)
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
@ -81,7 +83,7 @@ class TritonMLABackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_supported_head_sizes() -> List[int]:
|
||||
return [576]
|
||||
return [512]
|
||||
|
||||
|
||||
class TritonMLAState(AttentionState):
|
||||
@ -91,7 +93,8 @@ class TritonMLAState(AttentionState):
|
||||
self._is_graph_capturing = False
|
||||
|
||||
@contextmanager
|
||||
def graph_capture(self, max_batch_size: int):
|
||||
def graph_capture(self, max_batch_size: int,
|
||||
positions: Optional[torch.Tensor]):
|
||||
self._is_graph_capturing = True
|
||||
|
||||
self._graph_slot_mapping = torch.full((max_batch_size, ),
|
||||
@ -104,9 +107,8 @@ class TritonMLAState(AttentionState):
|
||||
self._graph_block_tables = torch.from_numpy(
|
||||
self.runner.graph_block_tables).to(device=self.runner.device)
|
||||
|
||||
self._positions = torch.zeros((max_batch_size, ),
|
||||
dtype=torch.long,
|
||||
device=self.runner.device)
|
||||
assert positions is not None
|
||||
self._positions = positions
|
||||
|
||||
yield
|
||||
|
||||
@ -114,7 +116,6 @@ class TritonMLAState(AttentionState):
|
||||
del self._graph_slot_mapping
|
||||
del self._graph_seq_lens
|
||||
del self._graph_block_tables
|
||||
del self._positions
|
||||
|
||||
def graph_clone(self, batch_size: int):
|
||||
assert self._is_graph_capturing
|
||||
@ -158,7 +159,6 @@ class TritonMLAState(AttentionState):
|
||||
"slot_mapping": attn_metadata.slot_mapping,
|
||||
"seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
|
||||
"block_tables": attn_metadata.decode_metadata.block_tables,
|
||||
"input_positions": attn_metadata.decode_metadata.input_positions,
|
||||
}
|
||||
if is_encoder_decoder_model:
|
||||
raise NotImplementedError(
|
||||
@ -170,16 +170,10 @@ class TritonMLAState(AttentionState):
|
||||
input_buffers,
|
||||
attn_metadata,
|
||||
is_encoder_decoder_model: bool = False):
|
||||
input_positions = attn_metadata.input_positions
|
||||
num_positions = input_positions.shape[0]
|
||||
input_buffers["seq_lens_tensor"].copy_(
|
||||
attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
|
||||
input_buffers["block_tables"].copy_(
|
||||
attn_metadata.decode_metadata.block_tables, non_blocking=True)
|
||||
# CUDA graph buffer is padded so only perform a partial copy based on
|
||||
# num_positions
|
||||
input_buffers["input_positions"][:num_positions].copy_(
|
||||
input_positions, non_blocking=True)
|
||||
if is_encoder_decoder_model:
|
||||
raise NotImplementedError(
|
||||
"TritonMLAState does not support encoder/decoder yet")
|
||||
@ -188,8 +182,8 @@ class TritonMLAState(AttentionState):
|
||||
return
|
||||
|
||||
|
||||
@dataclass
|
||||
class TritonMLAMetadata(MLACommonMetadata):
|
||||
@dataclass(kw_only=True)
|
||||
class TritonMLAMetadata(MLAMetadataCommon):
|
||||
"""Metadata for TritonMLAMetadata.
|
||||
|
||||
NOTE: Any python object stored here is not updated when it is
|
||||
@ -278,8 +272,10 @@ class TritonMLAMetadata(MLACommonMetadata):
|
||||
if self._cached_prefill_metadata is not None:
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
assert self.seq_lens is not None
|
||||
assert self.seq_lens_tensor is not None
|
||||
assert ((self.seq_lens is not None)
|
||||
or (self.encoder_seq_lens is not None))
|
||||
assert ((self.seq_lens_tensor is not None)
|
||||
or (self.encoder_seq_lens_tensor is not None))
|
||||
|
||||
# Compute some attn_metadata fields which default to None
|
||||
query_start_loc = (None if self.query_start_loc is None else
|
||||
@ -307,7 +303,6 @@ class TritonMLAMetadata(MLACommonMetadata):
|
||||
multi_modal_placeholder_index_maps=self.
|
||||
multi_modal_placeholder_index_maps,
|
||||
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
|
||||
input_positions=input_positions,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_query_len=self.max_query_len,
|
||||
@ -319,6 +314,7 @@ class TritonMLAMetadata(MLACommonMetadata):
|
||||
context_lens_tensor=context_lens_tensor,
|
||||
block_tables=block_tables,
|
||||
use_cuda_graph=False,
|
||||
input_positions=input_positions,
|
||||
head_dim=self.head_dim)
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
@ -329,7 +325,8 @@ class TritonMLAMetadata(MLACommonMetadata):
|
||||
|
||||
if self._cached_decode_metadata is not None:
|
||||
return self._cached_decode_metadata
|
||||
assert self.seq_lens_tensor is not None
|
||||
assert ((self.seq_lens_tensor is not None)
|
||||
or (self.encoder_seq_lens_tensor is not None))
|
||||
|
||||
# Compute some attn_metadata fields which default to None
|
||||
slot_mapping = (None if self.slot_mapping is None else
|
||||
@ -622,6 +619,8 @@ class TritonMLAMetadataBuilder(AttentionMetadataBuilder[TritonMLAMetadata]):
|
||||
self.multimodal_placeholder_maps.items()
|
||||
}
|
||||
|
||||
num_kv_splits = 8
|
||||
|
||||
return TritonMLAMetadata(
|
||||
num_prefills=self.num_prefills,
|
||||
slot_mapping=slot_mapping_tensor,
|
||||
@ -630,7 +629,6 @@ class TritonMLAMetadataBuilder(AttentionMetadataBuilder[TritonMLAMetadata]):
|
||||
seq_lens=seq_lens,
|
||||
multi_modal_placeholder_index_maps=placeholder_index_maps,
|
||||
enable_kv_scales_calculation=True,
|
||||
input_positions=input_positions,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_query_len=max_query_len,
|
||||
max_decode_query_len=max_decode_query_len,
|
||||
@ -641,12 +639,13 @@ class TritonMLAMetadataBuilder(AttentionMetadataBuilder[TritonMLAMetadata]):
|
||||
context_lens_tensor=context_lens_tensor,
|
||||
block_tables=block_tables,
|
||||
use_cuda_graph=use_captured_graph,
|
||||
num_kv_splits=4, # TODO(lucas) add heuristic
|
||||
input_positions=input_positions,
|
||||
num_kv_splits=num_kv_splits,
|
||||
head_dim=self.runner.model_config.get_head_size(),
|
||||
)
|
||||
|
||||
|
||||
class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]):
|
||||
class TritonMLAImpl(MLACommonImpl):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -689,7 +688,6 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]):
|
||||
k_pe: torch.Tensor,
|
||||
attn_metadata: TritonMLAMetadata,
|
||||
) -> torch.Tensor:
|
||||
assert isinstance(attn_metadata, TritonMLAMetadata)
|
||||
return self._forward_prefill_flash(q, kv_c_normed, k_pe,
|
||||
attn_metadata.seq_start_loc,
|
||||
attn_metadata.max_prefill_seq_len)
|
||||
@ -706,7 +704,6 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]):
|
||||
raise NotImplementedError("FP8 Triton MLA not yet supported")
|
||||
|
||||
decode_meta = attn_metadata.decode_metadata
|
||||
assert decode_meta is not None
|
||||
B = q_nope.shape[0]
|
||||
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from itertools import accumulate
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union
|
||||
from typing import (TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union, Optional)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -288,10 +288,11 @@ class CommonAttentionState(AttentionState):
|
||||
self._is_graph_capturing = False
|
||||
|
||||
@contextmanager
|
||||
def graph_capture(self, max_batch_size: int):
|
||||
def graph_capture(self, max_batch_size: int, positions: Optional[torch.Tensor]):
|
||||
assert positions is None
|
||||
|
||||
self._is_graph_capturing = True
|
||||
|
||||
|
||||
self._graph_slot_mapping = torch.full((max_batch_size, ),
|
||||
PAD_SLOT_ID,
|
||||
dtype=torch.long,
|
||||
@ -301,7 +302,7 @@ class CommonAttentionState(AttentionState):
|
||||
device=self.runner.device)
|
||||
self._graph_block_tables = torch.from_numpy(
|
||||
self.runner.graph_block_tables).to(device=self.runner.device)
|
||||
|
||||
|
||||
yield
|
||||
|
||||
self._is_graph_capturing = False
|
||||
|
||||
@ -200,9 +200,9 @@ class Attention(nn.Module):
|
||||
s += f", backend={self.impl.__class__.__name__}"
|
||||
return s
|
||||
|
||||
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
||||
def process_weights_after_loading(self):
|
||||
if hasattr(self.impl, "process_weights_after_loading"):
|
||||
self.impl.process_weights_after_loading(act_dtype)
|
||||
self.impl.process_weights_after_loading()
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
|
||||
@ -738,20 +738,18 @@ class ModelConfig:
|
||||
|
||||
@property
|
||||
def is_deepseek_mla(self) -> bool:
|
||||
# TODO add deepseek_v3
|
||||
return (hasattr(self.hf_text_config, "model_type")) \
|
||||
and (self.hf_text_config.model_type in \
|
||||
('deepseek_v2', 'deepseek_v3'))\
|
||||
and (self.hf_text_config.kv_lora_rank is not None)
|
||||
return hasattr(self.hf_text_config,
|
||||
"model_type") and (self.hf_text_config.model_type
|
||||
in ('deepseek_v2', 'deepseek_v3'))
|
||||
|
||||
def get_head_size(self) -> int:
|
||||
# TODO remove hard code
|
||||
if self.is_deepseek_mla:
|
||||
qk_rope_head_dim = getattr(self.hf_text_config, "qk_rope_head_dim",
|
||||
0)
|
||||
if self.use_mla:
|
||||
return self.hf_text_config.kv_lora_rank + qk_rope_head_dim
|
||||
return self.hf_text_config.kv_lora_rank
|
||||
else:
|
||||
qk_rope_head_dim = getattr(self.hf_text_config,
|
||||
"qk_rope_head_dim", 0)
|
||||
qk_nope_head_dim = getattr(self.hf_text_config,
|
||||
"qk_nope_head_dim", 0)
|
||||
if qk_rope_head_dim and qk_nope_head_dim:
|
||||
@ -970,36 +968,9 @@ class ModelConfig:
|
||||
|
||||
@property
|
||||
def use_mla(self) -> bool:
|
||||
if self.quantization is not None and self.quantization not in [\
|
||||
"fp8", "compressed-tensors"]:
|
||||
logger.warning(
|
||||
"MLA is not supported with %s quantization. "
|
||||
"Disabling MLA.", self.quantization)
|
||||
return False
|
||||
|
||||
# If using a "compressed-tensors" checkpoint, check that all groups
|
||||
# have fp8 for both weights and activations.
|
||||
if self.quantization == "compressed-tensors":
|
||||
quant_config = self._parse_quant_hf_config()
|
||||
for group_name, cfg in quant_config.get("config_groups",
|
||||
("", {})).items():
|
||||
act_cfg = cfg.get("input_activations", {})
|
||||
act_type = None if act_cfg is None else act_cfg.get("type", "")
|
||||
w_cfg = cfg.get("weights", {})
|
||||
w_type = None if w_cfg is None else w_cfg.get("type", "")
|
||||
if act_type != "fp8" or w_type != "fp8":
|
||||
logger.warning(
|
||||
"compressed-tensors MLA support requires fp8 "
|
||||
"activations and weights in group '%s', but got "
|
||||
"activations type '%s' and weights type '%s'.\n "
|
||||
"Full config: %s", group_name, act_type, w_type,
|
||||
quant_config)
|
||||
return False
|
||||
|
||||
use_mla = (self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE)
|
||||
return use_mla
|
||||
|
||||
@property
|
||||
def supported_runner_types(self) -> Set[RunnerType]:
|
||||
return {_TASK_RUNNER[task] for task in self.supported_tasks}
|
||||
|
||||
@ -3252,16 +3223,6 @@ class VllmConfig:
|
||||
|
||||
current_platform.check_and_update_config(self)
|
||||
|
||||
# If MLA is enabled, force disable chunked prefill and prefix caching
|
||||
if self.model_config and self.model_config.use_mla:
|
||||
logger.info("MLA is enabled; forcing chunked prefill and prefix "
|
||||
"caching to be disabled.")
|
||||
self.scheduler_config.enable_chunked_prefill = False
|
||||
self.scheduler_config.chunked_prefill_enabled = False
|
||||
|
||||
if self.cache_config is not None:
|
||||
self.cache_config.enable_prefix_caching = False
|
||||
|
||||
if not self.instance_id:
|
||||
self.instance_id = random_uuid()[:5]
|
||||
|
||||
|
||||
@ -203,7 +203,7 @@ class OpenAIServingModels:
|
||||
for lora_request in self.lora_requests):
|
||||
return create_error_response(
|
||||
message=
|
||||
f"The lora adapter '{request.lora_name}' has already been "
|
||||
f"The lora adapter '{request.lora_name}' has already been"
|
||||
"loaded.",
|
||||
err_type="InvalidUserInput",
|
||||
status_code=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
12
vllm/envs.py
@ -79,7 +79,6 @@ if TYPE_CHECKING:
|
||||
VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128
|
||||
VLLM_MLA_DISABLE: bool = False
|
||||
VLLM_MLA_PERFORM_MATRIX_ABSORPTION: bool = True
|
||||
VLLM_MLA_DISABLE_REQUANTIZATION: bool = False
|
||||
|
||||
|
||||
def get_default_cache_root():
|
||||
@ -520,16 +519,7 @@ environment_variables: Dict[str, Callable[[], Any]] = {
|
||||
# storing more weights, W_Q_UK and W_UV_O, so can increase memory usage,
|
||||
# the is enabled by default
|
||||
"VLLM_MLA_PERFORM_MATRIX_ABSORPTION":
|
||||
lambda: bool(int(os.getenv("VLLM_MLA_PERFORM_MATRIX_ABSORPTION", "1"))),
|
||||
|
||||
# When running MLA with matrix-absorption enabled and fp8 quantized weights
|
||||
# we perform the matrix-absorption in float32 precision, after the matrices
|
||||
# are absorbed we requantize the weights back to fp8, this flag can be used
|
||||
# to disable the requantization step, and instead convert the absorbed
|
||||
# matrices to match the activation type. This can lead to higher memory and
|
||||
# compute usage but better preserves the accuracy of the original model.
|
||||
"VLLM_MLA_DISABLE_REQUANTIZATION":
|
||||
lambda: bool(int(os.getenv("VLLM_MLA_DISABLE_REQUANTIZATION", "0")))
|
||||
lambda: bool(int(os.getenv("VLLM_MLA_PERFORM_MATRIX_ABSORPTION", "1")))
|
||||
}
|
||||
|
||||
# end-env-vars-definition
|
||||
|
||||
@ -307,8 +307,8 @@ class XGrammarLogitsProcessor:
|
||||
# Note: In this method, if the tensors have different dimensions
|
||||
# on CPU device fails, but on GPU it runs without error. Hence the
|
||||
# unsqueeze above for scores, to match the token bitmask shape
|
||||
xgr.apply_token_bitmask_inplace(
|
||||
scores, self.token_bitmask.to(scores.device, non_blocking=True))
|
||||
xgr.apply_token_bitmask_inplace(scores,
|
||||
self.token_bitmask.to(scores.device))
|
||||
if device_type != "cuda":
|
||||
scores = scores.to(dtype).to(device_type).squeeze()
|
||||
|
||||
|
||||
@ -1,146 +0,0 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@ -1,146 +0,0 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@ -598,27 +598,15 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
|
||||
)
|
||||
|
||||
|
||||
# Adapted from: https://github.com/sgl-project/sglang/pull/2628
|
||||
def get_config_file_name(E: int,
|
||||
N: int,
|
||||
dtype: Optional[str],
|
||||
block_shape: Optional[List[int]] = None) -> str:
|
||||
def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
|
||||
device_name = current_platform.get_device_name().replace(" ", "_")
|
||||
dtype_selector = "" if not dtype else f",dtype={dtype}"
|
||||
block_shape_selector = ("" if not block_shape or not all(block_shape) else
|
||||
f",block_shape={block_shape}")
|
||||
return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json" # noqa: E501
|
||||
return f"E={E},N={N},device_name={device_name}{dtype_selector}.json"
|
||||
|
||||
|
||||
# Adapted from: https://github.com/sgl-project/sglang/pull/2628
|
||||
@functools.lru_cache
|
||||
def get_moe_configs(
|
||||
E: int,
|
||||
N: int,
|
||||
dtype: Optional[str],
|
||||
block_n: Optional[int] = None,
|
||||
block_k: Optional[int] = None,
|
||||
) -> Optional[Dict[int, Any]]:
|
||||
def get_moe_configs(E: int, N: int,
|
||||
dtype: Optional[str]) -> Optional[Dict[int, Any]]:
|
||||
"""
|
||||
Return optimized configurations for the fused MoE kernel.
|
||||
|
||||
@ -630,8 +618,7 @@ def get_moe_configs(
|
||||
|
||||
# First look up if an optimized configuration is available in the configs
|
||||
# directory
|
||||
block_shape = [block_n, block_k] if block_n and block_k else None
|
||||
json_file_name = get_config_file_name(E, N, dtype, block_shape)
|
||||
json_file_name = get_config_file_name(E, N, dtype)
|
||||
|
||||
config_file_path = os.path.join(
|
||||
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
|
||||
@ -658,34 +645,21 @@ def get_default_config(
|
||||
topk: int,
|
||||
dtype: Optional[str],
|
||||
is_marlin: bool,
|
||||
block_shape: Optional[List[int]] = None,
|
||||
) -> Dict[str, int]:
|
||||
if dtype == "fp8_w8a8" and block_shape is not None:
|
||||
# Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
|
||||
# BLOCK_SIZE_K must be divisible by block_shape[1]
|
||||
config = {
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 64,
|
||||
'BLOCK_SIZE_K': 32,
|
||||
'GROUP_SIZE_M': 8
|
||||
}
|
||||
# A heuristic: fused marlin works faster with this config for small M
|
||||
if M <= E or (is_marlin and M <= 32):
|
||||
config = {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": block_shape[0],
|
||||
"BLOCK_SIZE_K": block_shape[1],
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3,
|
||||
'BLOCK_SIZE_M': 16,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 64,
|
||||
'GROUP_SIZE_M': 1
|
||||
}
|
||||
else:
|
||||
config = {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
}
|
||||
# A heuristic: fused marlin works faster with this config for small M
|
||||
if M <= E or (is_marlin and M <= 32):
|
||||
config = {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
}
|
||||
return config
|
||||
|
||||
|
||||
@ -705,9 +679,7 @@ def try_get_optimal_moe_config(
|
||||
else:
|
||||
# First try to load optimal config from the file
|
||||
E, _, N = w2_shape
|
||||
block_n = block_shape[0] if block_shape else 0
|
||||
block_k = block_shape[1] if block_shape else 0
|
||||
configs = get_moe_configs(E, N, dtype, block_n, block_k)
|
||||
configs = get_moe_configs(E, N, dtype)
|
||||
|
||||
if configs:
|
||||
# If an optimal configuration map has been found, look up the
|
||||
@ -716,7 +688,13 @@ def try_get_optimal_moe_config(
|
||||
else:
|
||||
# Else use the default config
|
||||
config = get_default_config(M, E, N, w1_shape[2], top_k, dtype,
|
||||
is_marlin, block_shape)
|
||||
is_marlin)
|
||||
# NOTE: For block-wise quant,
|
||||
# BLOCK_K must be divisible by block_shape[1]
|
||||
# BLOCK_N and BLOCK_M has no requirements
|
||||
if block_shape is not None and block_shape[0] != 0:
|
||||
config["BLOCK_SIZE_N"] = block_shape[0]
|
||||
config["BLOCK_SIZE_K"] = block_shape[1]
|
||||
return config
|
||||
|
||||
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from contextlib import suppress
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, cast
|
||||
from typing import Any, Dict, List, Literal, Optional, cast
|
||||
|
||||
import torch
|
||||
from compressed_tensors.config import (CompressionFormat,
|
||||
@ -45,7 +44,6 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
ignore: List[str],
|
||||
quant_format: str,
|
||||
sparsity_scheme_map: Dict[str, SparsityCompressionConfig],
|
||||
sparsity_ignore_list: List[str],
|
||||
kv_cache_scheme: Optional[Dict[str, Any]] = None,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
@ -56,7 +54,6 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
self.target_scheme_map = target_scheme_map
|
||||
self.kv_cache_scheme = kv_cache_scheme
|
||||
self.sparsity_scheme_map = sparsity_scheme_map
|
||||
self.sparsity_ignore_list = sparsity_ignore_list
|
||||
self.config = config
|
||||
|
||||
def get_linear_method(self) -> "CompressedTensorsLinearMethod":
|
||||
@ -101,7 +98,7 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
quant_format = cast(str, config.get("format"))
|
||||
target_scheme_map = cls._quantization_scheme_map_from_config(
|
||||
config=config)
|
||||
sparsity_scheme_map, sparsity_ignore_list = cls._parse_sparsity_config(
|
||||
sparsity_scheme_map = cls._sparsity_scheme_map_from_config(
|
||||
config=config)
|
||||
|
||||
return cls(
|
||||
@ -109,23 +106,20 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
ignore=ignore,
|
||||
quant_format=quant_format,
|
||||
sparsity_scheme_map=sparsity_scheme_map,
|
||||
sparsity_ignore_list=sparsity_ignore_list,
|
||||
config=config,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _parse_sparsity_config(
|
||||
cls, config: Dict[str, Any]
|
||||
) -> Tuple[Dict[str, SparsityCompressionConfig], List[str]]:
|
||||
def _sparsity_scheme_map_from_config(
|
||||
cls, config: Dict[str,
|
||||
Any]) -> Dict[str, SparsityCompressionConfig]:
|
||||
"""
|
||||
:param config: The `quantization_config` dictionary from config.json
|
||||
:return: A tuple with two elements
|
||||
1. A dictionary mapping target layer names to their corresponding
|
||||
sparsity_config
|
||||
2. A list of layer names to ignore for sparsity
|
||||
:return: A dictionary mapping target layer names to their corresponding
|
||||
sparsity compression configurations
|
||||
"""
|
||||
if not (sparsity_config := config.get(SPARSITY_CONFIG_NAME)):
|
||||
return dict(), []
|
||||
return dict()
|
||||
|
||||
sparsity_config = SparsityCompressionConfig.model_validate(
|
||||
sparsity_config)
|
||||
@ -133,8 +127,7 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
target: sparsity_config
|
||||
for target in sparsity_config.targets or list()
|
||||
}
|
||||
sparsity_ignore_list = sparsity_config.ignore or list()
|
||||
return sparse_scheme_map, sparsity_ignore_list
|
||||
return sparse_scheme_map
|
||||
|
||||
@classmethod
|
||||
def _quantization_scheme_map_from_config(
|
||||
@ -359,6 +352,7 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
"""
|
||||
compressed-tensors supports non uniform in the following way:
|
||||
|
||||
ignore: List of layer_names or nn.Module names to be ignored.
|
||||
targets of config_groups: There can be N config_groups which each
|
||||
have a quantization scheme. Each config_group has a list of targets
|
||||
which can be a full layer_name, a regex for a layer_name, or
|
||||
@ -376,8 +370,6 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
# need to make accelerate optional in ct to do this
|
||||
|
||||
# Will be empty for models with only sparsity
|
||||
weight_quant = input_quant = None
|
||||
sparsity_scheme: Optional[SparsityCompressionConfig] = None
|
||||
if self.target_scheme_map:
|
||||
matched_target = find_matched_target(
|
||||
layer_name=layer_name,
|
||||
@ -387,24 +379,19 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
scheme_dict = self.target_scheme_map[matched_target]
|
||||
weight_quant = scheme_dict.get("weights")
|
||||
input_quant = scheme_dict.get("input_activations")
|
||||
elif self.sparsity_scheme_map:
|
||||
matched_target = find_matched_target(
|
||||
layer_name=layer_name,
|
||||
module=layer,
|
||||
targets=self.sparsity_scheme_map.keys())
|
||||
weight_quant = None
|
||||
input_quant = None
|
||||
|
||||
if self.sparsity_scheme_map:
|
||||
is_ignored = False
|
||||
with suppress(ValueError):
|
||||
is_ignored = find_matched_target(
|
||||
layer_name=layer_name,
|
||||
module=layer,
|
||||
targets=self.sparsity_ignore_list)
|
||||
|
||||
# if the layer is in the sparsity ignore list,
|
||||
# we should not apply any sparsity scheme
|
||||
|
||||
if not is_ignored:
|
||||
matched_target = find_matched_target(
|
||||
layer_name=layer_name,
|
||||
module=layer,
|
||||
targets=self.sparsity_scheme_map.keys())
|
||||
sparsity_scheme = self.sparsity_scheme_map.get(matched_target)
|
||||
# For models with sparsity, assumes that the sparse layers are also
|
||||
# quantized for cutlass 2:4 support
|
||||
sparsity_scheme: Optional[
|
||||
SparsityCompressionConfig] = self.sparsity_scheme_map.get(
|
||||
matched_target)
|
||||
|
||||
if self.supports_cutlass_24(weight_quant=weight_quant,
|
||||
input_quant=input_quant,
|
||||
@ -432,8 +419,6 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
# Raise error if device does not support the scheme
|
||||
# (e.g. fp8 needs ada lovelace)
|
||||
self._check_scheme_supported(scheme.get_min_capability())
|
||||
logger.debug("Using scheme: %s for %s", scheme.__class__.__name__,
|
||||
layer_name)
|
||||
return scheme
|
||||
|
||||
def get_cache_scale(self, name: str) -> Optional[str]:
|
||||
|
||||
@ -12,7 +12,7 @@ def is_activation_quantization_format(format: str) -> bool:
|
||||
_ACTIVATION_QUANTIZATION_FORMATS = [
|
||||
CompressionFormat.naive_quantized.value,
|
||||
CompressionFormat.int_quantized.value,
|
||||
CompressionFormat.float_quantized.value,
|
||||
CompressionFormat.float_quantized.value
|
||||
]
|
||||
return format in _ACTIVATION_QUANTIZATION_FORMATS
|
||||
|
||||
@ -68,7 +68,7 @@ def should_ignore_layer(layer_name: Optional[str],
|
||||
def check_equal_or_regex_match(layer_name: str,
|
||||
targets: Iterable[str]) -> bool:
|
||||
"""
|
||||
Checks whether a layer_name is exactly equal or a regex match for
|
||||
Checks whether a layer_name is exactly equal or a regex match for
|
||||
if target starts with 're:' to any target in list.
|
||||
"""
|
||||
for target in targets:
|
||||
@ -77,64 +77,17 @@ def check_equal_or_regex_match(layer_name: str,
|
||||
return False
|
||||
|
||||
|
||||
def _handle_fused_layers(func):
|
||||
"""
|
||||
Decorator to handle fused layers by mapping vllm fused layer names
|
||||
to their corresponding unfused layer names for quantization/pruning schemes.
|
||||
"""
|
||||
# fused_layer_name -> unfused_layer_name
|
||||
fused_layer_map = {
|
||||
"qkv_proj": "q_proj",
|
||||
"gate_up_proj": "up_proj",
|
||||
}
|
||||
|
||||
def fused_layer_handler(layer_name: Optional[str], module: Module,
|
||||
targets: Iterable[str]) -> Optional[str]:
|
||||
"""
|
||||
Wrapper function specifically designed to support the
|
||||
find_matched_target function.
|
||||
|
||||
It handles cases where the provided layer name corresponds to a
|
||||
fused layer in vllm, mapping it to its equivalent unfused layer name
|
||||
based on the predefined fused_layer_map. If the original layer name
|
||||
raises a ValueError in the wrapped function, this handler
|
||||
will attempt to resolve the issue by substituting with unfused
|
||||
layer name.
|
||||
|
||||
:param layer_name: Name of the layer, which may be fused.
|
||||
:param module: An instance of torch.nn.Module.
|
||||
:param targets: A list of target names or patterns to match.
|
||||
:return: The result of the wrapped find_matched_target function with
|
||||
the resolved layer name.
|
||||
:raises ValueError: If the layer name cannot be resolved to a
|
||||
valid target.
|
||||
"""
|
||||
try:
|
||||
return func(layer_name, module, targets)
|
||||
except ValueError:
|
||||
if layer_name is None:
|
||||
layer_name = ""
|
||||
parent_name, fused_proj_name = layer_name.rsplit(".", 1)
|
||||
unfused_proj_name = fused_layer_map.get(fused_proj_name,
|
||||
fused_proj_name)
|
||||
new_layer_name = f"{parent_name}.{unfused_proj_name}"
|
||||
return func(new_layer_name, module, targets)
|
||||
|
||||
return fused_layer_handler
|
||||
|
||||
|
||||
@_handle_fused_layers
|
||||
def find_matched_target(layer_name: Optional[str], module: Module,
|
||||
targets: Iterable[str]) -> str:
|
||||
"""
|
||||
Helper function to look up which "target" in the compressed-tensors
|
||||
config that a layer corresponds to.
|
||||
|
||||
Recall that a compressed-tensors configs has a concept of
|
||||
Recall that a compressed-tensors configs has a concept of
|
||||
config_groups, where each layer can be quantized with with a different
|
||||
scheme.
|
||||
|
||||
targets in each config_group will be a list of either layer names
|
||||
targets in each config_group will be a list of either layer names
|
||||
(or regexes corresponding to layer names) or names of torch Modules.
|
||||
|
||||
First, we try to match the layer_name with a target
|
||||
@ -150,13 +103,11 @@ def find_matched_target(layer_name: Optional[str], module: Module,
|
||||
|
||||
matched_target = (_find_first_match(layer_name, targets)
|
||||
or _find_first_match(module.__class__.__name__, targets,
|
||||
True)
|
||||
or _match_fused_layer(layer_name, targets))
|
||||
True))
|
||||
|
||||
if matched_target is None:
|
||||
raise ValueError(
|
||||
f"Unable to find matching target for {layer_name} in the "
|
||||
"compressed-tensors config.")
|
||||
raise ValueError(f"Unable to find matching target for {module} in the "
|
||||
"compressed-tensors config.")
|
||||
|
||||
return matched_target
|
||||
|
||||
@ -201,41 +152,3 @@ def _is_equal_or_regex_match(value: str,
|
||||
elif target == value:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _match_fused_layer(layer_name: str,
|
||||
target_layers: Iterable[str]) -> Optional[str]:
|
||||
"""
|
||||
Match a fused layer name to its corresponding individual layer in
|
||||
target_layers.
|
||||
|
||||
Examples:
|
||||
layer_name = "model.layers.0.self_attn.qkv_proj"
|
||||
target_layers = ["model.layers.0.self_attn.q_proj",
|
||||
"model.layers.0.self_attn.k_proj",
|
||||
"model.layers.0.self_attn.v_proj"]
|
||||
"""
|
||||
# Split into parent path and layer type
|
||||
# e.g., "model.layers.0.self_attn" and "qkv_proj"
|
||||
parent_path = ".".join(layer_name.split(".")[:-1])
|
||||
layer_type = layer_name.split(".")[-1]
|
||||
|
||||
if layer_type not in FUSED_LAYER_NAME_MAPPING:
|
||||
return None
|
||||
|
||||
possible_layer_types = FUSED_LAYER_NAME_MAPPING[layer_type]
|
||||
|
||||
# Look for a target layer that:
|
||||
# 1. Has the same parent path
|
||||
# 2. Ends with one of the possible individual layer types
|
||||
for target in target_layers:
|
||||
is_same_parent = parent_path in target
|
||||
is_matching_type = any(type_suffix in target
|
||||
for type_suffix in possible_layer_types)
|
||||
|
||||
if is_same_parent and is_matching_type and all(
|
||||
'.'.join([parent_path, type_suffix])
|
||||
for type_suffix in possible_layer_types):
|
||||
return target
|
||||
|
||||
return None
|
||||
|
||||
@ -21,8 +21,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
is_layer_skipped)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
all_close_1d, apply_fp8_linear, convert_to_channelwise,
|
||||
cutlass_block_fp8_supported, cutlass_fp8_supported,
|
||||
normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
|
||||
cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
|
||||
requantize_with_max_scale)
|
||||
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
@ -134,7 +133,6 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
def __init__(self, quant_config: Fp8Config):
|
||||
self.quant_config = quant_config
|
||||
self.cutlass_fp8_supported = cutlass_fp8_supported()
|
||||
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
|
||||
|
||||
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
|
||||
# kernel for fast weight-only FP8 quantization
|
||||
@ -247,24 +245,20 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
layer.register_parameter("input_scale", None)
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
# TODO(rob): refactor block quant into separate class.
|
||||
# Block quant doesn't need to process weights after loading
|
||||
if self.block_quant:
|
||||
assert self.quant_config.activation_scheme == "dynamic"
|
||||
if current_platform.is_rocm():
|
||||
weight, weight_scale_inv, _ = \
|
||||
weight, weight_scale, _ = \
|
||||
normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale_inv)
|
||||
else:
|
||||
weight = layer.weight.data
|
||||
weight_scale_inv = layer.weight_scale_inv.data
|
||||
|
||||
# Torch.compile cannot use Parameter subclasses.
|
||||
layer.weight = Parameter(weight, requires_grad=False)
|
||||
layer.weight_scale_inv = Parameter(weight_scale_inv,
|
||||
requires_grad=False)
|
||||
weight_scale=layer.weight_scale_inv,
|
||||
input_scale=layer.input_scale)
|
||||
layer.weight = Parameter(weight, requires_grad=False)
|
||||
layer.weight_scale_inv = Parameter(weight_scale,
|
||||
requires_grad=False)
|
||||
return
|
||||
|
||||
layer.weight = torch.nn.Parameter(layer.weight.data,
|
||||
requires_grad=False)
|
||||
# If checkpoint not serialized fp8, quantize the weights.
|
||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
|
||||
@ -361,7 +355,6 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
weight_scale=layer.weight_scale_inv,
|
||||
input_scale=layer.input_scale,
|
||||
bias=bias,
|
||||
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
|
||||
)
|
||||
|
||||
return apply_fp8_linear(
|
||||
@ -514,9 +507,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
layer.w2_input_scale = None
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
# TODO (rob): refactor block quant into separate class.
|
||||
# Block quant doesn't need to process weights after loading
|
||||
if self.block_quant:
|
||||
assert self.quant_config.activation_scheme == "dynamic"
|
||||
if current_platform.is_rocm():
|
||||
w13_weight, w13_weight_scale_inv, w13_input_scale = \
|
||||
normalize_e4m3fn_to_e4m3fnuz(
|
||||
@ -526,21 +518,22 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
normalize_e4m3fn_to_e4m3fnuz(
|
||||
layer.w2_weight, layer.w2_weight_scale_inv,
|
||||
layer.w2_input_scale)
|
||||
else:
|
||||
w13_weight = layer.w13_weight.data
|
||||
w13_weight_scale_inv = layer.w13_weight_scale_inv.data
|
||||
w2_weight = layer.w2_weight
|
||||
w2_weight_scale_inv = layer.w2_weight_scale_inv
|
||||
|
||||
# torch.compile() cannot use Parameter subclasses.
|
||||
layer.w13_weight = Parameter(w13_weight, requires_grad=False)
|
||||
layer.w13_weight_scale_inv = Parameter(w13_weight_scale_inv,
|
||||
requires_grad=False)
|
||||
layer.w2_weight = Parameter(w2_weight, requires_grad=False)
|
||||
layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
|
||||
requires_grad=False)
|
||||
# Reset the parameter
|
||||
layer.w13_weight = torch.nn.Parameter(w13_weight,
|
||||
requires_grad=False)
|
||||
layer.w13_weight_scale_inv = torch.nn.Parameter(
|
||||
w13_weight_scale_inv, requires_grad=False)
|
||||
if w13_input_scale is not None:
|
||||
layer.w13_input_scale = torch.nn.Parameter(
|
||||
w13_input_scale, requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(w2_weight,
|
||||
requires_grad=False)
|
||||
layer.w2_weight_scale_inv = torch.nn.Parameter(
|
||||
w2_weight_scale_inv, requires_grad=False)
|
||||
if w2_input_scale is not None:
|
||||
layer.w2_input_scale = torch.nn.Parameter(
|
||||
w2_input_scale, requires_grad=False)
|
||||
return
|
||||
|
||||
# If checkpoint is fp16, quantize in place.
|
||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||
# If rocm, use float8_e4m3fnuz as dtype
|
||||
|
||||
@ -1,146 +0,0 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@ -1,146 +0,0 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@ -1,146 +0,0 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@ -1,146 +0,0 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@ -1,146 +0,0 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@ -1,146 +0,0 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@ -1,146 +0,0 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@ -1,146 +0,0 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@ -1,146 +0,0 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@ -1,146 +0,0 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
}
|
||||
}
|
||||
@ -1,146 +0,0 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||