diff --git a/docker/Dockerfile.rocm_base b/docker/Dockerfile.rocm_base
index b8523fbc2a..05192eb69b 100644
--- a/docker/Dockerfile.rocm_base
+++ b/docker/Dockerfile.rocm_base
@@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
 ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
 ARG FA_BRANCH="1a7f4dfa"
 ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
-ARG AITER_BRANCH="8970b25b"
+ARG AITER_BRANCH="5a77249"
 ARG AITER_REPO="https://github.com/ROCm/aiter.git"
 
 FROM ${BASE_IMAGE} AS base
diff --git a/vllm/envs.py b/vllm/envs.py
index ac60899778..0a7067b8a6 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -77,7 +77,6 @@ if TYPE_CHECKING:
     VLLM_ROCM_USE_AITER: bool = False
     VLLM_ROCM_USE_AITER_LINEAR: bool = True
     VLLM_ROCM_USE_AITER_MOE: bool = True
-    VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE: bool = False
     VLLM_ROCM_USE_AITER_RMSNORM: bool = True
     VLLM_ROCM_FP8_PADDING: bool = True
     VLLM_ROCM_MOE_PADDING: bool = True
@@ -546,13 +545,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
     lambda: (os.getenv("VLLM_ROCM_USE_AITER_MOE", "True").lower() in
              ("true", "1")),
 
-    # Whether to use aiter block scaled moe kernel.
-    # By default this is disabled.
-    "VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE":
-    lambda:
-    (os.getenv("VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE", "false").lower() in
-     ("true", "1")),
-
     # use aiter rms norm op if aiter ops are enabled.
     "VLLM_ROCM_USE_AITER_RMSNORM":
     lambda: (os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 2a988b8644..a209715ede 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -23,9 +23,7 @@ from vllm.model_executor.layers.quantization.utils.int8_utils import (
 from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 
-from .rocm_aiter_fused_moe import (is_rocm_aiter_moe_enabled,
-                                   rocm_aiter_fused_experts,
-                                   rocm_aiter_topk_softmax)
+from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
 
 logger = init_logger(__name__)
 
@@ -846,6 +844,7 @@ def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor,
 
 def dispatch_topk_func() -> Callable[..., tuple[torch.Tensor, ...]]:
     if is_rocm_aiter_moe_enabled():
+        from .rocm_aiter_fused_moe import rocm_aiter_topk_softmax
         return rocm_aiter_topk_softmax
     return vllm_topk_softmax
 
@@ -1102,6 +1101,7 @@ def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor:
 
 def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
     if is_rocm_aiter_moe_enabled():
+        from .rocm_aiter_fused_moe import rocm_aiter_fused_experts
        return rocm_aiter_fused_experts
     if inplace:
         return torch_vllm_inplace_fused_experts
diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
index 4214e89448..1315dcead4 100644
--- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
@@ -10,28 +10,68 @@ from vllm.platforms import current_platform
 def is_rocm_aiter_moe_enabled() -> bool:
     return current_platform.is_rocm() \
         and envs.VLLM_ROCM_USE_AITER_MOE \
-        and envs.VLLM_ROCM_USE_AITER \
+        and envs.VLLM_ROCM_USE_AITER
 
 
-def is_rocm_aiter_block_scaled_moe_enabled() -> bool:
-    return is_rocm_aiter_moe_enabled() and \
-        envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE
+def rocm_aiter_asm_moe_tkw1(hidden_states,
+                            w1,
+                            w2,
+                            topk_weight,
+                            topk_ids,
+                            fc1_scale=None,
+                            fc2_scale=None,
+                            fc1_smooth_scale=None,
+                            fc2_smooth_scale=None,
+                            a16=False,
+                            per_tensor_quant_scale=None,
+                            expert_mask=None,
+                            activation_str: str = "silu") -> None:
+
+    from aiter import ActivationType
+    from aiter.fused_moe_bf16_asm import asm_moe_tkw1
+
+    activation = \
+        ActivationType.Gelu if activation_str == "gelu" else ActivationType.Silu
+
+    return asm_moe_tkw1(hidden_states,
+                        w1,
+                        w2,
+                        topk_weight,
+                        topk_ids,
+                        fc1_scale=fc1_scale,
+                        fc2_scale=fc2_scale,
+                        fc1_smooth_scale=fc1_smooth_scale,
+                        fc2_smooth_scale=fc2_smooth_scale,
+                        a16=a16,
+                        per_tensor_quant_scale=per_tensor_quant_scale,
+                        expert_mask=expert_mask,
+                        activation=activation)
 
 
 def rocm_aiter_fused_experts(
-    *,
-    hidden_states: torch.Tensor,
-    w1: torch.Tensor,
-    w2: torch.Tensor,
-    topk_weights: torch.Tensor,
-    topk_ids: torch.Tensor,
-    use_fp8_w8a8: bool = False,
-    apply_router_weight_on_input: bool = False,
-    w1_scale: Optional[torch.Tensor] = None,
-    w2_scale: Optional[torch.Tensor] = None,
-    block_shape: Optional[List[int]] = None,
-    expert_mask: Optional[torch.Tensor] = None,
-    **kwagrs  # Ignore additional keyword arguments
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        inplace: bool = False,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
+        use_fp8_w8a8: bool = False,
+        use_int8_w8a8: bool = False,
+        use_int8_w8a16: bool = False,
+        use_int4_w4a16: bool = False,
+        per_channel_quant: bool = False,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        w1_scale: Optional[torch.Tensor] = None,
+        w2_scale: Optional[torch.Tensor] = None,
+        w1_zp: Optional[torch.Tensor] = None,
+        w2_zp: Optional[torch.Tensor] = None,
+        a1_scale: Optional[torch.Tensor] = None,
+        a2_scale: Optional[torch.Tensor] = None,
+        block_shape: Optional[List[int]] = None,
+        allow_deep_gemm: bool = False,
 ) -> torch.Tensor:
 
     import aiter as rocm_aiter
@@ -40,25 +80,21 @@ def rocm_aiter_fused_experts(
     from vllm.model_executor.layers.quantization.utils.fp8_utils import (
         per_token_group_quant_fp8)
 
-    if apply_router_weight_on_input:
-        assert (topk_weights.dim() == 2
-                ), "`topk_weights` should be in shape (num_tokens, topk)"
-        _, topk = topk_weights.shape
-        assert (
-            topk == 1
-        ), "Only support topk=1 when `apply_router_weight_on_input` is True"
+    # All AITER Fused MoE kernels are expecting the following datatypes
+    topk_weights = topk_weights.to(torch.float32)
+    topk_ids = topk_ids.to(torch.int32)
 
-        hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
-        topk_ids = topk_ids.to(torch.int32)
-        topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)
+    if (block_shape is not None) and use_fp8_w8a8:
+        assert not apply_router_weight_on_input, (
+            "apply_router_weight_on_input is not supported for block scaled moe"
+        )
 
-    if envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE and use_fp8_w8a8:
         assert w1_scale is not None
         assert w2_scale is not None
 
         local_E = E = w1.shape[0]
-        if expert_mask is not None:
-            E = expert_mask.numel()
+        if expert_map is not None:
+            E = expert_map.numel()
 
         topk = topk_ids.shape[1]
         model_dim = w1.shape[-1]
@@ -80,7 +116,7 @@ def rocm_aiter_fused_experts(
             E,
             model_dim,
             dtype,
-            expert_mask=expert_mask)
+            expert_mask=expert_map)
 
         a1, a1_scale = per_token_group_quant_fp8(hidden_states, scale_blk_k)
         rocm_aiter.fmoe_fp8_blockscale_g1u1(
@@ -102,7 +138,33 @@
         )
         return out_asm
 
+    elif per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8:
+        # AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input`
+        # This applies topk_weights on the GEMM output of the first FC layer
+        # rather than the second FC.
+        assert (topk_weights.dim() == 2
+                ), "`topk_weights` should be in shape (num_tokens, topk)"
+        assert topk_weights.shape[-1] == 1, (
+            "Only support topk=1 when"
+            " `apply_router_weight_on_input` is True")
+
+        return rocm_aiter_asm_moe_tkw1(hidden_states,
+                                       w1,
+                                       w2,
+                                       topk_weights,
+                                       topk_ids,
+                                       fc1_scale=w1_scale,
+                                       fc2_scale=w2_scale,
+                                       fc1_smooth_scale=None,
+                                       fc2_smooth_scale=None,
+                                       a16=False,
+                                       per_tensor_quant_scale=None,
+                                       expert_mask=expert_map,
+                                       activation_str=activation)
+
     elif use_fp8_w8a8:
+        assert not apply_router_weight_on_input, (
+            "apply_router_weight_on_input is not supported for fp8_w8a8")
         return rocm_aiter_asm_fmoe.asm_moe(hidden_states=hidden_states,
                                            w1=w1,
                                            w2=w2,
@@ -114,6 +176,18 @@ def rocm_aiter_fused_experts(
                                            fc2_smooth_scale=None,
                                            a16=False)
 
+    if apply_router_weight_on_input:
+        assert (topk_weights.dim() == 2
+                ), "`topk_weights` should be in shape (num_tokens, topk)"
+        _, topk = topk_weights.shape
+        assert (
+            topk == 1
+        ), "Only support topk=1 when `apply_router_weight_on_input` is True"
+
+        hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
+        topk_ids = topk_ids.to(torch.int32)
+        topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)
+
     return rocm_aiter.ck_moe(hidden_states=hidden_states,
                              w1=w1,
                              w2=w2,
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index 628724c5b7..4e01b298d0 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -250,6 +250,28 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                     requires_grad=False)
 
+        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+            is_rocm_aiter_moe_enabled)
+
+        # Property to determine if AITER is used
+        if is_rocm_aiter_moe_enabled():
+            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa E501
+                rocm_aiter_fused_experts, shuffle_weights)
+
+            # reshaping weights is required for aiter moe kernel.
+            shuffled_w13, shuffled_w2 = shuffle_weights(
+                layer.w13_weight.data, layer.w2_weight.data)
+
+            layer.w13_weight = torch.nn.Parameter(shuffled_w13,
+                                                  requires_grad=False)
+            layer.w2_weight = torch.nn.Parameter(shuffled_w2,
+                                                 requires_grad=False)
+
+            self.fused_experts_func = rocm_aiter_fused_experts
+        else:
+            from vllm.model_executor.layers.fused_moe import fused_experts
+            self.fused_experts_func = fused_experts
+
     def apply(
         self,
         layer: torch.nn.Module,
@@ -268,7 +290,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
     ) -> torch.Tensor:
-        from vllm.model_executor.layers.fused_moe import fused_experts
 
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
@@ -282,7 +303,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias)
 
-        return fused_experts(
+        return self.fused_experts_func(
             x,
             layer.w13_weight,
             layer.w2_weight,
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index b7327f4773..be76785bac 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -575,8 +575,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
     def process_weights_after_loading(self, layer: Module) -> None:
         # Lazy import to avoid importing triton too early.
         from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
-            expand_weights, is_rocm_aiter_block_scaled_moe_enabled,
-            is_rocm_aiter_moe_enabled, shuffle_weights)
+            expand_weights, is_rocm_aiter_moe_enabled, shuffle_weights)
 
         # TODO (rob): refactor block quant into separate class.
         if self.block_quant:
@@ -603,7 +602,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             layer.w2_weight = Parameter(w2_weight, requires_grad=False)
             layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
                                                   requires_grad=False)
-            if is_rocm_aiter_block_scaled_moe_enabled():
+            if is_rocm_aiter_moe_enabled():
                 # reshaping weights is required for aiter moe kernel.
                 shuffled_w13, shuffled_w2 = shuffle_weights(
                     layer.w13_weight.data, layer.w2_weight.data)
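
Illustrative note (not part of the patch): with `VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE` removed, `rocm_aiter_fused_experts` chooses a kernel from its quantization arguments alone, assuming `VLLM_ROCM_USE_AITER=1` and `VLLM_ROCM_USE_AITER_MOE=1`. The sketch below mirrors the if/elif order in the new code; the helper `pick_aiter_moe_kernel` is hypothetical and exists only for this example.

```python
from typing import List, Optional


def pick_aiter_moe_kernel(use_fp8_w8a8: bool,
                          block_shape: Optional[List[int]],
                          per_channel_quant: bool,
                          apply_router_weight_on_input: bool) -> str:
    # Hypothetical helper mirroring the branch order in
    # rocm_aiter_fused_experts() above; illustration only.
    if block_shape is not None and use_fp8_w8a8:
        # block-scaled FP8 path (formerly gated by the removed env var)
        return "fmoe_fp8_blockscale_g1u1"
    if per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8:
        # new tkw1 path: router weight applied to the first GEMM output
        return "asm_moe_tkw1"
    if use_fp8_w8a8:
        # per-tensor FP8 path
        return "asm_moe"
    # default path (router weight folded into hidden_states when requested)
    return "ck_moe"


assert pick_aiter_moe_kernel(True, [128, 128], False, False) \
    == "fmoe_fp8_blockscale_g1u1"
assert pick_aiter_moe_kernel(True, None, True, True) == "asm_moe_tkw1"
assert pick_aiter_moe_kernel(False, None, False, False) == "ck_moe"
```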