[Kernels][Bugfix] Use torch op for all kernels in FusedMoE forward. Add additional testing for cudagraphs. (#19717)
Signed-off-by: Bill Nell <bnell@redhat.com>
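
At a glance, the "use torch op" part of this commit is the FusedMoE.forward change near the end of the diff: the layer no longer short-circuits to a direct Python call when use_direct_call is set, so every forward pass goes through the registered torch.ops.vllm.moe_forward op and stays visible to torch.compile and CUDA graph capture. A condensed before/after of that hunk (same code as shown further down, trimmed for readability):

# Before: with dp_size == 1 the custom op was bypassed via a direct call.
def forward(self, hidden_states, router_logits):
    if self.use_direct_call:
        return self.forward_impl(hidden_states, router_logits)
    else:
        return torch.ops.vllm.moe_forward(hidden_states, router_logits,
                                          self.layer_name)

# After: always dispatch through the custom op registered for this layer.
def forward(self, hidden_states, router_logits):
    return torch.ops.vllm.moe_forward(hidden_states, router_logits,
                                      self.layer_name)

The "additional testing for cudagraphs" part adds a run_moe_test helper plus capture/replay checks in the fused MoE and DeepGemm tests, shown in the hunks that follow.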
@@ -29,7 +29,10 @@ MNK_FACTORS = [
    (224, 1024, 1536),
    (224, 3072, 1024),
    (224, 3072, 1536),
    (1024 * 128, 1024, 1024),
    (32768, 1024, 1024),
    # These sizes trigger wrong answers.
    #(7232, 2048, 5120),
    #(40000, 2048, 5120),
]

vllm_config = VllmConfig(parallel_config=ParallelConfig(
@@ -232,8 +235,10 @@ def test_cutlass_moe_8_bit_no_graph(
    topk: int,
    per_act_token: bool,
    per_out_ch: bool,
    monkeypatch,
):
    current_platform.seed_everything(7)
    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
    with set_current_vllm_config(vllm_config):
        mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
                                                  per_out_ch)
@@ -274,8 +279,10 @@ def test_cutlass_moe_8_bit_cuda_graph(
    topk: int,
    per_act_token: bool,
    per_out_ch: bool,
    monkeypatch,
):
    current_platform.seed_everything(7)
    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
    with set_current_vllm_config(vllm_config):
        dtype = torch.half

@@ -329,8 +336,10 @@ def test_cutlass_moe_8_bit_EP(
    per_act_token: bool,
    per_out_channel: bool,
    ep_size: int,
    monkeypatch,
):
    current_platform.seed_everything(7)
    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
    with set_current_vllm_config(vllm_config):
        mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
                                                  per_out_channel)

@@ -4,6 +4,9 @@

Run `pytest tests/kernels/test_moe.py`.
"""
import functools
from typing import Callable, Optional, Union

import pytest
import torch
from torch.nn import Parameter
@@ -14,6 +17,7 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
import vllm.model_executor.layers.fused_moe # noqa
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_moe import (
    fused_topk, modular_triton_fused_moe)
@@ -40,7 +44,76 @@ vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192


@pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128])
def run_moe_test(
    baseline: Union[Callable, torch.Tensor],
    moe_fn: Callable,
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    score: torch.Tensor,
    topk: int,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    padding: bool = False,
    use_compile: bool = False,
    use_cudagraph: bool = False,
    atol: float = 2e-2,
    rtol: float = 0,
) -> torch.Tensor:
    if isinstance(baseline, torch.Tensor):
        baseline_output = baseline
    else:
        baseline_output = baseline(a,
                                   w1,
                                   w2,
                                   score,
                                   topk,
                                   global_num_experts=global_num_experts,
                                   expert_map=expert_map)

    # Pad the weight if moe padding is enabled
    if padding:
        w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128]
        w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128]

    if use_compile:
        moe_fn = torch.compile(moe_fn, backend="inductor", fullgraph=True)
        torch._dynamo.mark_dynamic(a, 0)
        torch._dynamo.mark_dynamic(score, 0)

    test_output = moe_fn(a,
                         w1,
                         w2,
                         score,
                         topk,
                         global_num_experts=global_num_experts,
                         expert_map=expert_map)

    if use_cudagraph:
        test_output.fill_(0)
        stream = torch.cuda.Stream()
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, stream=stream):
            test_output = moe_fn(a,
                                 w1,
                                 w2,
                                 score,
                                 topk,
                                 global_num_experts=global_num_experts,
                                 expert_map=expert_map)
        torch.cuda.synchronize()
        graph.replay()
        torch.cuda.synchronize()

    torch.testing.assert_close(test_output,
                               baseline_output,
                               atol=atol,
                               rtol=rtol)

    return baseline_output

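The run_moe_test helper above is the backbone of the new cudagraph coverage: it compares a candidate MoE implementation against either a baseline callable or a precomputed baseline tensor, optionally capturing the candidate inside a torch.cuda.CUDAGraph and replaying it before the comparison. A hypothetical standalone invocation (illustrative only; the real tests build these calls through the functools.partial runner set up in test_fused_moe below, and a, w1, w2, score are assumed to already live on the GPU):

# Compute the reference once and reuse it for a second candidate kernel.
baseline_out = run_moe_test(torch_moe, iterative_moe, a, w1, w2, score, topk=2)
# Re-check the fused path under cudagraph capture/replay (fused_moe_fn is the
# functools.partial(fused_moe, renormalize=False) defined in the test below).
run_moe_test(baseline_out, fused_moe_fn, a, w1, w2, score, topk=2,
             use_cudagraph=True)
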
@pytest.mark.parametrize("m", [1, 33, 64, 222, 32768, 40000])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@@ -48,6 +121,7 @@ vllm_config.scheduler_config.max_model_len = 8192
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("padding", [True, False])
@pytest.mark.parametrize("chunk_size", [8192])
def test_fused_moe(
    m: int,
    n: int,
@@ -57,7 +131,17 @@ def test_fused_moe(
    ep_size: int,
    dtype: torch.dtype,
    padding: bool,
    chunk_size: int,
    monkeypatch,
):
    current_platform.seed_everything(7)

    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))

    #
    # Setup test data
    #

    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
@@ -77,58 +161,70 @@ def test_fused_moe(
    else:
        e_map = None

    m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=False,
                                           use_int8_w8a8=False,
                                           use_int8_w8a16=False,
                                           use_int4_w4a16=False,
                                           per_channel_quant=False,
                                           block_shape=None)
    #
    # Setup test functions
    #

    m_fused_moe_fn = modular_triton_fused_moe(use_fp8_w8a8=False,
                                              use_int8_w8a8=False,
                                              use_int8_w8a16=False,
                                              use_int4_w4a16=False,
                                              per_channel_quant=False,
                                              block_shape=None)

    def m_fused_moe(
        a: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        score: torch.Tensor,
        topk: int,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
        return m_fused_moe_fn(a,
                              w1,
                              w2,
                              topk_weights,
                              topk_ids,
                              global_num_experts=global_num_experts,
                              expert_map=expert_map)

    fused_moe_fn = functools.partial(fused_moe, renormalize=False)

    #
    # Run tests
    #
    runner = functools.partial(
        run_moe_test,
        a=a,
        w1=w1,
        w2=w2,
        score=score,
        topk=topk,
        global_num_experts=e,
        expert_map=e_map,
        padding=padding,
    )

    # Note: for now use_compile will error out if the problem size is
    # large enough to trigger chunking. I'm leaving the flag and
    # setup code in case we are able to revisit this later.
    use_compile = False

    use_cudagraph = (n >= 1024 and k >= 1024
                     and current_platform.is_cuda_alike())

    with set_current_vllm_config(vllm_config):
        torch_output = torch_moe(a, w1, w2, score, topk, e_map)
        iterative_output = iterative_moe(a,
                                         w1,
                                         w2,
                                         score,
                                         topk,
                                         global_num_experts=e,
                                         expert_map=e_map,
                                         renormalize=False)

        # Pad the weight if moe padding is enabled
        if padding:
            w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128]
            torch.cuda.empty_cache()
            w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128]
            torch.cuda.empty_cache()

        triton_output = fused_moe(a,
                                  w1,
                                  w2,
                                  score,
                                  topk,
                                  global_num_experts=e,
                                  expert_map=e_map,
                                  renormalize=False)

        topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
        m_triton_output = m_fused_moe(a,
                                      w1,
                                      w2,
                                      topk_weights,
                                      topk_ids,
                                      global_num_experts=e,
                                      expert_map=e_map)

    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
    torch.testing.assert_close(m_triton_output,
                               torch_output,
                               atol=2e-2,
                               rtol=0)
    torch.testing.assert_close(iterative_output,
                               torch_output,
                               atol=2e-2,
                               rtol=0)
        baseline_output = runner(torch_moe, iterative_moe)
        runner(baseline_output,
               fused_moe_fn,
               use_compile=use_compile,
               use_cudagraph=use_cudagraph)
        runner(baseline_output,
               m_fused_moe,
               use_compile=use_compile,
               use_cudagraph=use_cudagraph)


@pytest.mark.parametrize("m", [1, 32, 222])
@@ -238,7 +334,12 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
                            w1_zp=w1_qzeros if has_zp else None,
                            w2_zp=w2_qzeros if has_zp else None,
                            block_shape=[0, group_size])
        torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, e_map)
        torch_output = torch_moe(a,
                                 w1_ref,
                                 w2_ref,
                                 score,
                                 topk,
                                 expert_map=e_map)

    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)

@@ -265,45 +366,51 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
        pytest.skip("AITER ROCm test skip for float32")

    # Instantiate our and huggingface's MoE blocks
    config = MixtralConfig()
    hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda")
    vllm_moe = MixtralMoE(
        num_experts=config.num_local_experts,
        top_k=config.num_experts_per_tok,
        hidden_size=config.hidden_size,
        intermediate_size=config.intermediate_size,
        params_dtype=dtype,
        tp_size=1,
        dp_size=1,
    ).cuda()
    vllm_config.compilation_config.static_forward_context = dict()
    with (set_current_vllm_config(vllm_config),
          set_forward_context(None, vllm_config)):
        config = MixtralConfig()
        hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda")
        vllm_moe = MixtralMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            params_dtype=dtype,
            tp_size=1,
            dp_size=1,
        ).cuda()

    # Load the weights
    vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data
    for i in range(config.num_local_experts):
        weights = (hf_moe.experts[i].w1.weight.data,
                   hf_moe.experts[i].w3.weight.data)
        vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0)
        vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data
        # Load the weights
        vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data
        for i in range(config.num_local_experts):
            weights = (hf_moe.experts[i].w1.weight.data,
                       hf_moe.experts[i].w3.weight.data)
            vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0)
            vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data

    # Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
    hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda")
    # vLLM uses 1D query [num_tokens, hidden_dim]
    vllm_inputs = hf_inputs.flatten(0, 1)
        # Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
        hf_inputs = torch.randn(
            (1, 64, config.hidden_size)).to(dtype).to("cuda")
        # vLLM uses 1D query [num_tokens, hidden_dim]
        vllm_inputs = hf_inputs.flatten(0, 1)

    # Pad the weight if moe padding is enabled
    if padding:
        vllm_moe.experts.w13_weight = Parameter(F.pad(
            vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[..., 0:-128],
                                                requires_grad=False)
        torch.cuda.empty_cache()
        vllm_moe.experts.w2_weight = Parameter(F.pad(
            vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[..., 0:-128],
                                               requires_grad=False)
        torch.cuda.empty_cache()
        # Pad the weight if moe padding is enabled
        if padding:
            vllm_moe.experts.w13_weight = Parameter(F.pad(
                vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[...,
                                                                      0:-128],
                                                    requires_grad=False)
            torch.cuda.empty_cache()
            vllm_moe.experts.w2_weight = Parameter(F.pad(
                vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[...,
                                                                     0:-128],
                                                   requires_grad=False)
            torch.cuda.empty_cache()

    # Run forward passes for both MoE blocks
    hf_states, _ = hf_moe.forward(hf_inputs)
    vllm_states = vllm_moe.forward(vllm_inputs)
        # Run forward passes for both MoE blocks
        hf_states, _ = hf_moe.forward(hf_inputs)
        vllm_states = vllm_moe.forward(vllm_inputs)

    mixtral_moe_tol = {
        torch.float32: 1e-3,
@@ -546,7 +653,12 @@ def test_fused_marlin_moe(
    topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)

    with set_current_vllm_config(vllm_config):
        torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map)
        torch_output = torch_moe(a,
                                 w_ref1,
                                 w_ref2,
                                 score,
                                 topk,
                                 expert_map=e_map)

        marlin_output = torch.ops.vllm.fused_marlin_moe(
            a,

@@ -136,7 +136,7 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
                              device=w2.device,
                              block_size=quant_blocksize)

    torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk, None)
    torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk)

    torch.testing.assert_close(torch_output,
                               cutlass_output,

@@ -6,9 +6,9 @@ from typing import Optional
import pytest
import torch

from tests.kernels.utils import torch_experts
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.modular_kernel import (
@@ -164,22 +164,6 @@ vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192


def torch_moe2(a, w1, w2, topk_weight, topk_ids):
    M, K = a.shape
    topk = topk_ids.shape[1]
    a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)
    out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    num_experts = w1.shape[0]
    for i in range(num_experts):
        mask = (topk_ids == i).view(-1)
        if mask.sum():
            out[mask] = SiluAndMul()(
                a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)

    return (out.view(M, -1, w2.shape[1]) *
            topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)


def _pplx_moe(
    pgi: ProcessGroupInfo,
    dp_size: int,
@@ -210,8 +194,8 @@ def _pplx_moe(
        group_name = cpu_group.group_name

    with set_current_vllm_config(vllm_config):
        torch_output = torch_moe2(a_full, w1_full, w2_full, topk_weights,
                                  topk_ids)
        torch_output = torch_experts(a_full, w1_full, w2_full, topk_weights,
                                     topk_ids)
        pplx_output = pplx_cutlass_moe(pgi, dp_size, a, w1, w2, w1_scale,
                                       w2_scale, topk_weights, topk_ids,
                                       a1_scale, out_dtype, per_act_token,

@@ -18,8 +18,8 @@ try:
except ImportError:
    has_pplx = False

from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import override_config
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    BatchedExperts, BatchedPrepareAndFinalize, BatchedTritonExperts)
@@ -163,29 +163,6 @@ def batched_moe(
    return fused_experts(a, w1, w2, topk_weight, topk_ids, num_experts)


# Note: same as torch_moe but with fused_topk factored out.
def torch_moe2(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
) -> torch.Tensor:
    M, K = a.shape
    topk = topk_ids.shape[1]
    a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)
    out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    num_experts = w1.shape[0]
    for i in range(num_experts):
        mask = (topk_ids == i).view(-1)
        if mask.sum():
            out[mask] = SiluAndMul()(
                a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)

    return (out.view(M, -1, w2.shape[1]) *
            topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)


@pytest.mark.parametrize("m", [1, 33, 64, 222])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 512, 1024])
@@ -209,7 +186,7 @@ def test_fused_moe_batched_experts(

    with set_current_vllm_config(vllm_config):
        topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
        baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
        baseline_output = torch_experts(a, w1, w2, topk_weight, topk_ids)
        torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids)
        batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids)

@@ -409,7 +386,7 @@ def pplx_moe(
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    use_compile: bool = True,
    use_compile: bool = False,
    use_cudagraphs: bool = True,
) -> torch.Tensor:
    from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
@@ -470,10 +447,16 @@ def pplx_moe(
    w1_chunk = chunk_by_rank(w1, rank, world_size).to(device)
    w2_chunk = chunk_by_rank(w2, rank, world_size).to(device)

    # Note: for now use_compile will error out if the problem size is
    # large enough to trigger chunking. I'm leaving the flag and
    # setup code in case we are able to revisit this later.
    if use_compile:
        _fused_experts = torch.compile(fused_experts,
                                       backend='inductor',
                                       fullgraph=True)
        torch._dynamo.mark_dynamic(a_chunk, 0)
        torch._dynamo.mark_dynamic(chunk_topk_weight, 0)
        torch._dynamo.mark_dynamic(chunk_topk_ids, 0)
    else:
        _fused_experts = fused_experts

@@ -576,7 +559,7 @@ def _pplx_moe(

    with set_current_vllm_config(vllm_config), override_config(moe_config):
        topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
        torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
        torch_output = torch_experts(a, w1, w2, topk_weight, topk_ids)
        pplx_output = pplx_moe(group_name, pgi.rank, pgi.world_size, dp_size,
                               a, w1, w2, topk_weight, topk_ids)
        # TODO (bnell): fix + re-enable

@@ -403,19 +403,24 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk,
    itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, SEEDS))
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
@torch.inference_mode()
def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):

    block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
    block_size = [block_m, block_m]
    dtype = torch.bfloat16

def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
                                            monkeypatch):
    if topk > E:
        pytest.skip(f"Skipping test: topk={topk} > E={E}")

    if not _valid_deep_gemm_shape(M, N, K):
        pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}")

    chunk_size = 1024

    torch.manual_seed(seed)

    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))

    block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
    block_size = [block_m, block_m]
    dtype = torch.bfloat16

    fp8_info = torch.finfo(torch.float8_e4m3fn)
    fp8_max, fp8_min = fp8_info.max, fp8_info.min

@@ -451,6 +456,14 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
        w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i])
        w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i])

    # Note: for now use_compile will error out if the problem size is
    # large enough to trigger chunking. I'm leaving the flag and
    # setup code in case we are able to revisit this later.
    use_compile = False

    use_cudagraph = (chunk_size < M and N >= 1024 and K >= 1024
                     and current_platform.is_cuda_alike())

    # Set the context to avoid lots of warning spam.
    with set_current_vllm_config(vllm_config):
        if M >= 128:
@@ -463,7 +476,29 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
        topk_weights, topk_ids, token_expert_indices = fused_topk(
            a, score.float(), topk, False)

        out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids)
        if use_compile:
            deep_gemm_moe_fp8_fn = torch.compile(deep_gemm_moe_fp8,
                                                 backend="inductor",
                                                 fullgraph=True)
            torch._dynamo.mark_dynamic(a, 0)
            torch._dynamo.mark_dynamic(topk_weights, 0)
            torch._dynamo.mark_dynamic(topk_ids, 0)
        else:
            deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8

        out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights,
                                   topk_ids)

        if use_cudagraph:
            out.fill_(0)
            stream = torch.cuda.Stream()
            graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(graph, stream=stream):
                out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights,
                                           topk_ids)
            torch.cuda.synchronize()
            graph.replay()
            torch.cuda.synchronize()

    #print(f"{out.sum()=}")
    #print(f"{ref_out.sum()=}")

@@ -1054,12 +1054,21 @@ def compute_max_diff(output, output_ref):
                                                torch.abs(output_ref))


def torch_moe(a, w1, w2, score, topk, expert_map):
def torch_experts(a: torch.Tensor,
                  w1: torch.Tensor,
                  w2: torch.Tensor,
                  topk_weight: torch.Tensor,
                  topk_ids: torch.Tensor,
                  global_num_experts: int = -1,
                  expert_map: Optional[torch.Tensor] = None) -> torch.Tensor:
    assert (global_num_experts == -1
            or (global_num_experts == w1.shape[0] and expert_map is None)
            or (expert_map is not None
                and global_num_experts == expert_map.shape[0]))
    topk = topk_ids.shape[1]
    B, D = a.shape
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)
    if expert_map is not None:
@@ -1073,6 +1082,19 @@ def torch_moe(a, w1, w2, score, topk, expert_map):
            topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)


def torch_moe(a: torch.Tensor,
              w1: torch.Tensor,
              w2: torch.Tensor,
              score: torch.Tensor,
              topk: int,
              global_num_experts: int = -1,
              expert_map: Optional[torch.Tensor] = None) -> torch.Tensor:
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    return torch_experts(a, w1, w2, topk_weight, topk_ids, global_num_experts,
                         expert_map)


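The old torch_moe(a, w1, w2, score, topk, expert_map) reference helper is split in two above: torch_experts takes routing results (topk_weight, topk_ids) computed elsewhere, while torch_moe becomes a thin wrapper that does the softmax + topk itself and then delegates. That is what lets the pplx tests drop their local torch_moe2 copies and call torch_experts directly. A small equivalence sketch (illustrative only; a, w1, w2 are assumed to be suitably shaped CUDA tensors):

score = torch.randn((a.shape[0], w1.shape[0]), device="cuda", dtype=a.dtype)
topk_weight, topk_ids = torch.topk(
    torch.softmax(score, dim=-1, dtype=torch.float32), 2)
# torch_moe is now just routing followed by torch_experts, so these two
# results should match.
out_a = torch_moe(a, w1, w2, score, 2)
out_b = torch_experts(a, w1, w2, topk_weight, topk_ids)
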
def torch_moe_single(a, w, score, topk):
    B, D = a.shape
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)

@@ -981,6 +981,7 @@ def compute_hash() -> str:
        "VLLM_DP_RANK",
        "VLLM_DP_SIZE",
        "VLLM_USE_STANDALONE_COMPILE",
        "VLLM_FUSED_MOE_CHUNK_SIZE",
    ]
    for key in environment_variables_to_hash:
        if key in environment_variables:

@@ -41,24 +41,24 @@ def run_cutlass_moe_fp8(
    assert w1.dtype == torch.float8_e4m3fn
    assert w2.dtype == torch.float8_e4m3fn
    if expert_num_tokens is None:
        assert a1q.shape[1] == w1.shape[2], "Hidden size mismatch w1"
        assert a1q.size(1) == w1.size(2), "Hidden size mismatch w1"
    else:
        assert a1q.shape[2] == w1.shape[2], "Hidden size mismatch w1"
    assert w1.shape[1] == w2.shape[2] * 2, "Hidden size mismatch w2"
    assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[
        1] == w1.shape[1], "W1 scale shape mismatch"
    assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[
        1] == w2.shape[1], "W2 scale shape mismatch"
    assert w1.shape[0] == w2.shape[0], "Expert number mismatch"
    assert a1q_scale is None or a1q_scale.dim(
    ) == 0 or a1q_scale.shape[0] == 1 or a1q_scale.shape[0] == a1q.shape[
        0], "Input scale shape mismatch"
    assert w1.shape[0] == w2.shape[0], "Weights expert number mismatch"
    assert w1.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
    assert w1.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
    assert a2_scale is None or a2_scale.dim(
    ) == 0 or a2_scale.shape[0] == 1 or a2_scale.shape[0] == a1q.shape[
        0], "Intermediate scale shape mismatch"
        assert a1q.size(2) == w1.size(2), "Hidden size mismatch w1"
    assert w1.size(1) == w2.size(2) * 2, "Hidden size mismatch w2"
    assert w1_scale.dim() == 1 or w1_scale.size(
        1) == 1 or w1_scale.shape[1] == w1.size(1), "W1 scale shape mismatch"
    assert w2_scale.dim() == 1 or w2_scale.size(
        1) == 1 or w2_scale.shape[1] == w2.size(1), "W2 scale shape mismatch"
    assert w1.size(0) == w2.size(0), "Expert number mismatch"
    assert a1q_scale is None or a1q_scale.dim() == 0 or a1q_scale.size(
        0) == 1 or a1q_scale.size(
            0) == a1q.shape[0], "Input scale shape mismatch"
    assert w1.size(0) == w2.size(0), "Weights expert number mismatch"
    assert w1.size(0) == w1_scale.size(0), "w1 scales expert number mismatch"
    assert w1.size(0) == w2_scale.size(0), "w2 scales expert number mismatch"
    assert a2_scale is None or a2_scale.dim() == 0 or a2_scale.size(
        0) == 1 or a2_scale.size(
            0) == a1q.shape[0], "Intermediate scale shape mismatch"
    assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype"
    if expert_map is not None:
        assert expert_num_tokens is None
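
Most of the remaining kernel-side hunks make the same mechanical substitution seen in the assertions above: Tensor.shape[i] indexing is replaced by the equivalent Tensor.size(i) accessor, and whole-shape comparisons use .size(). A quick plain-PyTorch sanity check of that equivalence, independent of vLLM:

import torch

t = torch.empty(3, 5, 7)
# Per-dimension access and the full shape tuple are interchangeable.
assert t.shape[1] == t.size(1) == 5
assert t.shape == t.size() == torch.Size([3, 5, 7])
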
@@ -75,12 +75,12 @@ def run_cutlass_moe_fp8(
    # their tokens are already contiguous for each expert as a result of
    # the dispatch function.

    M = a1q.shape[0] # non batched expert M
    padded_M = a1q.shape[1] # batched expert M
    M = a1q.size(0) # non batched expert M
    padded_M = a1q.size(1) # batched expert M
    _, K, N = w2.shape
    device = a1q.device

    assert w1.shape[2] == K
    assert w1.size(2) == K
    assert global_num_experts != -1
    assert a1q_scale is not None

@@ -91,8 +91,8 @@ def run_cutlass_moe_fp8(
    else:
        local_topk_ids = topk_ids

    topk = local_topk_ids.shape[1]
    local_E = w1.shape[0]
    topk = local_topk_ids.size(1)
    local_E = w1.size(0)

    if use_batched_format:
        assert expert_num_tokens is not None
@@ -111,10 +111,10 @@ def run_cutlass_moe_fp8(
                                    problem_sizes2, expert_num_tokens,
                                    local_E, padded_M, N, K)

        w1_scale = w1_scale.reshape(w1_scale.shape[0], -1)
        w2_scale = w2_scale.reshape(w2_scale.shape[0], -1)
        a1q = a1q.reshape(-1, a1q.shape[2])
        a1q_scale = a1q_scale.reshape(-1, a1q_scale.shape[2]).contiguous()
        w1_scale = w1_scale.reshape(w1_scale.size(0), -1)
        w2_scale = w2_scale.reshape(w2_scale.size(0), -1)
        a1q = a1q.reshape(-1, a1q.size(2))
        a1q_scale = a1q_scale.reshape(-1, a1q_scale.size(2)).contiguous()

    else:
        expert_offsets = torch.empty((global_num_experts + 1),
@@ -151,19 +151,19 @@ def run_cutlass_moe_fp8(
        a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale
        expert_offsets = expert_offsets[:-1]

    ab_strides1 = torch.full((w1.shape[0], ),
    ab_strides1 = torch.full((w1.size(0), ),
                             K,
                             device=device,
                             dtype=torch.int64)
    c_strides1 = torch.full((w1.shape[0], ),
    c_strides1 = torch.full((w1.size(0), ),
                            2 * N,
                            device=device,
                            dtype=torch.int64)
    ab_strides2 = torch.full((w1.shape[0], ),
    ab_strides2 = torch.full((w1.size(0), ),
                             N,
                             device=device,
                             dtype=torch.int64)
    c_strides2 = torch.full((w1.shape[0], ),
    c_strides2 = torch.full((w1.size(0), ),
                            K,
                            device=device,
                            dtype=torch.int64)
@@ -237,7 +237,7 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
        workspace2: tuple[int, ...] = ()
        output: tuple[int, ...] = ()
        if self.use_batched_format:
            padded_M = aq.shape[1]
            padded_M = aq.size(1)
            workspace1 = (self.max_experts_per_worker, padded_M, max(N, K))
            workspace2 = (self.max_experts_per_worker, padded_M, (N // 2))
            output = (self.max_experts_per_worker, padded_M, K)
@@ -332,7 +332,7 @@ def cutlass_moe_fp8(
    """
    per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
        a2_scale.numel() != 1 if a2_scale is not None else False)
    per_out_ch = w1_scale.numel() != w1_q.shape[0]
    per_out_ch = w1_scale.numel() != w1_q.size(0)

    out_dtype = a.dtype

@@ -425,11 +425,11 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
    assert (m == m_a), "input shape mismatch"
    assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1"
    assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype"
    assert (topk_weights.shape[0] == m and topk_ids.shape[0]
    assert (topk_weights.size(0) == m and topk_ids.size(0)
            == m), ("topk must be provided for each row of a")

    out_dtype = a.dtype
    num_topk = topk_ids.shape[1]
    num_topk = topk_ids.size(1)

    expert_offsets = torch.empty((e + 1), dtype=torch.int32, device=device)
    blockscale_offsets = torch.empty((e + 1), dtype=torch.int32, device=device)
@@ -463,7 +463,7 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
                                   out_dtype, device)
    del rep_a_fp4, rep_a_blockscale
    # hidden size dimension is split to one halfpytho sized tensor.
    intermediate = torch.empty((m * num_topk, w1_fp4.shape[1] // 2),
    intermediate = torch.empty((m * num_topk, w1_fp4.size(1) // 2),
                               device=device,
                               dtype=out_dtype)


@@ -48,7 +48,7 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor,
    M = hidden_states.size(0)
    _, K, N = w2.size()
    if not _valid_deep_gemm_shape(M, N, K):
        logger.debug("DeepGemm disabled: unalinged problem size.")
        logger.debug("DeepGemm disabled: unaligned problem size.")
        return False

    if (w1.dtype != torch.float8_e4m3fn or w2.dtype != torch.float8_e4m3fn):

@@ -25,7 +25,7 @@ def dequant_fp8(expert_x_fp8: torch.Tensor,
    expert_x_fp32 = expert_x_fp8.to(torch.float32).view(
        num_experts, -1, DEEPEP_QUANT_BLOCK_SIZE)
    expert_x_scales = expert_x_scales.view(num_experts, -1, 1)
    return (expert_x_fp32 * expert_x_scales).view(expert_x_fp8.shape)
    return (expert_x_fp32 * expert_x_scales).view(expert_x_fp8.size())


class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):

@@ -488,10 +488,10 @@ def invoke_fused_moe_kernel(A: torch.Tensor,

    if use_fp8_w8a8 or use_int8_w8a8:
        assert B_scale is not None
        assert (block_shape is None or triton.cdiv(B.shape[-2], block_shape[0])
                == B_scale.shape[-2])
        assert (block_shape is None or triton.cdiv(B.shape[-1], block_shape[1])
                == B_scale.shape[-1])
        assert (block_shape is None
                or triton.cdiv(B.size(-2), block_shape[0]) == B_scale.size(-2))
        assert (block_shape is None
                or triton.cdiv(B.size(-1), block_shape[1]) == B_scale.size(-1))

    elif use_int8_w8a16 or use_int4_w4a16:
        assert B_scale is not None
@@ -500,19 +500,19 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
        assert A_scale is None
        assert B_scale is None

    M = A.shape[0]
    M = A.size(0)
    num_tokens = M * top_k

    EM = sorted_token_ids.shape[0]
    if A.shape[0] < config["BLOCK_SIZE_M"]:
    EM = sorted_token_ids.size(0)
    if A.size(0) < config["BLOCK_SIZE_M"]:
        # optimize for small batch_size.
        # We assume that top_ids of each token is unique, so
        # so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
        # and we can skip some invalid blocks.
        EM = min(sorted_token_ids.shape[0],
                 A.shape[0] * top_k * config['BLOCK_SIZE_M'])
        EM = min(sorted_token_ids.size(0),
                 A.size(0) * top_k * config['BLOCK_SIZE_M'])
    grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv(
        B.shape[1], META['BLOCK_SIZE_N']), )
        B.size(1), META['BLOCK_SIZE_N']), )

    if (use_int8_w8a16 or use_int4_w4a16) and \
            block_shape is not None and block_shape[1] > 0:
@@ -522,16 +522,16 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
        use_moe_wna16_cuda = should_moe_wna16_use_cuda(
            num_valid_tokens=num_tokens,
            group_size=block_shape[1],
            num_experts=B.shape[0],
            num_experts=B.size(0),
            bit=4 if use_int4_w4a16 else 8)
        config = config.copy()
        config.update(
            get_moe_wna16_block_config(config=config,
                                       use_moe_wna16_cuda=use_moe_wna16_cuda,
                                       num_valid_tokens=num_tokens,
                                       size_k=A.shape[1],
                                       size_n=B.shape[1],
                                       num_experts=B.shape[1],
                                       size_k=A.size(1),
                                       size_n=B.size(1),
                                       num_experts=B.size(1),
                                       group_size=block_shape[1],
                                       real_top_k=top_k,
                                       block_size_m=config["BLOCK_SIZE_M"]))
@@ -556,8 +556,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            B.shape[1],
            A.shape[1],
            B.size(1),
            A.size(1),
            EM,
            num_tokens,
            A.stride(0),
@@ -573,7 +573,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
            B_zp.stride(0) if B_zp is not None else 0,
            B_zp.stride(2) if B_zp is not None else 0,
            B_zp.stride(1) if B_zp is not None else 0,
            block_k_diviable=A.shape[1] % config["BLOCK_SIZE_K"] == 0,
            block_k_diviable=A.size(1) % config["BLOCK_SIZE_K"] == 0,
            group_size=block_shape[1],
            MUL_ROUTED_WEIGHT=mul_routed_weight,
            top_k=top_k,
@@ -599,8 +599,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            B.shape[1],
            B.shape[2],
            B.size(1),
            B.size(2),
            EM,
            num_tokens,
            A.stride(0),
@@ -818,7 +818,7 @@ def try_get_optimal_moe_config(
    M: int,
    is_marlin: bool = False,
    block_shape: Optional[list[int]] = None,
):
) -> dict[str, int]:
    from vllm.model_executor.layers.fused_moe import get_config
    override_config = get_config()
    if override_config:
@@ -873,10 +873,10 @@ def fused_topk(
    renormalize: bool,
    indices_type: Optional[torch.dtype] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    assert hidden_states.shape[0] == gating_output.shape[0], (
    assert hidden_states.size(0) == gating_output.size(0), (
        "Number of tokens mismatch")

    M, _ = hidden_states.shape
    M, _ = hidden_states.size()

    topk_weights = torch.empty(M,
                               topk,
@@ -915,7 +915,7 @@ def grouped_topk(
    e_score_correction_bias: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]:

    assert hidden_states.shape[0] == gating_output.shape[0], (
    assert hidden_states.size(0) == gating_output.size(0), (
        "Number of tokens mismatch")

    if scoring_func == "softmax":
@@ -925,7 +925,7 @@ def grouped_topk(
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    num_token = scores.shape[0]
    num_token = scores.size(0)
    if e_score_correction_bias is not None:
        # Store original scores before applying correction bias. We use biased
        # scores for expert selection but original scores for routing weights
@@ -942,7 +942,7 @@ def grouped_topk(
    group_mask.scatter_(1, group_idx, 1) # [n, n_group]
    score_mask = group_mask.unsqueeze(-1).expand(
        num_token, num_expert_group,
        scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e]
        scores.size(-1) // num_expert_group).reshape(num_token, -1) # [n, e]
    tmp_scores = scores.masked_fill(~score_mask.bool(),
                                    float("-inf")) # [n, e]

@@ -1162,7 +1162,7 @@ def fused_experts(hidden_states: torch.Tensor,
                  allow_deep_gemm: bool = False) -> torch.Tensor:
    # For now, disable DeepGemm for small N (<= 512) until better
    # permute/unpermute ops are available.
    N = w1.shape[1]
    N = w1.size(1)
    if (allow_deep_gemm and use_fp8_w8a8 and N > 512
            and _valid_deep_gemm(hidden_states, w1, w2)):
        assert apply_router_weight_on_input is False
@@ -1233,13 +1233,13 @@ def fused_experts_impl(
) -> torch.Tensor:
    # Check constraints.
    if use_int4_w4a16:
        assert hidden_states.shape[1] // 2 == w1.shape[
            2], "Hidden size mismatch"
        assert hidden_states.size(1) // 2 == w1.size(2), (
            "Hidden size mismatch")
    else:
        assert hidden_states.shape[1] == w1.shape[2], (
            f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[2]}")
        assert hidden_states.size(1) == w1.size(2), (
            f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}")

    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
    assert topk_weights.size() == topk_ids.size(), "topk shape mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
    assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
@@ -1247,12 +1247,12 @@ def fused_experts_impl(
        torch.float32, torch.float16, torch.bfloat16
    ]

    num_tokens = hidden_states.shape[0]
    E, N, _ = w1.shape
    K = w2.shape[1]
    num_tokens = hidden_states.size(0)
    E, N, _ = w1.size()
    K = w2.size(1)
    if global_num_experts == -1:
        global_num_experts = E
    top_k_num = topk_ids.shape[1]
    top_k_num = topk_ids.size(1)
    # We execute the fused_moe kernel in chunks to circumvent this issue:
    # https://github.com/vllm-project/vllm/issues/5938
    CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
@@ -1269,8 +1269,8 @@ def fused_experts_impl(

    get_config_func = functools.partial(
        try_get_optimal_moe_config,
        w1.shape,
        w2.shape,
        w1.size(),
        w2.size(),
        top_k_num,
        config_dtype,
        block_shape=block_shape,
@@ -1310,7 +1310,7 @@ def fused_experts_impl(
                                  min((chunk + 1) * CHUNK_SIZE,
                                      num_tokens))
        curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
        tokens_in_chunk, _ = curr_hidden_states.shape
        tokens_in_chunk, _ = curr_hidden_states.size()

        if tokens_in_chunk == 0:
            break
@@ -1322,7 +1322,7 @@ def fused_experts_impl(
            # do not need to be adjusted.
            intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
            intermediate_cache2 = intermediate_cache2[:tokens_in_chunk *
                                                      topk_ids.shape[1]]
                                                      topk_ids.size(1)]
            intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
            config = get_config_func(tokens_in_chunk)

@@ -1398,7 +1398,7 @@ def fused_experts_impl(
            per_channel_quant=per_channel_quant,
            block_shape=block_shape)

        ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape),
        ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
                    out_hidden_states[begin_chunk_idx:end_chunk_idx])

    return out_hidden_states
@@ -1611,8 +1611,8 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
            dtype=hidden_states.dtype)

        config = try_get_optimal_moe_config(
            w1.shape,
            w2.shape,
            w1.size(),
            w2.size(),
            top_k_num,
            config_dtype,
            num_tokens,

@@ -861,13 +861,11 @@ class FusedMoE(torch.nn.Module):
        self.global_num_experts = num_experts

        # For smuggling this layer into the fused moe custom op
        self.use_direct_call = self.dp_size == 1
        if not self.use_direct_call:
            compilation_config = vllm_config.compilation_config
            if prefix in compilation_config.static_forward_context:
                raise ValueError("Duplicate layer name: {}".format(prefix))
            compilation_config.static_forward_context[prefix] = self
            self.layer_name = prefix
        compilation_config = vllm_config.compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError("Duplicate layer name: {}".format(prefix))
        compilation_config.static_forward_context[prefix] = self
        self.layer_name = prefix

        # Determine expert maps
        if self.use_ep:
@@ -1361,11 +1359,8 @@ class FusedMoE(torch.nn.Module):

    def forward(self, hidden_states: torch.Tensor,
                router_logits: torch.Tensor):
        if self.use_direct_call:
            return self.forward_impl(hidden_states, router_logits)
        else:
            return torch.ops.vllm.moe_forward(hidden_states, router_logits,
                                              self.layer_name)
        return torch.ops.vllm.moe_forward(hidden_states, router_logits,
                                          self.layer_name)

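With the else branch removed, forward() above unconditionally dispatches through torch.ops.vllm.moe_forward, and the constructor change earlier in this file now registers every FusedMoE layer in compilation_config.static_forward_context regardless of dp_size. A rough sketch of what the op body can look like given only what this diff shows (the function name and the way the registry handle is obtained are illustrative assumptions, not vLLM's exact implementation):

def _moe_forward_sketch(hidden_states: torch.Tensor,
                        router_logits: torch.Tensor,
                        layer_name: str) -> torch.Tensor:
    # Illustrative: resolve the FusedMoE module that registered itself under
    # its prefix (compilation_config.static_forward_context[prefix] = self)
    # and run the eager implementation.
    layer = static_forward_context[layer_name]  # assumed registry handle
    return layer.forward_impl(hidden_states, router_logits)
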
    def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
                             full_router_logits: torch.Tensor):

@@ -69,7 +69,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
        a1 = a1 * rank_topk_weights.to(a1.dtype)

        repeat_cols = 4
        repeat_rows = 1 if self.per_act_token else a1.shape[0]
        repeat_rows = 1 if self.per_act_token else a1.size(0)
        a1q, a1q_scale = moe_kernel_quantize_input(
            a1, (None if self.per_act_token else a1_scale), self.quant_dtype,
            self.per_act_token, self.block_shape)