[ROCm] Add aiter tkw1 kernel for Llama4 fp8 (#16727)
Signed-off-by: kliuae <kuanfu.liu@embeddedllm.com> Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com> Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com> Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
|
||||
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
|
||||
ARG FA_BRANCH="1a7f4dfa"
|
||||
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
|
||||
ARG AITER_BRANCH="8970b25b"
|
||||
ARG AITER_BRANCH="5a77249"
|
||||
ARG AITER_REPO="https://github.com/ROCm/aiter.git"
|
||||
|
||||
FROM ${BASE_IMAGE} AS base
|
||||
|
||||
@ -77,7 +77,6 @@ if TYPE_CHECKING:
|
||||
VLLM_ROCM_USE_AITER: bool = False
|
||||
VLLM_ROCM_USE_AITER_LINEAR: bool = True
|
||||
VLLM_ROCM_USE_AITER_MOE: bool = True
|
||||
VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE: bool = False
|
||||
VLLM_ROCM_USE_AITER_RMSNORM: bool = True
|
||||
VLLM_ROCM_FP8_PADDING: bool = True
|
||||
VLLM_ROCM_MOE_PADDING: bool = True
|
||||
@ -546,13 +545,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
lambda: (os.getenv("VLLM_ROCM_USE_AITER_MOE", "True").lower() in
|
||||
("true", "1")),
|
||||
|
||||
# Whether to use aiter block scaled moe kernel.
|
||||
# By default this is disabled.
|
||||
"VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE":
|
||||
lambda:
|
||||
(os.getenv("VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE", "false").lower() in
|
||||
("true", "1")),
|
||||
|
||||
# use aiter rms norm op if aiter ops are enabled.
|
||||
"VLLM_ROCM_USE_AITER_RMSNORM":
|
||||
lambda: (os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in
|
||||
|
||||
@ -23,9 +23,7 @@ from vllm.model_executor.layers.quantization.utils.int8_utils import (
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
from .rocm_aiter_fused_moe import (is_rocm_aiter_moe_enabled,
|
||||
rocm_aiter_fused_experts,
|
||||
rocm_aiter_topk_softmax)
|
||||
from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -846,6 +844,7 @@ def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor,
|
||||
|
||||
def dispatch_topk_func() -> Callable[..., tuple[torch.Tensor, ...]]:
|
||||
if is_rocm_aiter_moe_enabled():
|
||||
from .rocm_aiter_fused_moe import rocm_aiter_topk_softmax
|
||||
return rocm_aiter_topk_softmax
|
||||
return vllm_topk_softmax
|
||||
|
||||
@ -1102,6 +1101,7 @@ def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor:
|
||||
|
||||
def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
|
||||
if is_rocm_aiter_moe_enabled():
|
||||
from .rocm_aiter_fused_moe import rocm_aiter_fused_experts
|
||||
return rocm_aiter_fused_experts
|
||||
if inplace:
|
||||
return torch_vllm_inplace_fused_experts
|
||||
|
||||
@ -10,28 +10,68 @@ from vllm.platforms import current_platform
|
||||
def is_rocm_aiter_moe_enabled() -> bool:
|
||||
return current_platform.is_rocm() \
|
||||
and envs.VLLM_ROCM_USE_AITER_MOE \
|
||||
and envs.VLLM_ROCM_USE_AITER \
|
||||
and envs.VLLM_ROCM_USE_AITER
|
||||
|
||||
|
||||
def is_rocm_aiter_block_scaled_moe_enabled() -> bool:
|
||||
return is_rocm_aiter_moe_enabled() and \
|
||||
envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE
|
||||
def rocm_aiter_asm_moe_tkw1(hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
topk_weight,
|
||||
topk_ids,
|
||||
fc1_scale=None,
|
||||
fc2_scale=None,
|
||||
fc1_smooth_scale=None,
|
||||
fc2_smooth_scale=None,
|
||||
a16=False,
|
||||
per_tensor_quant_scale=None,
|
||||
expert_mask=None,
|
||||
activation_str: str = "silu") -> None:
|
||||
|
||||
from aiter import ActivationType
|
||||
from aiter.fused_moe_bf16_asm import asm_moe_tkw1
|
||||
|
||||
activation = \
|
||||
ActivationType.Gelu if activation_str == "gelu" else ActivationType.Silu
|
||||
|
||||
return asm_moe_tkw1(hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
topk_weight,
|
||||
topk_ids,
|
||||
fc1_scale=fc1_scale,
|
||||
fc2_scale=fc2_scale,
|
||||
fc1_smooth_scale=fc1_smooth_scale,
|
||||
fc2_smooth_scale=fc2_smooth_scale,
|
||||
a16=a16,
|
||||
per_tensor_quant_scale=per_tensor_quant_scale,
|
||||
expert_mask=expert_mask,
|
||||
activation=activation)
|
||||
|
||||
|
||||
def rocm_aiter_fused_experts(
|
||||
*,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
use_fp8_w8a8: bool = False,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None,
|
||||
expert_mask: Optional[torch.Tensor] = None,
|
||||
**kwagrs # Ignore additional keyword arguments
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
per_channel_quant: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_zp: Optional[torch.Tensor] = None,
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None,
|
||||
allow_deep_gemm: bool = False,
|
||||
) -> torch.Tensor:
|
||||
|
||||
import aiter as rocm_aiter
|
||||
@ -40,25 +80,21 @@ def rocm_aiter_fused_experts(
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8)
|
||||
|
||||
if apply_router_weight_on_input:
|
||||
assert (topk_weights.dim() == 2
|
||||
), "`topk_weights` should be in shape (num_tokens, topk)"
|
||||
_, topk = topk_weights.shape
|
||||
assert (
|
||||
topk == 1
|
||||
), "Only support topk=1 when `apply_router_weight_on_input` is True"
|
||||
# All AITER Fused MoE kernels are expecting the following datatypes
|
||||
topk_weights = topk_weights.to(torch.float32)
|
||||
topk_ids = topk_ids.to(torch.int32)
|
||||
|
||||
hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
|
||||
topk_ids = topk_ids.to(torch.int32)
|
||||
topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)
|
||||
if (block_shape is not None) and use_fp8_w8a8:
|
||||
assert not apply_router_weight_on_input, (
|
||||
"apply_router_weight_on_input is not supported for block scaled moe"
|
||||
)
|
||||
|
||||
if envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE and use_fp8_w8a8:
|
||||
assert w1_scale is not None
|
||||
assert w2_scale is not None
|
||||
|
||||
local_E = E = w1.shape[0]
|
||||
if expert_mask is not None:
|
||||
E = expert_mask.numel()
|
||||
if expert_map is not None:
|
||||
E = expert_map.numel()
|
||||
|
||||
topk = topk_ids.shape[1]
|
||||
model_dim = w1.shape[-1]
|
||||
@ -80,7 +116,7 @@ def rocm_aiter_fused_experts(
|
||||
E,
|
||||
model_dim,
|
||||
dtype,
|
||||
expert_mask=expert_mask)
|
||||
expert_mask=expert_map)
|
||||
|
||||
a1, a1_scale = per_token_group_quant_fp8(hidden_states, scale_blk_k)
|
||||
rocm_aiter.fmoe_fp8_blockscale_g1u1(
|
||||
@ -102,7 +138,33 @@ def rocm_aiter_fused_experts(
|
||||
)
|
||||
return out_asm
|
||||
|
||||
elif per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8:
|
||||
# AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input`
|
||||
# This applies topk_weights on the GEMM output of the first FC layer
|
||||
# rather than the second FC.
|
||||
assert (topk_weights.dim() == 2
|
||||
), "`topk_weights` should be in shape (num_tokens, topk)"
|
||||
assert topk_weights.shape[-1] == 1, (
|
||||
"Only support topk=1 when"
|
||||
" `apply_router_weight_on_input` is True")
|
||||
|
||||
return rocm_aiter_asm_moe_tkw1(hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
fc1_scale=w1_scale,
|
||||
fc2_scale=w2_scale,
|
||||
fc1_smooth_scale=None,
|
||||
fc2_smooth_scale=None,
|
||||
a16=False,
|
||||
per_tensor_quant_scale=None,
|
||||
expert_mask=expert_map,
|
||||
activation_str=activation)
|
||||
|
||||
elif use_fp8_w8a8:
|
||||
assert not apply_router_weight_on_input, (
|
||||
"apply_router_weight_on_input is not supported for fp8_w8a8")
|
||||
return rocm_aiter_asm_fmoe.asm_moe(hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
@ -114,6 +176,18 @@ def rocm_aiter_fused_experts(
|
||||
fc2_smooth_scale=None,
|
||||
a16=False)
|
||||
|
||||
if apply_router_weight_on_input:
|
||||
assert (topk_weights.dim() == 2
|
||||
), "`topk_weights` should be in shape (num_tokens, topk)"
|
||||
_, topk = topk_weights.shape
|
||||
assert (
|
||||
topk == 1
|
||||
), "Only support topk=1 when `apply_router_weight_on_input` is True"
|
||||
|
||||
hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
|
||||
topk_ids = topk_ids.to(torch.int32)
|
||||
topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)
|
||||
|
||||
return rocm_aiter.ck_moe(hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
|
||||
@ -250,6 +250,28 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
|
||||
requires_grad=False)
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
is_rocm_aiter_moe_enabled)
|
||||
|
||||
# Property to determine if AITER is used
|
||||
if is_rocm_aiter_moe_enabled():
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501
|
||||
rocm_aiter_fused_experts, shuffle_weights)
|
||||
|
||||
# reshaping weights is required for aiter moe kernel.
|
||||
shuffled_w13, shuffled_w2 = shuffle_weights(
|
||||
layer.w13_weight.data, layer.w2_weight.data)
|
||||
|
||||
layer.w13_weight = torch.nn.Parameter(shuffled_w13,
|
||||
requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(shuffled_w2,
|
||||
requires_grad=False)
|
||||
|
||||
self.fused_experts_func = rocm_aiter_fused_experts
|
||||
else:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
self.fused_experts_func = fused_experts
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@ -268,7 +290,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
) -> torch.Tensor:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
@ -282,7 +303,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
|
||||
return fused_experts(
|
||||
return self.fused_experts_func(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
|
||||
@ -575,8 +575,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
# Lazy import to avoid importing triton too early.
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
expand_weights, is_rocm_aiter_block_scaled_moe_enabled,
|
||||
is_rocm_aiter_moe_enabled, shuffle_weights)
|
||||
expand_weights, is_rocm_aiter_moe_enabled, shuffle_weights)
|
||||
|
||||
# TODO (rob): refactor block quant into separate class.
|
||||
if self.block_quant:
|
||||
@ -603,7 +602,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
layer.w2_weight = Parameter(w2_weight, requires_grad=False)
|
||||
layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
|
||||
requires_grad=False)
|
||||
if is_rocm_aiter_block_scaled_moe_enabled():
|
||||
if is_rocm_aiter_moe_enabled():
|
||||
# reshaping weights is required for aiter moe kernel.
|
||||
shuffled_w13, shuffled_w2 = shuffle_weights(
|
||||
layer.w13_weight.data, layer.w2_weight.data)
|
||||
|
||||
Reference in New Issue
Block a user