[UX] Speedup DeepGEMM warmup with heuristics (#25619)
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
vllm/envs.py
@@ -146,7 +146,11 @@ if TYPE_CHECKING:
     VLLM_TPU_USING_PATHWAYS: bool = False
     VLLM_USE_DEEP_GEMM: bool = True
     VLLM_USE_DEEP_GEMM_E8M0: bool = True
-    VLLM_SKIP_DEEP_GEMM_WARMUP: bool = False
+    VLLM_DEEP_GEMM_WARMUP: Literal[
+        "skip",
+        "full",
+        "relax",
+    ] = "relax"
     VLLM_USE_FUSED_MOE_GROUPED_TOPK: bool = True
     VLLM_USE_FLASHINFER_MOE_FP16: bool = False
     VLLM_USE_FLASHINFER_MOE_FP8: bool = False
@@ -1088,9 +1092,21 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # JIT all the required kernels before model execution so there is no
     # JIT'ing in the hot-path. However, this warmup increases the engine
     # startup time by a couple of minutes.
-    # Set `VLLM_SKIP_DEEP_GEMM_WARMUP` to disable the warmup.
-    "VLLM_SKIP_DEEP_GEMM_WARMUP": lambda: bool(
-        int(os.getenv("VLLM_SKIP_DEEP_GEMM_WARMUP", "0"))
+    # Available options:
+    # - "skip" : Skip warmup.
+    # - "full" : Warmup deepgemm by running all possible gemm shapes the
+    #   engine could encounter.
+    # - "relax" : Select gemm shapes to run based on some heuristics. The
+    #   heuristic aims to have the same effect as running all possible gemm
+    #   shapes, but provides no guarantees.
+    "VLLM_DEEP_GEMM_WARMUP": env_with_choices(
+        "VLLM_DEEP_GEMM_WARMUP",
+        "relax",
+        [
+            "skip",
+            "full",
+            "relax",
+        ],
     ),
     # Whether to use fused grouped_topk used for MoE expert selection.
     "VLLM_USE_FUSED_MOE_GROUPED_TOPK": lambda: bool(
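For reference, the warmup control is now a three-way choice rather than a boolean. A minimal sketch of how a deployment might pick a mode before starting the engine (the validation below is illustrative; vLLM's env_with_choices helper does the real checking):

    import os

    # Choose the DeepGEMM warmup mode before the engine starts.
    #   "skip"  -> no warmup (roughly the old VLLM_SKIP_DEEP_GEMM_WARMUP=1)
    #   "full"  -> warm up every possible GEMM shape (the old default behaviour)
    #   "relax" -> heuristic subset of shapes (the new default)
    mode = os.getenv("VLLM_DEEP_GEMM_WARMUP", "relax")
    assert mode in ("skip", "full", "relax"), f"unexpected warmup mode: {mode!r}"
    os.environ["VLLM_DEEP_GEMM_WARMUP"] = mode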
@@ -26,6 +26,55 @@ from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
 from vllm.utils.deep_gemm import fp8_gemm_nt, m_grouped_fp8_gemm_nt_contiguous
 
 
+def _generate_optimal_warmup_m_values(
+    max_tokens: int, n: int, device: torch.device
+) -> list[int]:
+    """
+    Generate M values that cover all possible DeepGEMM kernel configurations.
+    Reference: https://github.com/deepseek-ai/DeepGEMM/blob/79f48ee15a82dd5fad5cd9beaa393c1f755e6b55/csrc/jit_kernels/heuristics/common.hpp
+
+    Args:
+        max_tokens: Maximum number of tokens to warmup for
+        n: The actual N dimension from the weight tensor
+        device: The torch device to get properties from.
+    """
+
+    def ceil_div(a: int, b: int) -> int:
+        return (a + b - 1) // b
+
+    # DeepGEMM's possible block sizes
+    block_ms = [64, 128, 256]
+    block_ns = list(range(16, min(257, n + 1), 16))
+    num_sms = torch.cuda.get_device_properties(device).multi_processor_count
+
+    m_values = set()
+
+    # Always include small cases
+    m_values.update([1, 2, 4] + [i for i in range(8, 65, 8)])
+
+    # Collect M values where different wave patterns occur
+    for block_m in block_ms:
+        for block_n in block_ns:
+            if block_n > n:
+                continue
+
+            # Add key M boundaries for this block combination
+            for wave in range(1, 11):  # Up to 10 waves
+                # M where this block config transitions to next wave
+                target_blocks = wave * num_sms
+                m = target_blocks * block_m // ceil_div(n, block_n)
+                if 1 <= m <= max_tokens:
+                    m_values.add(m)
+
+        # Add block_m boundaries
+        for multiple in range(1, max_tokens // block_m + 1):
+            m = multiple * block_m
+            if m <= max_tokens:
+                m_values.add(m)
+
+    return sorted(m_values)
+
+
 def _extract_data_from_linear_base_module(
     m: torch.nn.Module,
 ) -> tuple[torch.Tensor, torch.Tensor, list[int]]:
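To make the heuristic concrete, here is the wave-boundary calculation in isolation. This is a standalone sketch, not part of the diff; the 132-SM count is an assumed H100-like value, whereas the real code queries the device:

    def ceil_div(a: int, b: int) -> int:
        return (a + b - 1) // b

    # With n = 4096 tiled by block_n = 128, each band of block_m = 128 rows
    # launches ceil(4096 / 128) = 32 thread blocks. The formula below gives the
    # approximate M at which the grid grows past one full wave on a 132-SM GPU,
    # which is exactly the kind of boundary the heuristic records.
    num_sms, block_m, block_n, n = 132, 128, 128, 4096
    m_boundary = 1 * num_sms * block_m // ceil_div(n, block_n)
    assert m_boundary == 528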
@@ -136,14 +185,27 @@ def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, max_tokens:
     )
     out = torch.empty((max_tokens, n), device=device, dtype=torch.bfloat16)
 
-    pbar = tqdm(total=max_tokens, desc=f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()})")
-    num_tokens = max_tokens
-    while num_tokens > 0:
+    # Use optimal M values only if VLLM_DEEP_GEMM_WARMUP is set to "relax".
+    # Otherwise warmup all token sizes to avoid JIT compilation in hotpath
+    if envs.VLLM_DEEP_GEMM_WARMUP == "relax":
+        m_values = _generate_optimal_warmup_m_values(max_tokens, n, device)
+        desc = f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()}) [relaxed]"
+    else:
+        assert envs.VLLM_DEEP_GEMM_WARMUP == "full", (
+            "Expected "
+            'VLLM_DEEP_GEMM_WARMUP env to be set to "full" but got '
+            f"{envs.VLLM_DEEP_GEMM_WARMUP}"
+        )
+        m_values = list(range(1, max_tokens + 1))
+        desc = f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()}) [all tokens]"
+
+    pbar = tqdm(total=len(m_values), desc=desc)
+
+    for num_tokens in m_values:
         fp8_gemm_nt(
             (a1q[:num_tokens], a1q_scales[:num_tokens]), (w, ws), out[:num_tokens]
         )
         pbar.update(1)
-        num_tokens -= 1
 
     FP8_GEMM_NT_WARMUP_CACHE.add(w.size())
 
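For a rough sense of the saving, the loop above issues one fp8_gemm_nt launch per entry of m_values: every token count from 1 to max_tokens in "full" mode, but only the heuristic's M values in "relax" mode. A device-free re-creation of the heuristic (assuming 132 SMs; the real code asks torch.cuda) makes the difference easy to count:

    def relaxed_m_values(max_tokens: int, n: int, num_sms: int = 132) -> list[int]:
        def ceil_div(a: int, b: int) -> int:
            return (a + b - 1) // b

        m_values = {1, 2, 4, *range(8, 65, 8)}
        for block_m in (64, 128, 256):
            for block_n in range(16, min(257, n + 1), 16):
                for wave in range(1, 11):
                    m = wave * num_sms * block_m // ceil_div(n, block_n)
                    if 1 <= m <= max_tokens:
                        m_values.add(m)
            m_values.update(range(block_m, max_tokens + 1, block_m))
        return sorted(m_values)

    # "relax" issues a few hundred warmup launches here; "full" issues 8192.
    print(len(relaxed_m_values(max_tokens=8192, n=4096)), "vs", 8192)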
@@ -195,12 +257,16 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
         )
         out = torch.empty((MAX_M, n), device=device, dtype=torch.bfloat16)
 
+        # Generate M values in block_m increments (already optimized for MoE)
+        m_values = list(range(block_m, MAX_M + 1, block_m))
+
         pbar = tqdm(
-            total=MAX_BLOCKS,
-            desc=f"DeepGemm(m_grouped_fp8_gemm_nt_contiguous) warmup (W={w.size()})",
+            total=len(m_values),
+            desc=f"DeepGemm(m_grouped_fp8_gemm_nt_contiguous) warmup (W={w.size()}) "
+            f"[{len(m_values)} values, block_m={block_m}]",
         )
-        num_tokens = MAX_M
-        while num_tokens > 0:
+
+        for num_tokens in m_values:
             m_grouped_fp8_gemm_nt_contiguous(
                 (a1q[:num_tokens], a1q_scales[:num_tokens]),
                 (w, w_scale),
@@ -208,7 +274,6 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
                 expert_ids[:num_tokens],
             )
             pbar.update(1)
-            num_tokens = num_tokens - block_m
 
     for w, ws in [(w1, w1_scale), (w2, w2_scale)]:
         if w.size() not in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE:
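No heuristic is needed for the grouped GEMM: its M values were already restricted to multiples of the alignment block size, so the change just materializes them as an explicit list instead of counting down. Illustrative values only; block_m and MAX_M come from DeepGEMM's alignment and the MoE sizing in the real code:

    block_m, MAX_M = 128, 1024  # assumed values for illustration
    m_values = list(range(block_m, MAX_M + 1, block_m))
    assert m_values == [128, 256, 384, 512, 640, 768, 896, 1024]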
@@ -29,7 +29,7 @@ def kernel_warmup(worker: "Worker"):
     do_deep_gemm_warmup = (
         envs.VLLM_USE_DEEP_GEMM
         and is_deep_gemm_supported()
-        and not envs.VLLM_SKIP_DEEP_GEMM_WARMUP
+        and envs.VLLM_DEEP_GEMM_WARMUP != "skip"
     )
     if do_deep_gemm_warmup:
         model = worker.get_model()
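The gate simply swaps the old boolean flag for a comparison against the new choice variable. A tiny standalone restatement of that logic (names are illustrative, not vLLM APIs):

    def should_warmup_deep_gemm(use_deep_gemm: bool, supported: bool, mode: str) -> bool:
        return use_deep_gemm and supported and mode != "skip"

    assert should_warmup_deep_gemm(True, True, "relax")
    assert not should_warmup_deep_gemm(True, True, "skip")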