Silu v2 (#25074)
Signed-off-by: mgoin <mgoin64@gmail.com> Signed-off-by: elvircrn <elvircrn@gmail.com> Signed-off-by: Elvir Crnčević <elvircrn@gmail.com> Co-authored-by: mgoin <mgoin64@gmail.com> Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com>
This commit is contained in:
@ -5,7 +5,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||
silu_mul_fp8_quant_deep_gemm_cuda,
|
||||
persistent_masked_m_silu_mul_quant,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import cdiv
|
||||
@ -50,15 +50,15 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, fp8_type):
|
||||
# Input tensor of shape (E, T, 2*H)
|
||||
y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda")
|
||||
tokens_per_expert = torch.randint(
|
||||
low=T // 2,
|
||||
low=0,
|
||||
high=T,
|
||||
size=(E,),
|
||||
dtype=torch.int32,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
# Run the Triton kernel
|
||||
y_q, y_s = silu_mul_fp8_quant_deep_gemm_cuda(
|
||||
# Run the SiLU V2 kernel
|
||||
y_q, y_s = persistent_masked_m_silu_mul_quant(
|
||||
y, tokens_per_expert, group_size=group_size
|
||||
)
|
||||
|
||||
@ -115,10 +115,11 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, fp8_type):
|
||||
y_se = y_s[e].float()
|
||||
y_qe = y_q[e].float()
|
||||
|
||||
torch.testing.assert_close(y_se[:nt], ref_s[:nt], atol=1e-4, rtol=1e-2)
|
||||
torch.testing.assert_close(
|
||||
y_qe[:nt].to(torch.float32),
|
||||
ref_q[:nt].to(torch.float32),
|
||||
atol=2,
|
||||
rtol=2e-1,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(y_se[:nt], ref_s[:nt], atol=1e-4, rtol=1e-2)
|
||||
|
||||
Reference in New Issue
Block a user