deep_ep + use_fp8_dispatch

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
This commit is contained in:
Varun Sundar Rabindranath
2025-06-18 07:32:15 -07:00
parent 4c8f64faa7
commit 8de2fd39fc

View File

@ -45,7 +45,8 @@ if current_platform.is_cuda_alike():
from .pplx_prepare_finalize import PplxPrepareAndFinalize
if has_deepep:
from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize
from .deepep_ll_prepare_finalize import DeepEPLLPrepareAndFinalize
from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SIZE,
DeepEPLLPrepareAndFinalize)
else:
fused_experts = None # type: ignore
FusedMoEPermuteExpertsUnpermute = None # type: ignore
@ -377,6 +378,12 @@ class FusedMoEMethodBase(QuantizeMethodBase):
all2all_manager.world_size)
handle = all2all_manager.get_handle(all_to_all_args)
# Note : We may want to use FP8 dispatch even otherwise just to
# reduce datamovement
use_fp8_dispatch = (quant_dtype == current_platform.fp8_dtype()
and act_quant_block_size
== DEEPEP_QUANT_BLOCK_SIZE)
# Note (varun): Whether to use FP8 dispatch or not needs some
# profiling. Turning it off for now.
prepare_finalize = DeepEPLLPrepareAndFinalize(
@ -386,7 +393,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
max_tokens_per_rank=moe.max_num_tokens,
quant_dtype=quant_dtype,
block_shape=act_quant_block_size,
use_fp8_dispatch=False,
use_fp8_dispatch=use_fp8_dispatch,
)
self.topk_indices_dtype = None