[Kernel] [Quantization] Add MXFP4 and bias support for marlin kernel (#22428)
Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io> Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com> Signed-off-by: Huzaifa Sidhpurwala <huzaifas@redhat.com> Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Signed-off-by: Jee Jee Li <pandaleefree@gmail.com> Signed-off-by: mgoin <mgoin64@gmail.com> Signed-off-by: Animesh Jain <anijain@umich.edu> Signed-off-by: Rui Qiao <ruisearch42@gmail.com> Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com> Signed-off-by: Nick Hill <nhill@redhat.com> Signed-off-by: yewentao256 <zhyanwentao@126.com> Signed-off-by: kf <kuanfu.liu@embeddedllm.com> Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com> Signed-off-by: NickLucche <nlucches@redhat.com> Signed-off-by: Dipika Sikka <dipikasikka1@gmail.com> Signed-off-by: Sage Moore <sage@neuralmagic.com> Signed-off-by: tjtanaavllm <tunjian.tan@amd.com> Signed-off-by: Yong Hoon Shin <yhshin@meta.com> Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com> Signed-off-by: Roger Wang <hey@rogerw.me> Signed-off-by: Vadim Gimpelson <vadim.gimpelson@centml.ai> Signed-off-by: Isotr0py <2037008807@qq.com> Signed-off-by: zRzRzRzRzRzRzR <2448370773@qq.com> Signed-off-by: Chih-Chieh Yang <7364402+cyang49@users.noreply.github.com> Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com> Signed-off-by: yan <yan.ma@intel.com> Signed-off-by: Yan Ma <yan.ma@intel.com> Signed-off-by: Xiao Liu <xiszishu@gmail.com> Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com> Signed-off-by: LopezCastroRoberto <roberto.lopez.castro@udc.es> Signed-off-by: Andy Xie <andy.xning@gmail.com> Signed-off-by: Haibin Lin <haibin.lin@bytedance.com> Signed-off-by: David Ben-David <davidb@pliops.com> Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Signed-off-by: 
jiang1.li <jiang1.li@intel.com> Signed-off-by: Seiji Eicher <seiji@anyscale.com> Signed-off-by: zitian.zhao <zitian.zhao@tencentmusic.com> Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com> Signed-off-by: Abirdcfly <fp544037857@gmail.com> Signed-off-by: Giancarlo Delfin <gdelfin@meta.com> Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com> Signed-off-by: huangweixiao <huangweixiao@msh.team> Signed-off-by: alyosha-swamy <raghav@arcee.ai> Signed-off-by: Eric Hanley <ericehanley@google.com> Signed-off-by: Abatom <abzhonghua@gmail.com> Signed-off-by: CLFutureX <775523362@qq.com> Signed-off-by: Linkun Chen <github@lkchen.net> Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com> Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com> Signed-off-by: tlipoca9 <tlipoca9@gmail.com> Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com> Signed-off-by: zitian zhao <zitian.zhao@tencentmusic.com> Signed-off-by: mgoin <michael@neuralmagic.com> Signed-off-by: wang.yuqi <noooop@126.com> Signed-off-by: Benji Beck <benjibeck@meta.com> Signed-off-by: Siyuan Liu <lsiyuan@google.com> Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai> Signed-off-by: isotr0py <2037008807@qq.com> Signed-off-by: Chen Zhang <zhangch99@outlook.com> Signed-off-by: simon-mo <xmo@berkeley.edu> Signed-off-by: LucasWilkinson <lwilkinson@neuralmagic.com> Signed-off-by: Zhang Jason <ning.zhang2@amd.com> Signed-off-by: Yongye Zhu <zyy1102000@gmail.com> Signed-off-by: asafg <asafg@ai21.com> Signed-off-by: Siyuan Fu <siyuanf@nvidia.com> Signed-off-by: Lain <fusiyuan2000@hotmail.com> Signed-off-by: Max de Bayser <mbayser@br.ibm.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: Kunshang Ji <kunshang.ji@intel.com> Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com> Signed-off-by: Michael Goin <mgoin64@gmail.com> Signed-off-by: QscQ <qscqesze@gmail.com> Signed-off-by: qingjun <qingjun@minimaxi.com> Signed-off-by: Syed Muhammad Bin Asif 
<syedmba7@connect.hku.hk> Signed-off-by: Lionel Villard <villard@us.ibm.com> Signed-off-by: ycyaw66 <497410282@qq.com> Signed-off-by: David Chen <530634352@qq.com> Signed-off-by: Linkun <github@lkchen.net> Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> Signed-off-by: Ming Yang <minos.future@gmail.com> Signed-off-by: Adrian Garcia <adrian.garcia@inceptionai.ai> Signed-off-by: shaojunqi <shaojunqi.sjq@alibaba-inc.com> Signed-off-by: Ricardo Decal <rdecal@anyscale.com> Signed-off-by: Andrew Chan <andrewkchan.akc@gmail.com> Signed-off-by: Felix Marty <Felix.Marty@amd.com> Signed-off-by: Andrew Sansom <andrew@protopia.ai> Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com> Signed-off-by: Shu Wang <shuw@nvidia.com> Signed-off-by: Po-Han Huang <pohanh@nvidia.com> Signed-off-by: Shu Wang. <shuw@nvidia.com> Signed-off-by: XIn Li <xinli@nvidia.com> Signed-off-by: Junhao Li <junhao@ubicloud.com> Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com> Signed-off-by: iAmir97 <Amir.balwel@embeddedllm.com> Signed-off-by: iAmir97 <71513472+iAmir97@users.noreply.github.com> Signed-off-by: <zyy1102000@gmail.com> Signed-off-by: Guy Stone <guys@spotify.com> Signed-off-by: <yyweiss@gmail.com> Signed-off-by: yyw <yyweiss@gmail.com> Signed-off-by: Russell Bryant <rbryant@redhat.com> Signed-off-by: Pradyun Ramadorai <pradyunr@amazon.com> Signed-off-by: Pradyun92 <142861237+Pradyun92@users.noreply.github.com> Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com> Co-authored-by: rongfu.leng <rongfu.leng@daocloud.io> Co-authored-by: Huzaifa Sidhpurwala <huzaifas@redhat.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Russell Bryant <rbryant@redhat.com> Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com> Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: Jee Jee Li 
<pandaleefree@gmail.com> Co-authored-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Animesh Jain <jainanimesh2305@yahoo.com> Co-authored-by: Rui Qiao <161574667+ruisearch42@users.noreply.github.com> Co-authored-by: XiongfeiWei <isaacwxf23@gmail.com> Co-authored-by: Nick Hill <nhill@redhat.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Co-authored-by: JartX <sagformas@gmail.com> Co-authored-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com> Co-authored-by: kf <kuanfu.liu@embeddedllm.com> Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com> Co-authored-by: Dipika Sikka <dipikasikka1@gmail.com> Co-authored-by: Sage Moore <sage@neuralmagic.com> Co-authored-by: tjtanaavllm <tunjian.tan@amd.com> Co-authored-by: Yong Hoon Shin <48474650+sarckk@users.noreply.github.com> Co-authored-by: Chih-Chieh Yang <7364402+cyang49@users.noreply.github.com> Co-authored-by: Roger Wang <hey@rogerw.me> Co-authored-by: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com> Co-authored-by: Yuxuan Zhang <2448370773@qq.com> Co-authored-by: Isotr0py <2037008807@qq.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk> Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com> Co-authored-by: Yan Ma <yan.ma@intel.com> Co-authored-by: Xiao <xiszishu@gmail.com> Co-authored-by: jiahanc <173873397+jiahanc@users.noreply.github.com> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn> Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com> Co-authored-by: Roberto L. 
Castro <38211239+LopezCastroRoberto@users.noreply.github.com> Co-authored-by: Ning Xie <andy.xning@gmail.com> Co-authored-by: H <linhaibin.eric@gmail.com> Co-authored-by: David Ben-David <sdavidbd@gmail.com> Co-authored-by: David Ben-David <davidb@pliops.com> Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Co-authored-by: Li, Jiang <jiang1.li@intel.com> Co-authored-by: TankNee <nee@tanknee.cn> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Co-authored-by: Seiji Eicher <58963096+eicherseiji@users.noreply.github.com> Co-authored-by: ZiTian.Zhao <zitian.zhao@tencentmusic.com> Co-authored-by: 22quinn <33176974+22quinn@users.noreply.github.com> Co-authored-by: Abirdcfly <fp544037857@gmail.com> Co-authored-by: Giancarlo Delfin <32987265+TheEpicDolphin@users.noreply.github.com> Co-authored-by: Chenxi Yang <cxyang@cs.utexas.edu> Co-authored-by: Chenxi Yang <cxyang@meta.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Weixiao Huang <hwx.simle@gmail.com> Co-authored-by: Raghav Ravishankar <113712354+alyosha-swamy@users.noreply.github.com> Co-authored-by: ericehanley <ericehanley@google.com> Co-authored-by: Zhonghua Deng <abzhonghua@gmail.com> Co-authored-by: Po-Han Huang (NVIDIA) <53919306+nvpohanh@users.noreply.github.com> Co-authored-by: PiteXChen <44110731+CLFutureX@users.noreply.github.com> Co-authored-by: lkchen <github@lkchen.net> Co-authored-by: TJian <tunjian.tan@embeddedllm.com> Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Co-authored-by: tlipoca9 <160737620+tlipoca9@users.noreply.github.com> Co-authored-by: elvischenv <219235043+elvischenv@users.noreply.github.com> Co-authored-by: wang.yuqi <noooop@126.com> Co-authored-by: Benji Beck <benjibeck@meta.com> Co-authored-by: youkaichao <youkaichao@gmail.com> Co-authored-by: Siyuan Liu <lsiyuan@google.com> Co-authored-by: Benjamin Chislett <chislett.ben@gmail.com> Co-authored-by: LiuXiaoxuanPKU <lilyliupku@gmail.com> Co-authored-by: 
simon-mo <xmo@berkeley.edu> Co-authored-by: Chen Zhang <zhangch99@outlook.com> Co-authored-by: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com> Co-authored-by: Minseok Lee <47620120+minseokl@users.noreply.github.com> Co-authored-by: Yongye Zhu <zyy1102000@gmail.com> Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com> Co-authored-by: Zhang Jason <ning.zhang2@amd.com> Co-authored-by: Asaf Joseph Gardin <39553475+Josephasafg@users.noreply.github.com> Co-authored-by: asafg <asafg@ai21.com> Co-authored-by: Lain <siyuanf@nvidia.com> Co-authored-by: tc-mb <157115220+tc-mb@users.noreply.github.com> Co-authored-by: imning3 <hbning@pku.edu.cn> Co-authored-by: Maximilien de Bayser <mbayser@br.ibm.com> Co-authored-by: Kunshang Ji <kunshang.ji@intel.com> Co-authored-by: Tao He <linzhu.ht@alibaba-inc.com> Co-authored-by: qscqesze <qingjun@minimaxi.com> Co-authored-by: Syed Muhammad Bin Asif <92625830+syedmba@users.noreply.github.com> Co-authored-by: Lionel Villard <villard@us.ibm.com> Co-authored-by: WeiQing Chen <40507679+david6666666@users.noreply.github.com> Co-authored-by: ycyaw66 <497410282@qq.com> Co-authored-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> Co-authored-by: Ming Yang <minos.future@gmail.com> Co-authored-by: Adrián García García <adrigarvk8@gmail.com> Co-authored-by: Michael Goin <mgoin@redhat.com> Co-authored-by: JaceyShao <65159281+JaceyShao@users.noreply.github.com> Co-authored-by: shaojunqi <shaojunqi.sjq@alibaba-inc.com> Co-authored-by: Ricardo Decal <crypdick@users.noreply.github.com> Co-authored-by: Andrew Chan <andrewkchan.akc@gmail.com> Co-authored-by: fxmarty-amd <felmarty@amd.com> Co-authored-by: Andrew Sansom <andrew@protopia.ai> Co-authored-by: Zhiyu <zhiyuc@nvidia.com> Co-authored-by: Shu Wang <shuw@nvidia.com> Co-authored-by: XIn Li <xinli@nvidia.com> Co-authored-by: Junhao Li <streaver91@gmail.com> Co-authored-by: Chauncey <chaunceyjiang@gmail.com> Co-authored-by: iAmir97 
<71513472+iAmir97@users.noreply.github.com> Co-authored-by: iAmir97 <Amir.balwel@embeddedllm.com> Co-authored-by: Hong Hanh <hanh.usth@gmail.com> Co-authored-by: Daniel Serebrenik <74646983+pliops-daniels@users.noreply.github.com> Co-authored-by: yewentao256 <zhyanwentao@126.com> Co-authored-by: Guy Stone <guys@spotify.com> Co-authored-by: yyweiss <70619747+yyweiss@users.noreply.github.com> Co-authored-by: Pradyun92 <142861237+Pradyun92@users.noreply.github.com> Co-authored-by: Pradyun Ramadorai <pradyunr@amazon.com> Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com>
This commit is contained in:
@ -24,8 +24,10 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
fused_topk, modular_triton_fused_moe)
|
||||
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
|
||||
fused_moe as iterative_moe)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
marlin_permute_bias)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
||||
rand_marlin_weight_fp4_like)
|
||||
rand_marlin_weight_mxfp4_like, rand_marlin_weight_nvfp4_like)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
marlin_quant_fp8_torch)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
||||
@ -476,8 +478,11 @@ def marlin_moe_generate_valid_test_cases():
|
||||
if quant_type == scalar_types.float8_e4m3fn and \
|
||||
group_size not in [-1, 128]:
|
||||
return False
|
||||
if quant_type == scalar_types.float4_e2m1f and group_size != 16:
|
||||
return False
|
||||
if quant_type == scalar_types.float4_e2m1f:
|
||||
if group_size not in [16, 32]:
|
||||
return False
|
||||
if dtype == torch.float16 and group_size == 32:
|
||||
return False
|
||||
if quant_type != scalar_types.float4_e2m1f and group_size == 16:
|
||||
return False
|
||||
|
||||
@ -520,31 +525,6 @@ def test_fused_marlin_moe(
|
||||
torch.cuda.manual_seed(0)
|
||||
has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
|
||||
|
||||
if quant_type == scalar_types.float8_e4m3fn:
|
||||
if group_size not in [-1, 128]:
|
||||
return
|
||||
if act_order:
|
||||
return
|
||||
|
||||
# Filter act_order
|
||||
if act_order:
|
||||
if quant_type == scalar_types.float8_e4m3fn:
|
||||
return
|
||||
if group_size == -1:
|
||||
return
|
||||
if group_size in (k, n):
|
||||
return
|
||||
if has_zp:
|
||||
return
|
||||
else:
|
||||
if not is_k_full:
|
||||
return
|
||||
|
||||
if quant_type == scalar_types.float4_e2m1f and group_size != 16:
|
||||
return
|
||||
if quant_type != scalar_types.float4_e2m1f and group_size == 16:
|
||||
return
|
||||
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20
|
||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20
|
||||
@ -569,13 +549,19 @@ def test_fused_marlin_moe(
|
||||
|
||||
for i in range(w1.shape[0]):
|
||||
if quant_type == scalar_types.float4_e2m1f:
|
||||
w_ref1, qweight1, scales1, global_scale1 = \
|
||||
rand_marlin_weight_fp4_like(w1[i], group_size)
|
||||
if group_size == 16:
|
||||
w_ref1, qweight1, scales1, global_scale1 = \
|
||||
rand_marlin_weight_nvfp4_like(w1[i], group_size)
|
||||
else:
|
||||
w_ref1, qweight1, scales1 = \
|
||||
rand_marlin_weight_mxfp4_like(w1[i], group_size)
|
||||
global_scale1 = None
|
||||
|
||||
w_ref1_l.append(w_ref1.T)
|
||||
qweight1_l.append(qweight1)
|
||||
scales1_l.append(scales1)
|
||||
global_scale1_l.append(global_scale1)
|
||||
if global_scale1 is not None:
|
||||
global_scale1_l.append(global_scale1)
|
||||
elif quant_type == scalar_types.float8_e4m3fn:
|
||||
w_ref1, qweight1, scales1 = marlin_quant_fp8_torch(
|
||||
w1[i], group_size)
|
||||
@ -620,13 +606,19 @@ def test_fused_marlin_moe(
|
||||
|
||||
for i in range(w2.shape[0]):
|
||||
if quant_type == scalar_types.float4_e2m1f:
|
||||
w_ref2, qweight2, scales2, global_scale2 = \
|
||||
rand_marlin_weight_fp4_like(w2[i], group_size)
|
||||
if group_size == 16:
|
||||
w_ref2, qweight2, scales2, global_scale2 = \
|
||||
rand_marlin_weight_nvfp4_like(w2[i], group_size)
|
||||
else:
|
||||
w_ref2, qweight2, scales2 = \
|
||||
rand_marlin_weight_mxfp4_like(w2[i], group_size)
|
||||
global_scale2 = None
|
||||
|
||||
w_ref2_l.append(w_ref2.T)
|
||||
qweight2_l.append(qweight2)
|
||||
scales2_l.append(scales2)
|
||||
global_scale2_l.append(global_scale2)
|
||||
if global_scale2 is not None:
|
||||
global_scale2_l.append(global_scale2)
|
||||
elif quant_type == scalar_types.float8_e4m3fn:
|
||||
w_ref2, qweight2, scales2 = marlin_quant_fp8_torch(
|
||||
w2[i], group_size)
|
||||
@ -677,6 +669,8 @@ def test_fused_marlin_moe(
|
||||
a,
|
||||
qweight1,
|
||||
qweight2,
|
||||
None,
|
||||
None,
|
||||
scales1,
|
||||
scales2,
|
||||
score,
|
||||
@ -698,6 +692,119 @@ def test_fused_marlin_moe(
|
||||
torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
|
||||
|
||||
|
||||
@pytest.mark.flaky(reruns=2)
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
@pytest.mark.parametrize("m", [1, 256])
def test_fused_marlin_moe_with_bias(m):
    """Check fused_marlin_moe with per-expert biases against torch_moe.

    Both expert weight sets are quantized with marlin_quantize (uint4b8,
    group size 128, act_order disabled). The Marlin op receives the biases
    after marlin_permute_bias, while the torch_moe reference receives the
    unpermuted biases together with the w_ref weights returned by the
    quantizer. The two outputs must agree within an absolute tolerance of
    5e-2.
    """
    torch.cuda.manual_seed(0)

    num_experts, topk = 32, 4
    n, k = 2048, 2048
    group_size = 128
    act_order = False
    is_k_full = True
    quant_type = scalar_types.uint4b8
    dtype = torch.half

    # Random activations, expert weights and expert biases (creation order
    # is kept fixed so the seeded CUDA generator yields the same values).
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((num_experts, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((num_experts, k, n), device="cuda", dtype=dtype) / 10
    b_bias1 = torch.randn((num_experts, 2 * n), device="cuda",
                          dtype=dtype) / 10
    b_bias2 = torch.randn((num_experts, k), device="cuda", dtype=dtype) / 10

    def stack_or_none(items):
        # Mirror the "stack only when something was collected" guard.
        return stack_and_dev(items) if items else None

    def quantize_experts(weights, biases, perm_len):
        """Quantize every expert and stack the per-expert results.

        Returns (w_ref, qweight, scales, g_idx, sort_indices, marlin_bias).
        """
        refs, packed, scale_list = [], [], []
        g_idx_list, sort_list, bias_list = [], [], []

        for expert_w, expert_b in zip(weights, biases):
            test_perm = torch.randperm(perm_len)
            w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
                expert_w.transpose(1, 0), quant_type, group_size, act_order,
                test_perm)

            refs.append(w_ref.T)
            packed.append(qweight)
            scale_list.append(scales)
            g_idx_list.append(g_idx)
            sort_list.append(sort_indices)
            bias_list.append(marlin_permute_bias(expert_b))

        return (
            stack_and_dev(refs),
            stack_and_dev(packed).contiguous(),
            stack_and_dev(scale_list),
            stack_or_none(g_idx_list),
            stack_or_none(sort_list),
            stack_or_none(bias_list),
        )

    # First projection: the permutation length is k.
    (w_ref1, qweight1, scales1, g_idx1, sort_indices1,
     marlin_bias1) = quantize_experts(w1, b_bias1, k)
    # Second projection: the permutation length is n.
    (w_ref2, qweight2, scales2, g_idx2, sort_indices2,
     marlin_bias2) = quantize_experts(w2, b_bias2, n)

    score = torch.randn((m, num_experts), device="cuda", dtype=dtype)

    topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)

    # Reference result computed from the w_ref weights and raw biases.
    with set_current_vllm_config(vllm_config):
        torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, b_bias1,
                                 b_bias2)

    # This configuration passes neither global scales nor zero points.
    marlin_output = torch.ops.vllm.fused_marlin_moe(
        a,
        qweight1,
        qweight2,
        marlin_bias1,
        marlin_bias2,
        scales1,
        scales2,
        score,
        topk_weights,
        topk_ids,
        global_num_experts=num_experts,
        expert_map=None,
        global_scale1=None,
        global_scale2=None,
        g_idx1=g_idx1,
        g_idx2=g_idx2,
        sort_indices1=sort_indices1,
        sort_indices2=sort_indices2,
        w1_zeros=None,
        w2_zeros=None,
        quant_type_id=quant_type.id,
        is_k_full=is_k_full)

    torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
|
||||
|
||||
|
||||
def test_moe_align_block_size_opcheck():
|
||||
num_experts = 4
|
||||
block_size = 4
|
||||
|
||||
@ -19,10 +19,11 @@ from vllm.model_executor.layers.quantization.qqq import (
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
|
||||
MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx,
|
||||
marlin_make_workspace_new, marlin_permute_scales,
|
||||
marlin_make_workspace_new, marlin_permute_bias, marlin_permute_scales,
|
||||
query_marlin_supported_quant_types)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
||||
FP4_MARLIN_SUPPORTED_GROUP_SIZES, rand_marlin_weight_fp4_like)
|
||||
FP4_MARLIN_SUPPORTED_GROUP_SIZES, rand_marlin_weight_mxfp4_like,
|
||||
rand_marlin_weight_nvfp4_like)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
marlin_quant_fp8_torch)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
||||
@ -39,7 +40,7 @@ from vllm.scalar_type import scalar_types
|
||||
ACT_ORDER_OPTS = [False, True]
|
||||
K_FULL_OPTS = [False, True]
|
||||
USE_ATOMIC_ADD_OPTS = [False, True]
|
||||
USE_FP32_REDUCE_OPTS = [False, True]
|
||||
USE_FP32_REDUCE_OPTS = [True]
|
||||
|
||||
MARLIN_K_CHUNKS = [128]
|
||||
MARLIN_N_CHUNKS = [64, 256]
|
||||
@ -202,17 +203,10 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
|
||||
@pytest.mark.parametrize("is_k_full", K_FULL_OPTS)
|
||||
@pytest.mark.parametrize("use_atomic_add", USE_ATOMIC_ADD_OPTS)
|
||||
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
|
||||
def test_gptq_marlin_gemm(
|
||||
k_chunk,
|
||||
n_chunk,
|
||||
quant_type,
|
||||
group_size,
|
||||
mnk_factors,
|
||||
act_order,
|
||||
is_k_full,
|
||||
use_atomic_add,
|
||||
use_fp32_reduce,
|
||||
):
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
def test_gptq_marlin_gemm(k_chunk, n_chunk, quant_type, group_size,
|
||||
mnk_factors, act_order, is_k_full, use_atomic_add,
|
||||
use_fp32_reduce, dtype):
|
||||
m_factor, n_factor, k_factor = mnk_factors
|
||||
has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
|
||||
|
||||
@ -231,14 +225,23 @@ def test_gptq_marlin_gemm(
|
||||
if size_k % group_size != 0:
|
||||
return
|
||||
|
||||
a_input = rand_data((size_m, size_k))
|
||||
b_weight = rand_data((size_k, size_n))
|
||||
a_input = rand_data((size_m, size_k), dtype)
|
||||
b_weight = rand_data((size_k, size_n), dtype)
|
||||
|
||||
if quant_type == scalar_types.float4_e2m1f:
|
||||
if group_size != 16 or act_order:
|
||||
if group_size not in [16, 32] or act_order:
|
||||
return
|
||||
w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_fp4_like(
|
||||
b_weight.T, group_size)
|
||||
if group_size == 32 and dtype == torch.float16:
|
||||
return
|
||||
|
||||
if group_size == 16:
|
||||
w_ref, marlin_q_w, marlin_s, marlin_s2 = \
|
||||
rand_marlin_weight_nvfp4_like(b_weight.T, group_size)
|
||||
else:
|
||||
w_ref, marlin_q_w, marlin_s = \
|
||||
rand_marlin_weight_mxfp4_like(b_weight.T, group_size)
|
||||
marlin_s2 = None
|
||||
|
||||
g_idx = None
|
||||
sort_indices = None
|
||||
marlin_zp = None
|
||||
@ -272,8 +275,8 @@ def test_gptq_marlin_gemm(
|
||||
workspace = marlin_make_workspace_new(w_ref.device)
|
||||
|
||||
opcheck(torch.ops._C.gptq_marlin_gemm,
|
||||
(a_input, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, g_idx,
|
||||
sort_indices, workspace, quant_type.id, a_input.shape[0],
|
||||
(a_input, None, marlin_q_w, None, marlin_s, marlin_s2, marlin_zp,
|
||||
g_idx, sort_indices, workspace, quant_type.id, a_input.shape[0],
|
||||
b_weight.shape[1], a_input.shape[1], is_k_full, use_atomic_add,
|
||||
use_fp32_reduce, False),
|
||||
test_utils=DEFAULT_OPCHECK_TEST_UTILS)
|
||||
@ -282,6 +285,7 @@ def test_gptq_marlin_gemm(
|
||||
a_input,
|
||||
None,
|
||||
marlin_q_w,
|
||||
None,
|
||||
marlin_s,
|
||||
marlin_s2,
|
||||
marlin_zp,
|
||||
@ -418,6 +422,7 @@ def test_hqq_marlin_gemm(
|
||||
a_input,
|
||||
None,
|
||||
marlin_w_q,
|
||||
None,
|
||||
marlin_s,
|
||||
None,
|
||||
marlin_zp,
|
||||
@ -531,6 +536,7 @@ def test_marlin_gemm_subset_input():
|
||||
a_input,
|
||||
None,
|
||||
marlin_q_w,
|
||||
None,
|
||||
marlin_s,
|
||||
None,
|
||||
marlin_zp,
|
||||
@ -555,6 +561,53 @@ def test_marlin_gemm_subset_input():
|
||||
assert max_diff < 0.04
|
||||
|
||||
|
||||
@pytest.mark.parametrize("size_m", [1, 256])
def test_marlin_gemm_with_bias(size_m):
    """Check gptq_marlin_gemm with a bias against a torch.matmul reference.

    A (1024, 2048) weight is quantized with marlin_quantize (uint4b8, group
    size 128, act_order disabled). The kernel receives the bias after
    marlin_permute_bias; the reference adds the unpermuted bias to
    ``a_input @ w_ref``. The result of compute_max_diff must stay below 0.04.
    """
    quant_type = scalar_types.uint4b8
    group_size = 128
    size_k, size_n = 1024, 2048

    a_input = rand_data((size_m, size_k))
    b_weight = rand_data((size_k, size_n))
    # The bias is scaled up by 10 relative to the raw random data.
    b_bias = rand_data((size_n, )) * 10

    marlin_bias = marlin_permute_bias(b_bias)

    quantized = marlin_quantize(b_weight, quant_type, group_size, False)
    w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = quantized

    # The zero-point argument is filled with the empty-g_idx placeholder.
    marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
    workspace = marlin_make_workspace_new(a_input.device)

    # Problem sizes are read back from the tensors handed to the kernel.
    gemm_m, gemm_k = a_input.shape[0], a_input.shape[1]
    gemm_n = b_weight.shape[1]

    output = ops.gptq_marlin_gemm(
        a_input,
        None,
        marlin_q_w,
        marlin_bias,
        marlin_s,
        None,
        marlin_zp,
        g_idx,
        sort_indices,
        workspace,
        quant_type,
        gemm_m,
        gemm_n,
        gemm_k,
        is_k_full=True,
        use_atomic_add=False,
        use_fp32_reduce=True,
        is_zp_float=False,
    )
    output_ref = torch.matmul(a_input, w_ref) + b_bias.view(1, -1)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)
    assert max_diff < 0.04
|
||||
|
||||
|
||||
def test_marlin_gemm_opcheck():
|
||||
size_m = 2048
|
||||
size_n = 4096
|
||||
|
||||
@ -1064,6 +1064,8 @@ def torch_experts(
|
||||
topk_weight: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
global_num_experts: int = -1,
|
||||
b_bias1: Optional[torch.Tensor] = None,
|
||||
b_bias2: Optional[torch.Tensor] = None,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
@ -1108,8 +1110,13 @@ def torch_experts(
|
||||
if mask.sum():
|
||||
if quant_dtype is None:
|
||||
tmp1 = a[mask] @ w1[i].transpose(0, 1)
|
||||
if b_bias1 is not None:
|
||||
tmp1 = tmp1 + b_bias1[i].view(1, -1).to(tmp1.dtype)
|
||||
tmp2 = SiluAndMul()(tmp1)
|
||||
out[mask] = tmp2 @ w2[i].transpose(0, 1)
|
||||
if b_bias2 is not None:
|
||||
out[mask] = out[mask] + b_bias2[i].view(1, -1).to(
|
||||
tmp1.dtype)
|
||||
elif block_shape is not None:
|
||||
# block quantized
|
||||
assert (a_scale is not None and w1_scale is not None
|
||||
@ -1117,6 +1124,8 @@ def torch_experts(
|
||||
tmp1 = native_w8a8_block_matmul(a[mask], w1[i], a_scale[mask],
|
||||
w1_scale[i], block_shape,
|
||||
out.dtype)
|
||||
if b_bias1 is not None:
|
||||
tmp1 = tmp1 + b_bias1[i].view(1, -1).to(tmp1.dtype)
|
||||
tmp2 = SiluAndMul()(tmp1)
|
||||
tmp2, b_scale = moe_kernel_quantize_input(
|
||||
tmp2, a2_scale, quant_dtype, per_act_token_quant,
|
||||
@ -1125,6 +1134,9 @@ def torch_experts(
|
||||
out[mask] = native_w8a8_block_matmul(tmp2, w2[i], b_scale,
|
||||
w2_scale[i], block_shape,
|
||||
out.dtype)
|
||||
if b_bias2 is not None:
|
||||
out[mask] = out[mask] + b_bias2[i].view(1, -1).to(
|
||||
tmp1.dtype)
|
||||
else:
|
||||
assert (a_scale is not None and w1_scale is not None
|
||||
and w2_scale is not None)
|
||||
@ -1133,6 +1145,8 @@ def torch_experts(
|
||||
tmp1 = a[mask].to(f32) * scales
|
||||
w1_dq = (w1[i].to(f32) * w1_scale[i]).transpose(0, 1)
|
||||
tmp1 = (tmp1 @ w1_dq).to(out.dtype)
|
||||
if b_bias1 is not None:
|
||||
tmp1 = tmp1 + b_bias1[i].view(1, -1).to(out.dtype)
|
||||
|
||||
tmp2 = SiluAndMul()(tmp1).to(out.dtype)
|
||||
|
||||
@ -1144,6 +1158,9 @@ def torch_experts(
|
||||
tmp2 = tmp2.to(f32) * b_scale
|
||||
w2_dq = (w2[i].to(f32) * w2_scale[i]).transpose(0, 1)
|
||||
out[mask] = (tmp2 @ w2_dq).to(out.dtype)
|
||||
if b_bias2 is not None:
|
||||
out[mask] = out[mask] + b_bias2[i].view(1, -1).to(
|
||||
out.dtype)
|
||||
|
||||
if apply_router_weights_on_input:
|
||||
return out
|
||||
@ -1157,12 +1174,14 @@ def torch_moe(a: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
score: torch.Tensor,
|
||||
topk: int,
|
||||
b_bias1: Optional[torch.Tensor] = None,
|
||||
b_bias2: Optional[torch.Tensor] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
score = torch.softmax(score, dim=-1, dtype=torch.float32)
|
||||
topk_weight, topk_ids = torch.topk(score, topk)
|
||||
return torch_experts(a, w1, w2, topk_weight, topk_ids, global_num_experts,
|
||||
expert_map)
|
||||
b_bias1, b_bias2, expert_map)
|
||||
|
||||
|
||||
def torch_moe_single(a, w, score, topk):
|
||||
|
||||
Reference in New Issue
Block a user