[fix]: disable cutlass block scaled group gemm for EP (#20781)
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
@@ -201,11 +201,10 @@ void run_blockwise_scaled_group_mm(
       reinterpret_cast<typename ScheduleConfig::LayoutSFB*>(
           layout_sfb.data_ptr())};
 
-  cutlass::KernelHardwareInfo hw_info;
-  hw_info.device_id = a_ptrs.get_device();
-  hw_info.sm_count =
-      cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
-          hw_info.device_id);
+  int device_id = a_ptrs.device().index();
+  static const cutlass::KernelHardwareInfo hw_info{
+      device_id, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
+                     device_id)};
 
   // Epilogue Arguments
   typename GemmKernel::EpilogueArguments epilogue_args{
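For context, the value hw_info caches here is the device's streaming-multiprocessor count, now queried once and reused via the static const initializer. A quick standalone way to inspect the same number from Python (not part of this diff; uses only public torch.cuda APIs) is:

    # Standalone check of the SM count that
    # cutlass::KernelHardwareInfo::query_device_multiprocessor_count reports.
    import torch

    if torch.cuda.is_available():
        props = torch.cuda.get_device_properties(0)
        # multi_processor_count is the number of SMs on the device.
        print(f"{props.name}: {props.multi_processor_count} SMs")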
@@ -553,8 +553,10 @@ def cutlass_moe_fp4(a: torch.Tensor,
     return out.to(dtype=out_dtype)
 
 
-def _valid_cutlass_block_scaled_grouped_gemm(w1: torch.Tensor,
-                                             w2: torch.Tensor) -> bool:
+def _valid_cutlass_block_scaled_grouped_gemm(
+        w1: torch.Tensor, w2: torch.Tensor, inplace: bool, activation: str,
+        apply_router_weight_on_input: bool,
+        expert_map: Optional[torch.Tensor]) -> bool:
 
     def _valid_cutlass_block_scaled_grouped_gemm_shape(N: int, K: int):
         return N % 128 == 0 and K % 128 == 0
@@ -570,6 +572,29 @@ def _valid_cutlass_block_scaled_grouped_gemm(w1: torch.Tensor,
             "CutlassBlockScaledGroupedGemm disabled: invalid weight dtype(s).")
         return False
 
+    if expert_map is not None:
+        logger.debug(
+            "CutlassBlockScaledGroupedGemm disabled: expert_parallel is"
+            " not supported.")
+        return False
+
+    if activation != "silu":
+        logger.debug(
+            "CutlassBlockScaledGroupedGemm disabled: only activation silu is"
+            " supported.")
+        return False
+
+    if apply_router_weight_on_input:
+        logger.debug("CutlassBlockScaledGroupedGemm disabled:"
+                     " apply_router_weight_on_input is not supported.")
+        return False
+
+    if inplace:
+        logger.debug(
+            "CutlassBlockScaledGroupedGemm disabled: inplace is not supported."
+        )
+        return False
+
     return True
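The effect of the new checks, in isolation: any configuration that uses an expert map (expert parallelism), a non-silu activation, apply_router_weight_on_input, or inplace output now fails the guard. A minimal standalone sketch of that logic (hypothetical helper name and simplified types; not the vLLM function itself):

    import logging
    from typing import Optional

    logger = logging.getLogger(__name__)

    def _block_scaled_grouped_gemm_supported(
            inplace: bool, activation: str,
            apply_router_weight_on_input: bool,
            expert_map: Optional[object]) -> bool:
        # A non-None expert_map means expert parallelism, the case this PR disables.
        if expert_map is not None:
            logger.debug("disabled: expert_parallel is not supported")
            return False
        if activation != "silu":
            logger.debug("disabled: only silu activation is supported")
            return False
        if apply_router_weight_on_input:
            logger.debug("disabled: apply_router_weight_on_input is not supported")
            return False
        if inplace:
            logger.debug("disabled: inplace is not supported")
            return False
        return True

    # Under expert parallelism the CUTLASS block-scaled path is skipped.
    assert not _block_scaled_grouped_gemm_supported(
        inplace=False, activation="silu",
        apply_router_weight_on_input=False, expert_map=[0, 1, -1])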
@@ -1192,8 +1192,9 @@ def fused_experts(
             apply_router_weight_on_input=apply_router_weight_on_input,
         )
     elif (allow_cutlass_block_scaled_grouped_gemm and use_fp8_w8a8
-          and _valid_cutlass_block_scaled_grouped_gemm(w1, w2)):
-        assert apply_router_weight_on_input is False
+          and _valid_cutlass_block_scaled_grouped_gemm(
+              w1, w2, inplace, activation, apply_router_weight_on_input,
+              expert_map)):
         return run_cutlass_block_scaled_fused_experts(
             a=hidden_states,
             w1=w1,
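Taken together with the guard above, the call site no longer asserts on apply_router_weight_on_input: a configuration that fails the guard simply skips this elif and continues to the remaining fused_experts branches. A minimal control-flow sketch (hypothetical helper and return strings; the real branches live in fused_experts):

    def select_moe_path(allow_cutlass_block_scaled_grouped_gemm: bool,
                        use_fp8_w8a8: bool, guard_ok: bool) -> str:
        # Mirrors the elif above: all three conditions must hold to take the
        # CUTLASS block-scaled grouped GEMM path.
        if allow_cutlass_block_scaled_grouped_gemm and use_fp8_w8a8 and guard_ok:
            return "run_cutlass_block_scaled_fused_experts"
        return "default_fused_experts_path"

    # An expert-parallel run fails the guard, so the CUTLASS path is skipped.
    assert select_moe_path(True, True, guard_ok=False) == "default_fused_experts_path"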