[Kernels] Clean up FusedMoeMethodBase and modular kernel setup. Remove extra arguments from modular kernel methods. (#22035)
Signed-off-by: Bill Nell <bnell@redhat.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
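The diff below replaces the hard-coded kernel-type lists in the test Config with a capability registry (PrepareFinalizeInfo / ExpertInfo in mk_objects.py). A minimal sketch of the new pattern, mirroring the TritonExperts registration that appears later in this diff:

    # Each kernel class is registered once with its capabilities ...
    register_experts(
        TritonExperts,                  # kernel class to describe
        standard_format,                # mk.FusedMoEActivationFormat.Standard
        common_float_and_int_types,     # dtypes the kernel accepts
        blocked_quantization_support=True,
        supports_chunking=True,
        supports_expert_map=True,
        needs_matching_quant=True,
    )

    # ... and test predicates query capabilities instead of matching types:
    info = expert_info(TritonExperts)
    assert info.supports_chunking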
@@ -7,41 +7,22 @@ import torch

import vllm._custom_ops as ops
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from tests.kernels.moe.utils import make_test_weights, per_token_cast_to_fp8
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
                                                    FLOAT8_E4M3_MAX,
                                                    dequantize_nvfp4_to_dtype)
from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig
from vllm.distributed import get_dp_group, get_tensor_model_parallel_world_size
# Fused experts and PrepareFinalize imports
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
    BatchedDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
    BatchedTritonOrDeepGemmExperts)
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig)
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    BatchedTritonExperts, NaiveBatchedExperts)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.layer import (FusedMoEMethodBase,
                                                        TritonExperts)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
    TritonOrDeepGemmExperts)
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx

from .mk_objects import (expert_info, make_fused_experts,
                         make_prepare_finalize, prepare_finalize_info)
from .parallel_utils import ProcessGroupInfo
from .utils import (make_block_quant_fp8_weights, make_non_quant_weights,
                    make_quant_fp8_weights, per_token_cast_to_fp8)

if has_pplx():
    from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
        PplxPrepareAndFinalize)
if has_deep_ep():
    from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
        DeepEPHTPrepareAndFinalize)
    from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
        DeepEPLLPrepareAndFinalize)


def _describe_tensor(t: Optional[torch.Tensor], name: str) -> str:
@@ -69,24 +50,31 @@ class Config:

    torch_trace_dir_path: Optional[str] = None

    def __post_init__(self):
        if self.quant_config is None:
            self.quant_config = FusedMoEQuantConfig()

    def describe(self) -> str:
        s = ""
        s += "== Config: \n"
        s += f" world_size={self.world_size} \n"
        s += f" PF={self.prepare_finalize_type.__name__} \n"
        s += f" FE={self.fused_experts_type.__name__} \n"
        s += f" topk={self.topks} \n"
        s += f" dtype={self.dtype} \n"
        s += f" fused_moe_chunk_size={self.fused_moe_chunk_size} \n"
        s += " Quant: \n"
        s += f" fused_moe_chunk_size={self.fused_moe_chunk_size} \n "
        s += "== Config:\n"
        s += f" world_size={self.world_size}\n"
        s += f" PF={self.prepare_finalize_type.__name__}\n"
        s += f" FE={self.fused_experts_type.__name__}\n"
        s += f" E={self.E}\n"
        s += f" Ms={self.Ms}\n"
        s += f" N={self.N}\n"
        s += f" K={self.K}\n"
        s += f" topk={self.topks}\n"
        s += f" dtype={self.dtype}\n"
        s += f" fused_moe_chunk_size={self.fused_moe_chunk_size}\n"
        s += " Quant:\n"
        if self.quant_config is not None:
            s += f" q_dtype={self.quant_dtype} \n"
            s += f" q_block_shape={self.quant_block_shape} \n"
            s += f" q_per_out_ch_quant={self.is_per_out_ch_quant} \n"
            s += f" q_per_act_token={self.is_per_act_token_quant} \n"
            s += f" q_dtype={self.quant_dtype}\n"
            s += f" q_block_shape={self.quant_block_shape}\n"
            s += f" q_per_out_ch_quant={self.is_per_out_ch_quant}\n"
            s += f" q_per_act_token={self.is_per_act_token_quant}\n"
        else:
            s += " quant=None \n"
            s += " quant=None\n"
        return s

    @property
@@ -95,34 +83,28 @@ class Config:
        return self.Ms

    @property
    def quant_dtype(self) -> Optional[torch.dtype]:
        if self.quant_config is None:
            return None
    def quant_dtype(self) -> Union[torch.dtype, str, None]:
        assert self.quant_config is not None
        return self.quant_config.quant_dtype

    @property
    def is_per_act_token_quant(self) -> bool:
        if self.quant_config is None:
            return False
        assert self.quant_config is not None
        return self.quant_config.per_act_token_quant

    @property
    def is_per_tensor_act_quant(self) -> bool:
        if self.quant_config is None:
            return False
        return (not self.is_per_act_token_quant
                and self.quant_block_shape is None)

    @property
    def is_per_out_ch_quant(self) -> bool:
        if self.quant_config is None:
            return False
        assert self.quant_config is not None
        return self.quant_config.per_out_ch_quant

    @property
    def quant_block_shape(self) -> Optional[list[int]]:
        if self.quant_config is None:
            return None
        assert self.quant_config is not None
        return self.quant_config.block_shape

    @property
@@ -130,36 +112,30 @@ class Config:
        assert isinstance(self.topks, int)
        return self.topks

    @property
    def topk_ids_dtype(self) -> Optional[torch.dtype]:
        topk_ids_dtype = None
        if self.prepare_finalize_type == PplxPrepareAndFinalize:
            topk_ids_dtype = torch.uint32
        elif self.prepare_finalize_type in [
                DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize
        ]:
            topk_ids_dtype = torch.int64
        return topk_ids_dtype

    @property
    def num_local_experts(self) -> int:
        return self.E // self.world_size

    def make_env_data(self) -> tuple[VllmConfig, dict[Any, Any]]:
        """
        make env data for vllm launch.
        make env data for vllm launch.
        """
        vllm_config = VllmConfig()
        vllm_config.parallel_config.data_parallel_size = self.world_size
        vllm_config.parallel_config.enable_expert_parallel = True

        env_dict = {
            "VLLM_ALL2ALL_BACKEND": self.all2all_backend(),
            "VLLM_USE_DEEP_GEMM": str(int(self.needs_deep_gemm())),
        }

        backend = self.all2all_backend()
        if backend is not None:
            env_dict.update({"VLLM_ALL2ALL_BACKEND": backend})

        if self.fused_moe_chunk_size is not None:
            env_dict.update(
                {"VLLM_FUSED_MOE_CHUNK_SIZE": str(self.fused_moe_chunk_size)})

        return vllm_config, env_dict

    def is_fp8_block_quantized(self):
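As a concrete illustration of make_env_data() above: for a hypothetical config that selects the DeepEP low-latency backend, needs no DeepGemm experts, and sets fused_moe_chunk_size=1024, the returned env dict would be roughly:

    env_dict = {
        "VLLM_USE_DEEP_GEMM": "0",
        "VLLM_ALL2ALL_BACKEND": "deepep_low_latency",
        "VLLM_FUSED_MOE_CHUNK_SIZE": "1024",
    }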
@@ -167,85 +143,59 @@ class Config:
                and self.quant_block_shape is not None)

    def is_batched_prepare_finalize(self):
        return self.prepare_finalize_type in [
            PplxPrepareAndFinalize, DeepEPLLPrepareAndFinalize
        ]
        info = prepare_finalize_info(self.prepare_finalize_type)
        return (mk.FusedMoEActivationFormat.BatchedExperts ==
                info.activation_format)

    def is_batched_fused_experts(self):
        return self.fused_experts_type in [
            CutlassExpertsFp8, BatchedDeepGemmExperts, BatchedTritonExperts,
            NaiveBatchedExperts, BatchedTritonOrDeepGemmExperts
        ]
        info = expert_info(self.fused_experts_type)
        return (mk.FusedMoEActivationFormat.BatchedExperts ==
                info.activation_format)

    def is_standard_fused_experts(self):
        return self.fused_experts_type in [
            CutlassExpertsFp8, DeepGemmExperts, TritonOrDeepGemmExperts,
            TritonExperts
        ]
        info = expert_info(self.fused_experts_type)
        return mk.FusedMoEActivationFormat.Standard == info.activation_format

    def is_fe_16bit_supported(self):
        return self.fused_experts_type in [
            BatchedTritonExperts, BatchedTritonOrDeepGemmExperts,
            NaiveBatchedExperts, TritonExperts
        ]
    def fe_supported_types(self):
        info = expert_info(self.fused_experts_type)
        return info.supported_dtypes

    def is_fe_fp8_supported(self):
        return self.fused_experts_type in [
            BatchedDeepGemmExperts,
            BatchedTritonExperts,
            BatchedTritonOrDeepGemmExperts,
            CutlassExpertsFp8,
            DeepGemmExperts,
            TritonExperts,
            TritonOrDeepGemmExperts,
            NaiveBatchedExperts,
        ]
    def pf_supported_types(self):
        info = prepare_finalize_info(self.prepare_finalize_type)
        return info.supported_dtypes

    def is_fe_block_fp8_supported(self):
        return self.fused_experts_type in [
            BatchedDeepGemmExperts,
            BatchedTritonOrDeepGemmExperts,
            DeepGemmExperts,
            TritonExperts,
            TritonOrDeepGemmExperts,
            BatchedTritonExperts,
            NaiveBatchedExperts,
        ]
    def is_block_quant_supported(self):
        info = expert_info(self.fused_experts_type)
        return info.blocked_quantization_support

    def is_fe_supports_chunking(self):
        return self.fused_experts_type in [
            CutlassExpertsFp8, DeepGemmExperts, TritonOrDeepGemmExperts,
            TritonExperts
        ]
        info = expert_info(self.fused_experts_type)
        return info.supports_chunking

    def supports_expert_map(self):
        info = expert_info(self.fused_experts_type)
        return info.supports_expert_map

    def supports_apply_weight_on_input(self):
        info = prepare_finalize_info(self.prepare_finalize_type)
        return info.supports_apply_weight_on_input

    def needs_deep_gemm(self):
        return self.fused_experts_type in [
            BatchedDeepGemmExperts,
            DeepGemmExperts,
        ]
        info = expert_info(self.fused_experts_type)
        return info.needs_deep_gemm

    def needs_pplx(self):
        return self.prepare_finalize_type in [PplxPrepareAndFinalize]
        info = prepare_finalize_info(self.prepare_finalize_type)
        return info.backend == "pplx"

    def needs_deep_ep(self):
        return self.prepare_finalize_type in [
            DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize
        ]
        info = prepare_finalize_info(self.prepare_finalize_type)
        return (info.backend == "deepep_high_throughput"
                or info.backend == "deepep_low_latency")

    def all2all_backend(self):
        if self.needs_pplx():
            return "pplx"
        if self.prepare_finalize_type == DeepEPHTPrepareAndFinalize:
            return "deepep_high_throughput"
        if self.prepare_finalize_type == DeepEPLLPrepareAndFinalize:
            return "deepep_low_latency"
        return "naive"

    def needs_all2all(self):
        return self.prepare_finalize_type in [
            PplxPrepareAndFinalize, DeepEPHTPrepareAndFinalize,
            DeepEPLLPrepareAndFinalize
        ]
        info = prepare_finalize_info(self.prepare_finalize_type)
        return info.backend

    def is_valid(self):
        # Check prepare-finalize and fused-experts compatibility
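The predicates above now consult the registry instead of explicit type lists; note that needs_all2all() returns info.backend directly, relying on a non-None backend string being truthy. A small sketch of what the registry yields, matching the PplxPrepareAndFinalize registration later in this diff:

    info = prepare_finalize_info(PplxPrepareAndFinalize)
    assert info.backend == "pplx"
    assert info.activation_format == mk.FusedMoEActivationFormat.BatchedExperts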
@@ -267,28 +217,28 @@ class Config:
            # invalid quant config
            return False

        # check bf16 / fp16 support
        is_16bit = (self.dtype.itemsize == 2 and self.quant_dtype is None)
        if is_16bit and not self.is_fe_16bit_supported():
            return False
        # check type support
        if self.quant_dtype is None:
            if (self.dtype not in self.pf_supported_types()
                    or self.dtype not in self.fe_supported_types()):
                return False
        else:
            if (self.quant_dtype not in self.pf_supported_types()
                    or self.quant_dtype not in self.fe_supported_types()):
                return False

        # Check fp8 support
        is_fp8 = self.quant_dtype == torch.float8_e4m3fn
        if is_fp8 and not self.is_fe_fp8_supported():
            return False

        # Check fp8 block quantization support
        # Check block quantization support
        is_block_quatized = self.quant_block_shape is not None
        if is_block_quatized and not is_fp8:
        if is_block_quatized and self.quant_dtype is None:
            return False
        if is_block_quatized and not self.is_fe_block_fp8_supported():
        if is_block_quatized and not self.is_block_quant_supported():
            return False

        # deep_gemm only works with block-quantized
        if self.needs_deep_gemm() and not is_block_quatized:
            return False

        # Check dependencies
        # Check dependencies (turn into asserts?)
        if self.needs_deep_ep() and not has_deep_ep():
            return False
        if self.needs_deep_gemm() and not has_deep_gemm():
@@ -305,6 +255,8 @@ class WeightTensors:
    w2: torch.Tensor
    w1_scale: Optional[torch.Tensor]
    w2_scale: Optional[torch.Tensor]
    w1_gs: Optional[torch.Tensor] = None
    w2_gs: Optional[torch.Tensor] = None

    def describe(self):
        s = ""
@@ -313,13 +265,20 @@ class WeightTensors:
        s += f' - {_describe_tensor(self.w2, "w2")} \n'
        s += f' - {_describe_tensor(self.w1_scale, "w1_scale")} \n'
        s += f' - {_describe_tensor(self.w2_scale, "w2_scale")} \n'
        s += f' - {_describe_tensor(self.w1_gs, "w1_gs")} \n'
        s += f' - {_describe_tensor(self.w2_gs, "w2_gs")} \n'
        return s

    def is_quantized(self) -> bool:
        # or w1_scale is not None?
        return (self.w1.dtype == torch.float8_e4m3fn
                or self.w1.dtype == torch.uint8 or self.w1.dtype == torch.int8)

    def to_current_device(self):
        self.w1 = self.w1.to(device=torch.cuda.current_device())
        self.w2 = self.w2.to(device=torch.cuda.current_device())
        is_quantized = self.w1.dtype == torch.float8_e4m3fn
        if is_quantized:

        if self.is_quantized():
            assert self.w1_scale is not None
            assert self.w2_scale is not None
            self.w1_scale = self.w1_scale.to(
@@ -327,56 +286,51 @@ class WeightTensors:
            self.w2_scale = self.w2_scale.to(
                device=torch.cuda.current_device())

        if self.w1_gs is not None:
            assert self.w2_gs is not None
            self.w1_gs = self.w1_gs.to(device=torch.cuda.current_device())
            self.w2_gs = self.w2_gs.to(device=torch.cuda.current_device())

    def slice_weights(self, rank: int,
                      num_local_experts: int) -> "WeightTensors":
        s = rank * num_local_experts
        e = s + num_local_experts
        w1 = self.w1[s:e, :, :]
        w2 = self.w2[s:e, :, :]
        is_quantized = self.w1.dtype == torch.float8_e4m3fn

        w1_scale, w2_scale = (None, None)
        if is_quantized:
        if self.is_quantized():
            assert self.w1_scale is not None
            assert self.w2_scale is not None
            w1_scale = self.w1_scale[s:e, :, :]
            w2_scale = self.w2_scale[s:e, :, :]
        return WeightTensors(w1, w2, w1_scale, w2_scale)

        w1_gs = self.w1_gs
        w2_gs = self.w2_gs
        if w1_gs is not None:
            assert w2_gs is not None
            w1_gs = w1_gs[s:e]
            w2_gs = w2_gs[s:e]

        return WeightTensors(w1, w2, w1_scale, w2_scale, w1_gs, w2_gs)

    @staticmethod
    def make(config: Config) -> "WeightTensors":

        if config.quant_dtype is None:
            # just make normal dtype weights
            w1, w2 = make_non_quant_weights(e=config.E,
                                            n=config.N,
                                            k=config.K,
                                            dtype=config.dtype)
            return WeightTensors(w1=w1, w2=w2, w1_scale=None, w2_scale=None)

        assert config.quant_dtype == torch.float8_e4m3fn
        if not config.is_fp8_block_quantized():
            w1, w2, w1_scale, w2_scale = make_quant_fp8_weights(
                e=config.E,
                n=config.N,
                k=config.K,
                per_out_channel_quant=config.is_per_out_ch_quant,
            )
            return WeightTensors(w1=w1,
                                 w2=w2,
                                 w1_scale=w1_scale,
                                 w2_scale=w2_scale)

        assert config.quant_block_shape is not None
        w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights(
        (_, w1, w1_scale, w1_gs), (_, w2, w2_scale, w2_gs) = make_test_weights(
            e=config.E,
            n=config.N,
            k=config.K,
            block_size=config.quant_block_shape,
            in_dtype=config.dtype,
            quant_dtype=config.quant_dtype,
            block_shape=config.quant_block_shape,
            per_act_token_quant=config.is_per_out_ch_quant,
        )
        return WeightTensors(w1=w1,
                             w2=w2,
                             w1_scale=w1_scale,
                             w2_scale=w2_scale)
                             w2_scale=w2_scale,
                             w1_gs=w1_gs,
                             w2_gs=w2_gs)


@dataclass
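WeightTensors.make now builds quantized weights through the shared make_test_weights helper, which returns one 4-tuple per weight matrix rather than the old flat 6-tuple; the same unpacking pattern recurs in the test hunks further down. A sketch with hypothetical sizes:

    (w1_16, w1_q, w1_scale, w1_gs), (w2_16, w2_q, w2_scale, w2_gs) = make_test_weights(
        e=8, n=256, k=128,              # hypothetical expert count and dims
        in_dtype=torch.bfloat16,
        quant_dtype=torch.float8_e4m3fn,
        block_shape=[128, 128],
        per_act_token_quant=False,
    )
    # w*_16: reference bf16 weights; w*_q / w*_scale: quantized form and scales;
    # w*_gs: optional global scales (populated for nvfp4).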
@@ -449,7 +403,6 @@ class RankTensors:
                                  dtype=dtype)
        topk_weights, topk_ids, _ = fused_topk(hidden_states, score, topk,
                                               False)
        topk_ids = topk_ids.to(config.topk_ids_dtype)

        # distribute topk_ids evenly
        for mi in range(m):
@@ -457,7 +410,7 @@ class RankTensors:
        topk_ids = topk_ids.to(device=torch.cuda.current_device())

        expert_map = None
        if config.world_size > 1:
        if config.world_size > 1 and config.supports_expert_map():
            expert_map = torch.full((global_num_experts, ),
                                    fill_value=-1,
                                    dtype=torch.int32)
@@ -480,92 +433,100 @@ class RankTensors:
def reference_moe_impl(config: Config, weights: WeightTensors,
                       rank_tensors: RankTensors) -> torch.Tensor:

    return torch_experts(a=rank_tensors.hidden_states,
                         w1=weights.w1,
                         w2=weights.w2,
    if config.quant_dtype == "nvfp4":
        quant_blocksize = 16
        dtype = config.dtype

        w1_q = weights.w1
        w1_blockscale = weights.w1_scale
        w1_gs = weights.w1_gs

        w2_q = weights.w2
        w2_blockscale = weights.w2_scale
        w2_gs = weights.w2_gs

        a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(
            rank_tensors.hidden_states.flatten(), dim=-1)).to(torch.float32)

        assert w1_gs is not None
        assert w2_gs is not None
        assert w1_blockscale is not None
        assert w2_blockscale is not None

        assert w1_blockscale.shape[1] % 128 == 0
        assert w1_blockscale.shape[2] % 4 == 0
        assert w2_blockscale.shape[1] % 128 == 0
        assert w2_blockscale.shape[2] % 4 == 0

        a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(
            rank_tensors.hidden_states, a_global_scale)

        a = dequantize_nvfp4_to_dtype(a_fp4,
                                      a_scale_interleaved,
                                      a_global_scale,
                                      dtype=dtype,
                                      device=a_fp4.device,
                                      block_size=quant_blocksize)

        e = w1_q.shape[0]
        n = w1_q.shape[1] // 2
        k = w2_q.shape[1]

        w1 = torch.zeros((e, 2 * n, k), device="cuda", dtype=dtype)
        w2 = torch.zeros((e, k, n), device="cuda", dtype=dtype)

        for idx in range(0, e):
            w1[idx] = dequantize_nvfp4_to_dtype(w1_q[idx],
                                                w1_blockscale[idx],
                                                w1_gs[idx],
                                                dtype=dtype,
                                                device=w1_q.device,
                                                block_size=quant_blocksize)
            w2[idx] = dequantize_nvfp4_to_dtype(w2_q[idx],
                                                w2_blockscale[idx],
                                                w2_gs[idx],
                                                dtype=dtype,
                                                device=w2_q.device,
                                                block_size=quant_blocksize)
        a_scale = None
        w1_scale = None
        w2_scale = None
        quant_dtype = None
        per_act_token_quant = False
        block_shape = None
    else:
        a = rank_tensors.hidden_states
        a_scale = rank_tensors.hidden_states_scale
        w1 = weights.w1
        w1_scale = weights.w1_scale
        w2 = weights.w2
        w2_scale = weights.w2_scale
        quant_dtype = config.quant_dtype
        per_act_token_quant = config.is_per_act_token_quant
        block_shape = config.quant_block_shape

    return torch_experts(a=a,
                         w1=w1,
                         w2=w2,
                         topk_weight=rank_tensors.topk_weights,
                         topk_ids=rank_tensors.topk_ids,
                         global_num_experts=config.E,
                         expert_map=None,
                         w1_scale=weights.w1_scale,
                         w2_scale=weights.w2_scale,
                         a1_scale=rank_tensors.hidden_states_scale,
                         quant_dtype=config.quant_dtype,
                         per_act_token_quant=config.is_per_act_token_quant,
                         block_shape=config.quant_block_shape,
                         apply_router_weights_on_input=config.topk == 1)
                         w1_scale=w1_scale,
                         w2_scale=w2_scale,
                         a1_scale=a_scale,
                         quant_dtype=quant_dtype,
                         per_act_token_quant=per_act_token_quant,
                         block_shape=block_shape,
                         apply_router_weights_on_input=config.topk == 1
                         and config.supports_apply_weight_on_input())


def make_fused_experts(
        config: Config, moe: FusedMoEConfig,
        num_dispatchers: int) -> mk.FusedMoEPermuteExpertsUnpermute:

    use_fp8 = config.quant_dtype == torch.float8_e4m3fn
    batch_kwargs = {
        "max_num_tokens": moe.max_num_tokens,
        "num_dispatchers": num_dispatchers,
    }
    quant_kwargs = {
        "use_fp8_w8a8": use_fp8,
        "use_int8_w8a8": False,
        "use_int8_w8a16": False,
        "use_int4_w4a16": False,
        "block_shape": config.quant_block_shape,
        "per_act_token_quant": config.is_per_act_token_quant,
    }
    deepgemm_kwargs = {"allow_deep_gemm": has_deep_gemm()}

    if config.fused_experts_type == BatchedDeepGemmExperts:
        kwargs = batch_kwargs | {
            "block_shape": config.quant_block_shape,
            "per_act_token_quant": config.is_per_act_token_quant,
        }
        print(f"Making BatchedDeepGemmExperts {kwargs} ...")
        experts = BatchedDeepGemmExperts(**kwargs)
    elif config.fused_experts_type == BatchedTritonExperts:
        kwargs = batch_kwargs | quant_kwargs
        print(f"Making BatchedTritonExperts {kwargs} ...")
        experts = BatchedTritonExperts(**kwargs)
    elif config.fused_experts_type == BatchedTritonOrDeepGemmExperts:
        kwargs = batch_kwargs | quant_kwargs | deepgemm_kwargs
        print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...")
        experts = BatchedTritonOrDeepGemmExperts(**kwargs)
    elif config.fused_experts_type == DeepGemmExperts:
        print("Making DeepGemmExperts () ...")
        experts = DeepGemmExperts()
    elif config.fused_experts_type == TritonExperts:
        kwargs = quant_kwargs
        print(f"Making TritonExperts {kwargs} ...")
        experts = TritonExperts(**kwargs)
    elif config.fused_experts_type == TritonOrDeepGemmExperts:
        kwargs = quant_kwargs | deepgemm_kwargs
        print(f"Making TritonOrDeepGemmExperts {kwargs} ...")
        experts = TritonOrDeepGemmExperts(**kwargs)
    elif config.fused_experts_type == NaiveBatchedExperts:
        kwargs = batch_kwargs | quant_kwargs
        print(f"Making NaiveBatchedExperts {kwargs} ...")
        experts = NaiveBatchedExperts(**kwargs)
    elif config.fused_experts_type == CutlassExpertsFp8:
        use_batched_format = config.is_batched_prepare_finalize()
        num_experts = (moe.num_local_experts
                       if use_batched_format else moe.num_experts)
        kwargs = {
            "max_experts_per_worker": num_experts,
            "out_dtype": moe.in_dtype,
            "per_act_token_quant": config.is_per_act_token_quant,
            "per_out_ch_quant": config.is_per_out_ch_quant,
            "block_shape": config.quant_block_shape,
            "num_dispatchers": num_dispatchers,
            "use_batched_format": use_batched_format
        }
        print(f"Making CutlassExpertsFp8 {kwargs} ...")
        experts = CutlassExpertsFp8(**kwargs)

    return experts


def make_modular_kernel(config: Config,
                        vllm_config: VllmConfig) -> mk.FusedMoEModularKernel:
def make_modular_kernel(
    config: Config,
    vllm_config: VllmConfig,
    weights: WeightTensors,
) -> mk.FusedMoEModularKernel:

    def next_power_of_2(x):
        import math
@@ -579,6 +540,7 @@ def make_modular_kernel(config: Config,
        dp_size_=get_dp_group().world_size,
        vllm_parallel_config=vllm_config.parallel_config,
    )

    moe = FusedMoEConfig(
        num_experts=config.E,
        experts_per_token=config.topk,
@@ -591,15 +553,16 @@ def make_modular_kernel(config: Config,
    )

    # make modular kernel
    prepare_finalize = None
    if config.needs_all2all():
        prepare_finalize = FusedMoEMethodBase.maybe_make_prepare_finalize(moe)
        assert prepare_finalize is not None
    else:
        prepare_finalize = MoEPrepareAndFinalizeNoEP()
    prepare_finalize = make_prepare_finalize(config.prepare_finalize_type,
                                             config.all2all_backend(), moe)

    fused_experts = make_fused_experts(config, moe,
                                       prepare_finalize.num_dispatchers())
    fused_experts = make_fused_experts(
        config.fused_experts_type,
        moe,
        prepare_finalize.num_dispatchers(),
        weights.w1_gs,
        weights.w2_gs,
    )

    modular_kernel = mk.FusedMoEModularKernel(
        prepare_finalize=prepare_finalize, fused_experts=fused_experts)
@@ -620,22 +583,45 @@ def run_modular_kernel(
    # weights for rank
    rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts)

    mk = make_modular_kernel(config, vllm_config)
    mk = make_modular_kernel(config, vllm_config, weights)

    mk_kwargs = {
        "hidden_states": rank_tensors.hidden_states.clone(
        "hidden_states":
        rank_tensors.hidden_states.clone(
        ),  # impls might update the tensor in place
        "w1": rank_weights.w1,
        "w2": rank_weights.w2,
        "topk_weights": rank_tensors.topk_weights,
        "topk_ids": rank_tensors.topk_ids,
        "expert_map": rank_tensors.expert_map,
        "w1_scale": rank_weights.w1_scale,
        "w2_scale": rank_weights.w2_scale,
        "a1_scale": rank_tensors.hidden_states_scale,
        "global_num_experts": config.E,
        "apply_router_weight_on_input": config.topk == 1,
        "w1":
        rank_weights.w1,
        "w2":
        rank_weights.w2,
        "topk_weights":
        rank_tensors.topk_weights,
        "topk_ids":
        rank_tensors.topk_ids.to(mk.prepare_finalize.topk_indices_dtype()),
        "expert_map":
        rank_tensors.expert_map,
        "w1_scale":
        rank_weights.w1_scale,
        "w2_scale":
        rank_weights.w2_scale,
        "a1_scale":
        rank_tensors.hidden_states_scale,
        "global_num_experts":
        config.E,
        "apply_router_weight_on_input":
        config.topk == 1 and config.supports_apply_weight_on_input(),
    }
    out = mk.forward(**mk_kwargs)

    num_tokens = rank_tensors.hidden_states.shape[0]
    num_tokens_across_dp = torch.tensor([num_tokens] * config.world_size,
                                        device="cuda",
                                        dtype=torch.int)

    with set_forward_context(
            None,
            vllm_config,
            num_tokens=num_tokens,
            num_tokens_across_dp=num_tokens_across_dp,
    ):
        out = mk.forward(**mk_kwargs)

    return out

@@ -1,58 +1,316 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Optional, Union

import torch

# Fused experts and PrepareFinalize imports
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
    BatchedDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
    BatchedTritonOrDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
                                                         FusedMoEQuantConfig)
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    BatchedTritonExperts, NaiveBatchedExperts)
from vllm.model_executor.layers.fused_moe.layer import TritonExperts
from vllm.model_executor.layers.fused_moe.layer import (FusedMoEMethodBase,
                                                        TritonExperts)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
    TritonOrDeepGemmExperts)
from vllm.utils import has_deep_ep, has_pplx
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    cutlass_fp4_supported)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    cutlass_fp8_supported)
from vllm.platforms import current_platform
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
from vllm.utils.deep_gemm import is_deep_gemm_supported
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe

if has_deep_ep():

@dataclass
class PrepareFinalizeInfo:
    activation_format: mk.FusedMoEActivationFormat
    supported_dtypes: list[Union[torch.dtype, str]]
    blocked_quantization_support: bool
    backend: Optional[str]
    supports_apply_weight_on_input: bool = True


@dataclass
class ExpertInfo:
    activation_format: mk.FusedMoEActivationFormat
    supported_dtypes: list[Union[torch.dtype, str]]
    blocked_quantization_support: bool
    supports_chunking: bool
    supports_expert_map: bool
    needs_matching_quant: bool = False
    needs_deep_gemm: bool = False


PREPARE_FINALIZE_INFO: dict[mk.FusedMoEPrepareAndFinalize,
                            PrepareFinalizeInfo] = {}
EXPERT_INFO: dict[mk.FusedMoEPermuteExpertsUnpermute, ExpertInfo] = {}
MK_ALL_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = []
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = []
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = []
MK_FUSED_EXPERT_TYPES: list[mk.FusedMoEPermuteExpertsUnpermute] = []

standard_format = mk.FusedMoEActivationFormat.Standard
batched_format = mk.FusedMoEActivationFormat.BatchedExperts
common_float_types: list[Union[torch.dtype, str]] = [
    torch.float8_e4m3fn, torch.bfloat16, torch.float16, torch.float32
]
common_float_and_int_types = common_float_types + [torch.int8]
nv_fp4_types = ["nvfp4"]
fp8_types = [torch.float8_e4m3fn]


def register_prepare_and_finalize(
    kind,
    activation_format: mk.FusedMoEActivationFormat,
    supported_dtypes: list[Union[torch.dtype, str]],
    blocked_quantization_support: bool,
    backend: Optional[str],
    force_multigpu: bool = False,
    supports_apply_weight_on_input: bool = True,
):
    global PREPARE_FINALIZE_INFO
    global MK_ALL_PREPARE_FINALIZE_TYPES
    global MK_MULTI_GPU_PREPARE_FINALIZE_TYPES
    global MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES
    assert kind not in PREPARE_FINALIZE_INFO

    PREPARE_FINALIZE_INFO[kind] = PrepareFinalizeInfo(
        activation_format,
        supported_dtypes,
        blocked_quantization_support,
        backend,
        supports_apply_weight_on_input,
    )
    MK_ALL_PREPARE_FINALIZE_TYPES.append(kind)
    if backend is not None or force_multigpu:
        MK_MULTI_GPU_PREPARE_FINALIZE_TYPES.append(kind)
    else:
        MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES.append(kind)


def register_experts(
    kind,
    activation_format: mk.FusedMoEActivationFormat,
    supported_dtypes: list[Union[torch.dtype, str]],
    blocked_quantization_support: bool,
    supports_chunking: bool,
    supports_expert_map: bool,
    needs_matching_quant: bool = False,
    needs_deep_gemm: bool = False,
):
    global EXPERT_INFO
    global MK_FUSED_EXPERT_TYPES
    assert kind not in EXPERT_INFO

    EXPERT_INFO[kind] = ExpertInfo(
        activation_format,
        supported_dtypes,
        blocked_quantization_support,
        supports_chunking,
        supports_expert_map,
        needs_matching_quant,
        needs_deep_gemm,
    )

    MK_FUSED_EXPERT_TYPES.append(kind)


def prepare_finalize_info(kind) -> PrepareFinalizeInfo:
    info = PREPARE_FINALIZE_INFO.get(kind)
    assert info is not None
    return info


def expert_info(kind) -> ExpertInfo:
    info = EXPERT_INFO.get(kind)
    assert info is not None
    return info
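A minimal usage sketch of the registry API defined above (MyExperts is a hypothetical kernel class, not part of this diff):

    register_experts(
        MyExperts,
        standard_format,
        fp8_types,
        blocked_quantization_support=False,
        supports_chunking=True,
        supports_expert_map=True,
    )
    info = expert_info(MyExperts)
    assert not info.blocked_quantization_support
    assert MyExperts in MK_FUSED_EXPERT_TYPES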


register_prepare_and_finalize(
    MoEPrepareAndFinalizeNoEP,
    standard_format,
    common_float_types,
    blocked_quantization_support=True,
    backend=None,
)

register_experts(
    BatchedTritonExperts,
    batched_format,
    common_float_types,
    blocked_quantization_support=True,
    supports_chunking=False,
    supports_expert_map=False,
    needs_matching_quant=True,
)

register_experts(
    TritonExperts,
    standard_format,
    common_float_and_int_types,
    blocked_quantization_support=True,
    supports_chunking=True,
    supports_expert_map=True,
    needs_matching_quant=True,
)

register_experts(
    NaiveBatchedExperts,
    batched_format,
    common_float_and_int_types,
    blocked_quantization_support=True,
    supports_chunking=False,
    supports_expert_map=True,
)

# Disable on blackwell for now
if has_deep_ep() and not current_platform.has_device_capability(100):
    from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
        DeepEPHTPrepareAndFinalize)
    from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
        DeepEPLLPrepareAndFinalize)

    register_prepare_and_finalize(
        DeepEPHTPrepareAndFinalize,
        standard_format,
        common_float_types,
        blocked_quantization_support=True,
        backend="deepep_high_throughput",
    )

    register_prepare_and_finalize(
        DeepEPLLPrepareAndFinalize,
        batched_format,
        common_float_types,
        blocked_quantization_support=True,
        backend="deepep_low_latency",
    )

if has_pplx():
    from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
        PplxPrepareAndFinalize)
    register_prepare_and_finalize(
        PplxPrepareAndFinalize,
        batched_format,
        common_float_and_int_types,
        blocked_quantization_support=True,
        backend="pplx",
    )

MK_MULTI_GPU_PREPARE_FINALIZE_TYPES = []
if has_pplx():
    MK_MULTI_GPU_PREPARE_FINALIZE_TYPES += [PplxPrepareAndFinalize]
if has_deep_ep():
    MK_MULTI_GPU_PREPARE_FINALIZE_TYPES += [
        DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize
    ]
if (has_flashinfer_cutlass_fused_moe()
        and current_platform.has_device_capability(100)):
    from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
        FlashInferExperts)
    from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
        FlashInferCutlassMoEPrepareAndFinalize)

MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES = [MoEPrepareAndFinalizeNoEP]
    register_prepare_and_finalize(
        FlashInferCutlassMoEPrepareAndFinalize,
        standard_format,
        nv_fp4_types,
        blocked_quantization_support=True,
        backend=None,
        force_multigpu=True,
        supports_apply_weight_on_input=False,
    )

MK_ALL_PREPARE_FINALIZE_TYPES = (MK_MULTI_GPU_PREPARE_FINALIZE_TYPES +
                                 MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES)
    register_experts(
        FlashInferExperts,
        standard_format,
        nv_fp4_types,
        blocked_quantization_support=True,
        supports_chunking=True,
        # Note: this is a hack to get it to run for now
        supports_expert_map=True,
    )
else:
    FlashInferCutlassMoEPrepareAndFinalize = None

MK_FUSED_EXPERT_TYPES = [
    BatchedDeepGemmExperts,
    BatchedTritonExperts,
    NaiveBatchedExperts,
    BatchedTritonOrDeepGemmExperts,
    CutlassExpertsFp8,
    DeepGemmExperts,
    TritonOrDeepGemmExperts,
    TritonExperts,
]
if has_deep_gemm() and is_deep_gemm_supported():
    register_experts(
        BatchedDeepGemmExperts,
        batched_format,
        fp8_types,
        blocked_quantization_support=True,
        supports_chunking=False,
        supports_expert_map=False,
        needs_matching_quant=False,
        needs_deep_gemm=True,
    )
    register_experts(
        DeepGemmExperts,
        standard_format,
        fp8_types,
        blocked_quantization_support=True,
        supports_chunking=True,
        supports_expert_map=True,
        needs_matching_quant=False,
        needs_deep_gemm=True,
    ),
    register_experts(
        BatchedTritonOrDeepGemmExperts,
        batched_format,
        common_float_and_int_types,
        blocked_quantization_support=True,
        supports_chunking=False,
        supports_expert_map=False,
        needs_matching_quant=True,
        needs_deep_gemm=True,
    )
    register_experts(
        TritonOrDeepGemmExperts,
        standard_format,
        common_float_and_int_types,
        blocked_quantization_support=True,
        supports_chunking=True,
        supports_expert_map=True,
        needs_matching_quant=True,
        needs_deep_gemm=True,
    )

if cutlass_fp8_supported():
    from vllm.model_executor.layers.fused_moe import (CutlassBatchedExpertsFp8,
                                                      CutlassExpertsFp8)
    register_experts(
        CutlassExpertsFp8,
        standard_format,
        fp8_types,
        blocked_quantization_support=False,
        supports_chunking=True,
        supports_expert_map=False,
    )
    register_experts(
        CutlassBatchedExpertsFp8,
        batched_format,
        fp8_types,
        blocked_quantization_support=False,
        supports_chunking=False,
        supports_expert_map=False,
    )

if cutlass_fp4_supported():
    from vllm.model_executor.layers.fused_moe.cutlass_moe import (
        CutlassExpertsFp4)
    register_experts(
        CutlassExpertsFp4,
        standard_format,
        nv_fp4_types,
        blocked_quantization_support=True,
        supports_chunking=True,
        supports_expert_map=False,
    )

MK_QUANT_CONFIGS = [
    None,
@@ -85,3 +343,156 @@ MK_QUANT_CONFIGS = [
    # block-quantized weights and per-token activations
    # block-quantized weights and per-tensor activations
]

if cutlass_fp4_supported() or has_flashinfer_cutlass_fused_moe():
    MK_QUANT_CONFIGS += [
        FusedMoEQuantConfig(quant_dtype="nvfp4",
                            per_out_ch_quant=False,
                            per_act_token_quant=False,
                            block_shape=None),
    ]


def _make_gscale(num_experts: int) -> torch.Tensor:
    return torch.ones((num_experts, ),
                      device=torch.cuda.current_device(),
                      dtype=torch.float32)


def make_prepare_finalize(
    prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
    backend: Optional[str],
    moe: FusedMoEConfig,
) -> mk.FusedMoEPrepareAndFinalize:
    if backend != "naive" and backend is not None:
        prepare_finalize = FusedMoEMethodBase._maybe_make_prepare_finalize(moe)
        assert prepare_finalize is not None
        return prepare_finalize
    elif prepare_finalize_type == FlashInferCutlassMoEPrepareAndFinalize:
        return FlashInferCutlassMoEPrepareAndFinalize(
            use_dp=moe.moe_parallel_config.dp_size > 1,
            a1_gscale=_make_gscale(moe.num_local_experts),
        )
    else:
        return MoEPrepareAndFinalizeNoEP()


def _slice(rank: int, num_local_experts: int, t: torch.Tensor) -> torch.Tensor:
    s = rank * num_local_experts
    e = s + num_local_experts
    return t[s:e]


def make_fused_experts(
    fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
    moe: FusedMoEConfig,
    num_dispatchers: int,
    w1_gs: Optional[torch.Tensor],
    w2_gs: Optional[torch.Tensor],
) -> mk.FusedMoEPermuteExpertsUnpermute:

    use_fp8 = moe.quant_dtype == torch.float8_e4m3fn
    batch_kwargs = {
        "max_num_tokens": moe.max_num_tokens,
        "num_dispatchers": num_dispatchers,
    }
    quant_kwargs = {
        "use_fp8_w8a8": use_fp8,
        "use_int8_w8a8": False,
        "use_int8_w8a16": False,
        "use_int4_w4a16": False,
        "block_shape": moe.block_shape,
        "per_act_token_quant": moe.per_act_token_quant,
    }
    deepgemm_kwargs = {"allow_deep_gemm": has_deep_gemm()}

    if fused_experts_type == BatchedDeepGemmExperts:
        kwargs = batch_kwargs | {
            "block_shape": moe.block_shape,
            "per_act_token_quant": moe.per_act_token_quant,
        }
        print(f"Making BatchedDeepGemmExperts {kwargs} ...")
        experts = BatchedDeepGemmExperts(**kwargs)
    elif fused_experts_type == BatchedTritonExperts:
        kwargs = batch_kwargs | quant_kwargs
        print(f"Making BatchedTritonExperts {kwargs} ...")
        experts = BatchedTritonExperts(**kwargs)
    elif fused_experts_type == BatchedTritonOrDeepGemmExperts:
        kwargs = batch_kwargs | quant_kwargs | deepgemm_kwargs
        print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...")
        experts = BatchedTritonOrDeepGemmExperts(**kwargs)
    elif fused_experts_type == DeepGemmExperts:
        print("Making DeepGemmExperts () ...")
        experts = DeepGemmExperts()
    elif fused_experts_type == TritonExperts:
        kwargs = quant_kwargs
        print(f"Making TritonExperts {kwargs} ...")
        experts = TritonExperts(**kwargs)
    elif fused_experts_type == TritonOrDeepGemmExperts:
        kwargs = quant_kwargs | deepgemm_kwargs
        print(f"Making TritonOrDeepGemmExperts {kwargs} ...")
        experts = TritonOrDeepGemmExperts(**kwargs)
    elif fused_experts_type == NaiveBatchedExperts:
        kwargs = batch_kwargs | quant_kwargs
        print(f"Making NaiveBatchedExperts {kwargs} ...")
        experts = NaiveBatchedExperts(**kwargs)
    elif fused_experts_type == CutlassExpertsFp8:
        kwargs = {
            "out_dtype": moe.in_dtype,
            "per_act_token_quant": moe.per_act_token_quant,
            "per_out_ch_quant": moe.per_out_ch_quant,
            "block_shape": moe.block_shape,
        }
        print(f"Making CutlassExpertsFp8 {kwargs} ...")
        experts = CutlassExpertsFp8(**kwargs)
    elif fused_experts_type == CutlassBatchedExpertsFp8:
        kwargs = {
            "max_experts_per_worker": moe.num_local_experts,
            "num_dispatchers": num_dispatchers,
            "out_dtype": moe.in_dtype,
            "per_act_token_quant": moe.per_act_token_quant,
            "per_out_ch_quant": moe.per_out_ch_quant,
            "block_shape": moe.block_shape,
        }
        print(f"Making CutlassBatchedExpertsFp8 {kwargs} ...")
        experts = CutlassBatchedExpertsFp8(**kwargs)
    elif fused_experts_type == CutlassExpertsFp4:
        assert w1_gs is not None and w2_gs is not None
        num_experts = moe.num_local_experts
        rank = moe.moe_parallel_config.dp_rank
        kwargs = {
            "g1_alphas": _slice(rank, num_experts, (1 / w1_gs)),
            "g2_alphas": _slice(rank, num_experts, (1 / w2_gs)),
            "a1_gscale": _make_gscale(num_experts),
            "a2_gscale": _make_gscale(num_experts),
            "max_experts_per_worker": num_experts,
            "out_dtype": moe.in_dtype,
            "per_act_token_quant": moe.per_act_token_quant,
            "per_out_ch_quant": moe.per_out_ch_quant,
            "block_shape": moe.block_shape,
            "num_dispatchers": num_dispatchers,
        }
        print(f"Making CutlassExpertsFp4 {kwargs} ...")
        experts = CutlassExpertsFp4(**kwargs)
    elif fused_experts_type == FlashInferExperts:
        assert w1_gs is not None and w2_gs is not None
        num_experts = moe.num_local_experts
        rank = moe.moe_parallel_config.dp_rank
        kwargs = {
            "g1_alphas": _slice(rank, num_experts, (1 / w1_gs)),
            "g2_alphas": _slice(rank, num_experts, (1 / w2_gs)),
            "a1_gscale": _make_gscale(num_experts),
            "a2_gscale": _make_gscale(num_experts),
            "out_dtype": moe.in_dtype,
            "quant_dtype": "nvfp4",
            "ep_rank": moe.ep_rank,
            "ep_size": moe.ep_size,
            "tp_rank": moe.tp_rank,
            "tp_size": moe.tp_size,
        }
        print(f"Making FlashInferExperts {kwargs} ...")
        experts = FlashInferExperts(**kwargs)
    else:
        raise RuntimeError(f"Unknown fused experts type: {fused_experts_type}")

    return experts
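For the FP4 branches above, the global weight scales are inverted into per-expert alphas and sliced down to this rank's local experts with _slice. A sketch with hypothetical shapes:

    w1_gs = torch.full((16, ), 2.0)      # global scales for 16 experts
    g1_alphas = _slice(1, 8, 1 / w1_gs)  # rank 1 with 8 local experts
    assert g1_alphas.shape == (8, )      # experts 8..15, each alpha 0.5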
@@ -52,7 +52,7 @@ def profile_modular_kernel(
    rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts)

    # make modular kernel
    mk = make_modular_kernel(config, vllm_config)
    mk = make_modular_kernel(config, vllm_config, weights)

    mk_kwargs = {
        "hidden_states": rank_tensors.hidden_states,
@@ -83,7 +83,7 @@ def rank_worker(
    # sanity check
    from vllm import envs
    if config.fused_moe_chunk_size is not None:
        assert (config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE)
        assert config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE

    # get weights to this device
    weights.to_current_device()

@@ -1,117 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

import vllm._custom_ops as ops
from vllm.utils.deep_gemm import per_block_cast_to_fp8


def per_token_cast_to_fp8(
        x: torch.Tensor, block_size: int) -> tuple[torch.Tensor, torch.Tensor]:
    assert x.dim() == 2
    m, n = x.shape
    pad_size = (block_size - (n % block_size)) % block_size
    x = torch.nn.functional.pad(x,
                                (0, pad_size), value=0) if pad_size > 0 else x
    x_view = x.view(m, -1, block_size)
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
    return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)


def make_non_quant_weights(
    e: int,
    n: int,
    k: int,
    dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Return weights w1, w2
    """
    device = torch.cuda.current_device()
    w1 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 15
    w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 15
    return w1, w2


def make_block_quant_fp8_weights(
    e: int,
    n: int,
    k: int,
    block_size: list[int],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Return weights w1, w2, w1_scale, w2_scale
    """
    dtype = torch.bfloat16
    device = torch.cuda.current_device()

    fp8_info = torch.finfo(torch.float8_e4m3fn)
    fp8_max, fp8_min = fp8_info.max, fp8_info.min

    w1_bf16, w2_bf16 = make_non_quant_weights(e, n, k, dtype)
    w1_bf16 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype)
    w2_bf16 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype)

    block_n, block_k = block_size[0], block_size[1]
    n_tiles_w1 = ((2 * n) + block_n - 1) // block_n
    k_tiles_w1 = (k + block_k - 1) // block_k
    n_tiles_w2 = (k + block_n - 1) // block_n
    k_tiles_w2 = (n + block_k - 1) // block_k

    w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn, device=device)
    w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn, device=device)

    w1_s = torch.empty((e, n_tiles_w1, k_tiles_w1),
                       device=device,
                       dtype=torch.float32)
    w2_s = torch.empty((e, n_tiles_w2, k_tiles_w2),
                       device=device,
                       dtype=torch.float32)

    assert w1_s.shape == (e, (2 * n + (block_n - 1)) // block_n,
                          (k + (block_k - 1)) // block_k)
    assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2]

    for i in range(e):
        w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i],
                                               block_size=[block_k, block_n])
        w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i],
                                               block_size=[block_k, block_n])

    return w1, w2, w1_s, w2_s


def make_quant_fp8_weights(
    e: int,
    n: int,
    k: int,
    per_out_channel_quant: bool,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Return w1, w2, w1_scale, w2_scale
    """
    q_dtype = torch.float8_e4m3fn

    w1, w2 = make_non_quant_weights(e, n, k, dtype=torch.bfloat16)

    # w1 -> w1_q, w2 -> w2_q
    w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=q_dtype)
    w2_q = torch.empty((e, k, n), device="cuda", dtype=q_dtype)

    n_b_scales = 2 * n if per_out_channel_quant else 1
    k_b_scales = k if per_out_channel_quant else 1
    w1_scale = torch.empty((e, n_b_scales, 1),
                           device="cuda",
                           dtype=torch.float32)
    w2_scale = torch.empty((e, k_b_scales, 1),
                           device="cuda",
                           dtype=torch.float32)

    for expert in range(e):
        w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(
            w1[expert], use_per_token_if_dynamic=per_out_channel_quant)
        w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(
            w2[expert], use_per_token_if_dynamic=per_out_channel_quant)
    return w1_q, w2_q, w1_scale, w2_scale
@@ -133,7 +133,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
        per_act_token_quant=per_act_token_quant,
    )

    B, B_q, B_scale, _, _, _ = make_test_weights(
    (B, B_q, B_scale, _), _ = make_test_weights(
        num_experts,
        N // 2,
        K,
@@ -243,7 +243,7 @@ def test_fused_moe_batched_experts(
        act_dtype = dtype
        quant_dtype = None

    w1_16, w1, w1_s, w2_16, w2, w2_s = make_test_weights(
    (w1_16, w1, w1_s, _), (w2_16, w2, w2_s, _) = make_test_weights(
        e,
        n,
        k,

@@ -161,18 +161,20 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
    a = torch.randn((M, K), dtype=dtype) / 10
    score = torch.randn((M, E), dtype=dtype)

    _, w1, w1_s, _, w2, w2_s = make_test_weights(E,
                                                 N,
                                                 K,
                                                 dtype,
                                                 torch.float8_e4m3fn,
                                                 per_act_token_quant=False,
                                                 block_shape=block_size)
    (_, w1, w1_s, _), (_, w2, w2_s,
                       _) = make_test_weights(E,
                                              N,
                                              K,
                                              dtype,
                                              torch.float8_e4m3fn,
                                              per_act_token_quant=False,
                                              block_shape=block_size)

    m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=True,
                                           use_int8_w8a8=False,
                                           use_int8_w8a16=False,
                                           use_int4_w4a16=False,
                                           use_mxfp4_w4a4=False,
                                           per_act_token_quant=False,
                                           block_shape=block_size)

@@ -247,13 +249,14 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
    a = torch.randn((M, K), dtype=dtype) / 10
    score = torch.randn((M, E), dtype=dtype)

    _, w1, w1_s, _, w2, w2_s = make_test_weights(E,
                                                 N,
                                                 K,
                                                 dtype,
                                                 torch.float8_e4m3fn,
                                                 per_act_token_quant=False,
                                                 block_shape=block_size)
    (_, w1, w1_s, _), (_, w2, w2_s,
                       _) = make_test_weights(E,
                                              N,
                                              K,
                                              dtype,
                                              torch.float8_e4m3fn,
                                              per_act_token_quant=False,
                                              block_shape=block_size)

    # Note: for now use_compile will error out if the problem size is
    # large enough to trigger chunking. I'm leaving the flag and

@@ -118,13 +118,14 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
    a = torch.randn((M, K), dtype=dtype) / 10
    score = torch.randn((M, E), dtype=dtype)

    _, w1, w1_s, _, w2, w2_s = make_test_weights(E,
                                                 N,
                                                 K,
                                                 dtype,
                                                 torch.int8,
                                                 per_act_token_quant=False,
                                                 block_shape=block_size)
    (_, w1, w1_s, _), (_, w2, w2_s,
                       _) = make_test_weights(E,
                                              N,
                                              K,
                                              dtype,
                                              torch.int8,
                                              per_act_token_quant=False,
                                              block_shape=block_size)

    # Set the context to avoid lots of warning spam.
    with set_current_vllm_config(vllm_config):

@@ -9,6 +9,7 @@ import random
import pytest
import torch

from tests.kernels.moe.utils import per_token_cast_to_fp8
from tests.kernels.utils import baseline_scaled_mm
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
@@ -16,20 +17,6 @@ from vllm.utils import cdiv
from vllm.utils.deep_gemm import per_block_cast_to_fp8


def per_token_cast_to_fp8(
        x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    assert x.dim() == 2
    m, n = x.shape
    pad_size = (128 - (n % 128)) % 128
    x = torch.nn.functional.pad(x,
                                (0, pad_size), value=0) if pad_size > 0 else x
    x_view = x.view(m, -1, 128)
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    fp8_data = (x_view *
                (448.0 / x_amax.unsqueeze(2))).to(dtype=torch.float8_e4m3fn)
    return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)


@pytest.mark.parametrize("num_groups, expected_m_per_group, k, n", [
    (4, 8192, 7168, 4096),
    (4, 8192, 2048, 7168),
@@ -76,7 +63,7 @@ def test_cutlass_grouped_gemm(
        device=device,
        dtype=torch.float))
    for i in range(num_groups):
        y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i])
        y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i], [128, 128])

    for i in range(num_groups):
        a = x_fp8[0][ep_offset[i]:ep_offset[i + 1]]

@@ -70,8 +70,10 @@ def make_block_quant_fp8_weights(
    """
    Return weights w1q, w2q, w1_scale, w2_scale
    """
    w1, w1q, w1_scale, w2, w2q, w2_scale = make_test_weights(
        e, n, k, torch.bfloat16, torch.float8_e4m3fn, block_size)
    (_, w1q, w1_scale, _), (_, w2q, w2_scale,
                            _) = make_test_weights(e, n, k, torch.bfloat16,
                                                   torch.float8_e4m3fn,
                                                   block_size)
    return w1q, w2q, w1_scale, w2_scale


@@ -132,9 +132,9 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
    # Note: W1 has shape (E, 2N, K), so N = 512
    # can trigger the deepgemm path.
    MNKs = [
        (1024, 512, 128),
        (1024, 512, 512),
        (2048, 512, 512),
        (1024, 768, 128),
        (1024, 768, 512),
        (2048, 768, 512),
        (512, 1024, 1024),
        (512, 2048, 2048),
        (4096, 4096, 1024),

tests/kernels/moe/test_flashinfer_moe.py (new file, 147 lines)
@ -0,0 +1,147 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch

from tests.kernels.moe.utils import make_test_weights
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
                                                    FLOAT8_E4M3_MAX,
                                                    dequantize_nvfp4_to_dtype)
from tests.kernels.utils import torch_moe
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
    FlashInferExperts, is_valid_flashinfer_cutlass_fused_moe)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEModularKernel)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe

if not has_flashinfer_cutlass_fused_moe(
) or not current_platform.has_device_capability(100):
    pytest.skip("Requires flashinfer_cutlass_fused_moe and nvfp4 support",
                allow_module_level=True)

MNK_FACTORS = [
    (2, 1024, 1024),
    (2, 1024, 1536),
    (2, 3072, 1024),
    (2, 3072, 1536),
    (64, 1024, 1024),
    (64, 1024, 1536),
    (64, 3072, 1024),
    (64, 2048, 1536),
    (224, 1024, 1024),
    (224, 1024, 1536),
]

@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", [40, 64, 256])
#@pytest.mark.parametrize("e", [128, 256])
@pytest.mark.parametrize("topk", [1, 6, 8])
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
@torch.inference_mode()
def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
                                     dtype: torch.dtype):
    current_platform.seed_everything(7)
    with set_current_vllm_config(
            VllmConfig(parallel_config=ParallelConfig(
                pipeline_parallel_size=1))):

        a = torch.randn((m, k), device="cuda", dtype=dtype) / 10

        quant_blocksize = 16

        (_, w1_q, w1_blockscale,
         w1_gs), (_, w2_q, w2_blockscale, w2_gs) = make_test_weights(
             e,
             n,
             k,
             in_dtype=dtype,
             quant_dtype="nvfp4",
             block_shape=None,  # use quant_blocksize?
             per_act_token_quant=False,
         )

        score = torch.randn((m, e), device="cuda", dtype=dtype)
        topk_weights, topk_ids, _ = fused_topk(a,
                                               score,
                                               topk,
                                               renormalize=False)

        a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
        a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)

        assert is_valid_flashinfer_cutlass_fused_moe(a, w1_q, w2_q)

        assert w1_gs is not None
        assert w2_gs is not None
        assert w1_blockscale is not None
        assert w2_blockscale is not None

        flashinfer_experts = FusedMoEModularKernel(
            MoEPrepareAndFinalizeNoEP(),
            FlashInferExperts(
                a1_gscale=a1_gs,
                g1_alphas=(1 / w1_gs),
                a2_gscale=a2_gs,
                g2_alphas=(1 / w2_gs),
                out_dtype=dtype,
                quant_dtype="nvfp4",
            ))

        flashinfer_output = flashinfer_experts(
            hidden_states=a,
            w1=w1_q,
            w1_scale=w1_blockscale,
            w2=w2_q,
            w2_scale=w2_blockscale,
            a1_scale=a1_gs,
            a2_scale=a2_gs,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
        )

        # Reference check:
        a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
                          torch.amax(a.flatten(), dim=-1)).to(torch.float32)
        a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale)
        _, m_k = a_fp4.shape
        a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4,
                                               a_scale_interleaved,
                                               a_global_scale,
                                               dtype=a.dtype,
                                               device=a.device,
                                               block_size=quant_blocksize)

        w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=dtype)
        w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype)

        for idx in range(0, e):
            w1_d[idx] = dequantize_nvfp4_to_dtype(w1_q[idx],
                                                  w1_blockscale[idx],
                                                  w1_gs[idx],
                                                  dtype=dtype,
                                                  device=w1_q.device,
                                                  block_size=quant_blocksize)
            w2_d[idx] = dequantize_nvfp4_to_dtype(w2_q[idx],
                                                  w2_blockscale[idx],
                                                  w2_gs[idx],
                                                  dtype=dtype,
                                                  device=w2_q.device,
                                                  block_size=quant_blocksize)

        torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk)

        torch.testing.assert_close(torch_output,
                                   flashinfer_output,
                                   atol=1e-1,
                                   rtol=1e-1)


if __name__ == "__main__":
    test_flashinfer_fp4_moe_no_graph(2, 1024, 1024, 40, 1, torch.half)
@ -2,6 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import copy
import textwrap
import traceback
from itertools import product
from typing import Optional

@ -10,41 +12,51 @@ import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.config import VllmConfig, current_platform, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
    BatchedTritonOrDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.layer import TritonExperts
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
    TritonOrDeepGemmExperts)
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe

from .modular_kernel_tools.common import (Config, RankTensors, WeightTensors,
                                          reference_moe_impl,
                                          run_modular_kernel)
from .modular_kernel_tools.mk_objects import (
    MK_FUSED_EXPERT_TYPES, MK_MULTI_GPU_PREPARE_FINALIZE_TYPES,
    MK_QUANT_CONFIGS, MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES)
    MK_QUANT_CONFIGS, MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, expert_info)
from .modular_kernel_tools.parallel_utils import (ProcessGroupInfo,
                                                  parallel_launch_with_config)

# TODO (varun): These requirements are very strict and could be relaxed.
has_all_packages = (has_deep_ep() and has_deep_gemm() and has_pplx())
has_any_multi_gpu_package = (has_deep_ep() or has_deep_gemm() or has_pplx()
                             or has_flashinfer_cutlass_fused_moe())

meets_package_requirements = pytest.mark.skipif(
    not has_all_packages,
    reason="Requires deep_ep & deep_gemm & pplx packages",
meets_multi_gpu_requirements = pytest.mark.skipif(
    not has_any_multi_gpu_package,
    reason="Requires deep_ep or deep_gemm or pplx or flashinfer packages",
)

def format_result(verbose, msg, ex=None):
    if ex is not None:
        x = str(ex)
        newx = x.strip(" \n\t")[:16]
        if len(newx) < len(x):
            newx = newx + " ..."

        prefix = "E\t"
        print(f"{textwrap.indent(traceback.format_exc(), prefix)}")
        print(f"FAILED {msg} - {newx}\n")
    elif verbose:
        print(f"PASSED {msg}")
    else:
        print(".", end="")

def rank_worker(
    pgi: ProcessGroupInfo,
    vllm_config: VllmConfig,
    cpu_group,
    config: Config,
    weights: WeightTensors,
    verbose: bool,
):
    current_platform.seed_everything(pgi.rank)

@ -61,39 +73,64 @@ def rank_worker(
    TOPKs = config.topks
    assert isinstance(TOPKs, list)

    exceptions = []
    count = 0

    for m, topk in product(Ms, TOPKs):
        print(f"Running m={m}, topk={topk} ...")
        # override m and topk
        cfgx = copy.deepcopy(config)
        cfgx.Ms = m
        cfgx.topks = topk
        try:
            print(f"Running[{pgi.rank}]: m={m}, topk={topk} ...")
            count = count + 1
            # override m and topk
            cfgx = copy.deepcopy(config)
            cfgx.Ms = m
            cfgx.topks = topk

        # inputs for rank
        rank_tensors = RankTensors.make(cfgx, pgi)
            # inputs for rank
            rank_tensors = RankTensors.make(cfgx, pgi)

        # modular kernel out
        mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights,
                                    rank_tensors)
            # modular kernel out
            mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights,
                                        rank_tensors)

        with set_current_vllm_config(vllm_config):
            ref_out = reference_moe_impl(cfgx, weights, rank_tensors)
            with set_current_vllm_config(vllm_config):
                ref_out = reference_moe_impl(cfgx, weights, rank_tensors)

        torch.testing.assert_close(ref_out, mk_out, atol=3e-2, rtol=3e-2)
            if config.quant_dtype == "nvfp4":
                atol = 1e-1
                rtol = 1e-1
            else:
                atol = 3e-2
                rtol = 3e-2

            torch.testing.assert_close(ref_out, mk_out, atol=atol, rtol=rtol)
            format_result(verbose, config.describe())
        except Exception as ex:
            format_result(verbose, config.describe(), ex)
            exceptions.append(ex)

    if len(exceptions) > 0:
        raise RuntimeError(
            f"{len(exceptions)} of {count} tests failed in child process, "
            f"rank={pgi.rank}.")
    else:
        print(f"{count} of {count} tests passed in child process, "
              f"rank={pgi.rank}.")

def run(config: Config):
def run(config: Config, verbose: bool):
    assert config.is_valid()
    print(f"Testing config \n{config.describe()} ...")

    weights: WeightTensors = WeightTensors.make(config)

    vllm_config, env_dict = config.make_env_data()
    parallel_launch_with_config(config.world_size, rank_worker, vllm_config,
                                env_dict, config, weights)
                                env_dict, config, weights, verbose)


Ms = [32, 64]
Ks = [7168]  # hidden sizes
# hidden sizes, making this too large will cause fp4 tests to fail.
# Also needs to be a multiple of 1024 for deep_gemm.
Ks = [2048]
Ns = [2048]
TOPKs = [4, 1]
Es = [32]
@ -103,19 +140,16 @@ FUSED_MOE_CHUNK_SIZEs = [None, 16]


def is_nyi_config(config: Config) -> bool:
    # We know these configs to be legitimate, but they still fail.
    info = expert_info(config.fused_experts_type)

    if (config.fused_experts_type in [
            BatchedTritonExperts, BatchedTritonOrDeepGemmExperts,
            TritonExperts, TritonOrDeepGemmExperts
    ]):
    if info.needs_matching_quant:
        # The triton kernels expect both per-act-token-quant and
        # per-out-ch-quant or neither.
        unsupported_quant_config = ((config.is_per_act_token_quant +
                                     config.is_per_out_ch_quant) == 1)
        return unsupported_quant_config

    # cutlass kernels don't support expert_maps yet.
    return config.fused_experts_type == CutlassExpertsFp8
    return not info.supports_expert_map
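Reviewer note: the rewritten check above hangs off a small per-expert-class capability record. A minimal sketch of the shape of that record, for orientation only (the two field names are the ones used above; the real definition lives in .modular_kernel_tools.mk_objects and may carry more fields, and the registry entry below is hypothetical):

from dataclasses import dataclass


@dataclass(frozen=True)
class ExpertInfo:
    # True for the Triton-backed experts, which require per-act-token and
    # per-out-channel quantization to be enabled together or not at all.
    needs_matching_quant: bool
    # False for kernels that cannot honor an expert map yet, which is
    # exactly what is_nyi_config() screens for.
    supports_expert_map: bool


# Hypothetical registry entry, for illustration:
TRITON_EXPERT_INFO = ExpertInfo(needs_matching_quant=True,
                                supports_expert_map=True)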
@pytest.mark.parametrize("k", Ks)
@ -128,13 +162,13 @@ def is_nyi_config(config: Config) -> bool:
    product(MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES))
@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs)
@pytest.mark.parametrize("world_size", [2])
@meets_package_requirements
@meets_multi_gpu_requirements
def test_modular_kernel_combinations_multigpu(
        k: int, n: int, e: int, dtype: torch.dtype,
        quant_config: FusedMoEQuantConfig,
        quant_config: Optional[FusedMoEQuantConfig],
        combination: tuple[mk.FusedMoEPrepareAndFinalize,
                           mk.FusedMoEPermuteExpertsUnpermute],
        fused_moe_chunk_size: Optional[int], world_size: int):
        fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig):

    config = Config(
        Ms=Ms,
@ -149,14 +183,15 @@ def test_modular_kernel_combinations_multigpu(
        fused_moe_chunk_size=fused_moe_chunk_size,
        world_size=world_size,
    )

    if not config.is_valid():
        pytest.skip(f"Tests config {config} is not valid. Skipping ...")

    if is_nyi_config(config):
        pytest.skip(f"Tests config {config} is nyi. Skipping ...")

    print(f"{config.describe()}")
    run(config)
    verbosity = pytestconfig.getoption('verbose')
    run(config, verbosity > 0)

@pytest.mark.parametrize("k", Ks)
@ -169,13 +204,12 @@ def test_modular_kernel_combinations_multigpu(
    product(MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES))
@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs)
@pytest.mark.parametrize("world_size", [1])
@meets_package_requirements
def test_modular_kernel_combinations_singlegpu(
        k: int, n: int, e: int, dtype: torch.dtype,
        quant_config: FusedMoEQuantConfig,
        quant_config: Optional[FusedMoEQuantConfig],
        combination: tuple[mk.FusedMoEPrepareAndFinalize,
                           mk.FusedMoEPermuteExpertsUnpermute],
        fused_moe_chunk_size: Optional[int], world_size: int):
        fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig):
    config = Config(
        Ms=Ms,
        K=k,
@ -196,7 +230,8 @@ def test_modular_kernel_combinations_singlegpu(
    if is_nyi_config(config):
        pytest.skip(f"Tests config {config} is nyi. Skipping ...")

    run(config)
    verbosity = pytestconfig.getoption('verbose')
    run(config, verbosity > 0)

@ -211,4 +246,4 @@ if __name__ == '__main__':
    args = parser.parse_args()
    config = make_config(args)

    run(config)
    run(config, True)

@ -3,6 +3,7 @@
import pytest
import torch

from tests.kernels.moe.utils import make_test_weights
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
                                                    FLOAT8_E4M3_MAX,
                                                    dequantize_nvfp4_to_dtype)
@ -43,41 +44,20 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
            VllmConfig(parallel_config=ParallelConfig(
                pipeline_parallel_size=1))):

        a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
        w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
        quant_blocksize = 16
        round_up = lambda x, y: (x + y - 1) // y * y
        sf_w1_2n = round_up(2 * n, 128)
        sf_w1_k = round_up(k // quant_blocksize, 4)
        w1_blockscale = torch.empty((e, sf_w1_2n, sf_w1_k),
                                    device="cuda",
                                    dtype=torch.float8_e4m3fn)

        w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
        sf_w2_k = round_up(k, 128)
        sf_w2_n = round_up(n // quant_blocksize, 4)
        w2_blockscale = torch.empty((e, sf_w2_k, sf_w2_n),
                                    device="cuda",
                                    dtype=torch.float8_e4m3fn)
        a = torch.randn((m, k), device="cuda", dtype=dtype) / 10

        w1_q = torch.empty((e, 2 * n, k // 2),
                           device="cuda",
                           dtype=torch.uint8)
        w2_q = torch.empty((e, k, n // 2), device="cuda", dtype=torch.uint8)
        w1_gs = torch.empty((e, ), device="cuda", dtype=torch.float32)
        w2_gs = torch.empty((e, ), device="cuda", dtype=torch.float32)

        for expert in range(e):
            w1_amax = torch.abs(w1).max().to(torch.float32)
            w2_amax = torch.abs(w2).max().to(torch.float32)
            w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
            w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax

            w1_q[expert], w1_blockscale[expert] = ops.scaled_fp4_quant(
                w1[expert], w1_gs[expert])

            w2_q[expert], w2_blockscale[expert] = ops.scaled_fp4_quant(
                w2[expert], w2_gs[expert])
        (_, w1_q, w1_blockscale,
         w1_gs), (_, w2_q, w2_blockscale, w2_gs) = make_test_weights(
             e,
             n,
             k,
             in_dtype=dtype,
             quant_dtype="nvfp4",
             block_shape=None,  # use quant_blocksize?
             per_act_token_quant=False,
         )

        score = torch.randn((m, e), device="cuda", dtype=dtype)
        topk_weights, topk_ids, _ = fused_topk(a,
@ -88,6 +68,11 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
        a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
        a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)

        assert w1_gs is not None
        assert w2_gs is not None
        assert w1_blockscale is not None
        assert w2_blockscale is not None

        cutlass_output = cutlass_moe_fp4(
            a=a,
            a1_gscale=a1_gs,
@ -104,14 +89,13 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
            n=n,
            k=k,
            e=e,
            device=a.device,
        )

        # Reference check:
        a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
                          torch.amax(a.flatten(), dim=-1)).to(torch.float32)
        a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale)
        _, m_k = a_fp4.shape

        a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4,
                                               a_scale_interleaved,
                                               a_global_scale,
@ -126,14 +110,14 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
            w1_d[idx] = dequantize_nvfp4_to_dtype(w1_q[idx],
                                                  w1_blockscale[idx],
                                                  w1_gs[idx],
                                                  dtype=w1.dtype,
                                                  device=w1.device,
                                                  dtype=dtype,
                                                  device=w1_q.device,
                                                  block_size=quant_blocksize)
            w2_d[idx] = dequantize_nvfp4_to_dtype(w2_q[idx],
                                                  w2_blockscale[idx],
                                                  w2_gs[idx],
                                                  dtype=w2.dtype,
                                                  device=w2.device,
                                                  dtype=dtype,
                                                  device=w2_q.device,
                                                  block_size=quant_blocksize)

        torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk)

@ -9,7 +9,8 @@ import torch
from tests.kernels.utils import torch_experts
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
    CutlassBatchedExpertsFp8)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEModularKernel)
@ -123,12 +124,8 @@ def pplx_cutlass_moe(
        num_local_experts=num_local_experts,
        num_dispatchers=num_dispatchers)

    experts = CutlassExpertsFp8(num_local_experts,
                                out_dtype,
                                per_act_token,
                                per_out_ch,
                                num_dispatchers=num_dispatchers,
                                use_batched_format=True)
    experts = CutlassBatchedExpertsFp8(num_local_experts, num_dispatchers,
                                       out_dtype, per_act_token, per_out_ch)

    fused_cutlass_experts = FusedMoEModularKernel(
        prepare_finalize,
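For call sites being migrated, the substitution above reads as follows. This is a sketch that only restates the two versions visible in this hunk; the argument order is taken from the diff lines themselves, not re-derived from the class definitions:

from vllm.model_executor.layers.fused_moe.cutlass_moe import (
    CutlassBatchedExpertsFp8)


def make_batched_experts(num_local_experts: int, num_dispatchers: int,
                         out_dtype, per_act_token: bool, per_out_ch: bool):
    # Before: batched behavior was a flag on the general fp8 experts class:
    #   CutlassExpertsFp8(num_local_experts, out_dtype, per_act_token,
    #                     per_out_ch, num_dispatchers=num_dispatchers,
    #                     use_batched_format=True)
    # After: a dedicated batched class; num_dispatchers moves to the second
    # positional slot and use_batched_format disappears.
    return CutlassBatchedExpertsFp8(num_local_experts, num_dispatchers,
                                    out_dtype, per_act_token, per_out_ch)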
@ -770,7 +770,7 @@ def test_pplx_moe_slow(
    a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
    score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)

    _, w1, w1_s, _, w2, w2_s = make_test_weights(
    (_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights(
        e,
        n,
        k,
@ -836,7 +836,7 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,

    args = dict()
    if make_weights:
        _, w1, w1_s, _, w2, w2_s = make_test_weights(
        (_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights(
            e,
            n,
            k,

@ -1,11 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from typing import Optional, Union

import torch

import vllm._custom_ops as ops
from tests.kernels.quant_utils import per_block_cast_to_int8
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
                                                    FLOAT8_E4M3_MAX)
from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
@ -169,28 +171,41 @@ def make_quantized_test_activations(
def moe_quantize_weights(
    w: torch.Tensor,
    w_s: Optional[torch.Tensor],
    quant_dtype: Optional[torch.dtype],
    quant_dtype: Union[torch.dtype, str, None],
    per_token_quant: bool,
    block_shape: Optional[list[int]],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    assert (quant_dtype == torch.float8_e4m3fn
            or quant_dtype == torch.int8), "only fp8/int8 supported"
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    assert (quant_dtype == torch.float8_e4m3fn or quant_dtype == torch.int8
            or quant_dtype == "nvfp4"), "only fp8/int8/nvfp4 supported"

    w_gs = None

    if block_shape is not None:
        assert not per_token_quant
        if quant_dtype == torch.int8:
            w, w_s = per_block_cast_to_int8(w, block_shape)
        else:
        elif quant_dtype == torch.float8_e4m3fn:
            w, w_s = per_block_cast_to_fp8(w, block_shape)
        elif quant_dtype == "nvfp4":
            raise RuntimeError("blocked quantization not supported for nvfp4")
        else:
            raise RuntimeError(f"Unsupported quant type {quant_dtype}")
    else:
        if quant_dtype == torch.int8:
            w, w_s = ops.scaled_int8_quant(
                w, w_s, use_per_token_if_dynamic=per_token_quant)
        else:
        elif quant_dtype == torch.float8_e4m3fn:
            w, w_s = ops.scaled_fp8_quant(
                w, w_s, use_per_token_if_dynamic=per_token_quant)
        elif quant_dtype == "nvfp4":
            assert not per_token_quant
            w_amax = torch.abs(w).max().to(torch.float32)
            w_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w_amax
            w, w_s = ops.scaled_fp4_quant(w, w_gs)
        else:
            raise RuntimeError(f"Unsupported quant type {quant_dtype}")

    return w, w_s
    return w, w_s, w_gs

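The nvfp4 branch is the one genuinely new code path above, and its scale math is small enough to restate standalone. A sketch of just the global-scale computation, assuming only that FLOAT8_E4M3_MAX and FLOAT4_E2M1_MAX are the format maxima (448.0 and 6.0) that nvfp4_utils exports:

import torch

FLOAT8_E4M3_MAX = 448.0  # largest finite float8_e4m3fn value
FLOAT4_E2M1_MAX = 6.0    # largest e2m1 (fp4) value


def nvfp4_global_scale(w: torch.Tensor) -> torch.Tensor:
    # One per-tensor scale, sized so that, combined with the per-block fp8
    # scale factors, the largest |w| maps onto the representable fp4 range.
    w_amax = torch.abs(w).max().to(torch.float32)
    return FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w_amax

ops.scaled_fp4_quant(w, w_gs) then consumes this scale and returns the packed fp4 data together with the per-block scale tensor, exactly as the branch above does.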
@ -198,21 +213,26 @@ def make_test_weight(
    rows: int,
    cols: int,
    in_dtype: torch.dtype = torch.bfloat16,
    quant_dtype: Optional[torch.dtype] = None,
    quant_dtype: Union[torch.dtype, str, None] = None,
    block_shape: Optional[list[int]] = None,
    per_act_token_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
           Optional[torch.Tensor]]:
    w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15
    w_gs = None

    if quant_dtype is not None:
        w_l = [None] * e
        w_s_l = [None] * e
        w_gs_l = [None] * e
        for idx in range(e):
            w_l[idx], w_s_l[idx] = moe_quantize_weights(
            w_l[idx], w_s_l[idx], w_gs_l[idx] = moe_quantize_weights(
                w_16[idx], None, quant_dtype, per_act_token_quant, block_shape)

        w = torch.stack(w_l)
        w_s = torch.stack(w_s_l)
        if e > 0 and w_gs_l[0] is not None:
            w_gs = torch.stack(w_gs_l)
        if w_s.ndim == 2:
            assert w_s.shape[-1] == 1
            w_s = w_s.view(-1, 1, 1)
@ -225,8 +245,9 @@ def make_test_weight(
    else:
        w = w_16
        w_s = None
        w_gs = None

    return w_16, w, w_s
    return w_16, w, w_s, w_gs

@ -234,14 +255,30 @@ def make_test_weights(
    n: int,
    k: int,
    in_dtype: torch.dtype = torch.bfloat16,
    quant_dtype: Optional[torch.dtype] = None,
    quant_dtype: Union[torch.dtype, str, None] = None,
    block_shape: Optional[list[int]] = None,
    per_act_token_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor,
           torch.Tensor, Optional[torch.Tensor]]:
) -> tuple[tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
                 Optional[torch.Tensor]],
           tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
                 Optional[torch.Tensor]]]:
    return (
        *make_test_weight(e, 2 * n, k, in_dtype, quant_dtype, block_shape,
                          per_act_token_quant),
        *make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape,
                          per_act_token_quant),
        make_test_weight(e, 2 * n, k, in_dtype, quant_dtype, block_shape,
                         per_act_token_quant),
        make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape,
                         per_act_token_quant),
    )

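Since every caller in this commit switches to the nested unpacking, a compact reminder of the new call-site shape may help. This sketch assumes each inner tuple is (w_16, w_q, w_scale, w_global_scale), as returned by make_test_weight above; the sizes are arbitrary examples:

import torch

# w*_gs is None for every quant_dtype except "nvfp4".
(w1_16, w1_q, w1_scale, w1_gs), (w2_16, w2_q, w2_scale, w2_gs) = \
    make_test_weights(8,
                      128,
                      256,
                      in_dtype=torch.bfloat16,
                      quant_dtype=torch.float8_e4m3fn,
                      block_shape=[128, 128])
assert w1_q.shape == (8, 2 * 128, 256)  # w1 stacks gate/up: (E, 2N, K)
assert w2_q.shape == (8, 256, 128)      # w2 projects back: (E, K, N)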
def per_token_cast_to_fp8(
        x: torch.Tensor,
        block_size: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
    assert x.dim() == 2
    m, n = x.shape
    pad_size = (block_size - (n % block_size)) % block_size
    x = torch.nn.functional.pad(x,
                                (0, pad_size), value=0) if pad_size > 0 else x
    x_view = x.view(m, -1, block_size)
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
    return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
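A quick shape sketch of the now-parameterized helper (illustrative values; assumes a torch build that ships float8_e4m3fn): a (4, 300) input with the default block_size=128 is padded to 384 columns internally, so each row gets one scale per 128-wide group.

import torch

x = torch.randn(4, 300, dtype=torch.float32)
x_fp8, x_scale = per_token_cast_to_fp8(x)  # default block_size=128
assert x_fp8.shape == (4, 300)
assert x_fp8.dtype == torch.float8_e4m3fn
assert x_scale.shape == (4, 3)  # ceil(300 / 128) == 3 groups per row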