[UX] Speedup DeepGEMM warmup with heuristics (#25619)

Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Michael Goin, committed by GitHub, 2025-10-13 10:59:27 -04:00
parent 10214b6935
commit 0d21b9b51e
3 changed files with 95 additions and 14 deletions


@@ -146,7 +146,11 @@ if TYPE_CHECKING:
VLLM_TPU_USING_PATHWAYS: bool = False
VLLM_USE_DEEP_GEMM: bool = True
VLLM_USE_DEEP_GEMM_E8M0: bool = True
VLLM_SKIP_DEEP_GEMM_WARMUP: bool = False
VLLM_DEEP_GEMM_WARMUP: Literal[
"skip",
"full",
"relax",
] = "relax"
VLLM_USE_FUSED_MOE_GROUPED_TOPK: bool = True
VLLM_USE_FLASHINFER_MOE_FP16: bool = False
VLLM_USE_FLASHINFER_MOE_FP8: bool = False
@@ -1088,9 +1092,21 @@ environment_variables: dict[str, Callable[[], Any]] = {
# JIT all the required kernels before model execution so there is no
# JIT'ing in the hot-path. However, this warmup increases the engine
# startup time by a couple of minutes.
# Set `VLLM_SKIP_DEEP_GEMM_WARMUP` to disable the warmup.
"VLLM_SKIP_DEEP_GEMM_WARMUP": lambda: bool(
int(os.getenv("VLLM_SKIP_DEEP_GEMM_WARMUP", "0"))
# Available options:
# - "skip" : Skip warmup.
# - "full" : Warmup deepgemm by running all possible gemm shapes the
# engine could encounter.
# - "relax" : Select gemm shapes to run based on some heuristics. The
# heuristic aims to have the same effect as running all possible gemm
# shapes, but provides no guarantees.
"VLLM_DEEP_GEMM_WARMUP": env_with_choices(
"VLLM_DEEP_GEMM_WARMUP",
"relax",
[
"skip",
"full",
"relax",
],
),
# Whether to use fused grouped_topk used for MoE expert selection.
"VLLM_USE_FUSED_MOE_GROUPED_TOPK": lambda: bool(


@@ -26,6 +26,55 @@ from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
from vllm.utils.deep_gemm import fp8_gemm_nt, m_grouped_fp8_gemm_nt_contiguous
def _generate_optimal_warmup_m_values(
max_tokens: int, n: int, device: torch.device
) -> list[int]:
"""
Generate M values that cover all possible DeepGEMM kernel configurations.
Reference: https://github.com/deepseek-ai/DeepGEMM/blob/79f48ee15a82dd5fad5cd9beaa393c1f755e6b55/csrc/jit_kernels/heuristics/common.hpp
Args:
max_tokens: Maximum number of tokens to warmup for
n: The actual N dimension from the weight tensor
device: The torch device to get properties from.
"""
def ceil_div(a: int, b: int) -> int:
return (a + b - 1) // b
# DeepGEMM's possible block sizes
block_ms = [64, 128, 256]
block_ns = list(range(16, min(257, n + 1), 16))
num_sms = torch.cuda.get_device_properties(device).multi_processor_count
m_values = set()
# Always include small cases
m_values.update([1, 2, 4] + [i for i in range(8, 65, 8)])
# Collect M values where different wave patterns occur
for block_m in block_ms:
for block_n in block_ns:
if block_n > n:
continue
# Add key M boundaries for this block combination
for wave in range(1, 11): # Up to 10 waves
# M where this block config transitions to next wave
target_blocks = wave * num_sms
m = target_blocks * block_m // ceil_div(n, block_n)
if 1 <= m <= max_tokens:
m_values.add(m)
# Add block_m boundaries
for multiple in range(1, max_tokens // block_m + 1):
m = multiple * block_m
if m <= max_tokens:
m_values.add(m)
return sorted(m_values)
def _extract_data_from_linear_base_module(
m: torch.nn.Module,
) -> tuple[torch.Tensor, torch.Tensor, list[int]]:
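
To illustrate the heuristic above, here is a minimal standalone sketch (not part of the diff) that mirrors _generate_optimal_warmup_m_values with the SM count hard-coded so it runs without a GPU; 132 SMs is an assumption roughly matching an H100:

    def ceil_div(a: int, b: int) -> int:
        return (a + b - 1) // b

    def sketch_warmup_m_values(max_tokens: int, n: int, num_sms: int = 132) -> list[int]:
        # Always-warmed small cases, then wave-transition points and block_m multiples.
        m_values = set([1, 2, 4] + list(range(8, 65, 8)))
        for block_m in (64, 128, 256):
            for block_n in range(16, min(257, n + 1), 16):
                for wave in range(1, 11):
                    # M at which this (block_m, block_n) config spills into one more wave.
                    m = wave * num_sms * block_m // ceil_div(n, block_n)
                    if 1 <= m <= max_tokens:
                        m_values.add(m)
            m_values.update(range(block_m, max_tokens + 1, block_m))
        return sorted(m_values)

    # For example, with n=7168 and an 8192-token budget the heuristic visits only a
    # small fraction of the 8192 M values that the "full" mode would run.
    print(len(sketch_warmup_m_values(8192, 7168)))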
@@ -136,14 +185,27 @@ def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, max_tokens:
)
out = torch.empty((max_tokens, n), device=device, dtype=torch.bfloat16)
pbar = tqdm(total=max_tokens, desc=f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()})")
num_tokens = max_tokens
while num_tokens > 0:
# Use optimal M values only if VLLM_DEEP_GEMM_WARMUP is set to "relax".
# Otherwise warmup all token sizes to avoid JIT compilation in hotpath
if envs.VLLM_DEEP_GEMM_WARMUP == "relax":
m_values = _generate_optimal_warmup_m_values(max_tokens, n, device)
desc = f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()}) [relaxed]"
else:
assert envs.VLLM_DEEP_GEMM_WARMUP == "full", (
"Expected "
'VLLM_DEEP_GEMM_WARMUP env to be set to "full" but got '
f"{envs.VLLM_DEEP_GEMM_WARMUP}"
)
m_values = list(range(1, max_tokens + 1))
desc = f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()}) [all tokens]"
pbar = tqdm(total=len(m_values), desc=desc)
for num_tokens in m_values:
fp8_gemm_nt(
(a1q[:num_tokens], a1q_scales[:num_tokens]), (w, ws), out[:num_tokens]
)
pbar.update(1)
num_tokens -= 1
FP8_GEMM_NT_WARMUP_CACHE.add(w.size())
@@ -195,12 +257,16 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
)
out = torch.empty((MAX_M, n), device=device, dtype=torch.bfloat16)
# Generate M values in block_m increments (already optimized for MoE)
m_values = list(range(block_m, MAX_M + 1, block_m))
pbar = tqdm(
total=MAX_BLOCKS,
desc=f"DeepGemm(m_grouped_fp8_gemm_nt_contiguous) warmup (W={w.size()})",
total=len(m_values),
desc=f"DeepGemm(m_grouped_fp8_gemm_nt_contiguous) warmup (W={w.size()}) "
f"[{len(m_values)} values, block_m={block_m}]",
)
num_tokens = MAX_M
while num_tokens > 0:
for num_tokens in m_values:
m_grouped_fp8_gemm_nt_contiguous(
(a1q[:num_tokens], a1q_scales[:num_tokens]),
(w, w_scale),
@@ -208,7 +274,6 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
expert_ids[:num_tokens],
)
pbar.update(1)
num_tokens = num_tokens - block_m
for w, ws in [(w1, w1_scale), (w2, w2_scale)]:
if w.size() not in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE:


@@ -29,7 +29,7 @@ def kernel_warmup(worker: "Worker"):
do_deep_gemm_warmup = (
envs.VLLM_USE_DEEP_GEMM
and is_deep_gemm_supported()
and not envs.VLLM_SKIP_DEEP_GEMM_WARMUP
and envs.VLLM_DEEP_GEMM_WARMUP != "skip"
)
if do_deep_gemm_warmup:
model = worker.get_model()
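
For a quick sanity check, the same gating can be inspected interactively. This is a sketch; it assumes is_deep_gemm_supported is importable from vllm.utils.deep_gemm, as the imports earlier in this diff suggest:

    import vllm.envs as envs
    from vllm.utils.deep_gemm import is_deep_gemm_supported

    # Mirrors the worker-side condition: warmup runs only when DeepGEMM is enabled,
    # supported on the current platform, and the warmup mode is not "skip".
    would_warm_up = (
        envs.VLLM_USE_DEEP_GEMM
        and is_deep_gemm_supported()
        and envs.VLLM_DEEP_GEMM_WARMUP != "skip"
    )
    print(f"mode={envs.VLLM_DEEP_GEMM_WARMUP!r} warmup={would_warm_up}")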