[Kernels][Bugfix] Use torch op for all kernels in FusedMoE forward. Add additional testing for cudagraphs. (#19717)

Signed-off-by: Bill Nell <bnell@redhat.com>
Author: bnellnm
Date: 2025-06-25 02:22:58 -04:00
Committed by: GitHub
Parent: f59fc60fb3
Commit: 015fab8c2f
14 changed files with 379 additions and 238 deletions
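
The cudagraph coverage added in this commit follows a capture-then-replay pattern: run the kernel once eagerly, zero the output buffer, re-issue the call under torch.cuda.CUDAGraph capture, then replay the graph and compare against a baseline. A minimal standalone sketch of that pattern (moe_fn and its arguments stand in for the real inputs that run_moe_test wires up in the diff below):

import torch

def check_under_cudagraph(moe_fn, *args):
    # Eager warm-up call; this also allocates the output tensor.
    out = moe_fn(*args)
    out.fill_(0)
    stream = torch.cuda.Stream()
    graph = torch.cuda.CUDAGraph()
    # The call inside the capture context records kernel launches into the
    # graph instead of executing them eagerly.
    with torch.cuda.graph(graph, stream=stream):
        out = moe_fn(*args)
    torch.cuda.synchronize()
    # Replay runs the captured kernels; the result lands in the captured output.
    graph.replay()
    torch.cuda.synchronize()
    return out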

View File

@ -29,7 +29,10 @@ MNK_FACTORS = [
(224, 1024, 1536),
(224, 3072, 1024),
(224, 3072, 1536),
(1024 * 128, 1024, 1024),
(32768, 1024, 1024),
# These sizes trigger wrong answers.
#(7232, 2048, 5120),
#(40000, 2048, 5120),
]
vllm_config = VllmConfig(parallel_config=ParallelConfig(
@ -232,8 +235,10 @@ def test_cutlass_moe_8_bit_no_graph(
topk: int,
per_act_token: bool,
per_out_ch: bool,
monkeypatch,
):
current_platform.seed_everything(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
with set_current_vllm_config(vllm_config):
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
per_out_ch)
@ -274,8 +279,10 @@ def test_cutlass_moe_8_bit_cuda_graph(
topk: int,
per_act_token: bool,
per_out_ch: bool,
monkeypatch,
):
current_platform.seed_everything(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
with set_current_vllm_config(vllm_config):
dtype = torch.half
@ -329,8 +336,10 @@ def test_cutlass_moe_8_bit_EP(
per_act_token: bool,
per_out_channel: bool,
ep_size: int,
monkeypatch,
):
current_platform.seed_everything(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
with set_current_vllm_config(vllm_config):
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
per_out_channel)

View File

@ -4,6 +4,9 @@
Run `pytest tests/kernels/test_moe.py`.
"""
import functools
from typing import Callable, Optional, Union
import pytest
import torch
from torch.nn import Parameter
@ -14,6 +17,7 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
import vllm.model_executor.layers.fused_moe # noqa
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, modular_triton_fused_moe)
@ -40,7 +44,76 @@ vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
@pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128])
def run_moe_test(
baseline: Union[Callable, torch.Tensor],
moe_fn: Callable,
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
score: torch.Tensor,
topk: int,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
padding: bool = False,
use_compile: bool = False,
use_cudagraph: bool = False,
atol: float = 2e-2,
rtol: float = 0,
) -> torch.Tensor:
if isinstance(baseline, torch.Tensor):
baseline_output = baseline
else:
baseline_output = baseline(a,
w1,
w2,
score,
topk,
global_num_experts=global_num_experts,
expert_map=expert_map)
# Pad the weight if moe padding is enabled
if padding:
w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128]
w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128]
if use_compile:
moe_fn = torch.compile(moe_fn, backend="inductor", fullgraph=True)
torch._dynamo.mark_dynamic(a, 0)
torch._dynamo.mark_dynamic(score, 0)
test_output = moe_fn(a,
w1,
w2,
score,
topk,
global_num_experts=global_num_experts,
expert_map=expert_map)
if use_cudagraph:
test_output.fill_(0)
stream = torch.cuda.Stream()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):
test_output = moe_fn(a,
w1,
w2,
score,
topk,
global_num_experts=global_num_experts,
expert_map=expert_map)
torch.cuda.synchronize()
graph.replay()
torch.cuda.synchronize()
torch.testing.assert_close(test_output,
baseline_output,
atol=atol,
rtol=rtol)
return baseline_output
@pytest.mark.parametrize("m", [1, 33, 64, 222, 32768, 40000])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@ -48,6 +121,7 @@ vllm_config.scheduler_config.max_model_len = 8192
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("padding", [True, False])
@pytest.mark.parametrize("chunk_size", [8192])
def test_fused_moe(
m: int,
n: int,
@ -57,7 +131,17 @@ def test_fused_moe(
ep_size: int,
dtype: torch.dtype,
padding: bool,
chunk_size: int,
monkeypatch,
):
current_platform.seed_everything(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))
#
# Setup test data
#
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
@ -77,58 +161,70 @@ def test_fused_moe(
else:
e_map = None
m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=False,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
per_channel_quant=False,
block_shape=None)
#
# Setup test functions
#
m_fused_moe_fn = modular_triton_fused_moe(use_fp8_w8a8=False,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
per_channel_quant=False,
block_shape=None)
def m_fused_moe(
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
score: torch.Tensor,
topk: int,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
) -> torch.Tensor:
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
return m_fused_moe_fn(a,
w1,
w2,
topk_weights,
topk_ids,
global_num_experts=global_num_experts,
expert_map=expert_map)
fused_moe_fn = functools.partial(fused_moe, renormalize=False)
#
# Run tests
#
runner = functools.partial(
run_moe_test,
a=a,
w1=w1,
w2=w2,
score=score,
topk=topk,
global_num_experts=e,
expert_map=e_map,
padding=padding,
)
# Note: for now use_compile will error out if the problem size is
# large enough to trigger chunking. I'm leaving the flag and
# setup code in case we are able to revisit this later.
use_compile = False
use_cudagraph = (n >= 1024 and k >= 1024
and current_platform.is_cuda_alike())
with set_current_vllm_config(vllm_config):
torch_output = torch_moe(a, w1, w2, score, topk, e_map)
iterative_output = iterative_moe(a,
w1,
w2,
score,
topk,
global_num_experts=e,
expert_map=e_map,
renormalize=False)
# Pad the weight if moe padding is enabled
if padding:
w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128]
torch.cuda.empty_cache()
w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128]
torch.cuda.empty_cache()
triton_output = fused_moe(a,
w1,
w2,
score,
topk,
global_num_experts=e,
expert_map=e_map,
renormalize=False)
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
m_triton_output = m_fused_moe(a,
w1,
w2,
topk_weights,
topk_ids,
global_num_experts=e,
expert_map=e_map)
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
torch.testing.assert_close(m_triton_output,
torch_output,
atol=2e-2,
rtol=0)
torch.testing.assert_close(iterative_output,
torch_output,
atol=2e-2,
rtol=0)
baseline_output = runner(torch_moe, iterative_moe)
runner(baseline_output,
fused_moe_fn,
use_compile=use_compile,
use_cudagraph=use_cudagraph)
runner(baseline_output,
m_fused_moe,
use_compile=use_compile,
use_cudagraph=use_cudagraph)
@pytest.mark.parametrize("m", [1, 32, 222])
@ -238,7 +334,12 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
w1_zp=w1_qzeros if has_zp else None,
w2_zp=w2_qzeros if has_zp else None,
block_shape=[0, group_size])
torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, e_map)
torch_output = torch_moe(a,
w1_ref,
w2_ref,
score,
topk,
expert_map=e_map)
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
@ -265,45 +366,51 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
pytest.skip("AITER ROCm test skip for float32")
# Instantiate our and huggingface's MoE blocks
config = MixtralConfig()
hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda")
vllm_moe = MixtralMoE(
num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
params_dtype=dtype,
tp_size=1,
dp_size=1,
).cuda()
vllm_config.compilation_config.static_forward_context = dict()
with (set_current_vllm_config(vllm_config),
set_forward_context(None, vllm_config)):
config = MixtralConfig()
hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda")
vllm_moe = MixtralMoE(
num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
params_dtype=dtype,
tp_size=1,
dp_size=1,
).cuda()
# Load the weights
vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data
for i in range(config.num_local_experts):
weights = (hf_moe.experts[i].w1.weight.data,
hf_moe.experts[i].w3.weight.data)
vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0)
vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data
# Load the weights
vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data
for i in range(config.num_local_experts):
weights = (hf_moe.experts[i].w1.weight.data,
hf_moe.experts[i].w3.weight.data)
vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0)
vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data
# Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda")
# vLLM uses 1D query [num_tokens, hidden_dim]
vllm_inputs = hf_inputs.flatten(0, 1)
# Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
hf_inputs = torch.randn(
(1, 64, config.hidden_size)).to(dtype).to("cuda")
# vLLM uses 1D query [num_tokens, hidden_dim]
vllm_inputs = hf_inputs.flatten(0, 1)
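# e.g. with MixtralConfig's default hidden_size of 4096, this is a
# (1, 64, 4096) -> (64, 4096) flatten (shapes shown for illustration only).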
# Pad the weight if moe padding is enabled
if padding:
vllm_moe.experts.w13_weight = Parameter(F.pad(
vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[..., 0:-128],
requires_grad=False)
torch.cuda.empty_cache()
vllm_moe.experts.w2_weight = Parameter(F.pad(
vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[..., 0:-128],
requires_grad=False)
torch.cuda.empty_cache()
# Pad the weight if moe padding is enabled
if padding:
vllm_moe.experts.w13_weight = Parameter(F.pad(
vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[...,
0:-128],
requires_grad=False)
torch.cuda.empty_cache()
vllm_moe.experts.w2_weight = Parameter(F.pad(
vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[...,
0:-128],
requires_grad=False)
torch.cuda.empty_cache()
# Run forward passes for both MoE blocks
hf_states, _ = hf_moe.forward(hf_inputs)
vllm_states = vllm_moe.forward(vllm_inputs)
# Run forward passes for both MoE blocks
hf_states, _ = hf_moe.forward(hf_inputs)
vllm_states = vllm_moe.forward(vllm_inputs)
mixtral_moe_tol = {
torch.float32: 1e-3,
@ -546,7 +653,12 @@ def test_fused_marlin_moe(
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
with set_current_vllm_config(vllm_config):
torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map)
torch_output = torch_moe(a,
w_ref1,
w_ref2,
score,
topk,
expert_map=e_map)
marlin_output = torch.ops.vllm.fused_marlin_moe(
a,

View File

@ -136,7 +136,7 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
device=w2.device,
block_size=quant_blocksize)
torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk, None)
torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk)
torch.testing.assert_close(torch_output,
cutlass_output,

View File

@ -6,9 +6,9 @@ from typing import Optional
import pytest
import torch
from tests.kernels.utils import torch_experts
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.modular_kernel import (
@ -164,22 +164,6 @@ vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
def torch_moe2(a, w1, w2, topk_weight, topk_ids):
M, K = a.shape
topk = topk_ids.shape[1]
a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)
out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device)
num_experts = w1.shape[0]
for i in range(num_experts):
mask = (topk_ids == i).view(-1)
if mask.sum():
out[mask] = SiluAndMul()(
a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
return (out.view(M, -1, w2.shape[1]) *
topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)
def _pplx_moe(
pgi: ProcessGroupInfo,
dp_size: int,
@ -210,8 +194,8 @@ def _pplx_moe(
group_name = cpu_group.group_name
with set_current_vllm_config(vllm_config):
torch_output = torch_moe2(a_full, w1_full, w2_full, topk_weights,
topk_ids)
torch_output = torch_experts(a_full, w1_full, w2_full, topk_weights,
topk_ids)
pplx_output = pplx_cutlass_moe(pgi, dp_size, a, w1, w2, w1_scale,
w2_scale, topk_weights, topk_ids,
a1_scale, out_dtype, per_act_token,

View File

@ -18,8 +18,8 @@ try:
except ImportError:
has_pplx = False
from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import override_config
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedExperts, BatchedPrepareAndFinalize, BatchedTritonExperts)
@ -163,29 +163,6 @@ def batched_moe(
return fused_experts(a, w1, w2, topk_weight, topk_ids, num_experts)
# Note: same as torch_moe but with fused_topk factored out.
def torch_moe2(
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor:
M, K = a.shape
topk = topk_ids.shape[1]
a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)
out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device)
num_experts = w1.shape[0]
for i in range(num_experts):
mask = (topk_ids == i).view(-1)
if mask.sum():
out[mask] = SiluAndMul()(
a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
return (out.view(M, -1, w2.shape[1]) *
topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)
@pytest.mark.parametrize("m", [1, 33, 64, 222])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 512, 1024])
@ -209,7 +186,7 @@ def test_fused_moe_batched_experts(
with set_current_vllm_config(vllm_config):
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
baseline_output = torch_experts(a, w1, w2, topk_weight, topk_ids)
torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids)
batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids)
@ -409,7 +386,7 @@ def pplx_moe(
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
use_compile: bool = True,
use_compile: bool = False,
use_cudagraphs: bool = True,
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
@ -470,10 +447,16 @@ def pplx_moe(
w1_chunk = chunk_by_rank(w1, rank, world_size).to(device)
w2_chunk = chunk_by_rank(w2, rank, world_size).to(device)
# Note: for now use_compile will error out if the problem size is
# large enough to trigger chunking. I'm leaving the flag and
# setup code in case we are able to revisit this later.
if use_compile:
_fused_experts = torch.compile(fused_experts,
backend='inductor',
fullgraph=True)
torch._dynamo.mark_dynamic(a_chunk, 0)
torch._dynamo.mark_dynamic(chunk_topk_weight, 0)
torch._dynamo.mark_dynamic(chunk_topk_ids, 0)
else:
_fused_experts = fused_experts
@ -576,7 +559,7 @@ def _pplx_moe(
with set_current_vllm_config(vllm_config), override_config(moe_config):
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
torch_output = torch_experts(a, w1, w2, topk_weight, topk_ids)
pplx_output = pplx_moe(group_name, pgi.rank, pgi.world_size, dp_size,
a, w1, w2, topk_weight, topk_ids)
# TODO (bnell): fix + re-enable

View File

@ -403,19 +403,24 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk,
itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, SEEDS))
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
@torch.inference_mode()
def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
block_size = [block_m, block_m]
dtype = torch.bfloat16
def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
monkeypatch):
if topk > E:
pytest.skip(f"Skipping test: topk={topk} > E={E}")
if not _valid_deep_gemm_shape(M, N, K):
pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}")
chunk_size = 1024
torch.manual_seed(seed)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))
block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
block_size = [block_m, block_m]
dtype = torch.bfloat16
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min
@ -451,6 +456,14 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i])
w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i])
# Note: for now use_compile will error out if the problem size is
# large enough to trigger chunking. I'm leaving the flag and
# setup code in case we are able to revisit this later.
use_compile = False
use_cudagraph = (chunk_size < M and N >= 1024 and K >= 1024
and current_platform.is_cuda_alike())
# Set the context to avoid lots of warning spam.
with set_current_vllm_config(vllm_config):
if M >= 128:
@ -463,7 +476,29 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
topk_weights, topk_ids, token_expert_indices = fused_topk(
a, score.float(), topk, False)
out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids)
if use_compile:
deep_gemm_moe_fp8_fn = torch.compile(deep_gemm_moe_fp8,
backend="inductor",
fullgraph=True)
torch._dynamo.mark_dynamic(a, 0)
torch._dynamo.mark_dynamic(topk_weights, 0)
torch._dynamo.mark_dynamic(topk_ids, 0)
else:
deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8
out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights,
topk_ids)
if use_cudagraph:
out.fill_(0)
stream = torch.cuda.Stream()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):
out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights,
topk_ids)
torch.cuda.synchronize()
graph.replay()
torch.cuda.synchronize()
#print(f"{out.sum()=}")
#print(f"{ref_out.sum()=}")

View File

@ -1054,12 +1054,21 @@ def compute_max_diff(output, output_ref):
torch.abs(output_ref))
def torch_moe(a, w1, w2, score, topk, expert_map):
def torch_experts(a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None) -> torch.Tensor:
assert (global_num_experts == -1
or (global_num_experts == w1.shape[0] and expert_map is None)
or (expert_map is not None
and global_num_experts == expert_map.shape[0]))
topk = topk_ids.shape[1]
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk)
topk_weight = topk_weight.view(-1)
topk_ids = topk_ids.view(-1)
if expert_map is not None:
@ -1073,6 +1082,19 @@ def torch_moe(a, w1, w2, score, topk, expert_map):
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
def torch_moe(a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
score: torch.Tensor,
topk: int,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None) -> torch.Tensor:
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk)
return torch_experts(a, w1, w2, topk_weight, topk_ids, global_num_experts,
expert_map)
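# Usage sketch (illustrative shapes): torch_moe routes from raw router scores,
# while torch_experts takes topk weights/ids computed separately, which is what
# the pplx/batched tests now pass in directly.
#   a: (8, 16), w1: (4, 32, 16), w2: (4, 16, 16), score: (8, 4)
#   ref_a = torch_moe(a, w1, w2, score, topk=2)
#   topk_weight, topk_ids = torch.topk(
#       torch.softmax(score, dim=-1, dtype=torch.float32), 2)
#   ref_b = torch_experts(a, w1, w2, topk_weight, topk_ids)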
def torch_moe_single(a, w, score, topk):
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)

View File

@ -981,6 +981,7 @@ def compute_hash() -> str:
"VLLM_DP_RANK",
"VLLM_DP_SIZE",
"VLLM_USE_STANDALONE_COMPILE",
"VLLM_FUSED_MOE_CHUNK_SIZE",
]
for key in environment_variables_to_hash:
if key in environment_variables:

View File

@ -41,24 +41,24 @@ def run_cutlass_moe_fp8(
assert w1.dtype == torch.float8_e4m3fn
assert w2.dtype == torch.float8_e4m3fn
if expert_num_tokens is None:
assert a1q.shape[1] == w1.shape[2], "Hidden size mismatch w1"
assert a1q.size(1) == w1.size(2), "Hidden size mismatch w1"
else:
assert a1q.shape[2] == w1.shape[2], "Hidden size mismatch w1"
assert w1.shape[1] == w2.shape[2] * 2, "Hidden size mismatch w2"
assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[
1] == w1.shape[1], "W1 scale shape mismatch"
assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[
1] == w2.shape[1], "W2 scale shape mismatch"
assert w1.shape[0] == w2.shape[0], "Expert number mismatch"
assert a1q_scale is None or a1q_scale.dim(
) == 0 or a1q_scale.shape[0] == 1 or a1q_scale.shape[0] == a1q.shape[
0], "Input scale shape mismatch"
assert w1.shape[0] == w2.shape[0], "Weights expert number mismatch"
assert w1.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
assert w1.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
assert a2_scale is None or a2_scale.dim(
) == 0 or a2_scale.shape[0] == 1 or a2_scale.shape[0] == a1q.shape[
0], "Intermediate scale shape mismatch"
assert a1q.size(2) == w1.size(2), "Hidden size mismatch w1"
assert w1.size(1) == w2.size(2) * 2, "Hidden size mismatch w2"
assert w1_scale.dim() == 1 or w1_scale.size(
1) == 1 or w1_scale.shape[1] == w1.size(1), "W1 scale shape mismatch"
assert w2_scale.dim() == 1 or w2_scale.size(
1) == 1 or w2_scale.shape[1] == w2.size(1), "W2 scale shape mismatch"
assert w1.size(0) == w2.size(0), "Expert number mismatch"
assert a1q_scale is None or a1q_scale.dim() == 0 or a1q_scale.size(
0) == 1 or a1q_scale.size(
0) == a1q.shape[0], "Input scale shape mismatch"
assert w1.size(0) == w2.size(0), "Weights expert number mismatch"
assert w1.size(0) == w1_scale.size(0), "w1 scales expert number mismatch"
assert w1.size(0) == w2_scale.size(0), "w2 scales expert number mismatch"
assert a2_scale is None or a2_scale.dim() == 0 or a2_scale.size(
0) == 1 or a2_scale.size(
0) == a1q.shape[0], "Intermediate scale shape mismatch"
assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype"
if expert_map is not None:
assert expert_num_tokens is None
@ -75,12 +75,12 @@ def run_cutlass_moe_fp8(
# their tokens are already contiguous for each expert as a result of
# the dispatch function.
M = a1q.shape[0] # non batched expert M
padded_M = a1q.shape[1] # batched expert M
M = a1q.size(0) # non batched expert M
padded_M = a1q.size(1) # batched expert M
_, K, N = w2.shape
device = a1q.device
assert w1.shape[2] == K
assert w1.size(2) == K
assert global_num_experts != -1
assert a1q_scale is not None
@ -91,8 +91,8 @@ def run_cutlass_moe_fp8(
else:
local_topk_ids = topk_ids
topk = local_topk_ids.shape[1]
local_E = w1.shape[0]
topk = local_topk_ids.size(1)
local_E = w1.size(0)
if use_batched_format:
assert expert_num_tokens is not None
@ -111,10 +111,10 @@ def run_cutlass_moe_fp8(
problem_sizes2, expert_num_tokens,
local_E, padded_M, N, K)
w1_scale = w1_scale.reshape(w1_scale.shape[0], -1)
w2_scale = w2_scale.reshape(w2_scale.shape[0], -1)
a1q = a1q.reshape(-1, a1q.shape[2])
a1q_scale = a1q_scale.reshape(-1, a1q_scale.shape[2]).contiguous()
w1_scale = w1_scale.reshape(w1_scale.size(0), -1)
w2_scale = w2_scale.reshape(w2_scale.size(0), -1)
a1q = a1q.reshape(-1, a1q.size(2))
a1q_scale = a1q_scale.reshape(-1, a1q_scale.size(2)).contiguous()
else:
expert_offsets = torch.empty((global_num_experts + 1),
@ -151,19 +151,19 @@ def run_cutlass_moe_fp8(
a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale
expert_offsets = expert_offsets[:-1]
ab_strides1 = torch.full((w1.shape[0], ),
ab_strides1 = torch.full((w1.size(0), ),
K,
device=device,
dtype=torch.int64)
c_strides1 = torch.full((w1.shape[0], ),
c_strides1 = torch.full((w1.size(0), ),
2 * N,
device=device,
dtype=torch.int64)
ab_strides2 = torch.full((w1.shape[0], ),
ab_strides2 = torch.full((w1.size(0), ),
N,
device=device,
dtype=torch.int64)
c_strides2 = torch.full((w1.shape[0], ),
c_strides2 = torch.full((w1.size(0), ),
K,
device=device,
dtype=torch.int64)
@ -237,7 +237,7 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
workspace2: tuple[int, ...] = ()
output: tuple[int, ...] = ()
if self.use_batched_format:
padded_M = aq.shape[1]
padded_M = aq.size(1)
workspace1 = (self.max_experts_per_worker, padded_M, max(N, K))
workspace2 = (self.max_experts_per_worker, padded_M, (N // 2))
output = (self.max_experts_per_worker, padded_M, K)
@ -332,7 +332,7 @@ def cutlass_moe_fp8(
"""
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
a2_scale.numel() != 1 if a2_scale is not None else False)
per_out_ch = w1_scale.numel() != w1_q.shape[0]
per_out_ch = w1_scale.numel() != w1_q.size(0)
out_dtype = a.dtype
@ -425,11 +425,11 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
assert (m == m_a), "input shape mismatch"
assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1"
assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype"
assert (topk_weights.shape[0] == m and topk_ids.shape[0]
assert (topk_weights.size(0) == m and topk_ids.size(0)
== m), ("topk must be provided for each row of a")
out_dtype = a.dtype
num_topk = topk_ids.shape[1]
num_topk = topk_ids.size(1)
expert_offsets = torch.empty((e + 1), dtype=torch.int32, device=device)
blockscale_offsets = torch.empty((e + 1), dtype=torch.int32, device=device)
@ -463,7 +463,7 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
out_dtype, device)
del rep_a_fp4, rep_a_blockscale
# hidden size dimension is split into one half-sized tensor.
intermediate = torch.empty((m * num_topk, w1_fp4.shape[1] // 2),
intermediate = torch.empty((m * num_topk, w1_fp4.size(1) // 2),
device=device,
dtype=out_dtype)

View File

@ -48,7 +48,7 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor,
M = hidden_states.size(0)
_, K, N = w2.size()
if not _valid_deep_gemm_shape(M, N, K):
logger.debug("DeepGemm disabled: unalinged problem size.")
logger.debug("DeepGemm disabled: unaligned problem size.")
return False
if (w1.dtype != torch.float8_e4m3fn or w2.dtype != torch.float8_e4m3fn):

View File

@ -25,7 +25,7 @@ def dequant_fp8(expert_x_fp8: torch.Tensor,
expert_x_fp32 = expert_x_fp8.to(torch.float32).view(
num_experts, -1, DEEPEP_QUANT_BLOCK_SIZE)
expert_x_scales = expert_x_scales.view(num_experts, -1, 1)
return (expert_x_fp32 * expert_x_scales).view(expert_x_fp8.shape)
return (expert_x_fp32 * expert_x_scales).view(expert_x_fp8.size())
class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):

View File

@ -488,10 +488,10 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
if use_fp8_w8a8 or use_int8_w8a8:
assert B_scale is not None
assert (block_shape is None or triton.cdiv(B.shape[-2], block_shape[0])
== B_scale.shape[-2])
assert (block_shape is None or triton.cdiv(B.shape[-1], block_shape[1])
== B_scale.shape[-1])
assert (block_shape is None
or triton.cdiv(B.size(-2), block_shape[0]) == B_scale.size(-2))
assert (block_shape is None
or triton.cdiv(B.size(-1), block_shape[1]) == B_scale.size(-1))
elif use_int8_w8a16 or use_int4_w4a16:
assert B_scale is not None
@ -500,19 +500,19 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
assert A_scale is None
assert B_scale is None
M = A.shape[0]
M = A.size(0)
num_tokens = M * top_k
EM = sorted_token_ids.shape[0]
if A.shape[0] < config["BLOCK_SIZE_M"]:
EM = sorted_token_ids.size(0)
if A.size(0) < config["BLOCK_SIZE_M"]:
# optimize for small batch_size.
# We assume that top_ids of each token is unique, so
# so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
# and we can skip some invalid blocks.
EM = min(sorted_token_ids.shape[0],
A.shape[0] * top_k * config['BLOCK_SIZE_M'])
EM = min(sorted_token_ids.size(0),
A.size(0) * top_k * config['BLOCK_SIZE_M'])
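# Illustrative numbers: with A.size(0) == 4, top_k == 2 and
# BLOCK_SIZE_M == 16, at most 4 * 2 * 16 == 128 padded rows can be live,
# so EM is clamped to 128 even when sorted_token_ids is much longer.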
grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv(
B.shape[1], META['BLOCK_SIZE_N']), )
B.size(1), META['BLOCK_SIZE_N']), )
if (use_int8_w8a16 or use_int4_w4a16) and \
block_shape is not None and block_shape[1] > 0:
@ -522,16 +522,16 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
use_moe_wna16_cuda = should_moe_wna16_use_cuda(
num_valid_tokens=num_tokens,
group_size=block_shape[1],
num_experts=B.shape[0],
num_experts=B.size(0),
bit=4 if use_int4_w4a16 else 8)
config = config.copy()
config.update(
get_moe_wna16_block_config(config=config,
use_moe_wna16_cuda=use_moe_wna16_cuda,
num_valid_tokens=num_tokens,
size_k=A.shape[1],
size_n=B.shape[1],
num_experts=B.shape[1],
size_k=A.size(1),
size_n=B.size(1),
num_experts=B.size(1),
group_size=block_shape[1],
real_top_k=top_k,
block_size_m=config["BLOCK_SIZE_M"]))
@ -556,8 +556,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
B.shape[1],
A.shape[1],
B.size(1),
A.size(1),
EM,
num_tokens,
A.stride(0),
@ -573,7 +573,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
B_zp.stride(0) if B_zp is not None else 0,
B_zp.stride(2) if B_zp is not None else 0,
B_zp.stride(1) if B_zp is not None else 0,
block_k_diviable=A.shape[1] % config["BLOCK_SIZE_K"] == 0,
block_k_diviable=A.size(1) % config["BLOCK_SIZE_K"] == 0,
group_size=block_shape[1],
MUL_ROUTED_WEIGHT=mul_routed_weight,
top_k=top_k,
@ -599,8 +599,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
B.shape[1],
B.shape[2],
B.size(1),
B.size(2),
EM,
num_tokens,
A.stride(0),
@ -818,7 +818,7 @@ def try_get_optimal_moe_config(
M: int,
is_marlin: bool = False,
block_shape: Optional[list[int]] = None,
):
) -> dict[str, int]:
from vllm.model_executor.layers.fused_moe import get_config
override_config = get_config()
if override_config:
@ -873,10 +873,10 @@ def fused_topk(
renormalize: bool,
indices_type: Optional[torch.dtype] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert hidden_states.shape[0] == gating_output.shape[0], (
assert hidden_states.size(0) == gating_output.size(0), (
"Number of tokens mismatch")
M, _ = hidden_states.shape
M, _ = hidden_states.size()
topk_weights = torch.empty(M,
topk,
@ -915,7 +915,7 @@ def grouped_topk(
e_score_correction_bias: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]:
assert hidden_states.shape[0] == gating_output.shape[0], (
assert hidden_states.size(0) == gating_output.size(0), (
"Number of tokens mismatch")
if scoring_func == "softmax":
@ -925,7 +925,7 @@ def grouped_topk(
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
num_token = scores.shape[0]
num_token = scores.size(0)
if e_score_correction_bias is not None:
# Store original scores before applying correction bias. We use biased
# scores for expert selection but original scores for routing weights
@ -942,7 +942,7 @@ def grouped_topk(
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
score_mask = group_mask.unsqueeze(-1).expand(
num_token, num_expert_group,
scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e]
scores.size(-1) // num_expert_group).reshape(num_token, -1) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(),
float("-inf")) # [n, e]
@ -1162,7 +1162,7 @@ def fused_experts(hidden_states: torch.Tensor,
allow_deep_gemm: bool = False) -> torch.Tensor:
# For now, disable DeepGemm for small N (<= 512) until better
# permute/unpermute ops are available.
N = w1.shape[1]
N = w1.size(1)
if (allow_deep_gemm and use_fp8_w8a8 and N > 512
and _valid_deep_gemm(hidden_states, w1, w2)):
assert apply_router_weight_on_input is False
@ -1233,13 +1233,13 @@ def fused_experts_impl(
) -> torch.Tensor:
# Check constraints.
if use_int4_w4a16:
assert hidden_states.shape[1] // 2 == w1.shape[
2], "Hidden size mismatch"
assert hidden_states.size(1) // 2 == w1.size(2), (
"Hidden size mismatch")
else:
assert hidden_states.shape[1] == w1.shape[2], (
f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[2]}")
assert hidden_states.size(1) == w1.size(2), (
f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}")
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert topk_weights.size() == topk_ids.size(), "topk shape mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
@ -1247,12 +1247,12 @@ def fused_experts_impl(
torch.float32, torch.float16, torch.bfloat16
]
num_tokens = hidden_states.shape[0]
E, N, _ = w1.shape
K = w2.shape[1]
num_tokens = hidden_states.size(0)
E, N, _ = w1.size()
K = w2.size(1)
if global_num_experts == -1:
global_num_experts = E
top_k_num = topk_ids.shape[1]
top_k_num = topk_ids.size(1)
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
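# Illustrative math: with VLLM_FUSED_MOE_CHUNK_SIZE=8192 (the value the
# updated tests pin via monkeypatch) and num_tokens=32768, the chunked loop
# below runs ceil(32768 / 8192) = 4 iterations.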
@ -1269,8 +1269,8 @@ def fused_experts_impl(
get_config_func = functools.partial(
try_get_optimal_moe_config,
w1.shape,
w2.shape,
w1.size(),
w2.size(),
top_k_num,
config_dtype,
block_shape=block_shape,
@ -1310,7 +1310,7 @@ def fused_experts_impl(
min((chunk + 1) * CHUNK_SIZE,
num_tokens))
curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
tokens_in_chunk, _ = curr_hidden_states.shape
tokens_in_chunk, _ = curr_hidden_states.size()
if tokens_in_chunk == 0:
break
@ -1322,7 +1322,7 @@ def fused_experts_impl(
# do not need to be adjusted.
intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
intermediate_cache2 = intermediate_cache2[:tokens_in_chunk *
topk_ids.shape[1]]
topk_ids.size(1)]
intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
config = get_config_func(tokens_in_chunk)
@ -1398,7 +1398,7 @@ def fused_experts_impl(
per_channel_quant=per_channel_quant,
block_shape=block_shape)
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape),
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
out_hidden_states[begin_chunk_idx:end_chunk_idx])
return out_hidden_states
@ -1611,8 +1611,8 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
dtype=hidden_states.dtype)
config = try_get_optimal_moe_config(
w1.shape,
w2.shape,
w1.size(),
w2.size(),
top_k_num,
config_dtype,
num_tokens,

View File

@ -861,13 +861,11 @@ class FusedMoE(torch.nn.Module):
self.global_num_experts = num_experts
# For smuggling this layer into the fused moe custom op
self.use_direct_call = self.dp_size == 1
if not self.use_direct_call:
compilation_config = vllm_config.compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError("Duplicate layer name: {}".format(prefix))
compilation_config.static_forward_context[prefix] = self
self.layer_name = prefix
compilation_config = vllm_config.compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError("Duplicate layer name: {}".format(prefix))
compilation_config.static_forward_context[prefix] = self
self.layer_name = prefix
# Determine expert maps
if self.use_ep:
@ -1361,11 +1359,8 @@ class FusedMoE(torch.nn.Module):
def forward(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
if self.use_direct_call:
return self.forward_impl(hidden_states, router_logits)
else:
return torch.ops.vllm.moe_forward(hidden_states, router_logits,
self.layer_name)
return torch.ops.vllm.moe_forward(hidden_states, router_logits,
self.layer_name)
def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
full_router_logits: torch.Tensor):

View File

@ -69,7 +69,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
a1 = a1 * rank_topk_weights.to(a1.dtype)
repeat_cols = 4
repeat_rows = 1 if self.per_act_token else a1.shape[0]
repeat_rows = 1 if self.per_act_token else a1.size(0)
a1q, a1q_scale = moe_kernel_quantize_input(
a1, (None if self.per_act_token else a1_scale), self.quant_dtype,
self.per_act_token, self.block_shape)