Silu v2 (#25074)
Signed-off-by: mgoin <mgoin64@gmail.com> Signed-off-by: elvircrn <elvircrn@gmail.com> Signed-off-by: Elvir Crnčević <elvircrn@gmail.com> Co-authored-by: mgoin <mgoin64@gmail.com> Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com>
This commit is contained in:
@ -1,5 +1,19 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
"""
|
||||
Comprehensive 3-way SiLU Benchmark Suite
|
||||
|
||||
This benchmark compares three SiLU implementations:
|
||||
1. SiLU V2 (CUDA) - Optimized CUDA kernel implementation
|
||||
2. Triton Kernel - Triton-based implementation
|
||||
|
||||
The suite generates detailed performance comparisons including:
|
||||
- Memory bandwidth utilization
|
||||
- Speedup ratios (baseline vs optimized implementations)
|
||||
- Performance across different expert configurations and token distributions
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
@ -7,7 +21,7 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||
silu_mul_fp8_quant_deep_gemm_cuda,
|
||||
persistent_masked_m_silu_mul_quant,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
@ -94,6 +108,7 @@ def silu_mul_fp8_quant_deep_gemm_triton(
|
||||
num_parallel_tokens,
|
||||
group_size: int = 128,
|
||||
eps: float = 1e-10,
|
||||
expert_offsets: torch.Tensor = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales
|
||||
|
||||
@ -174,7 +189,7 @@ def silu_mul_fp8_quant_deep_gemm_triton(
|
||||
|
||||
|
||||
# Parse generation strategies
|
||||
strategies = ["uniform", "max_t", "first_t"]
|
||||
strategies = ["random_imbalanced", "uniform", "max_t"]
|
||||
|
||||
|
||||
def benchmark(
|
||||
@ -195,15 +210,27 @@ def benchmark(
|
||||
current_platform.seed_everything(42 + seed_offset)
|
||||
y = torch.rand((E, T, 2 * H), dtype=torch.bfloat16, device="cuda").contiguous()
|
||||
|
||||
if gen_strategy == "uniform":
|
||||
r = torch.rand(size=(E,), device="cuda")
|
||||
if gen_strategy == "random_imbalanced":
|
||||
|
||||
def generate_expert_loads(n_e, total_tokens, ratio, device="cuda"):
|
||||
mean = total_tokens // n_e
|
||||
min_max = mean // ratio
|
||||
e = torch.ones(size=(E,), dtype=torch.int64, device=device) * mean
|
||||
e[0] = min_max
|
||||
r = torch.rand(size=(E - 1,))
|
||||
r /= r.sum()
|
||||
r *= total_tokens - min_max
|
||||
r = r.round().long()
|
||||
e[1:] = r.to(device=device)
|
||||
return e
|
||||
|
||||
tokens_per_expert = generate_expert_loads(E, total_tokens, 0.7, "cuda")
|
||||
elif gen_strategy == "uniform":
|
||||
r = torch.rand(size=(E,))
|
||||
r /= r.sum()
|
||||
r *= total_tokens
|
||||
tokens_per_expert = r.int()
|
||||
tokens_per_expert = torch.minimum(
|
||||
tokens_per_expert,
|
||||
torch.ones((E,), device=r.device, dtype=torch.int) * T,
|
||||
)
|
||||
r = r.round().long()
|
||||
tokens_per_expert = r
|
||||
elif gen_strategy == "max_t":
|
||||
tokens_per_expert = torch.empty(size=(E,), dtype=torch.int32, device="cuda")
|
||||
tokens_per_expert.fill_(total_tokens / E)
|
||||
@ -281,40 +308,34 @@ def benchmark(
|
||||
|
||||
|
||||
def create_comparison_plot(
|
||||
ratio, cuda_times, baseline_times, config_labels, strategy_name, id
|
||||
ratios, silu_v2_times, triton_times, config_labels, strategy_name, id
|
||||
):
|
||||
"""Create a comparison plot for a specific generation strategy"""
|
||||
fig, ax = plt.subplots(1, 1, figsize=(16, 6))
|
||||
fig, ax = plt.subplots(1, 1, figsize=(18, 6))
|
||||
|
||||
# Configure x-axis positions
|
||||
x = np.arange(len(config_labels))
|
||||
width = 0.35
|
||||
width = 0.25
|
||||
|
||||
# Execution Time plot (lower is better)
|
||||
ax.bar(x, silu_v2_times, width, label="SiLU V2 (CUDA)", alpha=0.8, color="blue")
|
||||
ax.bar(
|
||||
x - width / 2, cuda_times, width, label="CUDA Kernel", alpha=0.8, color="blue"
|
||||
)
|
||||
ax.bar(
|
||||
x + width / 2,
|
||||
baseline_times,
|
||||
width,
|
||||
label="Baseline",
|
||||
alpha=0.8,
|
||||
color="orange",
|
||||
x + width, triton_times, width, label="Triton Kernel", alpha=0.8, color="green"
|
||||
)
|
||||
|
||||
# Add speedup labels over each bar pair
|
||||
# Add speedup labels over each bar trio
|
||||
for i in range(len(x)):
|
||||
speedup = ratio[i]
|
||||
max_height = max(cuda_times[i], baseline_times[i])
|
||||
triton_v2_speedup = ratios[i][1] # triton/v2
|
||||
max_height = max(silu_v2_times[i], triton_times[i])
|
||||
|
||||
# Triton/V2 speedup
|
||||
ax.text(
|
||||
x[i],
|
||||
x[i] + width / 2,
|
||||
max_height + max_height * 0.02,
|
||||
f"{speedup:.2f}x",
|
||||
f"{triton_v2_speedup:.2f}x",
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontweight="bold",
|
||||
fontsize=9,
|
||||
fontsize=8,
|
||||
)
|
||||
|
||||
ax.set_xlabel("Configuration")
|
||||
@ -332,56 +353,75 @@ def create_comparison_plot(
|
||||
|
||||
|
||||
def create_combined_plot(all_results):
|
||||
"""Create a combined plot with all strategies in one PNG"""
|
||||
num_strategies = len(all_results)
|
||||
fig, axes = plt.subplots(num_strategies, 1, figsize=(20, 6 * num_strategies))
|
||||
fig, axes = plt.subplots(num_strategies, 1, figsize=(22, 7 * num_strategies))
|
||||
|
||||
if num_strategies == 1:
|
||||
axes = [axes]
|
||||
|
||||
for idx, (
|
||||
strategy_name,
|
||||
ratio,
|
||||
cuda_times,
|
||||
baseline_times,
|
||||
all_ratios,
|
||||
all_silu_v2_results,
|
||||
all_triton_results,
|
||||
config_labels,
|
||||
config_x_axis,
|
||||
) in enumerate(all_results):
|
||||
ax = axes[idx]
|
||||
|
||||
# Flatten the nested results to get bandwidth percentages for plotting
|
||||
silu_v2_bandwidths = []
|
||||
triton_bandwidths = []
|
||||
flat_ratios = []
|
||||
|
||||
for config_results in all_silu_v2_results:
|
||||
for result in config_results:
|
||||
silu_v2_bandwidths.append(result[3]) # bandwidth percentage
|
||||
|
||||
for config_results in all_triton_results:
|
||||
for result in config_results:
|
||||
triton_bandwidths.append(result[3]) # bandwidth percentage
|
||||
|
||||
for config_ratios in all_ratios:
|
||||
for ratio in config_ratios:
|
||||
flat_ratios.append(ratio)
|
||||
|
||||
# Configure x-axis positions
|
||||
x = np.arange(len(config_labels))
|
||||
width = 0.35
|
||||
width = 0.25
|
||||
|
||||
# Execution Time plot (lower is better)
|
||||
# Bandwidth utilization plot (higher is better)
|
||||
ax.bar(
|
||||
x - width / 2,
|
||||
cuda_times,
|
||||
x,
|
||||
silu_v2_bandwidths,
|
||||
width,
|
||||
label="CUDA Kernel",
|
||||
label="SiLU V2 (CUDA)",
|
||||
alpha=0.8,
|
||||
color="blue",
|
||||
)
|
||||
ax.bar(
|
||||
x + width / 2,
|
||||
baseline_times,
|
||||
x + width,
|
||||
triton_bandwidths,
|
||||
width,
|
||||
label="Baseline",
|
||||
label="Triton Kernel",
|
||||
alpha=0.8,
|
||||
color="orange",
|
||||
color="green",
|
||||
)
|
||||
|
||||
# Add speedup labels over each bar pair
|
||||
# Add speedup labels over each bar trio
|
||||
for i in range(len(x)):
|
||||
speedup = ratio[i]
|
||||
max_height = max(cuda_times[i], baseline_times[i])
|
||||
triton_v2_speedup = flat_ratios[i] # triton/v2
|
||||
max_height = max(silu_v2_bandwidths[i], triton_bandwidths[i])
|
||||
|
||||
# Triton/V2 speedup
|
||||
ax.text(
|
||||
x[i],
|
||||
x[i] + width / 2,
|
||||
max_height + max_height * 0.02,
|
||||
f"{speedup:.2f}x",
|
||||
f"{triton_v2_speedup:.2f}x",
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontweight="bold",
|
||||
fontsize=9,
|
||||
fontsize=8,
|
||||
)
|
||||
|
||||
ax.set_xlabel("Configuration")
|
||||
@ -395,7 +435,7 @@ def create_combined_plot(all_results):
|
||||
ax.grid(True, alpha=0.3)
|
||||
|
||||
plt.tight_layout()
|
||||
filename = "../../silu_bench/silu_benchmark_combined.png"
|
||||
filename = "silu_benchmark_combined_3way.png"
|
||||
plt.savefig(filename, dpi=300, bbox_inches="tight")
|
||||
plt.show()
|
||||
|
||||
@ -405,7 +445,9 @@ def create_combined_plot(all_results):
|
||||
outer_dim = 7168
|
||||
configs = [
|
||||
# DeepSeekV3 Configs
|
||||
# (1, 56, 7168),
|
||||
(8, 1024, 7168),
|
||||
# (32, 56, 7168),
|
||||
# DeepSeekV3 Configs
|
||||
(32, 1024, 7168),
|
||||
# DeepSeekV3 Configs
|
||||
@ -417,6 +459,7 @@ num_warmups = 20
|
||||
|
||||
strategy_descriptions = {
|
||||
"uniform": "Uniform Random",
|
||||
"random_imbalanced": "Imbalanced Random",
|
||||
"max_t": "Even Assignment",
|
||||
"first_t": "experts[0] = T, experts[1:] = 0",
|
||||
}
|
||||
@ -433,28 +476,31 @@ for id, strategy in enumerate(strategies):
|
||||
print(f"Testing strategy: {strategy_descriptions[strategy]}")
|
||||
print(f"{'=' * 60}")
|
||||
|
||||
# Collect benchmark data for both algorithms
|
||||
# Collect benchmark data for all three algorithms
|
||||
config_labels = []
|
||||
config_x_axis = []
|
||||
all_cuda_results = []
|
||||
all_baseline_results = []
|
||||
all_silu_v2_results = []
|
||||
all_triton_results = []
|
||||
all_ratios = []
|
||||
|
||||
for E, T, H in configs:
|
||||
total_tokens_config = [8 * E, 16 * E, 32 * E, 64 * E, 128 * E, 256 * E]
|
||||
total_tokens_config = []
|
||||
for i in [8, 16, 32, 64, 128, 256, 512]:
|
||||
if i <= T:
|
||||
total_tokens_config.append(i * E)
|
||||
config_x_axis.append(total_tokens_config)
|
||||
|
||||
cuda_results = []
|
||||
baseline_results = []
|
||||
silu_v2_results = []
|
||||
triton_results = []
|
||||
ratios = []
|
||||
|
||||
for total_tokens in total_tokens_config:
|
||||
config_label = f"E={E},T={T},H={H},TT={total_tokens}"
|
||||
config_labels.append(config_label)
|
||||
|
||||
# CUDA kernel results
|
||||
time_ms_cuda, gflops, gbps, perc = benchmark(
|
||||
silu_mul_fp8_quant_deep_gemm_cuda,
|
||||
# SiLU V2 (CUDA kernel) results
|
||||
time_ms_silu_v2, gflops, gbps, perc = benchmark(
|
||||
persistent_masked_m_silu_mul_quant,
|
||||
E,
|
||||
T,
|
||||
H,
|
||||
@ -463,9 +509,9 @@ for id, strategy in enumerate(strategies):
|
||||
num_warmups=num_warmups,
|
||||
gen_strategy=strategy,
|
||||
)
|
||||
cuda_results.append((time_ms_cuda, gflops, gbps, perc))
|
||||
silu_v2_results.append((time_ms_silu_v2, gflops, gbps, perc))
|
||||
|
||||
# Baseline results
|
||||
# Triton kernel results
|
||||
time_ms_triton, gflops, gbps, perc = benchmark(
|
||||
silu_mul_fp8_quant_deep_gemm_triton,
|
||||
E,
|
||||
@ -476,12 +522,20 @@ for id, strategy in enumerate(strategies):
|
||||
num_warmups=num_warmups,
|
||||
gen_strategy=strategy,
|
||||
)
|
||||
baseline_results.append((time_ms_triton, gflops, gbps, perc))
|
||||
ratios.append(time_ms_triton / time_ms_cuda)
|
||||
triton_results.append((time_ms_triton, gflops, gbps, perc))
|
||||
|
||||
print(f"Completed: {config_label}")
|
||||
all_cuda_results.append(cuda_results)
|
||||
all_baseline_results.append(baseline_results)
|
||||
# Calculate speedup ratios (triton baseline / implementation)
|
||||
triton_v2_ratio = time_ms_triton / time_ms_silu_v2
|
||||
ratios.append(triton_v2_ratio)
|
||||
|
||||
print(
|
||||
f"Completed: {config_label}:"
|
||||
f" V2: {time_ms_silu_v2:.3f}ms,"
|
||||
f" Triton: {time_ms_triton:.3f}ms"
|
||||
)
|
||||
|
||||
all_silu_v2_results.append(silu_v2_results)
|
||||
all_triton_results.append(triton_results)
|
||||
all_ratios.append(ratios)
|
||||
|
||||
# Store results for combined plotting
|
||||
@ -489,8 +543,8 @@ for id, strategy in enumerate(strategies):
|
||||
(
|
||||
strategy_descriptions[strategy],
|
||||
all_ratios,
|
||||
all_cuda_results,
|
||||
all_baseline_results,
|
||||
all_silu_v2_results,
|
||||
all_triton_results,
|
||||
config_labels,
|
||||
config_x_axis,
|
||||
)
|
||||
@ -498,15 +552,18 @@ for id, strategy in enumerate(strategies):
|
||||
|
||||
# Print summary table for this strategy
|
||||
print(f"\nSummary Table - {strategy_descriptions[strategy]}:")
|
||||
print(f"{'Config':<20} {'CUDA Time(ms)':<12} {'Base Time(ms)':<12} {'Speedup':<8}")
|
||||
print("-" * 60)
|
||||
print(f" {'V2 Time(ms)':<12} {'Triton Time(ms)':<14} {'Triton/V2':<10}")
|
||||
print("-" * 90)
|
||||
|
||||
for i, (E, T, H) in enumerate(configs):
|
||||
speedup = baseline_results[i][0] / cuda_results[i][0]
|
||||
# Get the first result for each config (simplifying for summary)
|
||||
v2_time = silu_v2_results[i][0]
|
||||
triton_time = triton_results[i][0]
|
||||
triton_v2_speedup = triton_time / v2_time
|
||||
config_label = f"E={E:3d},T={T:4d},H={H:4d}"
|
||||
print(
|
||||
f"{config_label:<20} {cuda_results[i][0]:8.5f} "
|
||||
f"{baseline_results[i][0]:8.5f} {speedup:6.2f}x"
|
||||
f"{config_label:<20} {v2_time:8.5f} {triton_time:10.5f} "
|
||||
f"{triton_v2_speedup:8.2f}x"
|
||||
)
|
||||
|
||||
|
||||
@ -514,15 +571,14 @@ def create_total_tokens_plot(all_results):
|
||||
num_strategies = len(all_results)
|
||||
num_configs = len(configs)
|
||||
|
||||
# Create side-by-side subplots: 2 columns for speedup and bandwidth percentage
|
||||
fig, axs = plt.subplots(
|
||||
num_strategies, num_configs * 2, figsize=(28, 6 * num_strategies)
|
||||
num_strategies, num_configs * 2, figsize=(32, 8 * num_strategies)
|
||||
)
|
||||
|
||||
# Add main title to the entire figure
|
||||
fig.suptitle(
|
||||
"Performance Analysis: Speedup vs Bandwidth Utilization (Triton & CUDA)",
|
||||
fontsize=16,
|
||||
"Performance Analysis: Speedup vs Bandwidth Utilization (SiLU V2, and Triton)",
|
||||
fontsize=18,
|
||||
fontweight="bold",
|
||||
y=0.98,
|
||||
)
|
||||
@ -539,8 +595,8 @@ def create_total_tokens_plot(all_results):
|
||||
(
|
||||
strategy_name,
|
||||
all_ratios,
|
||||
all_cuda_results,
|
||||
all_baseline_results,
|
||||
all_silu_v2_results,
|
||||
all_triton_results,
|
||||
config_labels,
|
||||
config_x_axis,
|
||||
) = result
|
||||
@ -555,42 +611,54 @@ def create_total_tokens_plot(all_results):
|
||||
ratios = all_ratios[config_idx]
|
||||
total_tokens_values = config_x_axis[config_idx]
|
||||
|
||||
# Extract CUDA and Triton bandwidth percentages
|
||||
cuda_bandwidth_percentages = [
|
||||
result[3] for result in all_cuda_results[config_idx]
|
||||
# Extract speedup ratios
|
||||
triton_v2_ratios = [ratio for ratio in ratios]
|
||||
|
||||
# Extract bandwidth percentages for all implementations
|
||||
v2_bandwidth_percentages = [
|
||||
result[3] for result in all_silu_v2_results[config_idx]
|
||||
]
|
||||
triton_bandwidth_percentages = [
|
||||
result[3] for result in all_baseline_results[config_idx]
|
||||
result[3] for result in all_triton_results[config_idx]
|
||||
]
|
||||
|
||||
# Plot speedup ratios vs total tokens (left plot)
|
||||
ax_speedup.plot(
|
||||
total_tokens_values, ratios, "bo-", linewidth=3, markersize=8
|
||||
total_tokens_values,
|
||||
triton_v2_ratios,
|
||||
"go-",
|
||||
linewidth=3,
|
||||
markersize=8,
|
||||
label="Triton/V2 Speedup",
|
||||
)
|
||||
ax_speedup.set_title(
|
||||
f"{strategy_name}\nSpeedup (CUDA/Triton)\nE={E}, T={T}, H={H}",
|
||||
f"{strategy_name}\nSpeedup vs Baseline (Triton)\nE={E}, T={T}, H={H}",
|
||||
fontsize=12,
|
||||
fontweight="bold",
|
||||
)
|
||||
ax_speedup.set_xlabel("Total Tokens", fontweight="bold", fontsize=11)
|
||||
ax_speedup.set_ylabel("Speedup Ratio", fontweight="bold", fontsize=11)
|
||||
ax_speedup.legend(prop={"weight": "bold"})
|
||||
ax_speedup.grid(True, alpha=0.3)
|
||||
|
||||
# Plot bandwidth utilization (right plot)
|
||||
ax_bandwidth.plot(
|
||||
total_tokens_values,
|
||||
cuda_bandwidth_percentages,
|
||||
"ro-",
|
||||
v2_bandwidth_percentages,
|
||||
"o-",
|
||||
linewidth=3,
|
||||
markersize=8,
|
||||
label="CUDA",
|
||||
label="SiLU V2",
|
||||
color="blue",
|
||||
)
|
||||
ax_bandwidth.plot(
|
||||
total_tokens_values,
|
||||
triton_bandwidth_percentages,
|
||||
"go-",
|
||||
"o-",
|
||||
linewidth=3,
|
||||
markersize=8,
|
||||
label="Triton",
|
||||
color="green",
|
||||
)
|
||||
ax_bandwidth.set_title(
|
||||
f"{strategy_name}\nBandwidth Utilization (Hopper)\nE={E}, T={T}, H={H}",
|
||||
@ -618,38 +686,12 @@ def create_total_tokens_plot(all_results):
|
||||
for label in ax.get_xticklabels() + ax.get_yticklabels():
|
||||
label.set_fontweight("bold")
|
||||
|
||||
# Add value labels on speedup points
|
||||
for x, y in zip(total_tokens_values, ratios):
|
||||
# Add value labels on Triton/V2 speedup points
|
||||
for x, y in zip(total_tokens_values, triton_v2_ratios):
|
||||
ax_speedup.annotate(
|
||||
f"{y:.2f}x",
|
||||
(x, y),
|
||||
textcoords="offset points",
|
||||
xytext=(0, 12),
|
||||
ha="center",
|
||||
fontsize=10,
|
||||
fontweight="bold",
|
||||
bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.7),
|
||||
)
|
||||
|
||||
# Add value labels on CUDA bandwidth points
|
||||
for x, y in zip(total_tokens_values, cuda_bandwidth_percentages):
|
||||
ax_bandwidth.annotate(
|
||||
f"{y:.1f}%",
|
||||
(x, y),
|
||||
textcoords="offset points",
|
||||
xytext=(0, 12),
|
||||
ha="center",
|
||||
fontsize=9,
|
||||
fontweight="bold",
|
||||
bbox=dict(boxstyle="round,pad=0.2", facecolor="red", alpha=0.3),
|
||||
)
|
||||
|
||||
# Add value labels on Triton bandwidth points
|
||||
for x, y in zip(total_tokens_values, triton_bandwidth_percentages):
|
||||
ax_bandwidth.annotate(
|
||||
f"{y:.1f}%",
|
||||
(x, y),
|
||||
textcoords="offset points",
|
||||
xytext=(0, -15),
|
||||
ha="center",
|
||||
fontsize=9,
|
||||
@ -659,17 +701,20 @@ def create_total_tokens_plot(all_results):
|
||||
|
||||
plt.tight_layout()
|
||||
plt.subplots_adjust(top=0.93) # Make room for main title
|
||||
filename = "silu_benchmark_total_tokens.png"
|
||||
filename = "silu_benchmark_total_tokens_3way.png"
|
||||
plt.savefig(filename, dpi=300, bbox_inches="tight")
|
||||
plt.show()
|
||||
|
||||
return filename
|
||||
|
||||
|
||||
# Create combined plot with all strategies
|
||||
combined_plot_filename = create_total_tokens_plot(all_results)
|
||||
# Create comprehensive 3-way comparison plots
|
||||
combined_plot_filename = create_combined_plot(all_results)
|
||||
total_tokens_plot_filename = create_total_tokens_plot(all_results)
|
||||
|
||||
print(f"\n{'=' * 60}")
|
||||
print("Benchmark Complete!")
|
||||
print(f"Generated combined plot: {combined_plot_filename}")
|
||||
print(f"{'=' * 60}")
|
||||
print(f"\n{'=' * 80}")
|
||||
print("3-Way Benchmark Suite Complete!")
|
||||
print(f"Generated combined comparison plot: {combined_plot_filename}")
|
||||
print(f"Generated total tokens analysis plot: {total_tokens_plot_filename}")
|
||||
print("Compared: SiLU V2 (CUDA), and Triton implementations")
|
||||
print(f"{'=' * 80}")
|
||||
|
||||
Reference in New Issue
Block a user