Compare commits

...

444 Commits

Author SHA1 Message Date
221118dc85 [Bugfix] Use a different prompt for benchmark_serving.py test prompt
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
2025-05-17 18:36:31 +00:00
66e63e86ec [MISC] fix typo (#18305)
Signed-off-by: Andy Xie <andy.xning@gmail.com>
2025-05-17 10:52:09 -07:00
9214e60631 [Model] use AutoWeightsLoader for solar (#18113) 2025-05-17 00:24:17 -07:00
f880d42582 Fixed build on ppc64le due to openssl conflicts (#18262)
Signed-off-by: Nishidha Panpaliya <nishidha.panpaliya@partner.ibm.com>
2025-05-17 00:23:46 -07:00
dcfe95234c Update Dockerfile to build for Blackwell (#18095) 2025-05-17 00:23:25 -07:00
48ac2bed5b [Hardware][TPU] Optionally import for TPU backend (#18269)
Signed-off-by: Siyuan Liu <lsiyuan@google.com>
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
Co-authored-by: Carol Zheng <cazheng@google.com>
Co-authored-by: Jade Zheng <zheng.shoujian@outlook.com>
Co-authored-by: Hongmin Fan <fanhongmin@google.com>
2025-05-17 15:23:12 +08:00
3e0d435027 [P/D][V1] Support dynamic loading of external KV connector implementations (#18142)
Signed-off-by: David Ben-David <davidb@pliops.com>
Co-authored-by: David Ben-David <davidb@pliops.com>
2025-05-17 06:40:39 +00:00
4ee4826ede [BugFix] Correct max_model_len derivation from config.json for Mistral format (#17937)
Signed-off-by: 汪志鹏 <wangzhipeng628@gmail.com>
Co-authored-by: tracelogfb <48808670+tracelogfb@users.noreply.github.com>
Co-authored-by: Stephen Chen <tracelog@meta.com>
2025-05-17 04:20:13 +00:00
60017dc841 [Misc] reformat the collect-env output (#18285)
Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>
2025-05-16 19:46:18 -07:00
55f1a468d9 Move cli args docs to its own page (#18228) (#18264)
Signed-off-by: Trevor Royer <troyer@redhat.com>
2025-05-16 19:43:45 -07:00
fd195b194e [V1][P/D] Local attention optimization for NIXL (#18170)
Signed-off-by: mgoin <mgoin64@gmail.com>
2025-05-16 21:16:33 -04:00
fabe89bbc4 [Spec Decode] Don't fall back to V0 when spec decoding is enabled (#18265) 2025-05-16 16:10:27 -07:00
e73b7dfd69 [Bugfix] fix an illegal memory access was encountered of marlin kernel + act_order (#18245) 2025-05-16 16:02:44 -07:00
7fdfa01530 [Sampler] Adapt to FlashInfer 0.2.3 sampler API (#15777)
Signed-off-by: Bowen Wang <abmfy@icloud.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
2025-05-16 15:14:03 -07:00
aef94c6d07 [CI] Assign reviewer to mergify with changes to Tensorizer files (#18278) 2025-05-16 12:04:14 -07:00
0ceaebf87b [BugFix] Fix ordering of KVConnector finished send/rcv sets (#18211)
Signed-off-by: Nick Hill <nhill@redhat.com>
2025-05-16 09:20:54 -07:00
1db4f47f81 [BugFix] Fix multi async save in MultiConnector (#18246)
Signed-off-by: Nick Hill <nhill@redhat.com>
2025-05-16 08:13:47 -07:00
d3d91b6f71 [Misc][MacOS] fix bfloat16 error (#18249)
Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>
2025-05-16 15:05:59 +00:00
87d871470d [Model] Use autoweightloader for dbrx (#18251)
Signed-off-by: learner0810 <zhongjun.li@daocloud.io>
2025-05-16 07:54:13 -07:00
a5f8c111c2 [Fix] Fix typo in resolve_hf_chat_template (#18259)
Signed-off-by: Felix Marty <felmarty@amd.com>
2025-05-16 14:52:41 +00:00
e23564cb70 use ceil_div in cutlass block scaling shape check (#17918) 2025-05-16 03:02:58 -07:00
390ec88905 [Misc] Consolidate Audio tests into multimodal common generation tests (#18214)
Signed-off-by: Isotr0py <2037008807@qq.com>
2025-05-16 09:18:08 +00:00
541817670c [Misc] Add Ray Prometheus logger to V1 (#17925)
Signed-off-by: Seiji Eicher <seiji@anyscale.com>
2025-05-16 01:02:42 -07:00
67da5720d4 [PERF] Speed up Qwen2.5-VL model by speed up rotary position embedding (#17973)
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@centml.ai>
2025-05-15 23:31:02 -07:00
5c04bb8b86 [doc] fix multimodal example script (#18089)
Signed-off-by: David Xia <david@davidxia.com>
2025-05-16 06:05:34 +00:00
3d2779c29a [Feature] Support Pipeline Parallism in torchrun SPMD offline inference for V1 (#17827)
Signed-off-by: Lucia Fang <fanglu@fb.com>
2025-05-15 22:28:27 -07:00
6b31c84aff Throw better error for when running into k8s service discovery issue (#18209)
Signed-off-by: Will Eaton <weaton@redhat.com>
2025-05-15 21:07:28 -07:00
b18201fe06 Allow users to pass arbitrary JSON keys from CLI (#18208)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-05-15 21:05:34 -07:00
f4937a51c1 [Model] vLLM v1 supports Medusa (#17956)
Signed-off-by: lisiqi23 <lisiqi23@xiaomi.com>
Signed-off-by: skylee-01 <497627264@qq.com>
Co-authored-by: lisiqi23 <lisiqi23@xiaomi.com>
2025-05-15 21:05:31 -07:00
ee659e3b60 [Bugfix][ROCm] Use chunked_prefill_paged_decode as fallback for V1 attention on ROCm (#18093)
Signed-off-by: kf <kuanfu.liu@embeddedllm.com>
2025-05-15 19:30:17 -07:00
4e1c6a0264 [Bugfix] fix rotary embedding test for _get_padded_tensor_shape (#18229)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
2025-05-16 01:32:45 +00:00
c7852a6d9b [Build] Allow shipping PTX on a per-file basis (#18155)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
2025-05-15 16:41:55 -07:00
8795eb9975 [Bugfix] Fix test_eagle test (#18223)
Signed-off-by: Lucia Fang <fanglu@fb.com>
2025-05-15 15:59:42 -07:00
0b34593017 Adding "AMD: Tensorizer Test" to amdproduction. (#18216) 2025-05-15 11:01:25 -07:00
e3f3aee6f4 [Misc] Avoid cuda graph log when sizes still match (#18202)
Signed-off-by: NickLucche <nlucches@redhat.com>
2025-05-15 09:59:38 -07:00
92540529c0 [Bugfix] [ROCm]: Remove assertion logic when using AITER fused moe in unquantizedMethod to reenable LLama4 BF16 (#18205)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
2025-05-15 09:53:18 -07:00
fadb8d5c2d [Bugfix]Change the exception thrown by call_hf_processor from RuntimeError to ValueError (#18181)
Signed-off-by: Abatom <abzhonghua@gmail.com>
2025-05-15 09:01:47 -07:00
2aa5470ac5 [Frontend] Fix chat template content format detection (#18190)
Signed-off-by: Sebastian Schönnenbeck <sebastian.schoennenbeck@comma-soft.com>
2025-05-15 09:00:21 -07:00
51ff154639 Improve examples rendering in docs and GitHub (#18203)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-05-15 15:57:49 +00:00
566ec04c3d Adding "Basic Models Test" and "Multi-Modal Models Test (Extended) 3" in AMD Pipeline (#18106)
Signed-off-by: Alexei V. Ivanov <alexei.ivanov@amd.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
2025-05-15 08:49:23 -07:00
01c22335ba [Kernel] [V1] Fix performance regression for triton unified attention (#18161)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
2025-05-15 06:39:00 -07:00
451da4bcbd add tools into TokenizeChatRequest (#18187)
Signed-off-by: yangxia <yangxiast@gmail.com>
2025-05-15 04:01:49 -07:00
07ad27121f Update deprecated type hinting in model_loader (#18130)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-05-15 04:00:21 -07:00
a9944aabfa fix: typos (#18151)
Signed-off-by: omahs <73983677+omahs@users.noreply.github.com>
2025-05-15 02:16:15 -07:00
a8f5aec20a [V1] Update zmq socket creation in nixl connector (#18148)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
2025-05-14 23:17:57 -07:00
de71fec81b [CI] don't skip fixed test_kv_cache_events() (#18183)
Signed-off-by: David Xia <david@davidxia.com>
2025-05-14 23:17:16 -07:00
70f8b96724 [Bugfix] Fix FusedMoEPrepareAndFinalize for cuda-disalike backends (#18178)
Signed-off-by: Mengqing Cao <cmq0113@163.com>
2025-05-14 23:16:31 -07:00
dd2a94596a [Model] Allow the use of sliding window in Qwen2 (#17772)
Signed-off-by: inkcherry <mingzhi.liu@intel.com>
2025-05-14 22:29:38 -07:00
420caf7557 [UT] Add ut for none hash (#17892)
Signed-off-by: Andy Xie <andy.xning@gmail.com>
2025-05-15 13:28:11 +08:00
4f07a64075 Support custom implementations of VideoLoader backends. (#18091) 2025-05-15 13:26:49 +08:00
e6b8e65d2d [Bugfix] Fix fp8 tests for triton_unified_attention for Triton 3.3 (#18013)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
2025-05-15 13:26:34 +08:00
26d0419309 Update deprecated type hinting in models (#18132)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-05-14 22:06:50 -07:00
83f74c698f [Fix][ROCm] Enforce eager for all encoder-decoder models on ROCm (#18154)
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
2025-05-14 22:04:43 -07:00
2dff093574 [Misc] add lobe-chat support (#18177)
Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>
2025-05-15 05:02:23 +00:00
afe3236e90 [Chore] astral's ty (#18116)
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
2025-05-15 05:00:43 +00:00
65334ef3b9 [V1][Metrics] Remove unused code (#18158)
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
2025-05-14 20:13:17 -07:00
e60f550b38 [v1] Support multiple KV cache groups in GPU model runner (#17945)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
2025-05-14 18:54:54 -07:00
f25e0d1125 [Bugfix]: make most of test_openai_schema.py pass (#17664) 2025-05-14 17:04:35 -07:00
09f106a91e Upload vllm index for the rc builds (#18173) 2025-05-14 16:35:56 -07:00
2142035b51 [V1] Support multiple kv connectors (#17564)
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
2025-05-14 16:28:02 -07:00
78aa341d12 [CI] Fix race condition in test_kv_cache_events test (#18169)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
2025-05-14 16:27:48 -07:00
7974736740 Add support for loading torchao models with AOPerModuleConfig (#17826)
Signed-off-by: Jerry Zhang <jerryzh168@gmail.com>
2025-05-14 16:24:59 -07:00
2fc9075b82 [V1] Structured Outputs + Thinking compatibility (#16577)
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
2025-05-14 15:45:24 -07:00
d93c976a0d [Kernel] Have rotary embeddings support tensors (#18046)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
2025-05-14 15:43:55 -07:00
749f792553 [Frontend] decrease import time of vllm.multimodal (#18031)
Co-authored-by: Aaron Pham <Aaronpham0103@gmail.com>
2025-05-14 15:43:32 -07:00
856865008e [CI] Disable Failing Tests (#18165) 2025-05-14 13:49:56 -07:00
f9c069c85e Modularize fused experts and integrate PPLX kernels (#15956) 2025-05-14 13:11:54 -07:00
418d2f8bfb [V1][Spec Decode] Share input embedding of target model with EAGLE draft model to free ~1GB for llama 3 model (#17326)
Co-authored-by: root <root@ekagra-8xh100.us-east5-a.c.serving-efficiency-poc.internal>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-05-14 12:31:46 -07:00
964472b966 [Doc] Update prefix cache metrics to counting tokens (#18138)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
2025-05-14 15:23:30 +00:00
59dd311cf5 [KVConnector] Keep KVTransferParams as a dict (#18033) 2025-05-14 08:05:57 -07:00
d066e52013 [Bugfix] Fix chat utils tests (#18139)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
2025-05-14 05:38:21 -07:00
c8ea982d9b Update deprecated type hinting in platform, plugins, triton_utils, vllm_flash_attn (#18129)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-05-14 05:28:16 -07:00
dc372b9c8a Update deprecated type hinting in vllm/device_allocator and vllm/distributed (#18126)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-05-14 04:07:57 -07:00
9b5b39b650 Update deprecated type hinting in vllm/lora (#18128)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-05-14 03:57:59 -07:00
9ccc6ded42 [doc] add missing import (#18133)
Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>
2025-05-14 10:57:34 +00:00
d62a076e84 [Model] GritLM supports other attention backends (#18109)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
2025-05-14 03:33:19 -07:00
259127f8b8 [Bugfix] Fix LoRA test (#18123)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
2025-05-14 10:25:47 +00:00
612c2edb4f [FEAT] [ROCm]: Add AITER CK 2 Stages MoE support (#17110)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
2025-05-14 03:03:11 -07:00
38fe728d60 [Bugfix] Fix QKVCrossParallelLinear::sync_weight_attrs for PyTorch compile (#17844)
Signed-off-by: Andrzej Kotłowski <akotlowski@habana.ai>
2025-05-14 09:39:51 +00:00
82e7f9bb03 [Misc] replace does not exist model (#18119)
Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io>
2025-05-14 02:13:47 -07:00
63dc3426e0 [Model] Add packed_modules_mapping for Qwen3-MOE (#18118)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
2025-05-14 02:13:19 -07:00
8f5dc41481 [Bugfix] Fix entrypoints audio test failure (#18111)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
2025-05-14 09:08:07 +00:00
63ad622233 [New Model]: support GTE NewModel (#17986) 2025-05-14 01:31:31 -07:00
e7ef61c1f0 [Bugfix][Example] make lmcache v0 work. (#18051)
Signed-off-by: Ma, Jianpeng <jianpeng.ma@intel.com>
2025-05-13 23:43:44 -07:00
d4154c35a2 [Bugfix] fix moe marlin topk_weight loading (#18080)
Co-authored-by: mgoin <mgoin64@gmail.com>
2025-05-13 23:31:57 -07:00
6685890d11 [Fix] Move "model_config" as keyword args in chat_utils.py (#18098)
Signed-off-by: Linkun <github@lkchen.net>
2025-05-13 23:27:26 -07:00
33011318c2 Fix broken example: examples/offline_inference/profiling at scheduler_config (#18117) 2025-05-13 23:19:14 -07:00
4f8b373225 [BugFix][AMD] Compatible patch for AITER lib after 04/20 (#17912)
Signed-off-by: Qiang Li <qiang.li2@amd.com>
2025-05-13 23:05:20 -07:00
7b2f28deba [AMD][torch.compile] Enable silu+fp8_quant fusion for rocm (#18082)
Signed-off-by: charlifu <charlifu@amd.com>
2025-05-13 22:13:56 -07:00
2d912fb66f [FEAT] [ROCm] [V1]: Add AITER biased group topk for DeepSeekV3 (#17955)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
2025-05-13 22:03:47 -07:00
12e6c0b41c [Bugfix][V1] Fix FlashInfer V1 backend using the wrong VllmConfig (#18086) 2025-05-13 20:36:17 -07:00
9a2a6357de [Bugfix] Fix FP8 Marlin MoE and enable for compressed-tensors models (#18026)
Signed-off-by: mgoin <mgoin64@gmail.com>
2025-05-13 19:48:33 -07:00
6266c57bae [core][distributed] add ep group and all2all interface (#18077)
Signed-off-by: youkaichao <youkaichao@gmail.com>
2025-05-14 10:46:49 +08:00
754b699cbe [Bug]: Fix S3 model/tokenizer path resolution (#18083)
Signed-off-by: Jon Gill <jon@yurts.ai>
2025-05-13 19:34:17 -07:00
6e27c6d86b [Misc] Remove unused numpy tensor (#18084)
Signed-off-by: Roger Wang <hey@rogerw.me>
2025-05-13 19:33:40 -07:00
d5af47a149 [P/D] Add some more debug logs to NixlConnector (#18102)
Signed-off-by: Nick Hill <nhill@redhat.com>
2025-05-13 19:33:03 -07:00
65f0f74b66 [Hardware/NVIDIA/Modelopt] Fix modelopt forward method for v1 torch.compile (#18101)
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
2025-05-13 19:33:00 -07:00
176a95c670 [Fix] Support CUDAGraph capture for encoder-decoder on ROCm (#18104)
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
2025-05-13 19:31:42 -07:00
f2ae883b67 [v1][KVCacheManager] pass num_new_computed_tokens to kv cache manager (#18001)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
2025-05-13 19:09:39 -07:00
40de1ef455 [FEAT] [ROCm]: Add AITER Block-Scaled GEMM Feature (#14968)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
2025-05-13 19:08:20 -07:00
0189a65a2e [Docs] Expand security doc with firewall info (#18081)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
2025-05-13 19:36:00 +00:00
55aa7af994 [V1] DP scale-out (2/N): Decouple engine process management and comms (#15977)
Signed-off-by: Nick Hill <nhill@redhat.com>
2025-05-13 10:48:21 -07:00
0b217da646 Update deprecated type hinting in vllm/adapter_commons (#18073)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-05-13 08:32:51 -07:00
19324d660c Update deprecated type hinting in vllm/compilation (#18072)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-05-13 08:32:48 -07:00
fc407a1425 Give auto-merge label workflow permission to add labels to issues (#18078)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-05-13 07:53:13 -07:00
009d9e7590 Convert benchmarks to ruff format (#18068)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-05-13 13:43:29 +00:00
b922c2ebd2 [Bugfix] Fix entrypoints metrics tests (#18063)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
2025-05-13 06:42:43 -07:00
00b14e0f16 [CI] set token permissions for pre-commit CI job (#17729)
Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
2025-05-13 13:38:30 +00:00
54e467e6f8 [CI] Add token permissions for add-ready-label CI job (#17730)
Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
2025-05-13 13:38:13 +00:00
79a1d25bbd [CI] Add workflow permissions for helm CI job (#17727)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
2025-05-13 12:49:07 +00:00
9944011b30 [CI] Set token permissions for reminder comment CI job (#17728)
Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
2025-05-13 12:46:58 +00:00
8c946cecca Update deprecated type hinting in vllm/transformers_utils (#18058)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-05-13 04:34:37 -07:00
ff334ca1cd Update deprecated type hinting in vllm/profiler (#18057)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-05-13 04:34:34 -07:00
6223dd8114 Update deprecated type hinting in model_executor/layers (#18056)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-05-13 04:17:23 -07:00
906f0598fc [doc] add download/list/delete HF model CLI usage (#17940)
Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>
2025-05-13 11:15:51 +00:00
cb528d0585 [Fix] check to make sure processor has chat templates (#18047)
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
2025-05-13 03:04:10 -07:00
98fcba1575 Convert .buildkite to ruff format (#17656)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-05-13 09:28:31 +00:00
23b3134eb5 [Benchmarks] Refactor run_structured_output_benchmarks.sh (#17722)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
2025-05-13 01:47:29 -07:00
ea6ae8cb45 [Bugfix] Fix marlin moe fallback logic for llama4 (#18042)
Signed-off-by: mgoin <mgoin64@gmail.com>
2025-05-13 07:53:28 +00:00
2ff297dce9 [BugFix] Set default random seed to 0 for V1 (#17929)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-05-13 07:52:19 +00:00
8dd0671bac [Bugfix][V1] Only get input embeddings w/ multi-modal models if first PP (#17916)
Signed-off-by: Jin Huang <jinhun@amazon.com>
Co-authored-by: Jin Huang <jinhun@amazon.com>
2025-05-13 15:10:07 +08:00
f0d610a8ae [v1][KVCacheManager] Avoid full cache hit by controlling max_length (#17999)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-05-13 06:50:38 +00:00
e57e4d6e9e Fix Broken macro for cutlass moe (#18049)
Signed-off-by: drisspg <drisspguessous@gmail.com>
2025-05-12 23:31:06 -07:00
ee5be834e7 [BugFix] Fix 4-GPU RLHF tests (#18007)
Signed-off-by: Nick Hill <nhill@redhat.com>
2025-05-12 23:03:55 -07:00
48545728d8 cleanup invalid prints (#18050)
Signed-off-by: calvin chen <120380290@qq.com>
2025-05-12 23:01:57 -07:00
dc1a821768 [Feature][V1] Support tool_choice: required when using Xgrammar as the StructuredOutputBackend. (#17845)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
2025-05-12 23:01:31 -07:00
61e0a506a3 [Bugfix] Avoid repeatedly creating dummy data during engine startup (#17935)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
2025-05-12 22:40:19 -07:00
1df491c522 [Bugfix] Fixes for new marlin moe usage (#18017)
Signed-off-by: mgoin <mgoin64@gmail.com>
2025-05-13 03:50:04 +00:00
d8487ef557 [ROCm]: Fix build from source failure with gcc14 and ROCm 6.3 (#13779)
Signed-off-by: Arjun Kathuria <arjun.kathuria8@gmail.com>
2025-05-12 20:36:33 -07:00
c06af9a959 [Misc] Slight spelling modification (#18039)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
2025-05-12 20:36:27 -07:00
60f7624334 Implements dual-chunk-flash-attn backend for dual chunk attention with sparse attention support (#11844) 2025-05-12 19:52:47 -07:00
f6518b2b48 [ROCm] Skip tests for quantizations incompatible with ROCm (#17905)
Signed-off-by: Hissu Hyvarinen <hissu.hyvarinen@amd.com>
2025-05-12 18:39:28 -06:00
d67085c2c8 Remove noisy warnings from SchedulerConfig (#17995)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-05-13 00:33:45 +00:00
307939f299 Use NVFP4 Marlin for CompressedTensorsW4A16Fp4 (#18000)
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Dipika <dipikasikka1@gmail.com>
Co-authored-by: Dipika <dipikasikka1@gmail.com>
2025-05-12 18:07:34 -06:00
9d7ea9dbbf Update some more deprecated type hinting (#17998)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-05-12 23:49:33 +00:00
acee8f48aa [Model] Support MiMo-7B inference with MTP (#17433)
Signed-off-by: wp-alpha <wangpeng66@xiaomi.com>
Co-authored-by: wangpeng66 <wangpeng66@xiaomi.com>
2025-05-12 23:25:33 +00:00
f065de4e88 Fix FBGEMM integration (#18002)
Signed-off-by: mgoin <mgoin64@gmail.com>
2025-05-12 23:02:07 +00:00
dc9905368d [V1][Spec Decode] Eagle unit tests (#17350)
Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
2025-05-12 23:01:17 +00:00
ebab1ac37c [CI] Make JSON output tests less likely to fail (#17859)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
2025-05-12 22:31:54 +00:00
2b0db9b0e2 Enable standard language model for torhc nightly (#18004)
Signed-off-by: Yang Wang <elainewy@meta.com>
2025-05-12 14:00:04 -07:00
195adb47c0 [Chore] Remove unused method (#18024)
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
2025-05-12 13:59:47 -07:00
302f3aca7e [v1][KVCacheManager] Change prefix caching metric from counting blocks to counting tokens (#18003)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
2025-05-12 13:46:12 -07:00
e9c730c9bd Enabling "Weight Loading Multiple GPU Test - Large Models" (#18020) 2025-05-12 13:05:33 -07:00
289199feb6 [Core] Use platform-agnostic device control for DP engine core (#17245)
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
2025-05-12 12:09:16 -07:00
b9fd0d7a69 [CI/Build] Fix TPU V1 Test mixed use of & and && across tests (#17968) 2025-05-12 12:06:59 -07:00
72a3f6b898 Construct KVTransferConfig properly from Python instead of using JSON blobs without CLI (#17994)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-05-12 11:25:33 -07:00
98ea35601c [Lora][Frontend]Add default local directory LoRA resolver plugin. (#16855)
Signed-off-by: jberkhahn <jaberkha@us.ibm.com>
2025-05-12 10:39:10 -07:00
d19110204c [P/D] NIXL Integration (#17751)
Signed-off-by: ApostaC <yihua98@uchicago.edu>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
Signed-off-by: Robert Shaw <rshaw@neuralmagic.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: Brent Salisbury <bsalisbu@redhat.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: ApostaC <yihua98@uchicago.edu>
Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Tyler Michael Smith <tysmith@redhat.com>
Co-authored-by: Brent Salisbury <bsalisbu@redhat.com>
2025-05-12 09:46:16 -07:00
05a4324f8e Initialize the delta tool call fields explicitly (#17340)
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Co-authored-by: igmainc <igmainc@icloud.com>
2025-05-12 13:28:58 +00:00
7ea6cb28b2 [Misc] Improve modelscope import error (#17983)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
2025-05-12 10:46:45 +00:00
9fbf2bfbd5 Correcting testcases in builkite job for IBM Power (#17675)
Signed-off-by: Aaruni Aggarwal <aaruniagg@gmail.com>
2025-05-12 08:11:55 +00:00
3a5ea75129 [Feature] Support DeepSeekV3 Function Call (#17784)
Signed-off-by: 许文卿 <xwq391974@alibaba-inc.com>
Signed-off-by: Xu Wenqing <xuwq1993@qq.com>
2025-05-12 00:45:21 -07:00
891b9d33de [Fix] Benchmark "EngineClient" has no attribute "model_config" (#17976)
Signed-off-by: Brayden Zhong <b8zhong@uwaterloo.ca>
2025-05-11 22:55:53 -07:00
430783018c [Bugfix][TPU] Use np array when updating cache slot_mapping (#17971)
Signed-off-by: Siyuan Liu <lsiyuan@google.com>
2025-05-12 12:58:33 +08:00
19a3c78d1f [Bugfix] Fix pydantic.errors.PydanticUserError (#17962)
Signed-off-by: wangli <wangli858794774@gmail.com>
2025-05-12 12:58:23 +08:00
ada50aa295 [bugfix] fix the wrong parser (#17958)
Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>
2025-05-12 04:58:02 +00:00
08bf784078 [Bugfix] validate grammar and throw 400 error instead of crashing the engine when xgrammar validation fails (#17623)
Signed-off-by: Jason Cheng <jasoncky96@gmail.com>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
2025-05-12 09:06:10 +08:00
d45fe333fb [misc] add instructions on how to install nvshmem/pplx/deepep (#17964)
Signed-off-by: youkaichao <youkaichao@gmail.com>
2025-05-11 18:02:39 -07:00
021c16c7ca [Model] Broadcast Ovis2 implementation to fit Ovis1.6 (#17861)
Signed-off-by: Isotr0py <2037008807@qq.com>
2025-05-11 17:56:30 -07:00
7de18d541b [BUG] [ROCm] [MLA] Fix variable name bug due to change in variable name in PR #17483 (#17961)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
2025-05-11 09:14:30 -07:00
a810b5b088 [BugFix] [ROCm]: Bugfix and handle addition case of input for rocm_aiter_rms_norm (#17857)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
2025-05-11 04:17:11 -07:00
009b3d5382 [Misc] not show --model in vllm serve --help (#16691)
Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>
2025-05-11 08:47:58 +00:00
e4b8713380 [New Model]: nomic-embed-text-v2-moe (#17785) 2025-05-11 00:59:43 -07:00
06c0922a69 [FP8][ROCm][Attention] Enable FP8 KV cache on ROCm for V1 (#17870)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
2025-05-11 15:58:45 +08:00
cd3edfc908 [Misc] Add compressed-tensors NVFP4A16 emulation support (#17914)
Signed-off-by: Dipika Sikka <dipikasikka1@gmail.com>
Signed-off-by: Dipika <dipikasikka1@gmail.com>
2025-05-11 15:58:38 +08:00
9cea90eab4 [Frontend] Add /classify endpoint (#17032)
Signed-off-by: Frieda (Jingying) Huang <jingyingfhuang@gmail.com>
2025-05-11 07:57:07 +00:00
d1110f5b5a [doc] update lora doc (#17936)
Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>
2025-05-11 15:56:21 +08:00
8132365b74 [Bugfix]: v1 engine - consider lora adapters in allowed_token_ids (#17855)
Signed-off-by: Ben Browning <bbrownin@redhat.com>
2025-05-11 00:53:58 -07:00
eea22a56ab fix amd triton mla path (#17871) 2025-05-11 07:53:31 +00:00
9112155283 [Perf] Use small max_num_batched_tokens for A100 (#17885)
Signed-off-by: KuntaiDu <kuntai@uchicago.edu>
2025-05-11 07:53:23 +00:00
90d0a74b60 [Bugfix] Add revision to transformers.Auto*.from_pretrained processors (#17948)
Signed-off-by: Xin Li <xin@centml.ai>
2025-05-11 07:52:44 +00:00
d74e5f37bc [Kernel] fp4 marlin kernel (#17687)
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
2025-05-10 19:58:49 -07:00
ca66a1674c [v1] Rename specialized_manager.py to single_type_kv_cache_manager.py (#17946)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
2025-05-10 16:14:12 -07:00
950751a987 [v1] Pass BlockTable and KVCacheSpec to AttentionMetadataBuilders (#17483)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
2025-05-10 16:12:04 -07:00
4c31218f80 [Misc] remove --model from vllm serve usage (#17944)
Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>
2025-05-10 13:23:31 +00:00
68311891f5 Don't default construct ModelConfig when default constructing VllmConfig (#17943)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-05-10 13:23:00 +00:00
fc4441a4ee Add missing content type headers to /ping and /health (#17036) (#17786)
Signed-off-by: Ximo Guanter <ximo.guanter@gmail.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-05-10 07:13:32 +01:00
246e3e0a36 fix broken test vllm:test_kernels - test_attention_selector.py::test_flash_attn (#17873)
Co-authored-by: Stephen Chen <tracelog@meta.com>
2025-05-10 10:46:54 +08:00
7042cc96b0 [V1][Spec Decoding] Log accumulated metrics after system goes idle (#17913)
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
2025-05-09 18:23:07 -07:00
0c0fdae84f [Hardware/NVIDIA/Kernel] Enable nvidia/DeepSeek-R1-FP4 Model (#16362) 2025-05-09 16:24:41 -07:00
3b602cdea7 AMD conditional all test execution // new test groups (#17556)
Signed-off-by: Alexei V. Ivanov <alexei.ivanov@amd.com>
Signed-off-by: Yida Wu <yidawu@alumni.cmu.edu>
2025-05-09 15:35:58 -07:00
4b2ed7926a Improve configs - the rest! (#17562)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-05-09 15:18:44 -07:00
7e3571134f [V1][Spec Decoding] Include bonus tokens in mean acceptance length (#17908)
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
2025-05-09 13:32:36 -07:00
ea2236bf95 Add option to use torch._inductor.standalone_compile (#17057)
Signed-off-by: rzou <zou3519@gmail.com>
2025-05-09 12:59:04 -07:00
7d4aedae7c Handle error when str passed to /v1/audio/transcriptions (#17909)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-05-09 19:23:59 +00:00
22481fbfa3 Update CT WNA16MarlinMoE integration (#16666)
Signed-off-by: mgoin <mgoin64@gmail.com>
2025-05-09 13:19:45 -04:00
5c4c08f6f1 [Misc] Auto fallback to float16 for pre-Ampere GPUs when detected bfloat16 config (#17265)
Signed-off-by: Isotr0py <2037008807@qq.com>
2025-05-09 17:16:12 +00:00
c44c384b1c [Misc] Add references in ray_serve_deepseek example (#17907)
Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
2025-05-09 16:59:36 +00:00
85b72cb7b1 Revert "[BugFix][AMD] Compatible patch for latest AITER(05/07/2025)" (#17910) 2025-05-09 08:58:18 -07:00
6e5595ca39 [CI/Build] Automatically retry flaky tests (#17856)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
2025-05-09 09:55:17 -06:00
200da9a517 [v1] Move block management logic from KVCacheManager to SpecializedManager (#17474)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
2025-05-09 15:25:34 +00:00
9f64e93415 [BugFix][AMD] Compatible patch for latest AITER(05/07/2025) (#17864)
Signed-off-by: Qiang Li <qiang.li2@amd.com>
2025-05-09 08:59:36 -06:00
ec61ea20a8 [Misc] add dify integration (#17895)
Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>
2025-05-09 03:42:39 -07:00
c6798baa9c Change top_k to be disabled with 0 (still accept -1 for now) (#17773)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-05-09 10:01:49 +00:00
5b2dcbf0b8 Fix Whisper crash caused by invalid`` max_num_batched_tokens`` config (#17853)
Signed-off-by: inkcherry <mingzhi.liu@intel.com>
2025-05-09 09:16:26 +00:00
6e4a93e3f7 [Bugfix][CPU] Fix broken AVX2 CPU TP support (#17252)
Signed-off-by: Isotr0py <2037008807@qq.com>
2025-05-09 08:55:14 +00:00
217db4baa6 [Bugfix][ROCm] Fix AITER MLA V1 (#17880)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
2025-05-09 08:38:21 +00:00
ff8c400502 [Doc] remove visible token in doc (#17884)
Signed-off-by: yan <yanma1@habana.ai>
2025-05-09 01:21:31 -07:00
89a0315f4c [Doc] Update several links in reasoning_outputs.md (#17846)
Signed-off-by: windsonsea <haifeng.yao@daocloud.io>
2025-05-09 01:20:55 -07:00
3d1e387652 [Docs] Add Slides from NYC Meetup (#17879)
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-05-08 21:46:54 -07:00
d310e6de98 [BUGFIX]: return fast when request requires prompt logprobs (#17251) 2025-05-08 21:25:41 -07:00
5e6f939484 [Attention] MLA move rotary embedding to cuda-graph region (#17668)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
2025-05-09 11:14:42 +08:00
760e3ecc8f [V1][Structured Output] Update llguidance (>= 0.7.11) to avoid AttributeError (no StructTag) (#17839)
Signed-off-by: shen-shanshan <467638484@qq.com>
2025-05-08 20:14:18 -07:00
3c9396a64f [FEAT][ROCm]: Support AITER MLA on V1 Engine (#17523)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Co-authored-by: qli88 <qiang.li2@amd.com>
Co-authored-by: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com>
2025-05-09 10:42:05 +08:00
376786fac1 Add cutlass support for blackwell fp8 blockwise gemm (#14383)
Signed-off-by: Shu Wang <shuw@nvidia.com>
2025-05-08 15:09:55 -07:00
4f605a6de5 Fix noisy warning for uncalibrated q_scale/p_scale (#17414)
Signed-off-by: mgoin <mgoin64@gmail.com>
2025-05-08 15:56:59 -04:00
8342e3abd1 [CI] Prune down lm-eval small tests (#17012)
Signed-off-by: mgoin <mgoin64@gmail.com>
2025-05-08 19:00:26 +00:00
a83a0f92b5 [Test] Attempt all TPU V1 tests, even if some of them fail. (#17334)
Signed-off-by: Yarong Mu <ymu@google.com>
2025-05-08 17:20:54 +00:00
226a4272cf [V1] Improve VLLM_ALLOW_INSECURE_SERIALIZATION logging (#17860)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
2025-05-08 16:57:35 +00:00
ec54d73c31 [CI] Fix test_collective_rpc (#17858)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
2025-05-08 16:47:12 +00:00
a944f8ede7 [Misc] Delete LoRA-related redundancy code (#17841)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
2025-05-08 06:02:21 -07:00
015815fe01 [Bugfix] use_fast failing to be propagated to Qwen2-VL image processor (#17838)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
2025-05-08 05:39:21 -07:00
e4ca6e3a99 Fix transient dependency error in docs build (#17848)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-05-08 03:42:03 -07:00
53d0cb7423 [Misc] add chatbox integration (#17828)
Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>
2025-05-08 10:05:26 +00:00
f50dcb7c21 [Easy] Eliminate c10::optional usage in vllm/csrc (#17819) 2025-05-08 03:05:10 -07:00
a1e19b635d [Doc] Fix a typo in the file name (#17836)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
2025-05-08 18:04:18 +08:00
bb239a730f [Bugfix] Fix quark fp8 format loading on AMD GPUs (#12612)
Signed-off-by: Felix Marty <felmarty@amd.com>
Signed-off-by: kewang2 <kewang2@amd.com>
Co-authored-by: kewang2 <kewang2@amd.com>
2025-05-08 02:53:53 -07:00
a463555dee [TPU] Fix the test_sampler (#17820) 2025-05-08 05:51:33 -04:00
ca04b97c93 [Bugfix] Fix tool call template validation for Mistral models (#17644)
Signed-off-by: Rick Yuan <yuan821120@gmail.com>
Signed-off-by: RIck Yuan <yuan821120@gmail.com>
Co-authored-by: Aaron Pham <Aaronpham0103@gmail.com>
2025-05-08 09:47:19 +00:00
0a9bbaa104 [Misc] support model prefix & add deepseek vl2 tiny fused moe config (#17763)
Signed-off-by: 唯勤 <xsank.mz@alibaba-inc.com>
Co-authored-by: 唯勤 <xsank.mz@alibaba-inc.com>
2025-05-08 07:50:22 +00:00
39956efb3f [Bugfix] Fix bad words for Mistral models (#17753)
Signed-off-by: Qiong Zhou Huang <qiong@phonic.co>
2025-05-07 23:32:10 -07:00
597051e56f [Qwen3]add qwen3-235b-bf16 fused moe config on A100 (#17715) 2025-05-07 23:09:32 -07:00
96722aa81d [Frontend] Chat template fallbacks for multimodal models (#17805)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
2025-05-07 23:05:54 -07:00
843b222723 [Hardware][Intel-Gaudi] Support Automatic Prefix Caching on HPU (#17648)
Signed-off-by: Agata Dobrzyniewicz <adobrzyniewicz@habana.ai>
2025-05-07 22:37:03 -07:00
e515668edf [Hardware][Power] Enable compressed tensor W8A8 INT8 quantization for POWER (#17153)
Signed-off-by: Akash Kaothalkar <akash.kaothalkar@ibm.com>
Co-authored-by: Akash Kaothalkar <akash.kaothalkar@ibm.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
2025-05-07 22:35:03 -07:00
5a499e70d5 [Kernel][Hardware][AMD] Bf16 mfma opt for ROCm skinny GEMMs (#17071)
Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
Signed-off-by: charlifu <charlifu@amd.com>
Co-authored-by: charlifu <charlifu@amd.com>
2025-05-07 22:34:49 -07:00
6930a41116 [V1] Add VLLM_ALLOW_INSECURE_SERIALIZATION env var (#17490)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
2025-05-08 13:34:02 +08:00
998eea4a0e Only log non-default CLI args for online serving (#17803)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-05-07 22:33:29 -07:00
c747d84576 [Installation] OpenTelemetry version update (#17771)
Signed-off-by: Mikhail Podvitskii <podvitskiymichael@gmail.com>
2025-05-07 22:32:49 -07:00
b2da14a05a Improve exception reporting in MP engine (#17800)
Signed-off-by: Vadim Markovtsev <vadim@poolside.ai>
2025-05-08 05:32:39 +00:00
7ea2adb802 [Core] Support full cuda graph in v1 (#16072)
Signed-off-by: Chanh Nguyen <cnguyen@linkedin.com>
Co-authored-by: Chanh Nguyen <cnguyen@linkedin.com>
2025-05-07 22:30:15 -07:00
3d13ca0e24 [BugFix] Fix --disable-log-stats in V1 server mode (#17600)
Signed-off-by: Nick Hill <nhill@redhat.com>
2025-05-08 04:08:15 +00:00
66ab3b13c9 Don't call the venv vllm (#17810)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-05-08 04:06:39 +00:00
a8238bbdb0 [Chore][Doc] uses model id determined from OpenAI client (#17815)
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
2025-05-08 01:48:57 +00:00
d43f914d42 [Core][Feature] Input metadata dump on crash (#13407)
Signed-off-by: Wallas Santos <wallashss@ibm.com>
2025-05-07 22:15:09 +00:00
ed5272cf21 [BugFix] Avoid secondary missing MultiprocExecutor.workers error (#17811)
Signed-off-by: Nick Hill <nhill@redhat.com>
2025-05-07 21:55:04 +00:00
c20ef40fd0 [Hardware][TPU][V1] Multi-LoRA implementation for the V1 TPU backend (#14238)
Signed-off-by: Akshat Tripathi <akshat@krai.ai>
Signed-off-by: Chengji Yao <chengjiyao@google.com>
Co-authored-by: Chengji Yao <chengjiyao@google.com>
2025-05-07 16:28:47 -04:00
db593aa67f [Quantization] Quark MXFP4 format loading (#16943) 2025-05-07 15:05:05 -04:00
f98e307588 [Bugfix] Fix missing lora name mapping for lora without prefix (#17793)
Signed-off-by: Isotr0py <2037008807@qq.com>
2025-05-07 16:17:12 +00:00
646a31e51e Fix and simplify deprecated=True CLI kwarg (#17781)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-05-07 16:51:06 +01:00
be8ff88e66 [Bugfix] Fix Video IO error for short video (#17791)
Signed-off-by: Isotr0py <2037008807@qq.com>
2025-05-07 15:36:06 +00:00
1a6af1453d Only depend on importlib-metadata for Python < 3.10 (#17776)
Signed-off-by: Christian Heimes <christian@python.org>
2025-05-07 07:51:06 -07:00
32aa74c09c [ROCm][FP8][Kernel] FP8 quantization fused into Custom Paged Attention (#17139)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
2025-05-07 07:12:35 -07:00
7377dd0307 [doc] update the issue link (#17782)
Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>
2025-05-07 20:29:05 +08:00
98c89e16ff Make key optional for rotary embedding (#17566)
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
2025-05-07 00:11:46 -07:00
324a3119b0 Fix test_memory_usage_no_spec (#17754)
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
2025-05-07 00:10:33 -07:00
8a15c2603a [Frontend] Add missing chat templates for various MLLMs (#17758)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
2025-05-07 00:10:01 -07:00
043e4c4955 Add NeuronxDistributedInference support, Speculative Decoding, Dynamic on-device sampling (#16357)
Signed-off-by: Satyajith Chilappagari <satchill@amazon.com>
Co-authored-by: Aaron Dou <yzdou@amazon.com>
Co-authored-by: Shashwat Srijan <sssrijan@amazon.com>
Co-authored-by: Chongming Ni <chongmni@amazon.com>
Co-authored-by: Amulya Ballakur <amulyaab@amazon.com>
Co-authored-by: Patrick Lange <patlange@amazon.com>
Co-authored-by: Elaine Zhao <elaineyz@amazon.com>
Co-authored-by: Lin Lin Pan <tailinpa@amazon.com>
Co-authored-by: Navyadhara Gogineni <navyadha@amazon.com>
Co-authored-by: Yishan McNabb <yishanm@amazon.com>
Co-authored-by: Mrinal Shukla <181322398+mrinalks@users.noreply.github.com>
2025-05-07 00:07:30 -07:00
ba7703e659 [Misc] Remove qlora_adapter_name_or_path (#17699)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
2025-05-06 23:10:37 -07:00
f80ae5bdcf [Kernel] Use fused rmsnorm for some models like qwen3 series (#17735)
Signed-off-by: evian <eviantai@u.nus.edu>
Co-authored-by: evian <eviantai@u.nus.edu>
2025-05-06 23:10:02 -07:00
1a45a61387 [Kernel] GGUF MoeVec kernel (#16780)
Signed-off-by: SzymonOzog <szymon.ozog@aleph-alpha.com>
Signed-off-by: SzymonOzog <szymon.ozog@gmail.com>
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
2025-05-06 23:07:23 -07:00
c3e9d5060e [Misc] Use apply_rotary_emb from vllm_flash_attn for Qwen2-VL vision RoPE (#17726)
Signed-off-by: Isotr0py <2037008807@qq.com>
2025-05-07 04:51:33 +00:00
822de7fb94 [Misc] Split model loader (#17712)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
2025-05-07 12:42:26 +08:00
8d84d836d1 [BugFix][Spec Decode] Fix hidden size mismatch between target and eagle head (#17740)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-05-06 19:51:26 -07:00
950b71186f Replace lm-eval bash script with pytest and use enforce_eager for faster CI (#17717)
Signed-off-by: mgoin <mgoin64@gmail.com>
2025-05-06 18:00:10 -07:00
e50a1f1a9c [TPU] Add kernel test for moe_pallas (#17496)
Signed-off-by: Michael Goin <mgoin64@gmail.com>
2025-05-06 17:59:57 -07:00
a17cef70ea Removed unused marlin cuda code (#17684)
Signed-off-by: mgoin <mgoin64@gmail.com>
2025-05-06 17:59:47 -07:00
18dd5e01f2 [Model] Mamba2 causal conv1d Refactor to Split Prefill and Decode Requests for Corresponding Kernels (#17146)
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
2025-05-06 17:59:30 -07:00
6de3e13413 Add logging for torch nightly version (#17669)
Signed-off-by: Yang Wang <elainewy@meta.com>
2025-05-07 00:45:51 +00:00
ed3a1d2106 [ROCm] fix num_stages for default moe config to avoid triton OutOfResource error (#17744)
Signed-off-by: Hongxia Yang <hongxia.yang@amd.com>
2025-05-07 00:39:48 +00:00
022afbeb4e Fix doc build performance (#17748)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-05-07 00:36:41 +00:00
2f925e5777 [Kernel] Unified Triton kernel that doesn't distinguish between prefill + decode (#16828)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Co-authored-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
2025-05-06 18:21:48 -04:00
de906b95f9 [Bugfix] Fix for the condition to accept empty encoder inputs for mllama (#17732)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
2025-05-06 19:59:06 +00:00
d456aea71f [Misc] Add Next Edit Prediction (NEP) datasets support in benchmark_serving.py (#16839)
Signed-off-by: dtransposed <damian@damian-ml-machine.europe-west3-b.c.jetbrains-grazie.internal>
Signed-off-by: dtransposed <>
Co-authored-by: dtransposed <damian@damian-ml-machine.europe-west3-b.c.jetbrains-grazie.internal>
2025-05-06 15:38:45 -04:00
621ca2c0ab [TPU] Increase block size and reset block shapes (#16458) 2025-05-06 13:55:04 -04:00
6115b11582 Make right sidebar more readable in "Supported Models" (#17723)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-05-06 16:48:26 +00:00
5b8c390747 [Bugfix] Fix modality limits in vision language example (#17721)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
2025-05-06 16:12:28 +00:00
7525d5f3d5 [doc] Add RAG Integration example (#17692)
Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>
2025-05-06 16:10:23 +00:00
aabcd2cae3 [v1] Introduce KVCacheBlocks as interface between Scheduler and KVCacheManager (#17479)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
2025-05-06 08:50:34 -07:00
0d115460a7 [Docs] Use gh-file to add links to tool_calling.md (#17709)
Signed-off-by: windsonsea <haifeng.yao@daocloud.io>
2025-05-06 15:27:19 +00:00
175bda67a1 [Feat] Add deprecated=True to CLI args (#17426)
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
2025-05-06 08:11:27 -07:00
cba31c47c4 [v1] AttentionMetadata for each layer (#17394)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
2025-05-06 07:58:37 -07:00
a6fed02068 [V1][PP] Support PP for MultiprocExecutor (#14219)
Signed-off-by: jiang1.li <jiang1.li@intel.com>
Signed-off-by: jiang.li <jiang1.li@intel.com>
2025-05-06 07:58:05 -07:00
d419aa5dc4 [V1] Enable TPU V1 backend by default (#17673)
Signed-off-by: mgoin <mgoin64@gmail.com>
2025-05-06 06:49:49 -07:00
f9bc5a0693 [Bugfix] Fix triton import with local TritonPlaceholder (#17446)
Signed-off-by: Mengqing Cao <cmq0113@163.com>
2025-05-06 17:53:09 +08:00
05e1f96419 Fix dockerfilegraph pre-commit hook (#17698)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-05-06 08:56:48 +00:00
6eae34533a [Misc] Fix ScalarType float4 naming (#17690)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
2025-05-06 01:07:15 -07:00
63ced7b43f [Doc] Update notes for H2O-VL and Gemma3 (#17219)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
2025-05-06 07:51:02 +00:00
dc47ba32f8 [Bugfix] Fixed prompt length for random dataset (#17408)
Signed-off-by: Mikhail Podvitskii <podvitskiymichael@gmail.com>
2025-05-06 07:00:08 +00:00
edbf2d609e [easy] Fix logspam on PiecewiseBackend errors (#17138)
Signed-off-by: rzou <zou3519@gmail.com>
2025-05-05 23:46:11 -07:00
999328be0d [Model] Add GraniteMoeHybrid 4.0 model (#17497)
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
Signed-off-by: Stanislaw Wozniak <stw@zurich.ibm.com>
Co-authored-by: Thomas Ortner <boh@zurich.ibm.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Tyler Michael Smith <tysmith@redhat.com>
2025-05-06 12:00:31 +08:00
98834fefaa Update nm to rht in doc links + refine fp8 doc (#17678)
Signed-off-by: mgoin <mgoin64@gmail.com>
2025-05-06 00:41:14 +00:00
90bd2ae172 [Bugfix] LoRA - Retire unused maxnreg LoRA kernel argument (#17677) 2025-05-05 17:34:29 -07:00
5941e0b7ea [TPU][V1] Add support for top-logprobs (#17072)
Signed-off-by: NickLucche <nlucches@redhat.com>
2025-05-05 14:20:15 -07:00
9765940824 [TPU] Enable gemma3-27b with TP>1 on multi-chips. (#17335)
Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com>
2025-05-05 14:19:58 -07:00
5ea5c514da [BugFix] Increase timeout for startup failure test (#17642)
Signed-off-by: Nick Hill <nhill@redhat.com>
2025-05-05 20:53:19 +00:00
d3efde8176 [Benchmarks] Remove invalid option under V1 engine (#17651)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
2025-05-05 16:30:22 -04:00
aea302be6c Use git-path commit in hook (#17616)
Signed-off-by: Thomas J. Fan <thomasjpfan@gmail.com>
2025-05-05 17:55:32 +00:00
cc05b90d86 [Doc] Fix broken cuda installation doc rendering (#17654)
Signed-off-by: Isotr0py <2037008807@qq.com>
2025-05-05 17:52:40 +00:00
1d0c9d6b2d [Kernel] some optimizations for dense marlin and moe marlin (#16850)
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
2025-05-05 09:39:30 -07:00
f62cad6431 [Build/CI] Upgrade CUTLASS to 3.9.2 (#17641)
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
2025-05-04 19:23:17 -07:00
5394ad7387 [Bugfix] fix KeyError on top logprobs are special tokens (#17637)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
2025-05-04 19:22:35 -07:00
68e1ee0072 [Bugfix][Easy] Fix whitespace in shm_broadcast.py logging (#17635)
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
2025-05-04 19:20:19 -07:00
2858830c39 [Bugfix] Prioritize dtype in root config before checking text config (#17629)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
2025-05-04 12:43:05 +00:00
d6484ef3c3 Add full API docs and improve the UX of navigating them (#17485)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-05-03 19:42:43 -07:00
46fae69cf0 [Misc] V0 fallback for --enable-prompt-embeds (#17615)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
2025-05-03 22:59:24 +00:00
f66f1e0fa3 [Bugfix] Fix broken Qwen2.5-omni tests (#17613)
Signed-off-by: Isotr0py <2037008807@qq.com>
2025-05-03 17:08:14 +00:00
887d7af882 [Core] Gate prompt_embeds behind a feature flag (#17607)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
2025-05-04 00:19:20 +08:00
a92842454c [Bugfix][ROCm] Using device_type because on ROCm the API is still torch.cuda (#17601)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
2025-05-02 22:25:47 -07:00
c8386fa61d [Build/CI] Upgrade CUTLASS to 3.9.1 (#17602)
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
2025-05-02 22:25:14 -07:00
87baebebd8 [Frontend][TPU] Add TPU default max-num-batched-tokens based on device name (#17508)
Signed-off-by: Chenyaaang <chenyangli@google.com>
2025-05-02 21:42:44 -07:00
e3d0a1d190 [Quantizaton] [AMD] Add support for running DeepSeek int8 w8a8 MoE on ROCm (#17558)
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
2025-05-02 21:41:10 -07:00
d47b605eca Update test requirements to CUDA 12.8 (#17576)
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
2025-05-02 21:40:15 -07:00
22c6f6397f [Neuron][Build] Require setuptools >= 77.0.3 for PEP 639 (#17603)
Signed-off-by: Liangfu Chen <liangfc@amazon.com>
2025-05-03 02:41:59 +00:00
3ec97e2cc5 [release] Add command to clean up Docker containers/images in TPU release machine (#17606) 2025-05-02 18:54:34 -07:00
9b103a1d76 fix typo in logging (#17605) 2025-05-02 18:04:40 -07:00
b90b0852e9 [easy] Print number of needed GPUs in skip message (#17594)
Signed-off-by: rzou <zou3519@gmail.com>
2025-05-02 15:27:43 -07:00
9352cdb56d [Hardware][AMD] Improve OAM device ID + llama4 Maverick MOE tuning (#16263)
Signed-off-by: Lu Fang <lufang@fb.com>
Co-authored-by: Lu Fang <lufang@fb.com>
2025-05-02 19:44:19 +00:00
182f40ea8b Add NVIDIA TensorRT Model Optimizer in vLLM documentation (#17561) 2025-05-02 11:36:46 -07:00
3e887d2e0c permute/unpermute kernel for moe optimization (#14568)
Signed-off-by: Caleb_Du <Caleb_Du@zju.edu.cn>
2025-05-02 11:31:55 -07:00
0f87d8f7b2 [BugFix][Attention] Fix sliding window attention in V1 giving incorrect results (#17574)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
2025-05-02 11:01:38 -07:00
4c33d67321 [Bugfix] fix tmp_out and exp_sums dimensions (#17438)
Signed-off-by: Hui Liu <96135754+hliuca@users.noreply.github.com>
2025-05-02 16:44:07 +00:00
cb234955df [Misc] Clean up input processing (#17582)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
2025-05-02 08:11:53 -07:00
3a500cd0b6 [doc] miss result (#17589)
Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>
2025-05-02 07:04:49 -07:00
868c546da4 Support W8A8 INT8 MoE for compressed-tensors (#16745)
Signed-off-by: mgoin <mgoin64@gmail.com>
2025-05-02 10:03:32 -04:00
99404f53c7 [Security] Fix image hash collision (#17378)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
2025-05-02 08:36:39 -04:00
785d75a03b Automatically tell users that dict args must be valid JSON in CLI (#17577)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-05-02 05:24:55 -07:00
6d1479ca4b [doc] add the print result (#17584)
Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>
2025-05-02 05:24:45 -07:00
b8b0859b5c add more pytorch related tests for torch nightly (#17422)
Signed-off-by: Yang Wang <elainewy@meta.com>
2025-05-02 03:29:59 -07:00
d7543862bd [Misc] Rename assets for testing (#17575)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
2025-05-02 03:29:25 -07:00
c777df79f7 [BugFix] Fix Memory Leak (#17567)
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
2025-05-02 01:07:03 -07:00
cc2a77d7f1 [Core] [Bugfix] Add Input Embeddings (#15428)
Signed-off-by: Andrew Sansom <andrew@protopia.ai>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: 临景 <linjing.yx@alibaba-inc.com>
Co-authored-by: Bryce1010 <bryceyx@gmail.com>
Co-authored-by: Nan2018 <nan@protopia.ai>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
2025-05-02 01:06:39 -07:00
9e2de9b9e9 [Bugifx] Remove TritonPlaceholder from sys.modules (#17317)
Signed-off-by: Isotr0py <2037008807@qq.com>
2025-05-02 00:45:01 -07:00
109e15a335 Add pt_load_map_location to allow loading to cuda (#16869)
Signed-off-by: Jerry Zhang <jerryzh168@gmail.com>
2025-05-01 23:23:42 -07:00
f192ca90e6 Fix PixtralHF missing spatial_merge_size (#17571)
Signed-off-by: mgoin <mgoin64@gmail.com>
2025-05-01 22:14:09 -07:00
f89d0e11bf [Misc] Continue refactoring model tests (#17573)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
2025-05-01 22:06:08 -07:00
b4003d11fc Check if bitblas is installed during support check (#17572)
Signed-off-by: mgoin <mgoin64@gmail.com>
2025-05-02 04:32:54 +00:00
292fc59d61 [CI] Actually run tests/kv_transfer/test_disagg.py in CI (#17555)
Signed-off-by: mgoin <mgoin64@gmail.com>
2025-05-02 04:05:04 +00:00
afcb3f8863 [Attention] MLA move o_proj q_proj into cuda-graph region (#17484)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
2025-05-02 03:16:26 +00:00
afb12e4294 [Doc] note that not all unit tests pass on CPU platforms (#17554)
Signed-off-by: David Xia <david@davidxia.com>
2025-05-02 02:57:21 +00:00
24aebae177 [Bugfix] Disable gptq_bitblas for <SM80 to fix GPTQ on V100/T4 (#17541)
Signed-off-by: mgoin <mgoin64@gmail.com>
2025-05-01 17:59:35 -07:00
39c0813a7f [V1][Spec Decode] Apply torch.compile & cudagraph to EAGLE3 (#17504)
Signed-off-by: qizixi <qizixi@meta.com>
2025-05-01 16:19:30 -07:00
9b70e2b4c1 [Misc][Tools][Benchmark] Publish script to auto tune server parameters (#17207)
Signed-off-by: Chenyaaang <chenyangli@google.com>
2025-05-01 19:53:03 +00:00
173daac19d [Bug]change the position of cuda_graph_sizes in dataclasses (#17548)
Signed-off-by: CXIAAAAA <cxia0209@gmail.com>
2025-05-01 11:52:37 -07:00
04f2cfc894 Remove duplicate code from dbrx.py (#17550) 2025-05-01 11:51:58 -07:00
811a6c0972 [ROCM] Add gfx950 to the custom attention archs (#16034)
Signed-off-by: jpvillam <Juan.Villamizar@amd.com>
Signed-off-by: seungrokjung <seungrok.jung@amd.com>
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Co-authored-by: seungrokjung <seungrok.jung@amd.com>
Co-authored-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
2025-05-01 11:18:28 -07:00
9b1769dd9a [Bugfix] Fix lint error (#17547)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
2025-05-01 11:12:19 -07:00
61c299f81f [Misc]add configurable cuda graph size (#17201)
Signed-off-by: CXIAAAAA <cxia0209@gmail.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-05-01 11:04:50 -07:00
4acfa3354a [ROCm] update installation guide to include build aiter from source instructions (#17542)
Signed-off-by: Hongxia Yang <hongxia.yang@amd.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
2025-05-01 11:01:28 -07:00
88c8304104 [Model] Refactor Ovis2 to support original tokenizer (#17537)
Signed-off-by: Isotr0py <2037008807@qq.com>
2025-05-01 11:00:53 -07:00
6768ff4a22 Move the last arguments in arg_utils.py to be in their final groups (#17531)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-05-01 10:31:44 -07:00
f2e7af9b86 [CI/Build] Remove awscli dependency (#17532)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
2025-05-01 09:20:54 -07:00
7423cf0a9b [Misc] refactor example - cpu_offload_lmcache (#17460)
Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>
2025-05-01 15:05:24 +00:00
460a2b1100 [torch.compile] Add torch inductor pass for fusing silu_and_mul with subsequent scaled_fp8_quant operations (#10867)
Signed-off-by: Sage Moore <sage@neuralmagic.com>
2025-05-01 07:59:28 -07:00
28566d73b3 [ROCm] remove unsupported archs from rocm triton flash-attention supported list (#17536)
Signed-off-by: Hongxia Yang <hongxia.yang@amd.com>
2025-05-01 07:54:25 -07:00
98060b001d [Feature][Frontend]: Deprecate --enable-reasoning (#17452)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
2025-05-01 06:46:16 -07:00
f5a3c655b2 [FEAT] [ROCm]: Add Qwen/Qwen3-235B-A22B-FP8 TP4 triton fused moe config (#17535)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
2025-05-01 06:37:17 -07:00
7169f87ad0 [doc] add streamlit integration (#17522)
Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>
2025-05-01 13:34:02 +00:00
b74d888c63 Fix more broken speculative decode tests (#17450)
Signed-off-by: Huy Do <huydhn@gmail.com>
2025-05-01 06:05:58 -07:00
2007d4d54f [FEAT] [ROCm]: Add Qwen/Qwen3-30B-A3B-FP8 fused moe config for MI300X (#17530)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
2025-05-01 06:03:13 -07:00
48e925fab5 [Misc] Clean up test docstrings and names (#17521)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
2025-05-01 05:19:32 -07:00
1903c0b8a3 [Frontend] Show progress bar for adding requests (#17525)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
2025-05-01 05:15:32 -07:00
86a1f67a3b [Bugfix][Benchmarks] Allow benchmark of deepspeed-mii backend to select a model (#17285)
Signed-off-by: Teruaki Ishizaki <teruaki.ishizaki@ntt.com>
2025-05-01 11:54:51 +00:00
a257d9bccc Improve configs - ObservabilityConfig (#17453)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-05-01 03:52:05 -07:00
015069b017 [Misc] Optimize the Qwen3_ReasoningParser extract_reasoning_content (#17515)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
2025-05-01 03:29:01 -07:00
fbefc8a78d [Core] Enable IPv6 with vllm.utils.make_zmq_socket() (#16506)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
2025-05-01 09:38:18 +00:00
26bc4bbcd8 Avoid overwriting vllm_compile_cache.py (#17418)
Signed-off-by: Keyun Tong <tongkeyun@gmail.com>
2025-05-01 07:30:57 +00:00
3c3d767201 [BugFix] Fix mla cpu - missing 3 required positional arguments (#17494)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
2025-05-01 14:36:52 +08:00
13cf6b6236 [BugFix] fix speculative decoding memory leak when speculation is disabled (#15506)
Signed-off-by: Noah Yoshida <noahcy117@gmail.com>
2025-04-30 23:28:17 -07:00
90d0a54c4d [ROCm] Effort to reduce the number of environment variables in command line (#17229)
Signed-off-by: Hongxia Yang <hongxia.yang@amd.com>
2025-04-30 23:27:06 -07:00
7a0a146c54 [Build] Require setuptools >= 77.0.3 for PEP 639 (#17389)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
2025-04-30 23:25:36 -07:00
7ab643e425 FIxing the AMD test failures caused by PR#16457 (#17511)
Signed-off-by: Alexei V. Ivanov <alexei.ivanov@amd.com>
2025-04-30 23:23:07 -07:00
afb4429b4f [CI/Build] Reorganize models tests (#17459)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
2025-04-30 23:03:08 -07:00
aa4502e7f3 [CI][Bugfix] Fix failing V1 Test due to missing 'cache_salt' arg (#17500)
Signed-off-by: mgoin <mgoin64@gmail.com>
2025-04-30 21:03:30 -07:00
17b4d85f63 [CI][TPU] Skip structured outputs+spec decode tests on TPU (#17510)
Signed-off-by: mgoin <mgoin64@gmail.com>
2025-04-30 20:36:20 -07:00
1144a8efe7 [Bugfix] Temporarily disable gptq_bitblas on ROCm (#17411)
Signed-off-by: Yan Cangang <nalanzeyu@gmail.com>
2025-04-30 19:51:45 -07:00
08fb5587b4 [Bugfix][ROCm] Fix import error on ROCm (#17495)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
2025-04-30 19:51:42 -07:00
dbc18e7816 [CI][TPU] Skip Multimodal test (#17488)
Signed-off-by: Siyuan Liu <lsiyuan@google.com>
2025-04-30 19:51:39 -07:00
02bd654846 [Misc] Rename Audios -> Audio in Qwen2audio Processing (#17507)
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
2025-04-30 19:51:36 -07:00
200bbf92e8 Bump Compressed Tensors version to 0.9.4 (#17478)
Signed-off-by: Rahul Tuli <rtuli@redhat.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
2025-04-30 15:24:45 -07:00
81ecf425f0 [v1][Spec Decode] Make sliding window compatible with eagle prefix caching (#17398)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
2025-04-30 18:25:53 +00:00
42d9a2c4c7 doc: fix bug report Github template formatting (#17486)
Signed-off-by: David Xia <david@davidxia.com>
2025-04-30 10:03:20 -07:00
2ac74d098e [doc] add install tips (#17373)
Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>
2025-04-30 17:02:41 +00:00
584f5fb4c6 [Bugfix][ROCm] Restrict ray version due to a breaking release (#17480)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
2025-04-30 09:59:06 -07:00
d586ddc691 [BugFix] Fix authorization of openai_transcription_client.py (#17321)
Signed-off-by: zh Wang <rekind133@outlook.com>
2025-04-30 09:51:05 -07:00
0b7e701dd4 [Docs] Update optimization.md doc (#17482)
Signed-off-by: mgoin <mgoin64@gmail.com>
2025-04-30 09:34:02 -07:00
947f2f5375 [V1] Allow turning off pickle fallback in vllm.v1.serial_utils (#17427)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
2025-04-30 16:10:54 +00:00
739e03b344 [Bugfix] Fixed mistral tokenizer path when pointing to file (#17457)
Signed-off-by: Pete Savage <psavage@redhat.com>
2025-04-30 08:08:37 -07:00
da4e7687b5 [Fix] Support passing args to logger (#17425)
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
2025-04-30 08:06:58 -07:00
39317cf42b [Docs] Add command for running mypy tests from CI (#17475)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
2025-04-30 08:06:09 -07:00
2990cee95b [Feature] The Qwen3 reasoning parser supports guided decoding (#17466)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
2025-04-30 07:48:21 -07:00
0be6d05b5e [V1][Metrics] add support for kv event publishing (#16750)
Signed-off-by: alec-flowers <aflowers@nvidia.com>
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
Co-authored-by: Mark McLoughlin <markmc@redhat.com>
2025-04-30 07:44:45 -07:00
77073c77bc [Core] Prevent side-channel attacks via cache salting (#17045)
Signed-off-by: Marko Rosenmueller <5467316+dr75@users.noreply.github.com>
2025-04-30 20:27:21 +08:00
a7d5b016bd [TPU][V1][CI] Update regression test baseline for v6 CI (#17064)
Signed-off-by: NickLucche <nlucches@redhat.com>
2025-04-30 04:03:22 -07:00
d803786731 [V1][Bugfix]: vllm v1 verison metric num_gpu_blocks is None (#15755)
Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io>
2025-04-30 18:20:39 +08:00
1534d389af [Misc] Remove deprecated files (#17447)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
2025-04-30 01:52:19 -07:00
ece5a8b0b6 Make the _apply_rotary_emb compatible with dynamo (#17435) 2025-04-30 07:52:48 +00:00
54072f315f [MODEL ADDITION] Ovis2 Model Addition (#15826)
Signed-off-by: Marco <121761685+mlinmg@users.noreply.github.com>
Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: isotr0py <2037008807@qq.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
2025-04-30 07:33:29 +00:00
be633fba0f [Bugfix] Fix AttributeError: 'State' object has no attribute 'engine_client' (#17434)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
2025-04-30 00:11:04 -07:00
ed6cfb90c8 [Hardware][Intel GPU] Upgrade to torch 2.7 (#17444)
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Co-authored-by: Qiming Zhang <qiming1.zhang@intel.com>
2025-04-30 00:03:58 -07:00
6ed9f6047e [Intel GPU] [CI]Fix XPU ci, setuptools >=80.0 have build issue (#17298)
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
2025-04-29 22:54:10 -07:00
a44c4f1d2f Support LoRA for Mistral3 (#17428)
Signed-off-by: mgoin <mgoin64@gmail.com>
2025-04-29 21:10:30 -07:00
88fcf00dda Fix some speculative decode tests with tl.dot (#17371)
Signed-off-by: Huy Do <huydhn@gmail.com>
2025-04-29 19:41:02 -07:00
d1f569b1b9 Fix call to logger.info_once (#17416)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-04-29 19:39:18 -07:00
13698db634 Improve configs - ModelConfig (#17130)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-04-30 10:38:22 +08:00
2c4f59afc3 Update PyTorch to 2.7.0 (#16859) 2025-04-29 19:08:04 -07:00
1c2bc7ead0 Truncation control for embedding models (#14776)
Signed-off-by: Gabriel Marinho <gmarinho@ibm.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Co-authored-by: Max de Bayser <mbayser@br.ibm.com>
2025-04-30 09:24:57 +08:00
4055130a85 [release] Always git fetch all to get latest tag on TPU release (#17322) 2025-04-29 17:52:11 -07:00
34120f5acd [V1][Feature] Enable Speculative Decoding with Structured Outputs (#14702)
Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
Signed-off-by: Benjamin Chislett <chislett.ben@gmail.com>
2025-04-30 00:02:10 +00:00
7489ec0bab Remove Bamba 9B from CI (#17407)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-04-29 21:10:31 +00:00
70788bdbdc [V1][Spec Decode] Apply torch.compile & cudagraph to EAGLE (#17211)
Signed-off-by: Bryan Lu <yuzhelu@amazon.com>
2025-04-29 21:10:00 +00:00
c9c1b59e59 Fix: Python package installation for opentelmetry (#17049)
Signed-off-by: Dilip Gowda Bhagavan <dilip.bhagavan@ibm.com>
2025-04-29 20:20:24 +00:00
0350809f3a Remove Falcon3 2x7B from CI (#17404)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-04-29 19:52:25 +00:00
a6977dbd15 Simplify (and fix) passing of guided decoding backend options (#17008)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-04-29 19:02:23 +00:00
2fa2a50bf9 [Bugfix] Fix Minicpm-O-int4 GPTQ model inference (#17397)
Signed-off-by: Isotr0py <2037008807@qq.com>
2025-04-29 18:21:42 +00:00
08e15defa9 [CI/Build] Add retry mechanism for add-apt-repository (#17107)
Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>
2025-04-29 10:40:52 -07:00
b37685afbb [CI] Uses Python 3.11 for TPU (#17359)
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
2025-04-29 17:39:16 +00:00
792595b59d [TPU][V1][CI] Replace python3 setup.py develop with standard pip install --e on TPU (#17374)
Signed-off-by: NickLucche <nlucches@redhat.com>
2025-04-29 10:36:48 -07:00
0c1c788312 [Doc][Typo] Fixing label in new model requests link in overview.md (#17400) 2025-04-29 10:29:48 -07:00
56d64fbe30 [Docs] Propose a deprecation policy for the project (#17063)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
2025-04-29 10:29:44 -07:00
608968b7c5 Enabling multi-group kernel tests. (#17115)
Signed-off-by: Alexei V. Ivanov <alexei.ivanov@amd.com>
2025-04-29 10:27:27 -07:00
06ffc7e1d3 [Misc][ROCm] Exclude cutlass_mla_decode for ROCm build (#17289)
Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>
2025-04-29 10:26:42 -07:00
d3cf61b89b fix gemma3 results all zero (#17364)
Signed-off-by: mayuyuace <qiming1.zhang@intel.com>
2025-04-29 09:40:25 -07:00
a39203f99e [Bugfix] add qwen3 reasoning-parser fix content is None when disable … (#17369)
Signed-off-by: mofanke <mofanke@gmail.com>
2025-04-29 16:32:40 +00:00
24e6ad3f16 [V1] Remove num_input_tokens from attn_metadata (#17193)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
2025-04-29 09:28:41 -07:00
2ef5d106bb Improve literal dataclass field conversion to argparse argument (#17391)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-04-29 16:25:08 +00:00
0ed27ef66c Fix: Spelling of inference (#17387) 2025-04-29 09:23:39 -07:00
900edfa8d4 Transformers backend tweaks (#17365)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-04-29 09:08:03 -07:00
88ad9ec6b2 [Frontend] Support chat_template_kwargs in LLM.chat (#17356)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
2025-04-29 22:03:35 +08:00
40896bdf3f pre-commit autoupdate (#17380)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-04-29 06:46:55 -07:00
00ee37efa2 [Bugfix] Clean up MiniMax-VL and fix processing (#17354)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
2025-04-29 20:42:16 +08:00
890f104cdf [Doc] Fix QWen3MOE info (#17381)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
2025-04-29 12:38:32 +00:00
4a5e13149a Update docs requirements (#17379)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-04-29 11:35:47 +00:00
97cc8729f0 [Model] Ignore rotary embed load for Cohere model (#17319) 2025-04-29 00:30:40 -07:00
4464109219 [Build][Bugfix] Restrict setuptools version to <80 (#17320)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
2025-04-29 00:17:23 -07:00
193e78e35d [Fix] Documentation spacing in compilation config help text (#17342)
Signed-off-by: Zerohertz <ohg3417@gmail.com>
2025-04-29 00:16:17 -07:00
bdb2cddafc [Misc]Use a platform independent interface to obtain the device attributes (#17100) 2025-04-29 06:59:13 +00:00
ebb3930d28 [Misc] Move config fields to MultiModalConfig (#17343)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
2025-04-29 06:37:21 +00:00
cde384cd92 [Model] support MiniMax-VL-01 model (#16328)
Signed-off-by: qingjun <qingjun@minimaxi.com>
2025-04-29 12:05:50 +08:00
96e06e3cb7 [Misc] Add a Jinja template to support Mistral3 function calling (#17195)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
2025-04-28 19:53:44 -07:00
17eb306fcc [Bugfix] Add contiguous call inside rope kernel wrapper (#17091)
Signed-off-by: 苏政渊 <suzhengyuan@moonshot.cn>
Co-authored-by: 苏政渊 <suzhengyuan@moonshot.cn>
2025-04-28 19:24:07 -07:00
165cb56329 Ignore '<string>' filepath (#17330)
Signed-off-by: rzou <zou3519@gmail.com>
2025-04-28 19:23:29 -07:00
d6da8a8ff2 [Bugfix] Fix numel() downcast in fused_layernorm_dynamic_per_token_quant.cu (#17316) 2025-04-28 19:23:18 -07:00
b4ac4fa04d [model] make llama4 compatible with pure dense layers (#17315)
Signed-off-by: Lucia Fang <fanglu@fb.com>
2025-04-29 10:22:22 +08:00
e136000595 [V1][Spec Decode] Make Eagle model arch config driven (#17323) 2025-04-29 10:22:02 +08:00
86d9fc29cb implement Structural Tag with Guidance backend (#17333)
Signed-off-by: Michal Moskal <michal@moskal.me>
2025-04-29 02:21:32 +00:00
506475de5f [Optim] Compute multimodal hash only once per item (#17314)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
2025-04-29 09:40:35 +08:00
cfe4532093 [Benchmark] Add single turn MTBench to Serving Bench (#17202) 2025-04-28 16:46:15 -07:00
8fc88d63f1 [Model] Add tuned triton fused_moe configs for Qwen3Moe (#17328)
Signed-off-by: mgoin <mgoin64@gmail.com>
2025-04-28 15:20:24 -07:00
6e74fd4945 Support loading transformers models with named parameters (#16868)
Signed-off-by: Alex <alexwu@character.ai>
2025-04-28 23:15:58 +01:00
dcbac4cb4b [Model] Qwen3 Dense FP8 Compat Fixes (#17318)
Signed-off-by: simon-mo <xmo@berkeley.edu>
2025-04-28 14:12:01 -07:00
ed2462030f [Bugfix] Fix moe weight losing all extra attrs after process_weights_after_loading. (#16854)
Signed-off-by: charlifu <charlifu@amd.com>
2025-04-28 21:05:07 +00:00
cc5befbced [BugFix] Fix cascade attention - RuntimeError: scheduler_metadata must have shape (metadata_size) (#17283)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
2025-04-28 13:55:50 -07:00
2c89cd96a8 [Chore] cleanup license indicators in light of SPDX (#17259)
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
2025-04-28 19:43:52 +00:00
1112 changed files with 60355 additions and 24613 deletions

View File

@ -8,12 +8,12 @@ import zipfile
# Note that we have 400 MiB quota, please use it wisely. # Note that we have 400 MiB quota, please use it wisely.
# See https://github.com/pypi/support/issues/3792 . # See https://github.com/pypi/support/issues/3792 .
# Please also sync the value with the one in Dockerfile. # Please also sync the value with the one in Dockerfile.
VLLM_MAX_SIZE_MB = int(os.environ.get('VLLM_MAX_SIZE_MB', 400)) VLLM_MAX_SIZE_MB = int(os.environ.get("VLLM_MAX_SIZE_MB", 400))
def print_top_10_largest_files(zip_file): def print_top_10_largest_files(zip_file):
"""Print the top 10 largest files in the given zip file.""" """Print the top 10 largest files in the given zip file."""
with zipfile.ZipFile(zip_file, 'r') as z: with zipfile.ZipFile(zip_file, "r") as z:
file_sizes = [(f, z.getinfo(f).file_size) for f in z.namelist()] file_sizes = [(f, z.getinfo(f).file_size) for f in z.namelist()]
file_sizes.sort(key=lambda x: x[1], reverse=True) file_sizes.sort(key=lambda x: x[1], reverse=True)
for f, size in file_sizes[:10]: for f, size in file_sizes[:10]:
@ -28,14 +28,18 @@ def check_wheel_size(directory):
wheel_path = os.path.join(root, file_name) wheel_path = os.path.join(root, file_name)
wheel_size_mb = os.path.getsize(wheel_path) / (1024 * 1024) wheel_size_mb = os.path.getsize(wheel_path) / (1024 * 1024)
if wheel_size_mb > VLLM_MAX_SIZE_MB: if wheel_size_mb > VLLM_MAX_SIZE_MB:
print(f"Not allowed: Wheel {wheel_path} is larger " print(
f"({wheel_size_mb:.2f} MB) than the limit " f"Not allowed: Wheel {wheel_path} is larger "
f"({VLLM_MAX_SIZE_MB} MB).") f"({wheel_size_mb:.2f} MB) than the limit "
f"({VLLM_MAX_SIZE_MB} MB)."
)
print_top_10_largest_files(wheel_path) print_top_10_largest_files(wheel_path)
return 1 return 1
else: else:
print(f"Wheel {wheel_path} is within the allowed size " print(
f"({wheel_size_mb:.2f} MB).") f"Wheel {wheel_path} is within the allowed size "
f"({wheel_size_mb:.2f} MB)."
)
return 0 return 0

View File

@ -22,5 +22,5 @@ with open("index.html", "w") as f:
print(f"Generated index.html for {args.wheel}") print(f"Generated index.html for {args.wheel}")
# cloudfront requires escaping the '+' character # cloudfront requires escaping the '+' character
f.write( f.write(
template.format(wheel=filename, template.format(wheel=filename, wheel_html_escaped=filename.replace("+", "%2B"))
wheel_html_escaped=filename.replace("+", "%2B"))) )

View File

@ -0,0 +1,11 @@
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m RedHatAI/Llama-3.2-1B-Instruct-FP8 -b "auto" -l 1319 -f 5 -t 1
model_name: "RedHatAI/Llama-3.2-1B-Instruct-FP8"
tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.335
- name: "exact_match,flexible-extract"
value: 0.323
limit: 1319
num_fewshot: 5

View File

@ -0,0 +1,11 @@
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m Qwen/Qwen2.5-1.5B-Instruct -b auto -l 1319 -f 5 -t 1
model_name: "Qwen/Qwen2.5-1.5B-Instruct"
tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.54
- name: "exact_match,flexible-extract"
value: 0.59
limit: 1319
num_fewshot: 5

View File

@ -0,0 +1,11 @@
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m RedHatAI/Qwen2.5-VL-3B-Instruct-FP8-Dynamic -b auto -l 1319 -f 5 -t 1
model_name: "RedHatAI/Qwen2.5-VL-3B-Instruct-FP8-Dynamic"
tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.47
- name: "exact_match,flexible-extract"
value: 0.64
limit: 1319
num_fewshot: 5

View File

@ -3,3 +3,4 @@ Meta-Llama-3-70B-Instruct.yaml
Mixtral-8x7B-Instruct-v0.1.yaml Mixtral-8x7B-Instruct-v0.1.yaml
Qwen2-57B-A14-Instruct.yaml Qwen2-57B-A14-Instruct.yaml
DeepSeek-V2-Lite-Chat.yaml DeepSeek-V2-Lite-Chat.yaml
Meta-Llama-3-8B-QQQ.yaml

View File

@ -1,10 +1,6 @@
Meta-Llama-3-8B-Instruct.yaml Qwen2.5-1.5B-Instruct.yaml
Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
Meta-Llama-3.2-1B-Instruct-INT8-compressed-tensors.yaml Meta-Llama-3.2-1B-Instruct-INT8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml
Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml
Qwen1.5-MoE-W4A16-compressed-tensors.yaml Qwen1.5-MoE-W4A16-compressed-tensors.yaml
Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
Qwen2-1.5B-Instruct-FP8W8.yaml
Meta-Llama-3-8B-QQQ.yaml

View File

@ -0,0 +1,43 @@
# SPDX-License-Identifier: Apache-2.0
from pathlib import Path
import pytest
def pytest_addoption(parser):
parser.addoption(
"--config-list-file",
action="store",
help="Path to the file listing model config YAMLs (one per line)",
)
parser.addoption(
"--tp-size",
action="store",
default="1",
help="Tensor parallel size to use for evaluation",
)
@pytest.fixture(scope="session")
def config_list_file(pytestconfig, config_dir):
rel_path = pytestconfig.getoption("--config-list-file")
return config_dir / rel_path
@pytest.fixture(scope="session")
def tp_size(pytestconfig):
return pytestconfig.getoption("--tp-size")
def pytest_generate_tests(metafunc):
if "config_filename" in metafunc.fixturenames:
rel_path = metafunc.config.getoption("--config-list-file")
config_list_file = Path(rel_path).resolve()
config_dir = config_list_file.parent
with open(config_list_file, encoding="utf-8") as f:
configs = [
config_dir / line.strip()
for line in f
if line.strip() and not line.startswith("#")
]
metafunc.parametrize("config_filename", configs)

View File

@ -1,59 +0,0 @@
#!/bin/bash
usage() {
echo``
echo "Runs lm eval harness on GSM8k using vllm and compares to "
echo "precomputed baseline (measured by HF transformers.)"
echo
echo "usage: ${0} <options>"
echo
echo " -c - path to the test data config (e.g. configs/small-models.txt)"
echo " -t - tensor parallel size"
echo
}
SUCCESS=0
while getopts "c:t:" OPT; do
case ${OPT} in
c )
CONFIG="$OPTARG"
;;
t )
TP_SIZE="$OPTARG"
;;
\? )
usage
exit 1
;;
esac
done
# Parse list of configs.
IFS=$'\n' read -d '' -r -a MODEL_CONFIGS < "$CONFIG"
for MODEL_CONFIG in "${MODEL_CONFIGS[@]}"
do
LOCAL_SUCCESS=0
echo "=== RUNNING MODEL: $MODEL_CONFIG WITH TP SIZE: $TP_SIZE==="
export LM_EVAL_TEST_DATA_FILE=$PWD/configs/${MODEL_CONFIG}
export LM_EVAL_TP_SIZE=$TP_SIZE
pytest -s test_lm_eval_correctness.py || LOCAL_SUCCESS=$?
if [[ $LOCAL_SUCCESS == 0 ]]; then
echo "=== PASSED MODEL: ${MODEL_CONFIG} ==="
else
echo "=== FAILED MODEL: ${MODEL_CONFIG} ==="
fi
SUCCESS=$((SUCCESS + LOCAL_SUCCESS))
done
if [ "${SUCCESS}" -eq "0" ]; then
exit 0
else
exit 1
fi

View File

@ -3,67 +3,52 @@
LM eval harness on model to compare vs HF baseline computed offline. LM eval harness on model to compare vs HF baseline computed offline.
Configs are found in configs/$MODEL.yaml Configs are found in configs/$MODEL.yaml
* export LM_EVAL_TEST_DATA_FILE=configs/Meta-Llama-3-70B-Instruct.yaml pytest -s -v test_lm_eval_correctness.py \
* export LM_EVAL_TP_SIZE=4 --config-list-file=configs/models-small.txt \
* pytest -s test_lm_eval_correctness.py --tp-size=1
""" """
import os
from pathlib import Path
import lm_eval import lm_eval
import numpy import numpy as np
import pytest
import yaml import yaml
RTOL = 0.08 RTOL = 0.08
TEST_DATA_FILE = os.environ.get(
"LM_EVAL_TEST_DATA_FILE",
".buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct.yaml")
TP_SIZE = os.environ.get("LM_EVAL_TP_SIZE", 1)
def launch_lm_eval(eval_config): def launch_lm_eval(eval_config, tp_size):
trust_remote_code = eval_config.get('trust_remote_code', False) trust_remote_code = eval_config.get("trust_remote_code", False)
model_args = (
model_args = f"pretrained={eval_config['model_name']}," \ f"pretrained={eval_config['model_name']},"
f"tensor_parallel_size={TP_SIZE}," \ f"tensor_parallel_size={tp_size},"
f"add_bos_token=true," \ f"enforce_eager=true,"
f"trust_remote_code={trust_remote_code}" f"add_bos_token=true,"
f"trust_remote_code={trust_remote_code}"
)
results = lm_eval.simple_evaluate( results = lm_eval.simple_evaluate(
model="vllm", model="vllm",
model_args=model_args, model_args=model_args,
tasks=[task["name"] for task in eval_config["tasks"]], tasks=[task["name"] for task in eval_config["tasks"]],
num_fewshot=eval_config["num_fewshot"], num_fewshot=eval_config["num_fewshot"],
limit=eval_config["limit"], limit=eval_config["limit"],
batch_size="auto") batch_size="auto",
)
return results return results
def test_lm_eval_correctness(): def test_lm_eval_correctness_param(config_filename, tp_size):
eval_config = yaml.safe_load( eval_config = yaml.safe_load(config_filename.read_text(encoding="utf-8"))
Path(TEST_DATA_FILE).read_text(encoding="utf-8"))
if eval_config[ results = launch_lm_eval(eval_config, tp_size)
"model_name"] == "nm-testing/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform": #noqa: E501
pytest.skip("FBGEMM is currently failing on main.")
# Launch eval requests.
results = launch_lm_eval(eval_config)
# Confirm scores match ground truth.
success = True success = True
for task in eval_config["tasks"]: for task in eval_config["tasks"]:
for metric in task["metrics"]: for metric in task["metrics"]:
ground_truth = metric["value"] ground_truth = metric["value"]
measured_value = results["results"][task["name"]][metric["name"]] measured_value = results["results"][task["name"]][metric["name"]]
print(f'{task["name"]} | {metric["name"]}: ' print(
f'ground_truth={ground_truth} | measured={measured_value}') f"{task['name']} | {metric['name']}: "
success = success and numpy.isclose( f"ground_truth={ground_truth} | measured={measured_value}"
ground_truth, measured_value, rtol=RTOL) )
success = success and np.isclose(ground_truth, measured_value, rtol=RTOL)
# Assert at the end, print all scores even on failure for debugging.
assert success assert success

View File

@ -65,18 +65,18 @@ def read_markdown(file):
def results_to_json(latency, throughput, serving): def results_to_json(latency, throughput, serving):
return json.dumps({ return json.dumps(
'latency': latency.to_dict(), {
'throughput': throughput.to_dict(), "latency": latency.to_dict(),
'serving': serving.to_dict() "throughput": throughput.to_dict(),
}) "serving": serving.to_dict(),
}
)
if __name__ == "__main__": if __name__ == "__main__":
# collect results # collect results
for test_file in results_folder.glob("*.json"): for test_file in results_folder.glob("*.json"):
with open(test_file) as f: with open(test_file) as f:
raw_result = json.loads(f.read()) raw_result = json.loads(f.read())
@ -120,7 +120,8 @@ if __name__ == "__main__":
for perc in [10, 25, 50, 75, 90, 99]: for perc in [10, 25, 50, 75, 90, 99]:
# Multiply 1000 to convert the time unit from s to ms # Multiply 1000 to convert the time unit from s to ms
raw_result.update( raw_result.update(
{f"P{perc}": 1000 * raw_result["percentiles"][str(perc)]}) {f"P{perc}": 1000 * raw_result["percentiles"][str(perc)]}
)
raw_result["avg_latency"] = raw_result["avg_latency"] * 1000 raw_result["avg_latency"] = raw_result["avg_latency"] * 1000
# add the result to raw_result # add the result to raw_result
@ -153,26 +154,27 @@ if __name__ == "__main__":
serving_results = pd.DataFrame.from_dict(serving_results) serving_results = pd.DataFrame.from_dict(serving_results)
throughput_results = pd.DataFrame.from_dict(throughput_results) throughput_results = pd.DataFrame.from_dict(throughput_results)
raw_results_json = results_to_json(latency_results, throughput_results, raw_results_json = results_to_json(
serving_results) latency_results, throughput_results, serving_results
)
# remapping the key, for visualization purpose # remapping the key, for visualization purpose
if not latency_results.empty: if not latency_results.empty:
latency_results = latency_results[list( latency_results = latency_results[list(latency_column_mapping.keys())].rename(
latency_column_mapping.keys())].rename( columns=latency_column_mapping
columns=latency_column_mapping) )
if not serving_results.empty: if not serving_results.empty:
serving_results = serving_results[list( serving_results = serving_results[list(serving_column_mapping.keys())].rename(
serving_column_mapping.keys())].rename( columns=serving_column_mapping
columns=serving_column_mapping) )
if not throughput_results.empty: if not throughput_results.empty:
throughput_results = throughput_results[list( throughput_results = throughput_results[
throughput_results_column_mapping.keys())].rename( list(throughput_results_column_mapping.keys())
columns=throughput_results_column_mapping) ].rename(columns=throughput_results_column_mapping)
processed_results_json = results_to_json(latency_results, processed_results_json = results_to_json(
throughput_results, latency_results, throughput_results, serving_results
serving_results) )
for df in [latency_results, serving_results, throughput_results]: for df in [latency_results, serving_results, throughput_results]:
if df.empty: if df.empty:
@ -184,38 +186,39 @@ if __name__ == "__main__":
# The GPUs sometimes come in format of "GPUTYPE\nGPUTYPE\n...", # The GPUs sometimes come in format of "GPUTYPE\nGPUTYPE\n...",
# we want to turn it into "8xGPUTYPE" # we want to turn it into "8xGPUTYPE"
df["GPU"] = df["GPU"].apply( df["GPU"] = df["GPU"].apply(
lambda x: f"{len(x.split('\n'))}x{x.split('\n')[0]}") lambda x: f"{len(x.split('\n'))}x{x.split('\n')[0]}"
)
# get markdown tables # get markdown tables
latency_md_table = tabulate(latency_results, latency_md_table = tabulate(
headers='keys', latency_results, headers="keys", tablefmt="pipe", showindex=False
tablefmt='pipe', )
showindex=False) serving_md_table = tabulate(
serving_md_table = tabulate(serving_results, serving_results, headers="keys", tablefmt="pipe", showindex=False
headers='keys', )
tablefmt='pipe', throughput_md_table = tabulate(
showindex=False) throughput_results, headers="keys", tablefmt="pipe", showindex=False
throughput_md_table = tabulate(throughput_results, )
headers='keys',
tablefmt='pipe',
showindex=False)
# document the result # document the result
with open(results_folder / "benchmark_results.md", "w") as f: with open(results_folder / "benchmark_results.md", "w") as f:
results = read_markdown(
results = read_markdown("../.buildkite/nightly-benchmarks/" + "../.buildkite/nightly-benchmarks/"
"performance-benchmarks-descriptions.md") + "performance-benchmarks-descriptions.md"
)
results = results.format( results = results.format(
latency_tests_markdown_table=latency_md_table, latency_tests_markdown_table=latency_md_table,
throughput_tests_markdown_table=throughput_md_table, throughput_tests_markdown_table=throughput_md_table,
serving_tests_markdown_table=serving_md_table, serving_tests_markdown_table=serving_md_table,
benchmarking_results_in_json_string=processed_results_json) benchmarking_results_in_json_string=processed_results_json,
)
f.write(results) f.write(results)
# document benchmarking results in json # document benchmarking results in json
with open(results_folder / "benchmark_results.json", "w") as f: with open(results_folder / "benchmark_results.json", "w") as f:
results = (
results = latency_results.to_dict( latency_results.to_dict(orient="records")
orient='records') + throughput_results.to_dict( + throughput_results.to_dict(orient="records")
orient='records') + serving_results.to_dict(orient='records') + serving_results.to_dict(orient="records")
)
f.write(json.dumps(results)) f.write(json.dumps(results))

View File

@ -14,15 +14,12 @@ def main(model, cachedir):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Download and save Hugging Face tokenizer") description="Download and save Hugging Face tokenizer"
parser.add_argument("--model", )
type=str, parser.add_argument("--model", type=str, required=True, help="Name of the model")
required=True, parser.add_argument(
help="Name of the model") "--cachedir", type=str, required=True, help="Directory to save the tokenizer"
parser.add_argument("--cachedir", )
type=str,
required=True,
help="Directory to save the tokenizer")
args = parser.parse_args() args = parser.parse_args()
main(args.model, args.cachedir) main(args.model, args.cachedir)

View File

@ -11,33 +11,33 @@ from tabulate import tabulate
def parse_arguments(): def parse_arguments():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description= description="Parse command line arguments for summary-nightly-results script."
'Parse command line arguments for summary-nightly-results script.') )
parser.add_argument('--results-folder', parser.add_argument(
type=str, "--results-folder",
required=True, type=str,
help='The folder where the results are stored.') required=True,
parser.add_argument('--description', help="The folder where the results are stored.",
type=str, )
required=True, parser.add_argument(
help='Description of the results.') "--description", type=str, required=True, help="Description of the results."
)
args = parser.parse_args() args = parser.parse_args()
return args return args
def get_perf(df, method, model, metric): def get_perf(df, method, model, metric):
means = [] means = []
for qps in [2, 4, 8, 16, "inf"]: for qps in [2, 4, 8, 16, "inf"]:
target = df['Test name'].str.contains(model) target = df["Test name"].str.contains(model)
target = target & df['Engine'].str.contains(method) target = target & df["Engine"].str.contains(method)
target = target & df['Test name'].str.contains("qps_" + str(qps)) target = target & df["Test name"].str.contains("qps_" + str(qps))
filtered_df = df[target] filtered_df = df[target]
if filtered_df.empty: if filtered_df.empty:
means.append(0.) means.append(0.0)
else: else:
means.append(filtered_df[metric].values[0]) means.append(filtered_df[metric].values[0])
@ -45,7 +45,6 @@ def get_perf(df, method, model, metric):
def get_perf_w_std(df, method, model, metric): def get_perf_w_std(df, method, model, metric):
if metric in ["TTFT", "ITL"]: if metric in ["TTFT", "ITL"]:
mean = get_perf(df, method, model, "Mean " + metric + " (ms)") mean = get_perf(df, method, model, "Mean " + metric + " (ms)")
mean = mean.tolist() mean = mean.tolist()
@ -60,7 +59,8 @@ def get_perf_w_std(df, method, model, metric):
else: else:
assert metric == "Tput" assert metric == "Tput"
mean = get_perf(df, method, model, "Input Tput (tok/s)") + get_perf( mean = get_perf(df, method, model, "Input Tput (tok/s)") + get_perf(
df, method, model, "Output Tput (tok/s)") df, method, model, "Output Tput (tok/s)"
)
mean = mean.tolist() mean = mean.tolist()
std = None std = None
@ -80,18 +80,17 @@ def main(args):
# generate markdown table # generate markdown table
df = pd.DataFrame.from_dict(results) df = pd.DataFrame.from_dict(results)
md_table = tabulate(df, headers='keys', tablefmt='pipe', showindex=False) md_table = tabulate(df, headers="keys", tablefmt="pipe", showindex=False)
with open(args.description) as f: with open(args.description) as f:
description = f.read() description = f.read()
description = description.format( description = description.format(nightly_results_benchmarking_table=md_table)
nightly_results_benchmarking_table=md_table)
with open("nightly_results.md", "w") as f: with open("nightly_results.md", "w") as f:
f.write(description) f.write(description)
if __name__ == '__main__': if __name__ == "__main__":
args = parse_arguments() args = parse_arguments()
main(args) main(args)

View File

@ -34,10 +34,8 @@ serving_column_mapping = {
} }
if __name__ == "__main__": if __name__ == "__main__":
# collect results # collect results
for test_file in results_folder.glob("*.json"): for test_file in results_folder.glob("*.json"):
with open(test_file) as f: with open(test_file) as f:
raw_result = json.loads(f.read()) raw_result = json.loads(f.read())
@ -56,17 +54,16 @@ if __name__ == "__main__":
serving_results = pd.DataFrame.from_dict(serving_results) serving_results = pd.DataFrame.from_dict(serving_results)
if not serving_results.empty: if not serving_results.empty:
serving_results = serving_results[list( serving_results = serving_results[list(serving_column_mapping.keys())].rename(
serving_column_mapping.keys())].rename( columns=serving_column_mapping
columns=serving_column_mapping) )
serving_md_table_with_headers = tabulate(serving_results, serving_md_table_with_headers = tabulate(
headers='keys', serving_results, headers="keys", tablefmt="pipe", showindex=False
tablefmt='pipe', )
showindex=False)
# remove the first line of header # remove the first line of header
serving_md_table_lines = serving_md_table_with_headers.split('\n') serving_md_table_lines = serving_md_table_with_headers.split("\n")
serving_md_table_without_header = '\n'.join(serving_md_table_lines[2:]) serving_md_table_without_header = "\n".join(serving_md_table_lines[2:])
prefix = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") prefix = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
prefix = prefix + "_" + os.environ.get("CURRENT_LLM_SERVING_ENGINE") prefix = prefix + "_" + os.environ.get("CURRENT_LLM_SERVING_ENGINE")
@ -76,10 +73,9 @@ if __name__ == "__main__":
# document results with header. # document results with header.
# for those who wants to reproduce our benchmark. # for those who wants to reproduce our benchmark.
f.write(serving_md_table_with_headers) f.write(serving_md_table_with_headers)
f.write('\n') f.write("\n")
# document benchmarking results in json # document benchmarking results in json
with open(results_folder / f"{prefix}_nightly_results.json", "w") as f: with open(results_folder / f"{prefix}_nightly_results.json", "w") as f:
results = serving_results.to_dict(orient="records")
results = serving_results.to_dict(orient='records')
f.write(json.dumps(results)) f.write(json.dumps(results))

51
.buildkite/pyproject.toml Normal file
View File

@ -0,0 +1,51 @@
# This local pyproject file is part of the migration from yapf to ruff format.
# It uses the same core rules as the main pyproject.toml file, but with the
# following differences:
# - ruff line length is overridden to 88
# - deprecated typing ignores (UP006, UP035) have been removed
[tool.ruff]
line-length = 88
exclude = [
# External file, leaving license intact
"examples/other/fp8/quantizer/quantize.py",
"vllm/vllm_flash_attn/flash_attn_interface.pyi"
]
[tool.ruff.lint.per-file-ignores]
"vllm/third_party/**" = ["ALL"]
"vllm/version.py" = ["F401"]
"vllm/_version.py" = ["ALL"]
[tool.ruff.lint]
select = [
# pycodestyle
"E",
# Pyflakes
"F",
# pyupgrade
"UP",
# flake8-bugbear
"B",
# flake8-simplify
"SIM",
# isort
"I",
# flake8-logging-format
"G",
]
ignore = [
# star imports
"F405", "F403",
# lambda expression assignment
"E731",
# Loop control variable not used within loop body
"B007",
# f-string format
"UP032",
# Can remove once 3.10+ is the minimum Python version
"UP007",
]
[tool.ruff.format]
docstring-code-format = true

View File

@ -1,20 +1,20 @@
steps: steps:
- label: "Build wheel - CUDA 12.4" - label: "Build wheel - CUDA 12.8"
agents: agents:
queue: cpu_queue_postmerge queue: cpu_queue_postmerge
commands: commands:
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.4.0 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
- "mkdir artifacts" - "mkdir artifacts"
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
- "bash .buildkite/scripts/upload-wheels.sh" - "bash .buildkite/scripts/upload-wheels.sh"
env: env:
DOCKER_BUILDKIT: "1" DOCKER_BUILDKIT: "1"
- label: "Build wheel - CUDA 12.1" - label: "Build wheel - CUDA 12.6"
agents: agents:
queue: cpu_queue_postmerge queue: cpu_queue_postmerge
commands: commands:
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.1.0 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.6.3 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
- "mkdir artifacts" - "mkdir artifacts"
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
- "bash .buildkite/scripts/upload-wheels.sh" - "bash .buildkite/scripts/upload-wheels.sh"
@ -48,7 +48,7 @@ steps:
queue: cpu_queue_postmerge queue: cpu_queue_postmerge
commands: commands:
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.4.0 --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain -f docker/Dockerfile ." - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain -f docker/Dockerfile ."
- "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT" - "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT"
- label: "Build and publish TPU release image" - label: "Build and publish TPU release image"
@ -57,6 +57,8 @@ steps:
agents: agents:
queue: tpu_queue_postmerge queue: tpu_queue_postmerge
commands: commands:
- "yes | docker system prune -a"
- "git fetch --all"
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --tag vllm/vllm-tpu:nightly --tag vllm/vllm-tpu:$BUILDKITE_COMMIT --progress plain -f docker/Dockerfile.tpu ." - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --tag vllm/vllm-tpu:nightly --tag vllm/vllm-tpu:$BUILDKITE_COMMIT --progress plain -f docker/Dockerfile.tpu ."
- "docker push vllm/vllm-tpu:nightly" - "docker push vllm/vllm-tpu:nightly"
- "docker push vllm/vllm-tpu:$BUILDKITE_COMMIT" - "docker push vllm/vllm-tpu:$BUILDKITE_COMMIT"

View File

@ -3,6 +3,9 @@
# This script runs test inside the corresponding ROCm docker container. # This script runs test inside the corresponding ROCm docker container.
set -o pipefail set -o pipefail
# Export Python path
export PYTHONPATH=".."
# Print ROCm version # Print ROCm version
echo "--- Confirming Clean Initial State" echo "--- Confirming Clean Initial State"
while true; do while true; do
@ -74,38 +77,69 @@ HF_MOUNT="/root/.cache/huggingface"
commands=$@ commands=$@
echo "Commands:$commands" echo "Commands:$commands"
if [[ $commands == *"pytest -v -s basic_correctness/test_basic_correctness.py"* ]]; then
commands=${commands//"pytest -v -s basic_correctness/test_basic_correctness.py"/"VLLM_USE_TRITON_FLASH_ATTN=0 pytest -v -s basic_correctness/test_basic_correctness.py"}
fi
if [[ $commands == *"pytest -v -s models/test_registry.py"* ]]; then
commands=${commands//"pytest -v -s models/test_registry.py"/"pytest -v -s models/test_registry.py -k 'not BambaForCausalLM and not GritLM and not Mamba2ForCausalLM and not Zamba2ForCausalLM'"}
fi
if [[ $commands == *"VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2'"* ]]; then
commands=${commands//"VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2'"/"VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2 and not BambaForCausalLM and not Gemma2ForCausalLM and not Grok1ModelForCausalLM and not Zamba2ForCausalLM and not Gemma2Model and not GritLM'"}
fi
if [[ $commands == *"pytest -v -s compile/test_basic_correctness.py"* ]]; then
commands=${commands//"pytest -v -s compile/test_basic_correctness.py"/"VLLM_USE_TRITON_FLASH_ATTN=0 pytest -v -s compile/test_basic_correctness.py"}
fi
#ignore certain kernels tests #ignore certain kernels tests
if [[ $commands == *" kernels "* ]]; then if [[ $commands == *" kernels/core"* ]]; then
commands="${commands} \ commands="${commands} \
--ignore=kernels/test_attention_selector.py \ --ignore=kernels/core/test_fused_quant_layernorm.py \
--ignore=kernels/test_blocksparse_attention.py \ --ignore=kernels/core/test_permute_cols.py"
--ignore=kernels/test_causal_conv1d.py \ fi
--ignore=kernels/test_cutlass.py \
--ignore=kernels/test_encoder_decoder_attn.py \ if [[ $commands == *" kernels/attention"* ]]; then
--ignore=kernels/test_flash_attn.py \ commands="${commands} \
--ignore=kernels/test_flashinfer.py \ --ignore=kernels/attention/stest_attention_selector.py \
--ignore=kernels/test_int8_quant.py \ --ignore=kernels/attention/test_blocksparse_attention.py \
--ignore=kernels/test_machete_gemm.py \ --ignore=kernels/attention/test_encoder_decoder_attn.py \
--ignore=kernels/test_mamba_ssm.py \ --ignore=kernels/attention/test_attention_selector.py \
--ignore=kernels/test_marlin_gemm.py \ --ignore=kernels/attention/test_flash_attn.py \
--ignore=kernels/test_moe.py \ --ignore=kernels/attention/test_flashinfer.py \
--ignore=kernels/test_prefix_prefill.py \ --ignore=kernels/attention/test_prefix_prefill.py \
--ignore=kernels/test_rand.py \ --ignore=kernels/attention/test_cascade_flash_attn.py \
--ignore=kernels/test_sampler.py \ --ignore=kernels/attention/test_mha_attn.py \
--ignore=kernels/test_cascade_flash_attn.py \ --ignore=kernels/attention/test_lightning_attn.py \
--ignore=kernels/test_mamba_mixer2.py \ --ignore=kernels/attention/test_attention.py"
--ignore=kernels/test_aqlm.py \ fi
--ignore=kernels/test_machete_mm.py \
--ignore=kernels/test_mha_attn.py \ if [[ $commands == *" kernels/quantization"* ]]; then
--ignore=kernels/test_block_fp8.py \ commands="${commands} \
--ignore=kernels/test_cutlass_moe.py \ --ignore=kernels/quantization/test_int8_quant.py \
--ignore=kernels/test_mamba_ssm_ssd.py \ --ignore=kernels/quantization/test_aqlm.py \
--ignore=kernels/test_attention.py \ --ignore=kernels/quantization/test_machete_mm.py \
--ignore=kernels/test_block_int8.py \ --ignore=kernels/quantization/test_block_fp8.py \
--ignore=kernels/test_fused_quant_layernorm.py \ --ignore=kernels/quantization/test_block_int8.py \
--ignore=kernels/test_int8_kernel.py \ --ignore=kernels/quantization/test_marlin_gemm.py \
--ignore=kernels/test_triton_moe_ptpc_fp8.py \ --ignore=kernels/quantization/test_cutlass_scaled_mm.py \
--ignore=kernels/test_permute_cols.py" --ignore=kernels/quantization/test_int8_kernel.py"
fi
if [[ $commands == *" kernels/mamba"* ]]; then
commands="${commands} \
--ignore=kernels/mamba/test_mamba_mixer2.py \
--ignore=kernels/mamba/test_causal_conv1d.py \
--ignore=kernels/mamba/test_mamba_ssm_ssd.py"
fi
if [[ $commands == *" kernels/moe"* ]]; then
commands="${commands} \
--ignore=kernels/moe/test_moe.py \
--ignore=kernels/moe/test_cutlass_moe.py \
--ignore=kernels/moe/test_triton_moe_ptpc_fp8.py"
fi fi
#ignore certain Entrypoints/openai tests #ignore certain Entrypoints/openai tests
@ -147,6 +181,8 @@ fi
PARALLEL_JOB_COUNT=8 PARALLEL_JOB_COUNT=8
MYPYTHONPATH=".."
# check if the command contains shard flag, we will run all shards in parallel because the host have 8 GPUs. # check if the command contains shard flag, we will run all shards in parallel because the host have 8 GPUs.
if [[ $commands == *"--shard-id="* ]]; then if [[ $commands == *"--shard-id="* ]]; then
# assign job count as the number of shards used # assign job count as the number of shards used
@ -167,6 +203,7 @@ if [[ $commands == *"--shard-id="* ]]; then
-e AWS_SECRET_ACCESS_KEY \ -e AWS_SECRET_ACCESS_KEY \
-v "${HF_CACHE}:${HF_MOUNT}" \ -v "${HF_CACHE}:${HF_MOUNT}" \
-e "HF_HOME=${HF_MOUNT}" \ -e "HF_HOME=${HF_MOUNT}" \
-e "PYTHONPATH=${MYPYTHONPATH}" \
--name "${container_name}_${GPU}" \ --name "${container_name}_${GPU}" \
"${image_name}" \ "${image_name}" \
/bin/bash -c "${commands_gpu}" \ /bin/bash -c "${commands_gpu}" \
@ -197,6 +234,7 @@ else
-e AWS_SECRET_ACCESS_KEY \ -e AWS_SECRET_ACCESS_KEY \
-v "${HF_CACHE}:${HF_MOUNT}" \ -v "${HF_CACHE}:${HF_MOUNT}" \
-e "HF_HOME=${HF_MOUNT}" \ -e "HF_HOME=${HF_MOUNT}" \
-e "PYTHONPATH=${MYPYTHONPATH}" \
--name "${container_name}" \ --name "${container_name}" \
"${image_name}" \ "${image_name}" \
/bin/bash -c "${commands}" /bin/bash -c "${commands}"

View File

@ -32,9 +32,12 @@ function cpu_tests() {
set -e set -e
pip install pytest pytest-asyncio einops peft Pillow soundfile transformers_stream_generator matplotlib pip install pytest pytest-asyncio einops peft Pillow soundfile transformers_stream_generator matplotlib
pip install sentence-transformers datamodel_code_generator pip install sentence-transformers datamodel_code_generator
pytest -v -s tests/models/embedding/language/test_cls_models.py::test_classification_models[float-jason9693/Qwen2.5-1.5B-apeach] pytest -v -s tests/models/language/generation/test_bart.py -m cpu_model
pytest -v -s tests/models/embedding/language/test_embedding.py::test_models[half-BAAI/bge-base-en-v1.5] pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-openai-community/gpt2]
pytest -v -s tests/models/encoder_decoder/language -m cpu_model" pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-facebook/opt-125m]
pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-google/gemma-1.1-2b-it]
pytest -v -s tests/models/language/pooling/test_classification.py::test_models[float-jason9693/Qwen2.5-1.5B-apeach]
pytest -v -s tests/models/language/pooling/test_embedding.py::test_models[half-BAAI/bge-base-en-v1.5]"
} }
# All of CPU tests are expected to be finished less than 40 mins. # All of CPU tests are expected to be finished less than 40 mins.

View File

@ -1,6 +1,6 @@
#!/bin/bash #!/bin/bash
set -xue set -xu
# Build the docker image. # Build the docker image.
docker build -f docker/Dockerfile.tpu -t vllm-tpu . docker build -f docker/Dockerfile.tpu -t vllm-tpu .
@ -24,31 +24,80 @@ docker run --privileged --net host --shm-size=16G -it \
&& export VLLM_XLA_CHECK_RECOMPILATION=1 \ && export VLLM_XLA_CHECK_RECOMPILATION=1 \
&& echo HARDWARE \ && echo HARDWARE \
&& tpu-info \ && tpu-info \
&& echo TEST_0 \ && { \
&& pytest -v -s /workspace/vllm/tests/v1/tpu/test_perf.py \ echo TEST_0: Running test_perf.py; \
&& echo TEST_1 \ python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_perf.py; \
&& pytest -v -s /workspace/vllm/tests/tpu/test_compilation.py \ echo TEST_0_EXIT_CODE: \$?; \
&& echo TEST_2 \ } & \
&& pytest -v -s /workspace/vllm/tests/v1/tpu/test_basic.py \ { \
&& echo TEST_3 \ echo TEST_1: Running test_compilation.py; \
&& pytest -v -s /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine \ python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_compilation.py; \
&& echo TEST_4 \ echo TEST_1_EXIT_CODE: \$?; \
&& pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py \ } & \
&& echo TEST_5 \ { \
&& python3 /workspace/vllm/examples/offline_inference/tpu.py \ echo TEST_2: Running test_basic.py; \
&& echo TEST_6 \ python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_basic.py; \
&& pytest -s -v /workspace/vllm/tests/v1/tpu/worker/test_tpu_model_runner.py \ echo TEST_2_EXIT_CODE: \$?; \
&& echo TEST_7 \ } & \
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py \ { \
&& echo TEST_8 \ echo TEST_3: Running test_accuracy.py::test_lm_eval_accuracy_v1_engine; \
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_topk_topp_sampler.py \ python3 -m pytest -s -v /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine; \
&& echo TEST_9 \ echo TEST_3_EXIT_CODE: \$?; \
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_multimodal.py \ } & \
&& echo TEST_10 \ { \
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py \ echo TEST_4: Running test_quantization_accuracy.py; \
&& echo TEST_11 \ python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py; \
&& pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py" \ echo TEST_4_EXIT_CODE: \$?; \
} & \
{ \
echo TEST_5: Running examples/offline_inference/tpu.py; \
python3 /workspace/vllm/examples/offline_inference/tpu.py; \
echo TEST_5_EXIT_CODE: \$?; \
} & \
{ \
echo TEST_6: Running test_tpu_model_runner.py; \
python3 -m pytest -s -v /workspace/vllm/tests/tpu/worker/test_tpu_model_runner.py; \
echo TEST_6_EXIT_CODE: \$?; \
} & \
{ \
echo TEST_7: Running test_sampler.py; \
python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py; \
echo TEST_7_EXIT_CODE: \$?; \
} & \
{ \
echo TEST_8: Running test_topk_topp_sampler.py; \
python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_topk_topp_sampler.py; \
echo TEST_8_EXIT_CODE: \$?; \
} & \
{ \
echo TEST_9: Running test_multimodal.py; \
python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_multimodal.py; \
echo TEST_9_EXIT_CODE: \$?; \
} & \
{ \
echo TEST_10: Running test_pallas.py; \
python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py; \
echo TEST_10_EXIT_CODE: \$?; \
} & \
{ \
echo TEST_11: Running test_struct_output_generate.py; \
python3 -m pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py; \
echo TEST_11_EXIT_CODE: \$?; \
} & \
{ \
echo TEST_12: Running test_moe_pallas.py; \
python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py; \
echo TEST_12_EXIT_CODE: \$?; \
} & \
# Disable the TPU LoRA tests until the feature is activated
# & { \
# echo TEST_13: Running test_moe_pallas.py; \
# python3 -m pytest -s -v /workspace/vllm/tests/tpu/lora/; \
# echo TEST_13_EXIT_CODE: \$?; \
# } & \
wait \
&& echo 'All tests have attempted to run. Check logs for individual test statuses and exit codes.' \
"
# TODO: This test fails because it uses RANDOM_SEED sampling # TODO: This test fails because it uses RANDOM_SEED sampling
# && VLLM_USE_V1=1 pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py \ # && VLLM_USE_V1=1 pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py \

View File

@ -50,11 +50,11 @@ aws s3 cp "$normal_wheel" "s3://vllm-wheels/$BUILDKITE_COMMIT/"
if [[ $normal_wheel == *"cu118"* ]]; then if [[ $normal_wheel == *"cu118"* ]]; then
# if $normal_wheel matches cu118, do not upload the index.html # if $normal_wheel matches cu118, do not upload the index.html
echo "Skipping index files for cu118 wheels" echo "Skipping index files for cu118 wheels"
elif [[ $normal_wheel == *"cu121"* ]]; then elif [[ $normal_wheel == *"cu126"* ]]; then
# if $normal_wheel matches cu121, do not upload the index.html # if $normal_wheel matches cu126, do not upload the index.html
echo "Skipping index files for cu121 wheels" echo "Skipping index files for cu126 wheels"
else else
# only upload index.html for cu124 wheels (default wheels) # only upload index.html for cu128 wheels (default wheels)
aws s3 cp index.html "s3://vllm-wheels/$BUILDKITE_COMMIT/vllm/index.html" aws s3 cp index.html "s3://vllm-wheels/$BUILDKITE_COMMIT/vllm/index.html"
aws s3 cp "s3://vllm-wheels/nightly/index.html" "s3://vllm-wheels/$BUILDKITE_COMMIT/index.html" aws s3 cp "s3://vllm-wheels/nightly/index.html" "s3://vllm-wheels/$BUILDKITE_COMMIT/index.html"
fi fi
@ -66,12 +66,13 @@ aws s3 cp "$normal_wheel" "s3://vllm-wheels/nightly/"
if [[ $normal_wheel == *"cu118"* ]]; then if [[ $normal_wheel == *"cu118"* ]]; then
# if $normal_wheel matches cu118, do not upload the index.html # if $normal_wheel matches cu118, do not upload the index.html
echo "Skipping index files for cu118 wheels" echo "Skipping index files for cu118 wheels"
elif [[ $normal_wheel == *"cu121"* ]]; then elif [[ $normal_wheel == *"cu126"* ]]; then
# if $normal_wheel matches cu121, do not upload the index.html # if $normal_wheel matches cu126, do not upload the index.html
echo "Skipping index files for cu121 wheels" echo "Skipping index files for cu126 wheels"
else else
# only upload index.html for cu124 wheels (default wheels) # only upload index.html for cu128 wheels (default wheels)
aws s3 cp index.html "s3://vllm-wheels/nightly/vllm/index.html" aws s3 cp index.html "s3://vllm-wheels/nightly/vllm/index.html"
fi fi
aws s3 cp "$wheel" "s3://vllm-wheels/$version/" aws s3 cp "$wheel" "s3://vllm-wheels/$version/"
aws s3 cp index.html "s3://vllm-wheels/$version/vllm/index.html"

View File

@ -32,6 +32,7 @@ steps:
##### fast check tests ##### ##### fast check tests #####
- label: Documentation Build # 2min - label: Documentation Build # 2min
mirror_hardwares: [amdexperimental]
working_dir: "/vllm-workspace/test_docs/docs" working_dir: "/vllm-workspace/test_docs/docs"
fast_check: true fast_check: true
no_gpu: True no_gpu: True
@ -39,9 +40,10 @@ steps:
- pip install -r ../../requirements/docs.txt - pip install -r ../../requirements/docs.txt
- SPHINXOPTS=\"-W\" make html - SPHINXOPTS=\"-W\" make html
# Check API reference (if it fails, you may have missing mock imports) # Check API reference (if it fails, you may have missing mock imports)
- grep \"sig sig-object py\" build/html/api/inference_params.html - grep \"sig sig-object py\" build/html/api/vllm/vllm.sampling_params.html
- label: Async Engine, Inputs, Utils, Worker Test # 24min - label: Async Engine, Inputs, Utils, Worker Test # 24min
mirror_hardwares: [amdexperimental]
source_file_dependencies: source_file_dependencies:
- vllm/ - vllm/
- tests/mq_llm_engine - tests/mq_llm_engine
@ -62,6 +64,7 @@ steps:
- pytest -v -s worker # Worker - pytest -v -s worker # Worker
- label: Python-only Installation Test - label: Python-only Installation Test
mirror_hardwares: [amdexperimental]
source_file_dependencies: source_file_dependencies:
- tests/standalone_tests/python_only_compile.sh - tests/standalone_tests/python_only_compile.sh
- setup.py - setup.py
@ -69,7 +72,7 @@ steps:
- bash standalone_tests/python_only_compile.sh - bash standalone_tests/python_only_compile.sh
- label: Basic Correctness Test # 30min - label: Basic Correctness Test # 30min
#mirror_hardwares: [amd] mirror_hardwares: [amdexperimental, amdproduction]
fast_check: true fast_check: true
torch_nightly: true torch_nightly: true
source_file_dependencies: source_file_dependencies:
@ -86,6 +89,7 @@ steps:
- VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py
- label: Chunked Prefill Test - label: Chunked Prefill Test
mirror_hardwares: [amdexperimental]
source_file_dependencies: source_file_dependencies:
- vllm/ - vllm/
- tests/basic_correctness/test_chunked_prefill - tests/basic_correctness/test_chunked_prefill
@ -94,7 +98,7 @@ steps:
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
- label: Core Test # 10min - label: Core Test # 10min
mirror_hardwares: [amd] mirror_hardwares: [amdexperimental, amdproduction]
fast_check: true fast_check: true
source_file_dependencies: source_file_dependencies:
- vllm/core - vllm/core
@ -104,10 +108,10 @@ steps:
- pytest -v -s core - pytest -v -s core
- label: Entrypoints Test # 40min - label: Entrypoints Test # 40min
mirror_hardwares: [amdexperimental]
working_dir: "/vllm-workspace/tests" working_dir: "/vllm-workspace/tests"
fast_check: true fast_check: true
torch_nightly: true torch_nightly: true
#mirror_hardwares: [amd]
source_file_dependencies: source_file_dependencies:
- vllm/ - vllm/
- tests/entrypoints/llm - tests/entrypoints/llm
@ -126,6 +130,7 @@ steps:
- VLLM_USE_V1=0 pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests - VLLM_USE_V1=0 pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests
- label: Distributed Tests (4 GPUs) # 10min - label: Distributed Tests (4 GPUs) # 10min
mirror_hardwares: [amdexperimental]
working_dir: "/vllm-workspace/tests" working_dir: "/vllm-workspace/tests"
num_gpus: 4 num_gpus: 4
source_file_dependencies: source_file_dependencies:
@ -143,6 +148,8 @@ steps:
# test with tp=2 and external_dp=2 # test with tp=2 and external_dp=2
- VLLM_USE_V1=0 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py - VLLM_USE_V1=0 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
- torchrun --nproc-per-node=4 distributed/test_torchrun_example.py - torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
# test with tp=2 and pp=2
- PP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
# test with internal dp # test with internal dp
- python3 ../examples/offline_inference/data_parallel.py - python3 ../examples/offline_inference/data_parallel.py
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
@ -153,12 +160,12 @@ steps:
# TODO: create a dedicated test section for multi-GPU example tests # TODO: create a dedicated test section for multi-GPU example tests
# when we have multiple distributed example tests # when we have multiple distributed example tests
- pushd ../examples/offline_inference - pushd ../examples/offline_inference
- python3 rlhf.py - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py
- RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py - VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
- popd - popd
- label: Metrics, Tracing Test # 10min - label: Metrics, Tracing Test # 10min
mirror_hardwares: [amd] mirror_hardwares: [amdexperimental, amdproduction]
num_gpus: 2 num_gpus: 2
source_file_dependencies: source_file_dependencies:
- vllm/ - vllm/
@ -172,7 +179,7 @@ steps:
##### 1 GPU test ##### ##### 1 GPU test #####
- label: Regression Test # 5min - label: Regression Test # 5min
#mirror_hardwares: [amd] mirror_hardwares: [amdexperimental, amdproduction]
source_file_dependencies: source_file_dependencies:
- vllm/ - vllm/
- tests/test_regression - tests/test_regression
@ -182,7 +189,7 @@ steps:
working_dir: "/vllm-workspace/tests" # optional working_dir: "/vllm-workspace/tests" # optional
- label: Engine Test # 10min - label: Engine Test # 10min
mirror_hardwares: [amd] mirror_hardwares: [amdexperimental, amdproduction]
source_file_dependencies: source_file_dependencies:
- vllm/ - vllm/
- tests/engine - tests/engine
@ -196,7 +203,7 @@ steps:
- pytest -v -s tokenization - pytest -v -s tokenization
- label: V1 Test - label: V1 Test
#mirror_hardwares: [amd] mirror_hardwares: [amdexperimental]
source_file_dependencies: source_file_dependencies:
- vllm/ - vllm/
- tests/v1 - tests/v1
@ -209,8 +216,8 @@ steps:
- pytest -v -s v1/worker - pytest -v -s v1/worker
- pytest -v -s v1/structured_output - pytest -v -s v1/structured_output
- pytest -v -s v1/spec_decode - pytest -v -s v1/spec_decode
- pytest -v -s v1/kv_connector/unit
- pytest -v -s v1/test_serial_utils.py - pytest -v -s v1/test_serial_utils.py
- pytest -v -s v1/test_stats.py
- pytest -v -s v1/test_utils.py - pytest -v -s v1/test_utils.py
- pytest -v -s v1/test_oracle.py - pytest -v -s v1/test_oracle.py
# TODO: accuracy does not match, whether setting # TODO: accuracy does not match, whether setting
@ -221,8 +228,8 @@ steps:
- pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine - pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine
- label: Examples Test # 25min - label: Examples Test # 25min
mirror_hardwares: [amdexperimental]
working_dir: "/vllm-workspace/examples" working_dir: "/vllm-workspace/examples"
#mirror_hardwares: [amd]
source_file_dependencies: source_file_dependencies:
- vllm/entrypoints - vllm/entrypoints
- examples/ - examples/
@ -246,7 +253,7 @@ steps:
- VLLM_USE_V1=0 python3 offline_inference/profiling.py --model facebook/opt-125m run_num_steps --num-steps 2 - VLLM_USE_V1=0 python3 offline_inference/profiling.py --model facebook/opt-125m run_num_steps --num-steps 2
- label: Prefix Caching Test # 9min - label: Prefix Caching Test # 9min
mirror_hardwares: [amd] mirror_hardwares: [amdexperimental, amdproduction]
source_file_dependencies: source_file_dependencies:
- vllm/ - vllm/
- tests/prefix_caching - tests/prefix_caching
@ -254,6 +261,7 @@ steps:
- pytest -v -s prefix_caching - pytest -v -s prefix_caching
- label: Samplers Test # 36min - label: Samplers Test # 36min
mirror_hardwares: [amdexperimental]
source_file_dependencies: source_file_dependencies:
- vllm/model_executor/layers - vllm/model_executor/layers
- vllm/sampling_metadata.py - vllm/sampling_metadata.py
@ -264,7 +272,7 @@ steps:
- VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers - VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers
- label: LogitsProcessor Test # 5min - label: LogitsProcessor Test # 5min
mirror_hardwares: [amd] mirror_hardwares: [amdexperimental, amdproduction]
source_file_dependencies: source_file_dependencies:
- vllm/model_executor/layers - vllm/model_executor/layers
- vllm/model_executor/guided_decoding - vllm/model_executor/guided_decoding
@ -275,6 +283,7 @@ steps:
- pytest -v -s model_executor/test_guided_processors.py - pytest -v -s model_executor/test_guided_processors.py
- label: Speculative decoding tests # 40min - label: Speculative decoding tests # 40min
mirror_hardwares: [amdexperimental]
source_file_dependencies: source_file_dependencies:
- vllm/spec_decode - vllm/spec_decode
- tests/spec_decode - tests/spec_decode
@ -285,7 +294,7 @@ steps:
- pytest -v -s spec_decode/e2e/test_eagle_correctness.py - pytest -v -s spec_decode/e2e/test_eagle_correctness.py
- label: LoRA Test %N # 15min each - label: LoRA Test %N # 15min each
#mirror_hardwares: [amd] mirror_hardwares: [amdexperimental]
source_file_dependencies: source_file_dependencies:
- vllm/lora - vllm/lora
- tests/lora - tests/lora
@ -293,15 +302,20 @@ steps:
parallelism: 4 parallelism: 4
- label: PyTorch Compilation Unit Tests - label: PyTorch Compilation Unit Tests
mirror_hardwares: [amdexperimental, amdproduction]
torch_nightly: true
source_file_dependencies: source_file_dependencies:
- vllm/ - vllm/
- tests/compile - tests/compile
commands: commands:
- pytest -v -s compile/test_pass_manager.py - pytest -v -s compile/test_pass_manager.py
- pytest -v -s compile/test_fusion.py - pytest -v -s compile/test_fusion.py
- pytest -v -s compile/test_silu_mul_quant_fusion.py
- pytest -v -s compile/test_sequence_parallelism.py - pytest -v -s compile/test_sequence_parallelism.py
- label: PyTorch Fullgraph Smoke Test # 9min - label: PyTorch Fullgraph Smoke Test # 9min
mirror_hardwares: [amdexperimental, amdproduction]
torch_nightly: true
source_file_dependencies: source_file_dependencies:
- vllm/ - vllm/
- tests/compile - tests/compile
@ -312,6 +326,8 @@ steps:
- pytest -v -s compile/piecewise/test_toy_llama.py - pytest -v -s compile/piecewise/test_toy_llama.py
- label: PyTorch Fullgraph Test # 18min - label: PyTorch Fullgraph Test # 18min
mirror_hardwares: [amdexperimental, amdproduction]
torch_nightly: true
source_file_dependencies: source_file_dependencies:
- vllm/ - vllm/
- tests/compile - tests/compile
@ -319,6 +335,7 @@ steps:
- pytest -v -s compile/test_full_graph.py - pytest -v -s compile/test_full_graph.py
- label: Kernels Core Operation Test - label: Kernels Core Operation Test
mirror_hardwares: [amdexperimental, amdproduction]
source_file_dependencies: source_file_dependencies:
- csrc/ - csrc/
- tests/kernels/core - tests/kernels/core
@ -326,6 +343,7 @@ steps:
- pytest -v -s kernels/core - pytest -v -s kernels/core
- label: Kernels Attention Test %N - label: Kernels Attention Test %N
mirror_hardwares: [amdexperimental, amdproduction]
source_file_dependencies: source_file_dependencies:
- csrc/attention/ - csrc/attention/
- vllm/attention - vllm/attention
@ -336,6 +354,7 @@ steps:
parallelism: 2 parallelism: 2
- label: Kernels Quantization Test %N - label: Kernels Quantization Test %N
mirror_hardwares: [amdexperimental, amdproduction]
source_file_dependencies: source_file_dependencies:
- csrc/quantization/ - csrc/quantization/
- vllm/model_executor/layers/quantization - vllm/model_executor/layers/quantization
@ -345,6 +364,7 @@ steps:
parallelism: 2 parallelism: 2
- label: Kernels MoE Test - label: Kernels MoE Test
mirror_hardwares: [amdexperimental]
source_file_dependencies: source_file_dependencies:
- csrc/moe/ - csrc/moe/
- tests/kernels/moe - tests/kernels/moe
@ -353,6 +373,7 @@ steps:
- pytest -v -s kernels/moe - pytest -v -s kernels/moe
- label: Kernels Mamba Test - label: Kernels Mamba Test
mirror_hardwares: [amdexperimental]
source_file_dependencies: source_file_dependencies:
- csrc/mamba/ - csrc/mamba/
- tests/kernels/mamba - tests/kernels/mamba
@ -360,7 +381,7 @@ steps:
- pytest -v -s kernels/mamba - pytest -v -s kernels/mamba
- label: Tensorizer Test # 11min - label: Tensorizer Test # 11min
# mirror_hardwares: [amd] mirror_hardwares: [amdexperimental, amdproduction]
soft_fail: true soft_fail: true
source_file_dependencies: source_file_dependencies:
- vllm/model_executor/model_loader - vllm/model_executor/model_loader
@ -371,37 +392,42 @@ steps:
- pytest -v -s tensorizer_loader - pytest -v -s tensorizer_loader
- label: Benchmarks # 9min - label: Benchmarks # 9min
mirror_hardwares: [amdexperimental, amdproduction]
working_dir: "/vllm-workspace/.buildkite" working_dir: "/vllm-workspace/.buildkite"
mirror_hardwares: [amd]
source_file_dependencies: source_file_dependencies:
- benchmarks/ - benchmarks/
commands: commands:
- bash scripts/run-benchmarks.sh - bash scripts/run-benchmarks.sh
- label: Benchmarks CLI Test # 10min - label: Benchmarks CLI Test # 10min
mirror_hardwares: [amdexperimental, amdproduction]
source_file_dependencies: source_file_dependencies:
- vllm/ - vllm/
- tests/benchmarks/ - tests/benchmarks/
commands: commands:
- pytest -v -s benchmarks/ - pytest -v -s benchmarks/
- label: Quantization Test # 33min - label: Quantization Test
mirror_hardwares: [amdexperimental]
source_file_dependencies: source_file_dependencies:
- csrc/ - csrc/
- vllm/model_executor/layers/quantization - vllm/model_executor/layers/quantization
- tests/quantization - tests/quantization
command: VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization commands:
- VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization
- label: LM Eval Small Models # 53min - label: LM Eval Small Models # 53min
mirror_hardwares: [amdexperimental]
working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" working_dir: "/vllm-workspace/.buildkite/lm-eval-harness"
source_file_dependencies: source_file_dependencies:
- csrc/ - csrc/
- vllm/model_executor/layers/quantization - vllm/model_executor/layers/quantization
commands: commands:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn - export VLLM_WORKER_MULTIPROC_METHOD=spawn
- bash ./run-tests.sh -c configs/models-small.txt -t 1 - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-small.txt --tp-size=1
- label: OpenAI API correctness - label: OpenAI API correctness
mirror_hardwares: [amdexperimental]
source_file_dependencies: source_file_dependencies:
- csrc/ - csrc/
- vllm/entrypoints/openai/ - vllm/entrypoints/openai/
@ -410,6 +436,7 @@ steps:
- pytest -s entrypoints/openai/correctness/ - pytest -s entrypoints/openai/correctness/
- label: Encoder Decoder tests # 5min - label: Encoder Decoder tests # 5min
mirror_hardwares: [amdexperimental]
source_file_dependencies: source_file_dependencies:
- vllm/ - vllm/
- tests/encoder_decoder - tests/encoder_decoder
@ -417,8 +444,8 @@ steps:
- pytest -v -s encoder_decoder - pytest -v -s encoder_decoder
- label: OpenAI-Compatible Tool Use # 20 min - label: OpenAI-Compatible Tool Use # 20 min
mirror_hardwares: [amdexperimental]
fast_check: false fast_check: false
#mirror_hardwares: [ amd ]
source_file_dependencies: source_file_dependencies:
- vllm/ - vllm/
- tests/tool_use - tests/tool_use
@ -430,92 +457,98 @@ steps:
##### models test ##### ##### models test #####
- label: Basic Models Test # 24min - label: Basic Models Test # 24min
mirror_hardwares: [amdexperimental, amdproduction]
torch_nightly: true
source_file_dependencies: source_file_dependencies:
- vllm/ - vllm/
- tests/models - tests/models
commands: commands:
- pytest -v -s models/test_transformers.py - pytest -v -s models/test_transformers.py
- pytest -v -s models/test_registry.py - pytest -v -s models/test_registry.py
- pytest -v -s models/test_utils.py
- pytest -v -s models/test_vision.py
# V1 Test: https://github.com/vllm-project/vllm/issues/14531 # V1 Test: https://github.com/vllm-project/vllm/issues/14531
- VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2' - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2'
- VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'llama4' - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'llama4'
- VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'plamo2' - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'plamo2'
- label: Language Models Test (Standard) # 32min - label: Language Models Test (Standard)
#mirror_hardwares: [amd] mirror_hardwares: [amdexperimental]
torch_nightly: true
source_file_dependencies: source_file_dependencies:
- vllm/ - vllm/
- tests/models/decoder_only/language - tests/models/language
- tests/models/embedding/language
- tests/models/encoder_decoder/language
commands: commands:
# Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile. # Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile.
- pip install causal-conv1d - pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8'
- pytest -v -s models/decoder_only/language -m 'core_model or quant_model' - pip freeze | grep -E 'torch'
- pytest -v -s models/embedding/language -m core_model - pytest -v -s models/language -m core_model
- label: Language Models Test (Extended) # 1h10min - label: Language Models Test (Extended)
mirror_hardwares: [amdexperimental]
optional: true optional: true
source_file_dependencies: source_file_dependencies:
- vllm/ - vllm/
- tests/models/decoder_only/language - tests/models/language
- tests/models/embedding/language
- tests/models/encoder_decoder/language
commands: commands:
# Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile. # Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile.
- pip install causal-conv1d - pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8'
- pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model' - pytest -v -s models/language -m 'not core_model'
- pytest -v -s models/embedding/language -m 'not core_model'
- label: Multi-Modal Models Test (Standard) # 40min - label: Multi-Modal Models Test (Standard)
#mirror_hardwares: [amd] mirror_hardwares: [amdexperimental]
torch_nightly: true
source_file_dependencies: source_file_dependencies:
- vllm/ - vllm/
- tests/models/decoder_only/audio_language - tests/models/multimodal
- tests/models/decoder_only/vision_language
- tests/models/embedding/vision_language
- tests/models/encoder_decoder/audio_language
- tests/models/encoder_decoder/vision_language
commands: commands:
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
- pytest -v -s models/multimodal - pip freeze | grep -E 'torch'
- pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model' - pytest -v -s models/multimodal/processing
- pytest -v -s models/decoder_only/vision_language -m 'core_model or quant_model' - pytest -v -s --ignore models/multimodal/generation/test_whisper.py models/multimodal -m core_model
- pytest -v -s models/embedding/vision_language -m core_model - cd .. && pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work
- pytest -v -s models/encoder_decoder/audio_language -m core_model
- pytest -v -s models/encoder_decoder/language -m core_model
- pytest -v -s models/encoder_decoder/vision_language -m core_model
- pytest -v -s models/decoder_only/vision_language/test_interleaved.py
- label: Multi-Modal Models Test (Extended) 1 # 48m - label: Multi-Modal Models Test (Extended) 1
mirror_hardwares: [amdexperimental]
optional: true optional: true
source_file_dependencies: source_file_dependencies:
- vllm/ - vllm/
- tests/models/decoder_only/audio_language - tests/models/multimodal
- tests/models/decoder_only/vision_language
- tests/models/embedding/vision_language
- tests/models/encoder_decoder/vision_language
commands: commands:
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
- pytest -v -s models/decoder_only/audio_language -m 'not core_model and not quant_model' - pytest -v -s --ignore models/multimodal/generation/test_common.py --ignore models/multimodal/processing models/multimodal -m 'not core_model'
- pytest -v -s models/decoder_only/vision_language/test_models.py -m 'split(group=0) and not core_model and not quant_model'
- pytest -v -s --ignore models/decoder_only/vision_language/test_models.py models/decoder_only/vision_language -m 'not core_model and not quant_model'
- pytest -v -s models/embedding/vision_language -m 'not core_model'
- pytest -v -s models/encoder_decoder/language -m 'not core_model'
- pytest -v -s models/encoder_decoder/vision_language -m 'not core_model'
- label: Multi-Modal Models Test (Extended) 2 # 38m - label: Multi-Modal Models Test (Extended) 2
mirror_hardwares: [amdexperimental]
optional: true optional: true
source_file_dependencies: source_file_dependencies:
- vllm/ - vllm/
- tests/models/decoder_only/vision_language - tests/models/multimodal
commands: commands:
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
- pytest -v -s models/decoder_only/vision_language/test_models.py -m 'split(group=1) and not core_model and not quant_model' - pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=0) and not core_model'
- label: Multi-Modal Models Test (Extended) 3
mirror_hardwares: [amdexperimental, amdproduction]
optional: true
source_file_dependencies:
- vllm/
- tests/models/multimodal
commands:
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
- pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=1) and not core_model'
- label: Quantized Models Test
mirror_hardwares: [amdexperimental, amdproduction]
source_file_dependencies:
- vllm/model_executor/layers/quantization
- tests/models/quantization
commands:
- pytest -v -s models/quantization
# This test is used only in PR development phase to test individual models and should never run on main # This test is used only in PR development phase to test individual models and should never run on main
- label: Custom Models Test - label: Custom Models Test
mirror_hardwares: [amd] mirror_hardwares: [amdexperimental, amdproduction]
optional: true optional: true
commands: commands:
- echo 'Testing custom models...' - echo 'Testing custom models...'
@ -527,7 +560,7 @@ steps:
##### multi gpus test ##### ##### multi gpus test #####
- label: Distributed Comm Ops Test # 7min - label: Distributed Comm Ops Test # 7min
mirror_hardwares: [amd] mirror_hardwares: [amdexperimental, amdproduction]
working_dir: "/vllm-workspace/tests" working_dir: "/vllm-workspace/tests"
num_gpus: 2 num_gpus: 2
source_file_dependencies: source_file_dependencies:
@ -538,6 +571,7 @@ steps:
- pytest -v -s distributed/test_shm_broadcast.py - pytest -v -s distributed/test_shm_broadcast.py
- label: 2 Node Tests (4 GPUs in total) # 16min - label: 2 Node Tests (4 GPUs in total) # 16min
mirror_hardwares: [amdexperimental]
working_dir: "/vllm-workspace/tests" working_dir: "/vllm-workspace/tests"
num_gpus: 2 num_gpus: 2
num_nodes: 2 num_nodes: 2
@ -556,7 +590,7 @@ steps:
- VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
- label: Distributed Tests (2 GPUs) # 40min - label: Distributed Tests (2 GPUs) # 40min
#mirror_hardwares: [amd] mirror_hardwares: [amdexperimental]
working_dir: "/vllm-workspace/tests" working_dir: "/vllm-workspace/tests"
num_gpus: 2 num_gpus: 2
source_file_dependencies: source_file_dependencies:
@ -581,9 +615,8 @@ steps:
- TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)' - TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)'
# Avoid importing model tests that cause CUDA reinitialization error # Avoid importing model tests that cause CUDA reinitialization error
- pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)' - pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)'
- pytest models/encoder_decoder/language/test_bart.py -v -s -m 'distributed(num_gpus=2)' - pytest models/language -v -s -m 'distributed(num_gpus=2)'
- pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m 'distributed(num_gpus=2)' - pytest models/multimodal -v -s -m 'distributed(num_gpus=2)'
- pytest models/decoder_only/vision_language/test_models.py -v -s -m 'distributed(num_gpus=2)'
# test sequence parallel # test sequence parallel
- pytest -v -s distributed/test_sequence_parallel.py - pytest -v -s distributed/test_sequence_parallel.py
# this test fails consistently. # this test fails consistently.
@ -594,13 +627,14 @@ steps:
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown
- label: Plugin Tests (2 GPUs) # 40min - label: Plugin Tests (2 GPUs) # 40min
mirror_hardwares: [amdexperimental]
working_dir: "/vllm-workspace/tests" working_dir: "/vllm-workspace/tests"
num_gpus: 2 num_gpus: 2
source_file_dependencies: source_file_dependencies:
- vllm/plugins/ - vllm/plugins/
- tests/plugins/ - tests/plugins/
commands: commands:
# begin platform plugin tests, all the code in-between runs on dummy platform # begin platform plugin and general plugin tests, all the code in-between runs on dummy platform
- pip install -e ./plugins/vllm_add_dummy_platform - pip install -e ./plugins/vllm_add_dummy_platform
- pytest -v -s plugins_tests/test_platform_plugins.py - pytest -v -s plugins_tests/test_platform_plugins.py
- pip uninstall vllm_add_dummy_platform -y - pip uninstall vllm_add_dummy_platform -y
@ -611,8 +645,10 @@ steps:
- pytest -v -s distributed/test_distributed_oot.py - pytest -v -s distributed/test_distributed_oot.py
- pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process - pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process
- pytest -v -s models/test_oot_registration.py # it needs a clean process - pytest -v -s models/test_oot_registration.py # it needs a clean process
- pytest -v -s plugins/lora_resolvers # unit tests for in-tree lora resolver plugins
- label: Multi-step Tests (4 GPUs) # 36min - label: Multi-step Tests (4 GPUs) # 36min
mirror_hardwares: [amdexperimental]
working_dir: "/vllm-workspace/tests" working_dir: "/vllm-workspace/tests"
num_gpus: 4 num_gpus: 4
source_file_dependencies: source_file_dependencies:
@ -633,6 +669,7 @@ steps:
- pytest -v -s multi_step/test_correctness_llm.py - pytest -v -s multi_step/test_correctness_llm.py
- label: Pipeline Parallelism Test # 45min - label: Pipeline Parallelism Test # 45min
mirror_hardwares: [amdexperimental, amdproduction]
working_dir: "/vllm-workspace/tests" working_dir: "/vllm-workspace/tests"
num_gpus: 4 num_gpus: 4
source_file_dependencies: source_file_dependencies:
@ -646,6 +683,7 @@ steps:
- pytest -v -s distributed/test_pipeline_parallel.py - pytest -v -s distributed/test_pipeline_parallel.py
- label: LoRA TP Test (Distributed) - label: LoRA TP Test (Distributed)
mirror_hardwares: [amdexperimental, amdproduction]
num_gpus: 4 num_gpus: 4
source_file_dependencies: source_file_dependencies:
- vllm/lora - vllm/lora
@ -661,6 +699,7 @@ steps:
- label: Weight Loading Multiple GPU Test # 33min - label: Weight Loading Multiple GPU Test # 33min
mirror_hardwares: [amdexperimental]
working_dir: "/vllm-workspace/tests" working_dir: "/vllm-workspace/tests"
num_gpus: 2 num_gpus: 2
source_file_dependencies: source_file_dependencies:
@ -670,6 +709,7 @@ steps:
- bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models.txt - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models.txt
- label: Weight Loading Multiple GPU Test - Large Models # optional - label: Weight Loading Multiple GPU Test - Large Models # optional
mirror_hardwares: [amdexperimental]
working_dir: "/vllm-workspace/tests" working_dir: "/vllm-workspace/tests"
num_gpus: 2 num_gpus: 2
gpu: a100 gpu: a100
@ -708,4 +748,4 @@ steps:
- vllm/model_executor/layers/quantization - vllm/model_executor/layers/quantization
commands: commands:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn - export VLLM_WORKER_MULTIPROC_METHOD=spawn
- bash ./run-tests.sh -c configs/models-large.txt -t 4 - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4

View File

@ -21,7 +21,7 @@ body:
It is suggested to download and execute the latest script, as vllm might frequently update the diagnosis information needed for accurately and quickly responding to issues. It is suggested to download and execute the latest script, as vllm might frequently update the diagnosis information needed for accurately and quickly responding to issues.
value: | value: |
<details> <details>
<summary>The output of `python collect_env.py`</summary> <summary>The output of <code>python collect_env.py</code></summary>
```text ```text
Your output of `python collect_env.py` here Your output of `python collect_env.py` here
@ -75,7 +75,7 @@ body:
``` ```
``` ```
The error message you got, with the full traceback. The error message you got, with the full traceback and the error logs with [dump_input.py:##] if present.
``` ```
validations: validations:
required: true required: true

11
.github/mergify.yml vendored
View File

@ -163,6 +163,17 @@ pull_request_rules:
https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork
- name: assign reviewer for tensorizer changes
conditions:
- files~=^vllm/model_executor/model_loader/tensorizer.py
- files~=^vllm/model_executor/model_loader/tensorizer_loader.py
- files~=^tests/entrypoints/openai/test_tensorizer_entrypoint.py
- files~=^tests/tensorizer_loader/
actions:
assign:
users:
- "sangstar"
- name: remove 'needs-rebase' label when conflict is resolved - name: remove 'needs-rebase' label when conflict is resolved
conditions: conditions:
- -conflict - -conflict

View File

@ -1,4 +1,6 @@
name: Add label on auto-merge enabled name: Add label on auto-merge enabled
permissions:
pull-requests: write
on: on:
pull_request_target: pull_request_target:
types: types:

View File

@ -2,6 +2,9 @@ name: Lint and Deploy Charts
on: pull_request on: pull_request
permissions:
contents: read
jobs: jobs:
lint-and-deploy: lint-and-deploy:
runs-on: ubuntu-latest runs-on: ubuntu-latest

View File

@ -5,6 +5,9 @@ on:
push: push:
branches: [main] branches: [main]
permissions:
contents: read
jobs: jobs:
pre-commit: pre-commit:
runs-on: ubuntu-latest runs-on: ubuntu-latest

View File

@ -1,4 +1,6 @@
name: PR Reminder Comment Bot name: PR Reminder Comment Bot
permissions:
pull-requests: write
on: on:
pull_request_target: pull_request_target:
types: [opened] types: [opened]

1
.gitignore vendored
View File

@ -80,6 +80,7 @@ instance/
# Sphinx documentation # Sphinx documentation
docs/_build/ docs/_build/
docs/source/getting_started/examples/ docs/source/getting_started/examples/
docs/source/api/vllm
# PyBuilder # PyBuilder
.pybuilder/ .pybuilder/

View File

@ -12,29 +12,31 @@ repos:
- id: yapf - id: yapf
args: [--in-place, --verbose] args: [--in-place, --verbose]
- repo: https://github.com/astral-sh/ruff-pre-commit - repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.9.3 rev: v0.11.7
hooks: hooks:
- id: ruff - id: ruff
args: [--output-format, github, --fix] args: [--output-format, github, --fix]
- id: ruff-format
files: ^(.buildkite|benchmarks)/.*
- repo: https://github.com/codespell-project/codespell - repo: https://github.com/codespell-project/codespell
rev: v2.4.0 rev: v2.4.1
hooks: hooks:
- id: codespell - id: codespell
additional_dependencies: ['tomli'] additional_dependencies: ['tomli']
args: ['--toml', 'pyproject.toml'] args: ['--toml', 'pyproject.toml']
- repo: https://github.com/PyCQA/isort - repo: https://github.com/PyCQA/isort
rev: 0a0b7a830386ba6a31c2ec8316849ae4d1b8240d # 6.0.0 rev: 6.0.1
hooks: hooks:
- id: isort - id: isort
- repo: https://github.com/pre-commit/mirrors-clang-format - repo: https://github.com/pre-commit/mirrors-clang-format
rev: v19.1.7 rev: v20.1.3
hooks: hooks:
- id: clang-format - id: clang-format
exclude: 'csrc/(moe/topk_softmax_kernels.cu|quantization/gguf/(ggml-common.h|dequantize.cuh|vecdotq.cuh|mmq.cuh|mmvq.cuh))|vllm/third_party/.*' exclude: 'csrc/(moe/topk_softmax_kernels.cu|quantization/gguf/(ggml-common.h|dequantize.cuh|vecdotq.cuh|mmq.cuh|mmvq.cuh))|vllm/third_party/.*'
types_or: [c++, cuda] types_or: [c++, cuda]
args: [--style=file, --verbose] args: [--style=file, --verbose]
- repo: https://github.com/jackdewinter/pymarkdown - repo: https://github.com/jackdewinter/pymarkdown
rev: v0.9.27 rev: v0.9.29
hooks: hooks:
- id: pymarkdown - id: pymarkdown
args: [fix] args: [fix]
@ -43,10 +45,10 @@ repos:
hooks: hooks:
- id: actionlint - id: actionlint
- repo: https://github.com/astral-sh/uv-pre-commit - repo: https://github.com/astral-sh/uv-pre-commit
rev: 0.6.2 rev: 0.6.17
hooks: hooks:
- id: pip-compile - id: pip-compile
args: [requirements/test.in, -o, requirements/test.txt] args: [requirements/test.in, -o, requirements/test.txt, --index-strategy, unsafe-best-match, --torch-backend, cu128]
files: ^requirements/test\.(in|txt)$ files: ^requirements/test\.(in|txt)$
- repo: local - repo: local
hooks: hooks:
@ -101,8 +103,8 @@ repos:
args: args:
- -c - -c
- | - |
if ! grep -q "^Signed-off-by: $(git config user.name) <$(git config user.email)>" .git/COMMIT_EDITMSG; then if ! grep -q "^Signed-off-by: $(git config user.name) <$(git config user.email)>" "$(git rev-parse --git-path COMMIT_EDITMSG)"; then
printf "\nSigned-off-by: $(git config user.name) <$(git config user.email)>\n" >> .git/COMMIT_EDITMSG printf "\nSigned-off-by: $(git config user.name) <$(git config user.email)>\n" >> "$(git rev-parse --git-path COMMIT_EDITMSG)"
fi fi
language: system language: system
verbose: true verbose: true
@ -125,8 +127,6 @@ repos:
name: Update Dockerfile dependency graph name: Update Dockerfile dependency graph
entry: tools/update-dockerfile-graph.sh entry: tools/update-dockerfile-graph.sh
language: script language: script
files: ^docker/Dockerfile$
pass_filenames: false
# Keep `suggestion` last # Keep `suggestion` last
- id: suggestion - id: suggestion
name: Suggestion name: Suggestion

View File

@ -15,7 +15,6 @@ project(vllm_extensions LANGUAGES CXX)
# CUDA by default, can be overridden by using -DVLLM_TARGET_DEVICE=... (used by setup.py) # CUDA by default, can be overridden by using -DVLLM_TARGET_DEVICE=... (used by setup.py)
set(VLLM_TARGET_DEVICE "cuda" CACHE STRING "Target device backend for vLLM") set(VLLM_TARGET_DEVICE "cuda" CACHE STRING "Target device backend for vLLM")
message(STATUS "Build type: ${CMAKE_BUILD_TYPE}") message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
message(STATUS "Target device: ${VLLM_TARGET_DEVICE}") message(STATUS "Target device: ${VLLM_TARGET_DEVICE}")
@ -46,8 +45,8 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1
# requirements.txt files and should be kept consistent. The ROCm torch # requirements.txt files and should be kept consistent. The ROCm torch
# versions are derived from docker/Dockerfile.rocm # versions are derived from docker/Dockerfile.rocm
# #
set(TORCH_SUPPORTED_VERSION_CUDA "2.6.0") set(TORCH_SUPPORTED_VERSION_CUDA "2.7.0")
set(TORCH_SUPPORTED_VERSION_ROCM "2.6.0") set(TORCH_SUPPORTED_VERSION_ROCM "2.7.0")
# #
# Try to find python package with an executable that exactly matches # Try to find python package with an executable that exactly matches
@ -231,6 +230,7 @@ set(VLLM_EXT_SRC
"csrc/attention/paged_attention_v1.cu" "csrc/attention/paged_attention_v1.cu"
"csrc/attention/paged_attention_v2.cu" "csrc/attention/paged_attention_v2.cu"
"csrc/attention/merge_attn_states.cu" "csrc/attention/merge_attn_states.cu"
"csrc/attention/vertical_slash_index.cu"
"csrc/pos_encoding_kernels.cu" "csrc/pos_encoding_kernels.cu"
"csrc/activation_kernels.cu" "csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu" "csrc/layernorm_kernels.cu"
@ -241,6 +241,7 @@ set(VLLM_EXT_SRC
"csrc/quantization/fp8/common.cu" "csrc/quantization/fp8/common.cu"
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu" "csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
"csrc/quantization/gguf/gguf_kernel.cu" "csrc/quantization/gguf/gguf_kernel.cu"
"csrc/quantization/activation_kernels.cu"
"csrc/cuda_utils_kernels.cu" "csrc/cuda_utils_kernels.cu"
"csrc/prepare_inputs/advance_step.cu" "csrc/prepare_inputs/advance_step.cu"
"csrc/custom_all_reduce.cu" "csrc/custom_all_reduce.cu"
@ -249,9 +250,8 @@ set(VLLM_EXT_SRC
if(VLLM_GPU_LANG STREQUAL "CUDA") if(VLLM_GPU_LANG STREQUAL "CUDA")
SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
# Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case. # Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building.
# Please keep this in sync with FetchContent_Declare line below. set(CUTLASS_REVISION "v3.9.2" CACHE STRING "CUTLASS revision to use")
set(CUTLASS_REVISION "v3.9.0" CACHE STRING "CUTLASS revision to use")
# Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR}) if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR})
@ -269,7 +269,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
cutlass cutlass
GIT_REPOSITORY https://github.com/nvidia/cutlass.git GIT_REPOSITORY https://github.com/nvidia/cutlass.git
# Please keep this in sync with CUTLASS_REVISION line above. # Please keep this in sync with CUTLASS_REVISION line above.
GIT_TAG v3.9.0 GIT_TAG ${CUTLASS_REVISION}
GIT_PROGRESS TRUE GIT_PROGRESS TRUE
# Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history. # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
@ -289,6 +289,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu" "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
"csrc/quantization/fp4/nvfp4_quant_entry.cu" "csrc/quantization/fp4/nvfp4_quant_entry.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu" "csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
"csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu"
"csrc/sparse/cutlass/sparse_scaled_mm_entry.cu" "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
"csrc/cutlass_extensions/common.cpp" "csrc/cutlass_extensions/common.cpp"
"csrc/attention/mla/cutlass_mla_entry.cu") "csrc/attention/mla/cutlass_mla_entry.cu")
@ -300,10 +301,55 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# Only build Marlin kernels if we are building for at least some compatible archs. # Only build Marlin kernels if we are building for at least some compatible archs.
# Keep building Marlin for 9.0 as there are some group sizes and shapes that # Keep building Marlin for 9.0 as there are some group sizes and shapes that
# are not supported by Machete yet. # are not supported by Machete yet.
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}") # 9.0 for latest bf16 atomicAdd PTX
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;9.0+PTX" "${CUDA_ARCHS}")
if (MARLIN_ARCHS) if (MARLIN_ARCHS)
#
# For the Marlin kernels we automatically generate sources for various
# preselected input type pairs and schedules.
# Generate sources:
set(MARLIN_GEN_SCRIPT
${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/gptq_marlin/generate_kernels.py)
file(MD5 ${MARLIN_GEN_SCRIPT} MARLIN_GEN_SCRIPT_HASH)
message(STATUS "Marlin generation script hash: ${MARLIN_GEN_SCRIPT_HASH}")
message(STATUS "Last run Marlin generate script hash: $CACHE{MARLIN_GEN_SCRIPT_HASH}")
if (NOT DEFINED CACHE{MARLIN_GEN_SCRIPT_HASH}
OR NOT $CACHE{MARLIN_GEN_SCRIPT_HASH} STREQUAL ${MARLIN_GEN_SCRIPT_HASH})
execute_process(
COMMAND ${CMAKE_COMMAND} -E env
PYTHONPATH=$PYTHONPATH
${Python_EXECUTABLE} ${MARLIN_GEN_SCRIPT}
RESULT_VARIABLE marlin_generation_result
OUTPUT_VARIABLE marlin_generation_result
OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log
ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log
)
if (NOT marlin_generation_result EQUAL 0)
message(FATAL_ERROR "Marlin generation failed."
" Result: \"${marlin_generation_result}\""
"\nCheck the log for details: "
"${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log")
else()
set(MARLIN_GEN_SCRIPT_HASH ${MARLIN_GEN_SCRIPT_HASH}
CACHE STRING "Last run Marlin generate script hash" FORCE)
message(STATUS "Marlin generation completed successfully.")
endif()
else()
message(STATUS "Marlin generation script has not changed, skipping generation.")
endif()
file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/kernel_*.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}"
CUDA_ARCHS "${MARLIN_ARCHS}")
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC})
set(MARLIN_SRCS set(MARLIN_SRCS
"csrc/quantization/fp8/fp8_marlin.cu"
"csrc/quantization/marlin/dense/marlin_cuda_kernel.cu" "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu" "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
"csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu" "csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu"
@ -375,6 +421,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set(SRCS set(SRCS
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu" "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu"
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu" "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu"
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu"
) )
set_gencode_flags_for_srcs( set_gencode_flags_for_srcs(
SRCS "${SRCS}" SRCS "${SRCS}"
@ -399,8 +446,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# #
# For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x) # For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
# kernels for the remaining archs that are not already built for 3x. # kernels for the remaining archs that are not already built for 3x.
# (Build 8.9 for FP8)
cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
"7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}") "7.5;8.0;8.9+PTX" "${CUDA_ARCHS}")
# subtract out the archs that are already built for 3x # subtract out the archs that are already built for 3x
list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS}) list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
if (SCALED_MM_2X_ARCHS) if (SCALED_MM_2X_ARCHS)
@ -451,7 +499,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND FP4_ARCHS) if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND FP4_ARCHS)
set(SRCS set(SRCS
"csrc/quantization/fp4/nvfp4_quant_kernels.cu" "csrc/quantization/fp4/nvfp4_quant_kernels.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu") "csrc/quantization/fp4/nvfp4_experts_quant.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu"
"csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
set_gencode_flags_for_srcs( set_gencode_flags_for_srcs(
SRCS "${SRCS}" SRCS "${SRCS}"
CUDA_ARCHS "${FP4_ARCHS}") CUDA_ARCHS "${FP4_ARCHS}")
@ -489,7 +539,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and only works # The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and only works
# on Hopper). get_cutlass_moe_mm_data should only be compiled if it's possible # on Hopper). get_cutlass_moe_mm_data should only be compiled if it's possible
# to compile MoE kernels that use its output. # to compile MoE kernels that use its output.
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}") cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS) if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu" set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu"
"csrc/quantization/cutlass_w8a8/moe/moe_data.cu") "csrc/quantization/cutlass_w8a8/moe/moe_data.cu")
@ -627,7 +677,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
CUDA_ARCHS "${CUDA_ARCHS}") CUDA_ARCHS "${CUDA_ARCHS}")
list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}") list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}")
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}") # 9.0 for latest bf16 atomicAdd PTX
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;9.0+PTX" "${CUDA_ARCHS}")
if (MARLIN_MOE_ARCHS) if (MARLIN_MOE_ARCHS)
# #
@ -645,7 +696,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH} STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH}) OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH} STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH})
execute_process( execute_process(
COMMAND ${CMAKE_COMMAND} -E env COMMAND ${CMAKE_COMMAND} -E env
PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$PYTHONPATH PYTHONPATH=$PYTHONPATH
${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT} ${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT}
RESULT_VARIABLE moe_marlin_generation_result RESULT_VARIABLE moe_marlin_generation_result
OUTPUT_VARIABLE moe_marlin_generation_output OUTPUT_VARIABLE moe_marlin_generation_output
@ -681,6 +732,17 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif() endif()
endif() endif()
if(VLLM_GPU_LANG STREQUAL "CUDA")
set(MOE_PERMUTE_SRC
"csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu"
"csrc/moe/moe_permute_unpermute_op.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_PERMUTE_SRC}"
CUDA_ARCHS "${MOE_PERMUTE_ARCHS}")
list(APPEND VLLM_MOE_EXT_SRC "${MOE_PERMUTE_SRC}")
endif()
message(STATUS "Enabling moe extension.") message(STATUS "Enabling moe extension.")
define_gpu_extension_target( define_gpu_extension_target(
_moe_C _moe_C
@ -689,6 +751,8 @@ define_gpu_extension_target(
SOURCES ${VLLM_MOE_EXT_SRC} SOURCES ${VLLM_MOE_EXT_SRC}
COMPILE_FLAGS ${VLLM_GPU_FLAGS} COMPILE_FLAGS ${VLLM_GPU_FLAGS}
ARCHITECTURES ${VLLM_GPU_ARCHES} ARCHITECTURES ${VLLM_GPU_ARCHES}
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
USE_SABI 3 USE_SABI 3
WITH_SOABI) WITH_SOABI)

View File

@ -16,18 +16,20 @@ Easy, fast, and cheap LLM serving for everyone
--- ---
*Latest News* 🔥 *Latest News* 🔥
- [2025/05] We hosted [NYC vLLM Meetup](https://lu.ma/c1rqyf1f)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing).
- [2025/05] vLLM is now a hosted project under PyTorch Foundation! Please find the announcement [here](https://pytorch.org/blog/pytorch-foundation-welcomes-vllm/).
- [2025/04] We hosted [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing). - [2025/04] We hosted [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing).
- [2025/03] We hosted [vLLM x Ollama Inference Night](https://lu.ma/vllm-ollama)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/16T2PDD1YwRnZ4Tu8Q5r6n53c5Lr5c73UV9Vd2_eBo4U/edit?usp=sharing).
- [2025/03] We hosted [the first vLLM China Meetup](https://mp.weixin.qq.com/s/n77GibL2corAtQHtVEAzfg)! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1REHvfQMKGnvz6p3Fd23HhSO4c8j5WPGZV0bKYLwnHyQ/edit?usp=sharing).
- [2025/03] We hosted [the East Coast vLLM Meetup](https://lu.ma/7mu4k4xx)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1NHiv8EUFF1NLd3fEYODm56nDmL26lEeXCaDgyDlTsRs/edit#slide=id.g31441846c39_0_0).
- [2025/02] We hosted [the ninth vLLM meetup](https://lu.ma/h7g3kuj9) with Meta! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1jzC_PZVXrVNSFVCW-V4cFXb6pn7zZ2CyP_Flwo05aqg/edit?usp=sharing) and AMD [here](https://drive.google.com/file/d/1Zk5qEJIkTmlQ2eQcXQZlljAx3m9s7nwn/view?usp=sharing). The slides from Meta will not be posted.
- [2025/01] We are excited to announce the alpha release of vLLM V1: A major architectural upgrade with 1.7x speedup! Clean code, optimized execution loop, zero-overhead prefix caching, enhanced multimodal support, and more. Please check out our blog post [here](https://blog.vllm.ai/2025/01/27/v1-alpha-release.html). - [2025/01] We are excited to announce the alpha release of vLLM V1: A major architectural upgrade with 1.7x speedup! Clean code, optimized execution loop, zero-overhead prefix caching, enhanced multimodal support, and more. Please check out our blog post [here](https://blog.vllm.ai/2025/01/27/v1-alpha-release.html).
- [2025/01] We hosted [the eighth vLLM meetup](https://lu.ma/zep56hui) with Google Cloud! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1epVkt4Zu8Jz_S5OhEHPc798emsYh2BwYfRuDDVEF7u4/edit?usp=sharing), and Google Cloud team [here](https://drive.google.com/file/d/1h24pHewANyRL11xy5dXUbvRC9F9Kkjix/view?usp=sharing).
- [2024/12] vLLM joins [pytorch ecosystem](https://pytorch.org/blog/vllm-joins-pytorch)! Easy, Fast, and Cheap LLM Serving for Everyone!
<details> <details>
<summary>Previous News</summary> <summary>Previous News</summary>
- [2025/03] We hosted [vLLM x Ollama Inference Night](https://lu.ma/vllm-ollama)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/16T2PDD1YwRnZ4Tu8Q5r6n53c5Lr5c73UV9Vd2_eBo4U/edit?usp=sharing).
- [2025/03] We hosted [the first vLLM China Meetup](https://mp.weixin.qq.com/s/n77GibL2corAtQHtVEAzfg)! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1REHvfQMKGnvz6p3Fd23HhSO4c8j5WPGZV0bKYLwnHyQ/edit?usp=sharing).
- [2025/03] We hosted [the East Coast vLLM Meetup](https://lu.ma/7mu4k4xx)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1NHiv8EUFF1NLd3fEYODm56nDmL26lEeXCaDgyDlTsRs/edit#slide=id.g31441846c39_0_0).
- [2025/02] We hosted [the ninth vLLM meetup](https://lu.ma/h7g3kuj9) with Meta! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1jzC_PZVXrVNSFVCW-V4cFXb6pn7zZ2CyP_Flwo05aqg/edit?usp=sharing) and AMD [here](https://drive.google.com/file/d/1Zk5qEJIkTmlQ2eQcXQZlljAx3m9s7nwn/view?usp=sharing). The slides from Meta will not be posted.
- [2025/01] We hosted [the eighth vLLM meetup](https://lu.ma/zep56hui) with Google Cloud! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1epVkt4Zu8Jz_S5OhEHPc798emsYh2BwYfRuDDVEF7u4/edit?usp=sharing), and Google Cloud team [here](https://drive.google.com/file/d/1h24pHewANyRL11xy5dXUbvRC9F9Kkjix/view?usp=sharing).
- [2024/12] vLLM joins [pytorch ecosystem](https://pytorch.org/blog/vllm-joins-pytorch)! Easy, Fast, and Cheap LLM Serving for Everyone!
- [2024/11] We hosted [the seventh vLLM meetup](https://lu.ma/h0qvrajz) with Snowflake! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1e3CxQBV3JsfGp30SwyvS3eM_tW-ghOhJ9PAJGK6KR54/edit?usp=sharing), and Snowflake team [here](https://docs.google.com/presentation/d/1qF3RkDAbOULwz9WK5TOltt2fE9t6uIc_hVNLFAaQX6A/edit?usp=sharing). - [2024/11] We hosted [the seventh vLLM meetup](https://lu.ma/h0qvrajz) with Snowflake! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1e3CxQBV3JsfGp30SwyvS3eM_tW-ghOhJ9PAJGK6KR54/edit?usp=sharing), and Snowflake team [here](https://docs.google.com/presentation/d/1qF3RkDAbOULwz9WK5TOltt2fE9t6uIc_hVNLFAaQX6A/edit?usp=sharing).
- [2024/10] We have just created a developer slack ([slack.vllm.ai](https://slack.vllm.ai)) focusing on coordinating contributions and discussing features. Please feel free to join us there! - [2024/10] We have just created a developer slack ([slack.vllm.ai](https://slack.vllm.ai)) focusing on coordinating contributions and discussing features. Please feel free to join us there!
- [2024/10] Ray Summit 2024 held a special track for vLLM! Please find the opening talk slides from the vLLM team [here](https://docs.google.com/presentation/d/1B_KQxpHBTRa_mDF-tR6i8rWdOU5QoTZNcEg2MKZxEHM/edit?usp=sharing). Learn more from the [talks](https://www.youtube.com/playlist?list=PLzTswPQNepXl6AQwifuwUImLPFRVpksjR) from other vLLM contributors and users! - [2024/10] Ray Summit 2024 held a special track for vLLM! Please find the opening talk slides from the vLLM team [here](https://docs.google.com/presentation/d/1B_KQxpHBTRa_mDF-tR6i8rWdOU5QoTZNcEg2MKZxEHM/edit?usp=sharing). Learn more from the [talks](https://www.youtube.com/playlist?list=PLzTswPQNepXl6AQwifuwUImLPFRVpksjR) from other vLLM contributors and users!
@ -72,7 +74,7 @@ vLLM is flexible and easy to use with:
- OpenAI-compatible API server - OpenAI-compatible API server
- Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, TPU, and AWS Neuron. - Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, TPU, and AWS Neuron.
- Prefix caching support - Prefix caching support
- Multi-lora support - Multi-LoRA support
vLLM seamlessly supports most popular open-source models on HuggingFace, including: vLLM seamlessly supports most popular open-source models on HuggingFace, including:
- Transformer-like LLMs (e.g., Llama) - Transformer-like LLMs (e.g., Llama)

212
benchmarks/auto_tune.sh Normal file
View File

@ -0,0 +1,212 @@
#!/bin/bash
# This script aims to tune the best server parameter combinations to maximize throughput for given requirement.
# The current server parameter combination is max_num_seqs and max_num_batched_tokens
# It also supports additional requirement: e2e latency and prefix cache.
# Pre-requisite:
# 1. Checkout to your branch, install/ update the correct running env. For TPU, activate conda env and install the corresponding torch, xla version.
# 2. If the model is customized, replace the MODEL's config with the customized config.
# 3. Set variables (ALL REQUIRED)
# BASE: your directory for vllm repo
# MODEL: the model served by vllm
# DOWNLOAD_DIR: directory to download and load model weights.
# INPUT_LEN: request input len
# OUTPUT_LEN: request output len
# MIN_CACHE_HIT_PCT: prefix cache rate
# MAX_LATENCY_ALLOWED_MS: (e2e) latency requirement. If there's no latency requirement, set it to a large number like 1000000000
# 4. Run the script, it might take a long time, you can use tmux to avoid the script stop if disconnection happens.
# 5. The final result will be saved in RESULT file.
# Example use cases
# 1. Given input_len=1800, output_len=20, what's the best max_num_seqs and max_num_batched_tokens to get highest throughput?
# Use INPUT_LEN=1800, OUTPUT_LEN=20, MIN_CACHE_HIT_PCT=0, MAX_LATENCY_ALLOWED_MS=100000000000
# 2. If we have latency requirement to be lower than 500ms, what's the best server parameter?
# Use INPUT_LEN=1800, OUTPUT_LEN=20, MIN_CACHE_HIT_PCT=0, MAX_LATENCY_ALLOWED_MS=500
# 3. If we want to reach 60% prefix cache, what's the best server parameter?
# Use INPUT_LEN=1800, OUTPUT_LEN=20, MIN_CACHE_HIT_PCT=60, MAX_LATENCY_ALLOWED_MS=500
TAG=$(date +"%Y_%m_%d_%H_%M")
BASE=""
MODEL="meta-llama/Llama-3.1-8B-Instruct"
DOWNLOAD_DIR=""
INPUT_LEN=4000
OUTPUT_LEN=16
MIN_CACHE_HIT_PCT_PCT=0
MAX_LATENCY_ALLOWED_MS=100000000000
LOG_FOLDER="$BASE/auto-benchmark/$TAG"
RESULT="$LOG_FOLDER/result.txt"
echo "result file$ $RESULT"
echo "model: $MODEL"
echo
rm -rf $LOG_FOLDER
mkdir -p $LOG_FOLDER
cd "$BASE/vllm"
# create sonnet-4x.txt so that we can sample 2048 tokens for input
echo "" > benchmarks/sonnet_4x.txt
for _ in {1..4}
do
cat benchmarks/sonnet.txt >> benchmarks/sonnet_4x.txt
done
pip install datasets
current_hash=$(git rev-parse HEAD)
echo "hash:$current_hash" >> "$RESULT"
echo "current_hash: $current_hash"
best_throughput=0
best_max_num_seqs=0
best_num_batched_tokens=0
best_goodput=0
run_benchmark() {
local max_num_seqs=$1
local max_num_batched_tokens=$2
echo "max_num_seq: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens"
local vllm_log="$LOG_FOLDER/vllm_log_${max_num_seqs}_${max_num_batched_tokens}.txt"
echo "vllm_log: $vllm_log"
echo
rm -f $vllm_log
# start the server
VLLM_USE_V1=1 VLLM_SERVER_DEV_MODE=1 vllm serve $MODEL \
--disable-log-requests \
--port 8004 \
--gpu-memory-utilization 0.98 \
--max-num-seqs $max_num_seqs \
--max-num-batched-tokens $max_num_batched_tokens \
--tensor-parallel-size 1 \
--enable-prefix-caching \
--load-format dummy \
--download-dir $DOWNLOAD_DIR \
--max-model-len $(( INPUT_LEN+OUTPUT_LEN )) > "$vllm_log" 2>&1 &
echo "wait for 10 minutes.."
echo
# wait for 10 minutes...
server_started=0
for i in {1..60}; do
if grep -Fq "Application startup complete" "$vllm_log"; then
echo "Application started"
server_started=1
break
else
# echo "wait for 10 seconds..."
sleep 10
fi
done
if (( ! server_started )); then
echo "server did not start within 10 minutes, terminate the benchmarking. Please check server log at $vllm_log"
echo "pkill -f vllm"
echo
pkill vllm
sleep 10
return 1
fi
echo "run benchmark test..."
echo
meet_latency_requirement=0
# get a basic qps by using request-rate inf
bm_log="$LOG_FOLDER/bm_log_${max_num_seqs}_${max_num_batched_tokens}_requestrate_inf.txt"
prefix_len=$(( INPUT_LEN * MIN_CACHE_HIT_PCT / 100 ))
python benchmarks/benchmark_serving.py \
--backend vllm \
--model $MODEL \
--dataset-name sonnet \
--dataset-path benchmarks/sonnet_4x.txt \
--sonnet-input-len $INPUT_LEN \
--sonnet-output-len $OUTPUT_LEN \
--ignore-eos \
--disable-tqdm \
--request-rate inf \
--percentile-metrics ttft,tpot,itl,e2el \
--goodput e2el:$MAX_LATENCY_ALLOWED_MS \
--num-prompts 100 \
--sonnet-prefix-len $prefix_len \
--port 8004 > "$bm_log"
through_put=$(grep "Request throughput (req/s):" "$bm_log" | sed 's/[^0-9.]//g')
e2el=$(grep "P99 E2EL (ms):" "$bm_log" | awk '{print $NF}')
goodput=$(grep "Request goodput (req/s):" "$bm_log" | sed 's/[^0-9.]//g')
if (( $(echo "$e2el <= $MAX_LATENCY_ALLOWED_MS" | bc -l) )); then
meet_latency_requirement=1
fi
if (( ! meet_latency_requirement )); then
# start from request-rate as int(through_put) + 1
request_rate=$((${through_put%.*} + 1))
while ((request_rate > 0)); do
# clear prefix cache
curl -X POST http://0.0.0.0:8004/reset_prefix_cache
sleep 5
bm_log="$LOG_FOLDER/bm_log_${max_num_seqs}_${max_num_batched_tokens}_requestrate_${request_rate}.txt"
python benchmarks/benchmark_serving.py \
--backend vllm \
--model $MODEL \
--dataset-name sonnet \
--dataset-path benchmarks/sonnet_4x.txt \
--sonnet-input-len $INPUT_LEN \
--sonnet-output-len $OUTPUT_LEN \
--ignore_eos \
--disable-tqdm \
--request-rate $request_rate \
--percentile-metrics ttft,tpot,itl,e2el \
--goodput e2el:$MAX_LATENCY_ALLOWED_MS \
--num-prompts 100 \
--sonnet-prefix-len $prefix_len \
--port 8004 > "$bm_log"
through_put=$(grep "Request throughput (req/s):" "$bm_log" | sed 's/[^0-9.]//g')
e2el=$(grep "P99 E2EL (ms):" "$bm_log" | awk '{print $NF}')
goodput=$(grep "Request goodput (req/s):" "$bm_log" | sed 's/[^0-9.]//g')
if (( $(echo "$e2el <= $MAX_LATENCY_ALLOWED_MS" | bc -l) )); then
meet_latency_requirement=1
break
fi
request_rate=$((request_rate-1))
done
fi
# write the results and update the best result.
if ((meet_latency_requirement)); then
echo "max_num_seqs: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens, request_rate: $request_rate, e2el: $e2el, through put: $through_put, goodput: $goodput"
echo "max_num_seqs: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens, request_rate: $request_rate, e2el: $e2el, through put: $through_put, goodput: $goodput" >> "$RESULT"
if (( $(echo "$through_put > $best_throughput" | bc -l) )); then
best_throughput=$through_put
best_max_num_seqs=$max_num_seqs
best_num_batched_tokens=$max_num_batched_tokens
best_goodput=$goodput
fi
else
echo "max_num_seqs: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens does not meet latency requirement ${MAX_LATENCY_ALLOWED_MS}"
echo "max_num_seqs: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens does not meet latency requirement ${MAX_LATENCY_ALLOWED_MS}" >> "$RESULT"
fi
echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput"
echo "pkill -f vllm"
echo
pkill vllm
sleep 10
rm -f $vllm_log
printf '=%.0s' $(seq 1 20)
return 0
}
num_seqs_list="128 256"
num_batched_tokens_list="512 1024 2048 4096"
for num_seqs in $num_seqs_list; do
for num_batched_tokens in $num_batched_tokens_list; do
run_benchmark $num_seqs $num_batched_tokens
exit 0
done
done
echo "finish permutations"
echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput"
echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput" >> "$RESULT"

View File

@ -12,8 +12,7 @@ from typing import Optional, Union
import aiohttp import aiohttp
import huggingface_hub.constants import huggingface_hub.constants
from tqdm.asyncio import tqdm from tqdm.asyncio import tqdm
from transformers import (AutoTokenizer, PreTrainedTokenizer, from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
PreTrainedTokenizerFast)
# NOTE(simon): do not import vLLM here so the benchmark script # NOTE(simon): do not import vLLM here so the benchmark script
# can run without vLLM installed. # can run without vLLM installed.
@ -43,8 +42,7 @@ class RequestFuncOutput:
latency: float = 0.0 latency: float = 0.0
output_tokens: int = 0 output_tokens: int = 0
ttft: float = 0.0 # Time to first token ttft: float = 0.0 # Time to first token
itl: list[float] = field( itl: list[float] = field(default_factory=list) # list of inter-token latencies
default_factory=list) # list of inter-token latencies
tpot: float = 0.0 # avg next-token latencies tpot: float = 0.0 # avg next-token latencies
prompt_len: int = 0 prompt_len: int = 0
error: str = "" error: str = ""
@ -57,8 +55,9 @@ async def async_request_tgi(
api_url = request_func_input.api_url api_url = request_func_input.api_url
assert api_url.endswith("generate_stream") assert api_url.endswith("generate_stream")
async with aiohttp.ClientSession(trust_env=True, async with aiohttp.ClientSession(
timeout=AIOHTTP_TIMEOUT) as session: trust_env=True, timeout=AIOHTTP_TIMEOUT
) as session:
params = { params = {
"max_new_tokens": request_func_input.output_len, "max_new_tokens": request_func_input.output_len,
"do_sample": True, "do_sample": True,
@ -105,8 +104,7 @@ async def async_request_tgi(
# Decoding phase # Decoding phase
else: else:
output.itl.append(timestamp - output.itl.append(timestamp - most_recent_timestamp)
most_recent_timestamp)
most_recent_timestamp = timestamp most_recent_timestamp = timestamp
@ -133,8 +131,9 @@ async def async_request_trt_llm(
api_url = request_func_input.api_url api_url = request_func_input.api_url
assert api_url.endswith("generate_stream") assert api_url.endswith("generate_stream")
async with aiohttp.ClientSession(trust_env=True, async with aiohttp.ClientSession(
timeout=AIOHTTP_TIMEOUT) as session: trust_env=True, timeout=AIOHTTP_TIMEOUT
) as session:
payload = { payload = {
"accumulate_tokens": True, "accumulate_tokens": True,
"text_input": request_func_input.prompt, "text_input": request_func_input.prompt,
@ -159,8 +158,7 @@ async def async_request_trt_llm(
if not chunk_bytes: if not chunk_bytes:
continue continue
chunk = chunk_bytes.decode("utf-8").removeprefix( chunk = chunk_bytes.decode("utf-8").removeprefix("data:")
"data:")
data = json.loads(chunk) data = json.loads(chunk)
output.generated_text += data["text_output"] output.generated_text += data["text_output"]
@ -172,8 +170,7 @@ async def async_request_trt_llm(
# Decoding phase # Decoding phase
else: else:
output.itl.append(timestamp - output.itl.append(timestamp - most_recent_timestamp)
most_recent_timestamp)
most_recent_timestamp = timestamp most_recent_timestamp = timestamp
@ -197,10 +194,11 @@ async def async_request_deepspeed_mii(
request_func_input: RequestFuncInput, request_func_input: RequestFuncInput,
pbar: Optional[tqdm] = None, pbar: Optional[tqdm] = None,
) -> RequestFuncOutput: ) -> RequestFuncOutput:
async with aiohttp.ClientSession(trust_env=True, async with aiohttp.ClientSession(
timeout=AIOHTTP_TIMEOUT) as session: trust_env=True, timeout=AIOHTTP_TIMEOUT
) as session:
payload = { payload = {
"model": request_func_input.model,
"prompt": request_func_input.prompt, "prompt": request_func_input.prompt,
"max_tokens": request_func_input.output_len, "max_tokens": request_func_input.output_len,
"temperature": 0.01, # deepspeed-mii does not accept 0.0 temp. "temperature": 0.01, # deepspeed-mii does not accept 0.0 temp.
@ -216,19 +214,21 @@ async def async_request_deepspeed_mii(
st = time.perf_counter() st = time.perf_counter()
try: try:
async with session.post(url=request_func_input.api_url, async with session.post(
json=payload) as response: url=request_func_input.api_url, json=payload
) as response:
if response.status == 200: if response.status == 200:
parsed_resp = await response.json() parsed_resp = await response.json()
output.latency = time.perf_counter() - st output.latency = time.perf_counter() - st
if "choices" in parsed_resp: if "choices" in parsed_resp:
output.generated_text = parsed_resp["choices"][0][ output.generated_text = parsed_resp["choices"][0]["text"]
"text"]
elif "text" in parsed_resp: elif "text" in parsed_resp:
output.generated_text = parsed_resp["text"][0] output.generated_text = parsed_resp["text"][0]
else: else:
output.error = ("Unexpected response format: " output.error = (
"neither 'choices' nor 'text' found") "Unexpected response format: "
"neither 'choices' nor 'text' found"
)
output.success = False output.success = False
output.success = True output.success = True
else: else:
@ -249,17 +249,20 @@ async def async_request_openai_completions(
pbar: Optional[tqdm] = None, pbar: Optional[tqdm] = None,
) -> RequestFuncOutput: ) -> RequestFuncOutput:
api_url = request_func_input.api_url api_url = request_func_input.api_url
assert api_url.endswith( assert api_url.endswith(("completions", "profile")), (
("completions", "profile") "OpenAI Completions API URL must end with 'completions' or 'profile'."
), "OpenAI Completions API URL must end with 'completions' or 'profile'." )
async with aiohttp.ClientSession(trust_env=True, async with aiohttp.ClientSession(
timeout=AIOHTTP_TIMEOUT) as session: trust_env=True, timeout=AIOHTTP_TIMEOUT
) as session:
payload = { payload = {
"model": request_func_input.model_name \ "model": request_func_input.model_name
if request_func_input.model_name else request_func_input.model, if request_func_input.model_name
else request_func_input.model,
"prompt": request_func_input.prompt, "prompt": request_func_input.prompt,
"temperature": 0.0, "temperature": 0.0,
"repetition_penalty": 1.0,
"max_tokens": request_func_input.output_len, "max_tokens": request_func_input.output_len,
"logprobs": request_func_input.logprobs, "logprobs": request_func_input.logprobs,
"stream": True, "stream": True,
@ -271,9 +274,7 @@ async def async_request_openai_completions(
payload["ignore_eos"] = request_func_input.ignore_eos payload["ignore_eos"] = request_func_input.ignore_eos
if request_func_input.extra_body: if request_func_input.extra_body:
payload.update(request_func_input.extra_body) payload.update(request_func_input.extra_body)
headers = { headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
}
output = RequestFuncOutput() output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len output.prompt_len = request_func_input.prompt_len
@ -282,8 +283,9 @@ async def async_request_openai_completions(
st = time.perf_counter() st = time.perf_counter()
most_recent_timestamp = st most_recent_timestamp = st
try: try:
async with session.post(url=api_url, json=payload, async with session.post(
headers=headers) as response: url=api_url, json=payload, headers=headers
) as response:
if response.status == 200: if response.status == 200:
first_chunk_received = False first_chunk_received = False
async for chunk_bytes in response.content: async for chunk_bytes in response.content:
@ -291,8 +293,7 @@ async def async_request_openai_completions(
if not chunk_bytes: if not chunk_bytes:
continue continue
chunk = chunk_bytes.decode("utf-8").removeprefix( chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
"data: ")
if chunk != "[DONE]": if chunk != "[DONE]":
data = json.loads(chunk) data = json.loads(chunk)
@ -312,21 +313,20 @@ async def async_request_openai_completions(
# Decoding phase # Decoding phase
else: else:
output.itl.append(timestamp - output.itl.append(timestamp - most_recent_timestamp)
most_recent_timestamp)
most_recent_timestamp = timestamp most_recent_timestamp = timestamp
generated_text += text or "" generated_text += text or ""
elif usage := data.get("usage"): elif usage := data.get("usage"):
output.output_tokens = usage.get( output.output_tokens = usage.get("completion_tokens")
"completion_tokens")
if first_chunk_received: if first_chunk_received:
output.success = True output.success = True
else: else:
output.success = False output.success = False
output.error = ( output.error = (
"Never received a valid chunk to calculate TTFT." "Never received a valid chunk to calculate TTFT."
"This response will be marked as failed!") "This response will be marked as failed!"
)
output.generated_text = generated_text output.generated_text = generated_text
output.latency = most_recent_timestamp - st output.latency = most_recent_timestamp - st
else: else:
@ -347,23 +347,22 @@ async def async_request_openai_chat_completions(
pbar: Optional[tqdm] = None, pbar: Optional[tqdm] = None,
) -> RequestFuncOutput: ) -> RequestFuncOutput:
api_url = request_func_input.api_url api_url = request_func_input.api_url
assert api_url.endswith( assert api_url.endswith(("chat/completions", "profile")), (
("chat/completions", "profile") "OpenAI Chat Completions API URL must end with 'chat/completions'."
), "OpenAI Chat Completions API URL must end with 'chat/completions'." )
async with aiohttp.ClientSession(trust_env=True, async with aiohttp.ClientSession(
timeout=AIOHTTP_TIMEOUT) as session: trust_env=True, timeout=AIOHTTP_TIMEOUT
) as session:
content = [{"type": "text", "text": request_func_input.prompt}] content = [{"type": "text", "text": request_func_input.prompt}]
if request_func_input.multi_modal_content: if request_func_input.multi_modal_content:
content.append(request_func_input.multi_modal_content) content.append(request_func_input.multi_modal_content)
payload = { payload = {
"model": request_func_input.model_name \ "model": request_func_input.model_name
if request_func_input.model_name else request_func_input.model, if request_func_input.model_name
else request_func_input.model,
"messages": [ "messages": [
{ {"role": "user", "content": content},
"role": "user",
"content": content
},
], ],
"temperature": 0.0, "temperature": 0.0,
"max_completion_tokens": request_func_input.output_len, "max_completion_tokens": request_func_input.output_len,
@ -389,16 +388,16 @@ async def async_request_openai_chat_completions(
st = time.perf_counter() st = time.perf_counter()
most_recent_timestamp = st most_recent_timestamp = st
try: try:
async with session.post(url=api_url, json=payload, async with session.post(
headers=headers) as response: url=api_url, json=payload, headers=headers
) as response:
if response.status == 200: if response.status == 200:
async for chunk_bytes in response.content: async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip() chunk_bytes = chunk_bytes.strip()
if not chunk_bytes: if not chunk_bytes:
continue continue
chunk = chunk_bytes.decode("utf-8").removeprefix( chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
"data: ")
if chunk != "[DONE]": if chunk != "[DONE]":
timestamp = time.perf_counter() timestamp = time.perf_counter()
data = json.loads(chunk) data = json.loads(chunk)
@ -412,13 +411,11 @@ async def async_request_openai_chat_completions(
# Decoding phase # Decoding phase
else: else:
output.itl.append(timestamp - output.itl.append(timestamp - most_recent_timestamp)
most_recent_timestamp)
generated_text += content or "" generated_text += content or ""
elif usage := data.get("usage"): elif usage := data.get("usage"):
output.output_tokens = usage.get( output.output_tokens = usage.get("completion_tokens")
"completion_tokens")
most_recent_timestamp = timestamp most_recent_timestamp = timestamp
@ -444,25 +441,28 @@ async def async_request_openai_audio(
) -> RequestFuncOutput: ) -> RequestFuncOutput:
# Lazy import without PlaceholderModule to avoid vllm dep. # Lazy import without PlaceholderModule to avoid vllm dep.
import soundfile import soundfile
api_url = request_func_input.api_url api_url = request_func_input.api_url
assert api_url.endswith( assert api_url.endswith(("transcriptions", "translations")), (
("transcriptions", "translations" "OpenAI Chat Completions API URL must end with 'transcriptions' "
)), "OpenAI Chat Completions API URL must end with 'transcriptions' " )
"or `translations`." "or `translations`."
async with aiohttp.ClientSession(trust_env=True, async with aiohttp.ClientSession(
timeout=AIOHTTP_TIMEOUT) as session: trust_env=True, timeout=AIOHTTP_TIMEOUT
) as session:
content = [{"type": "text", "text": request_func_input.prompt}] content = [{"type": "text", "text": request_func_input.prompt}]
payload = { payload = {
"model": request_func_input.model_name \ "model": request_func_input.model_name
if request_func_input.model_name else request_func_input.model, if request_func_input.model_name
else request_func_input.model,
"temperature": 0.0, "temperature": 0.0,
"max_completion_tokens": request_func_input.output_len, "max_completion_tokens": request_func_input.output_len,
"stream": True, "stream": True,
"language": "en", "language": "en",
# Flattened due to multipart/form-data # Flattened due to multipart/form-data
"stream_include_usage": True, "stream_include_usage": True,
"stream_continuous_usage_stats": True "stream_continuous_usage_stats": True,
} }
if request_func_input.extra_body: if request_func_input.extra_body:
payload.update(request_func_input.extra_body) payload.update(request_func_input.extra_body)
@ -477,9 +477,9 @@ async def async_request_openai_audio(
buffer.seek(0) buffer.seek(0)
return buffer return buffer
with to_bytes(*request_func_input.multi_modal_content['audio']) as f: with to_bytes(*request_func_input.multi_modal_content["audio"]) as f:
form = aiohttp.FormData() form = aiohttp.FormData()
form.add_field('file', f, content_type='audio/wav') form.add_field("file", f, content_type="audio/wav")
for key, value in payload.items(): for key, value in payload.items():
form.add_field(key, str(value)) form.add_field(key, str(value))
@ -491,24 +491,22 @@ async def async_request_openai_audio(
st = time.perf_counter() st = time.perf_counter()
most_recent_timestamp = st most_recent_timestamp = st
try: try:
async with session.post(url=api_url, async with session.post(
data=form, url=api_url, data=form, headers=headers
headers=headers) as response: ) as response:
if response.status == 200: if response.status == 200:
async for chunk_bytes in response.content: async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip() chunk_bytes = chunk_bytes.strip()
if not chunk_bytes: if not chunk_bytes:
continue continue
chunk = chunk_bytes.decode("utf-8").removeprefix( chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
"data: ")
if chunk != "[DONE]": if chunk != "[DONE]":
timestamp = time.perf_counter() timestamp = time.perf_counter()
data = json.loads(chunk) data = json.loads(chunk)
if choices := data.get("choices"): if choices := data.get("choices"):
content = choices[0]["delta"].get( content = choices[0]["delta"].get("content")
"content")
# First token # First token
if ttft == 0.0: if ttft == 0.0:
ttft = timestamp - st ttft = timestamp - st
@ -517,12 +515,14 @@ async def async_request_openai_audio(
# Decoding phase # Decoding phase
else: else:
output.itl.append( output.itl.append(
timestamp - most_recent_timestamp) timestamp - most_recent_timestamp
)
generated_text += content or "" generated_text += content or ""
elif usage := data.get("usage"): elif usage := data.get("usage"):
output.output_tokens = usage.get( output.output_tokens = usage.get(
"completion_tokens") "completion_tokens"
)
most_recent_timestamp = timestamp most_recent_timestamp = timestamp
@ -543,7 +543,7 @@ async def async_request_openai_audio(
def get_model(pretrained_model_name_or_path: str) -> str: def get_model(pretrained_model_name_or_path: str) -> str:
if os.getenv('VLLM_USE_MODELSCOPE', 'False').lower() == 'true': if os.getenv("VLLM_USE_MODELSCOPE", "False").lower() == "true":
from modelscope import snapshot_download from modelscope import snapshot_download
from vllm.model_executor.model_loader.weight_utils import get_lock from vllm.model_executor.model_loader.weight_utils import get_lock
@ -554,7 +554,8 @@ def get_model(pretrained_model_name_or_path: str) -> str:
model_path = snapshot_download( model_path = snapshot_download(
model_id=pretrained_model_name_or_path, model_id=pretrained_model_name_or_path,
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"]) ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"],
)
return model_path return model_path
return pretrained_model_name_or_path return pretrained_model_name_or_path
@ -567,23 +568,23 @@ def get_tokenizer(
**kwargs, **kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
if pretrained_model_name_or_path is not None and not os.path.exists( if pretrained_model_name_or_path is not None and not os.path.exists(
pretrained_model_name_or_path): pretrained_model_name_or_path
pretrained_model_name_or_path = get_model( ):
pretrained_model_name_or_path) pretrained_model_name_or_path = get_model(pretrained_model_name_or_path)
if tokenizer_mode == "slow": if tokenizer_mode == "slow":
if kwargs.get("use_fast", False): if kwargs.get("use_fast", False):
raise ValueError( raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
"Cannot use the fast tokenizer in slow tokenizer mode.")
kwargs["use_fast"] = False kwargs["use_fast"] = False
if tokenizer_mode == "mistral": if tokenizer_mode == "mistral":
try: try:
from vllm.transformers_utils.tokenizer import MistralTokenizer from vllm.transformers_utils.tokenizer import MistralTokenizer
except ImportError as e: except ImportError as e:
raise ImportError("MistralTokenizer requires vllm package.\n" raise ImportError(
"Please install it with `pip install vllm` " "MistralTokenizer requires vllm package.\n"
"to use mistral tokenizer mode.") from e "Please install it with `pip install vllm` "
return MistralTokenizer.from_pretrained( "to use mistral tokenizer mode."
str(pretrained_model_name_or_path)) ) from e
return MistralTokenizer.from_pretrained(str(pretrained_model_name_or_path))
else: else:
return AutoTokenizer.from_pretrained( return AutoTokenizer.from_pretrained(
pretrained_model_name_or_path, pretrained_model_name_or_path,
@ -606,7 +607,7 @@ ASYNC_REQUEST_FUNCS = {
} }
OPENAI_COMPATIBLE_BACKENDS = [ OPENAI_COMPATIBLE_BACKENDS = [
k for k, v in ASYNC_REQUEST_FUNCS.items() k
if v in (async_request_openai_completions, for k, v in ASYNC_REQUEST_FUNCS.items()
async_request_openai_chat_completions) if v in (async_request_openai_completions, async_request_openai_chat_completions)
] ]

View File

@ -82,14 +82,12 @@ class BenchmarkDataset(ABC):
self.dataset_path = dataset_path self.dataset_path = dataset_path
# Set the random seed, ensuring that a None value is replaced with the # Set the random seed, ensuring that a None value is replaced with the
# default seed. # default seed.
self.random_seed = (random_seed self.random_seed = random_seed if random_seed is not None else self.DEFAULT_SEED
if random_seed is not None else self.DEFAULT_SEED)
self.data = None self.data = None
def apply_multimodal_chat_transformation( def apply_multimodal_chat_transformation(
self, self, prompt: str, mm_content: Optional[MultiModalDataDict] = None
prompt: str, ) -> list[dict]:
mm_content: Optional[MultiModalDataDict] = None) -> list[dict]:
""" """
Transform a prompt and optional multimodal content into a chat format. Transform a prompt and optional multimodal content into a chat format.
This method is used for chat models that expect a specific conversation This method is used for chat models that expect a specific conversation
@ -111,8 +109,7 @@ class BenchmarkDataset(ABC):
NotImplementedError: If a subclass does not implement this method. NotImplementedError: If a subclass does not implement this method.
""" """
# TODO (jenniferzhao): add support for downloading data # TODO (jenniferzhao): add support for downloading data
raise NotImplementedError( raise NotImplementedError("load_data must be implemented in subclasses.")
"load_data must be implemented in subclasses.")
def get_random_lora_request( def get_random_lora_request(
self, self,
@ -158,8 +155,9 @@ class BenchmarkDataset(ABC):
return lora_request, lora_tokenizer_cache[lora_id] or tokenizer return lora_request, lora_tokenizer_cache[lora_id] or tokenizer
@abstractmethod @abstractmethod
def sample(self, tokenizer: PreTrainedTokenizerBase, def sample(
num_requests: int) -> list[SampleRequest]: self, tokenizer: PreTrainedTokenizerBase, num_requests: int
) -> list[SampleRequest]:
""" """
Abstract method to generate sample requests from the dataset. Abstract method to generate sample requests from the dataset.
@ -177,8 +175,9 @@ class BenchmarkDataset(ABC):
""" """
raise NotImplementedError("sample must be implemented in subclasses.") raise NotImplementedError("sample must be implemented in subclasses.")
def maybe_oversample_requests(self, requests: list[SampleRequest], def maybe_oversample_requests(
num_requests: int) -> None: self, requests: list[SampleRequest], num_requests: int
) -> None:
""" """
Oversamples the list of requests if its size is less than the desired Oversamples the list of requests if its size is less than the desired
number. number.
@ -189,11 +188,9 @@ class BenchmarkDataset(ABC):
""" """
if len(requests) < num_requests: if len(requests) < num_requests:
random.seed(self.random_seed) random.seed(self.random_seed)
additional = random.choices(requests, additional = random.choices(requests, k=num_requests - len(requests))
k=num_requests - len(requests))
requests.extend(additional) requests.extend(additional)
logger.info("Oversampled requests to reach %d total samples.", logger.info("Oversampled requests to reach %d total samples.", num_requests)
num_requests)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ -218,14 +215,14 @@ def is_valid_sequence(
""" """
# Check for invalid conditions # Check for invalid conditions
prompt_too_short = prompt_len < min_len prompt_too_short = prompt_len < min_len
output_too_short = (not skip_min_output_len_check) and (output_len output_too_short = (not skip_min_output_len_check) and (output_len < min_len)
< min_len)
prompt_too_long = prompt_len > max_prompt_len prompt_too_long = prompt_len > max_prompt_len
combined_too_long = (prompt_len + output_len) > max_total_len combined_too_long = (prompt_len + output_len) > max_total_len
# Return True if none of the invalid conditions are met # Return True if none of the invalid conditions are met
return not (prompt_too_short or output_too_short or prompt_too_long return not (
or combined_too_long) prompt_too_short or output_too_short or prompt_too_long or combined_too_long
)
@cache @cache
@ -257,28 +254,28 @@ def process_image(image: Any) -> Mapping[str, Any]:
Raises: Raises:
ValueError: If the input is not a supported type. ValueError: If the input is not a supported type.
""" """
if isinstance(image, dict) and 'bytes' in image: if isinstance(image, dict) and "bytes" in image:
image = Image.open(BytesIO(image['bytes'])) image = Image.open(BytesIO(image["bytes"]))
if isinstance(image, Image.Image): if isinstance(image, Image.Image):
image = image.convert("RGB") image = image.convert("RGB")
with io.BytesIO() as image_data: with io.BytesIO() as image_data:
image.save(image_data, format="JPEG") image.save(image_data, format="JPEG")
image_base64 = base64.b64encode( image_base64 = base64.b64encode(image_data.getvalue()).decode("utf-8")
image_data.getvalue()).decode("utf-8")
return { return {
"type": "image_url", "type": "image_url",
"image_url": { "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"},
"url": f"data:image/jpeg;base64,{image_base64}"
},
} }
if isinstance(image, str): if isinstance(image, str):
image_url = (image if image.startswith( image_url = (
("http://", "file://")) else f"file://{image}") image if image.startswith(("http://", "file://")) else f"file://{image}"
)
return {"type": "image_url", "image_url": {"url": image_url}} return {"type": "image_url", "image_url": {"url": image_url}}
raise ValueError(f"Invalid image input {image}. Must be a PIL.Image.Image" raise ValueError(
" or str or dictionary with raw image bytes.") f"Invalid image input {image}. Must be a PIL.Image.Image"
" or str or dictionary with raw image bytes."
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ -315,42 +312,56 @@ class RandomDataset(BenchmarkDataset):
) )
vocab_size = tokenizer.vocab_size vocab_size = tokenizer.vocab_size
num_special_tokens = tokenizer.num_special_tokens_to_add()
real_input_len = input_len - num_special_tokens
prefix_token_ids = (np.random.randint( prefix_token_ids = (
0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else []) np.random.randint(0, vocab_size, size=prefix_len).tolist()
if prefix_len > 0
else []
)
# New sampling logic: [X * (1 - b), X * (1 + b)] # New sampling logic: [X * (1 - b), X * (1 + b)]
input_low = int(input_len * (1 - range_ratio)) input_low = int(real_input_len * (1 - range_ratio))
input_high = int(input_len * (1 + range_ratio)) input_high = int(real_input_len * (1 + range_ratio))
output_low = int(output_len * (1 - range_ratio)) output_low = int(output_len * (1 - range_ratio))
output_high = int(output_len * (1 + range_ratio)) output_high = int(output_len * (1 + range_ratio))
# Add logging for debugging # Add logging for debugging
logger.info("Sampling input_len from [%s, %s]", input_low, input_high) logger.info("Sampling input_len from [%s, %s]", input_low, input_high)
logger.info("Sampling output_len from [%s, %s]", output_low, logger.info("Sampling output_len from [%s, %s]", output_low, output_high)
output_high)
input_lens = np.random.randint(input_low, input_lens = np.random.randint(input_low, input_high + 1, size=num_requests)
input_high + 1, output_lens = np.random.randint(output_low, output_high + 1, size=num_requests)
size=num_requests)
output_lens = np.random.randint(output_low,
output_high + 1,
size=num_requests)
offsets = np.random.randint(0, vocab_size, size=num_requests) offsets = np.random.randint(0, vocab_size, size=num_requests)
requests = [] requests = []
for i in range(num_requests): for i in range(num_requests):
inner_seq = ((offsets[i] + i + np.arange(input_lens[i])) % inner_seq = (
vocab_size).tolist() (offsets[i] + i + np.arange(input_lens[i])) % vocab_size
).tolist()
token_sequence = prefix_token_ids + inner_seq token_sequence = prefix_token_ids + inner_seq
prompt = tokenizer.decode(token_sequence) prompt = tokenizer.decode(token_sequence)
# After decoding the prompt we have to encode and decode it again.
# This is done because in some cases N consecutive tokens
# give a string tokenized into != N number of tokens.
# For example for GPT2Tokenizer:
# [6880, 6881] -> ['Ġcalls', 'here'] ->
# [1650, 939, 486] -> ['Ġcall', 'sh', 'ere']
# To avoid uncontrolled change of the prompt length,
# the encoded sequence is truncated before being decode again.
re_encoded_sequence = tokenizer.encode(prompt, add_special_tokens=False)[
: input_lens[i]
]
prompt = tokenizer.decode(re_encoded_sequence)
total_input_len = prefix_len + int(input_lens[i]) total_input_len = prefix_len + int(input_lens[i])
requests.append( requests.append(
SampleRequest( SampleRequest(
prompt=prompt, prompt=prompt,
prompt_len=total_input_len, prompt_len=total_input_len,
expected_output_len=int(output_lens[i]), expected_output_len=int(output_lens[i]),
)) )
)
return requests return requests
@ -377,7 +388,8 @@ class ShareGPTDataset(BenchmarkDataset):
self.data = json.load(f) self.data = json.load(f)
# Filter entries with at least two conversation turns. # Filter entries with at least two conversation turns.
self.data = [ self.data = [
entry for entry in self.data entry
for entry in self.data
if "conversations" in entry and len(entry["conversations"]) >= 2 if "conversations" in entry and len(entry["conversations"]) >= 2
] ]
random.seed(self.random_seed) random.seed(self.random_seed)
@ -403,27 +415,28 @@ class ShareGPTDataset(BenchmarkDataset):
) )
lora_request, tokenizer = self.get_random_lora_request( lora_request, tokenizer = self.get_random_lora_request(
tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path) tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path
)
prompt_ids = tokenizer(prompt).input_ids prompt_ids = tokenizer(prompt).input_ids
completion_ids = tokenizer(completion).input_ids completion_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_ids) prompt_len = len(prompt_ids)
new_output_len = (len(completion_ids) new_output_len = len(completion_ids) if output_len is None else output_len
if output_len is None else output_len) if not is_valid_sequence(
if not is_valid_sequence(prompt_len, prompt_len,
new_output_len, new_output_len,
skip_min_output_len_check=output_len skip_min_output_len_check=output_len is not None,
is not None): ):
continue continue
if enable_multimodal_chat: if enable_multimodal_chat:
prompt = self.apply_multimodal_chat_transformation( prompt = self.apply_multimodal_chat_transformation(prompt, None)
prompt, None)
samples.append( samples.append(
SampleRequest( SampleRequest(
prompt=prompt, prompt=prompt,
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=new_output_len, expected_output_len=new_output_len,
lora_request=lora_request, lora_request=lora_request,
)) )
)
self.maybe_oversample_requests(samples, num_requests) self.maybe_oversample_requests(samples, num_requests)
return samples return samples
@ -469,20 +482,20 @@ class SonnetDataset(BenchmarkDataset):
) -> list: ) -> list:
# Calculate average token length for a poem line. # Calculate average token length for a poem line.
tokenized_lines = [tokenizer(line).input_ids for line in self.data] tokenized_lines = [tokenizer(line).input_ids for line in self.data]
avg_len = sum(len(tokens) avg_len = sum(len(tokens) for tokens in tokenized_lines) / len(tokenized_lines)
for tokens in tokenized_lines) / len(tokenized_lines)
# Build the base prompt. # Build the base prompt.
base_prompt = "Pick as many lines as you can from these poem lines:\n" base_prompt = "Pick as many lines as you can from these poem lines:\n"
base_msg = [{"role": "user", "content": base_prompt}] base_msg = [{"role": "user", "content": base_prompt}]
base_fmt = tokenizer.apply_chat_template(base_msg, base_fmt = tokenizer.apply_chat_template(
add_generation_prompt=True, base_msg, add_generation_prompt=True, tokenize=False
tokenize=False) )
base_offset = len(tokenizer(base_fmt).input_ids) base_offset = len(tokenizer(base_fmt).input_ids)
if input_len <= base_offset: if input_len <= base_offset:
raise ValueError( raise ValueError(
f"'input_len' must be higher than the base prompt length " f"'input_len' must be higher than the base prompt length "
f"({base_offset}).") f"({base_offset})."
)
# Determine how many poem lines to use. # Determine how many poem lines to use.
num_input_lines = round((input_len - base_offset) / avg_len) num_input_lines = round((input_len - base_offset) / avg_len)
@ -491,21 +504,23 @@ class SonnetDataset(BenchmarkDataset):
samples = [] samples = []
while len(samples) < num_requests: while len(samples) < num_requests:
extra_lines = random.choices(self.data, extra_lines = random.choices(
k=num_input_lines - num_prefix_lines) self.data, k=num_input_lines - num_prefix_lines
)
prompt = f"{base_prompt}{''.join(prefix_lines + extra_lines)}" prompt = f"{base_prompt}{''.join(prefix_lines + extra_lines)}"
msg = [{"role": "user", "content": prompt}] msg = [{"role": "user", "content": prompt}]
prompt_formatted = tokenizer.apply_chat_template( prompt_formatted = tokenizer.apply_chat_template(
msg, add_generation_prompt=True, tokenize=False) msg, add_generation_prompt=True, tokenize=False
)
prompt_len = len(tokenizer(prompt_formatted).input_ids) prompt_len = len(tokenizer(prompt_formatted).input_ids)
if prompt_len <= input_len: if prompt_len <= input_len:
samples.append( samples.append(
SampleRequest( SampleRequest(
prompt=prompt_formatted prompt=prompt_formatted if return_prompt_formatted else prompt,
if return_prompt_formatted else prompt,
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
)) )
)
return samples return samples
@ -525,7 +540,9 @@ class BurstGPTDataset(BenchmarkDataset):
super().__init__(**kwargs) super().__init__(**kwargs)
self.load_data() self.load_data()
def load_data(self, ): def load_data(
self,
):
if self.dataset_path is None: if self.dataset_path is None:
raise ValueError("dataset_path must be provided for loading data.") raise ValueError("dataset_path must be provided for loading data.")
@ -539,8 +556,7 @@ class BurstGPTDataset(BenchmarkDataset):
def _sample_loaded_data(self, num_requests: int) -> list: def _sample_loaded_data(self, num_requests: int) -> list:
if num_requests <= len(self.data): if num_requests <= len(self.data):
data = self.data.sample(n=num_requests, data = self.data.sample(n=num_requests, random_state=self.random_seed)
random_state=self.random_seed)
else: else:
data = self.data.sample( data = self.data.sample(
n=num_requests, n=num_requests,
@ -564,7 +580,8 @@ class BurstGPTDataset(BenchmarkDataset):
input_len = int(data[i][2]) input_len = int(data[i][2])
output_len = int(data[i][3]) output_len = int(data[i][3])
lora_req, tokenizer = self.get_random_lora_request( lora_req, tokenizer = self.get_random_lora_request(
tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path) tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path
)
vocab_size = tokenizer.vocab_size vocab_size = tokenizer.vocab_size
# Generate a synthetic prompt: a list of token IDs computed as (i + # Generate a synthetic prompt: a list of token IDs computed as (i +
# j) modulo vocab_size. # j) modulo vocab_size.
@ -576,7 +593,8 @@ class BurstGPTDataset(BenchmarkDataset):
prompt_len=input_len, prompt_len=input_len,
expected_output_len=output_len, expected_output_len=output_len,
lora_request=lora_req, lora_request=lora_req,
)) )
)
return samples return samples
@ -619,20 +637,23 @@ class HuggingFaceDataset(BenchmarkDataset):
class ConversationDataset(HuggingFaceDataset): class ConversationDataset(HuggingFaceDataset):
"""Dataset for conversation data with multimodal support.""" """Dataset for conversation data with multimodal support."""
SUPPORTED_DATASET_PATHS = { SUPPORTED_DATASET_PATHS = {
'lmms-lab/LLaVA-OneVision-Data', 'Aeala/ShareGPT_Vicuna_unfiltered' "lmms-lab/LLaVA-OneVision-Data",
"Aeala/ShareGPT_Vicuna_unfiltered",
} }
IS_MULTIMODAL = True IS_MULTIMODAL = True
def sample(self, def sample(
tokenizer: PreTrainedTokenizerBase, self,
num_requests: int, tokenizer: PreTrainedTokenizerBase,
output_len: Optional[int] = None, num_requests: int,
enable_multimodal_chat: bool = False, output_len: Optional[int] = None,
**kwargs) -> list: enable_multimodal_chat: bool = False,
**kwargs,
) -> list:
# Filter examples with at least 2 conversations # Filter examples with at least 2 conversations
filtered_data = self.data.filter( filtered_data = self.data.filter(lambda x: len(x["conversations"]) >= 2)
lambda x: len(x["conversations"]) >= 2)
sampled_requests = [] sampled_requests = []
dynamic_output = output_len is None dynamic_output = output_len is None
@ -648,24 +669,22 @@ class ConversationDataset(HuggingFaceDataset):
completion_len = len(completion_ids) completion_len = len(completion_ids)
output_len = completion_len if dynamic_output else output_len output_len = completion_len if dynamic_output else output_len
assert isinstance(output_len, int) and output_len > 0 assert isinstance(output_len, int) and output_len > 0
if dynamic_output and not is_valid_sequence( if dynamic_output and not is_valid_sequence(prompt_len, completion_len):
prompt_len, completion_len):
continue continue
mm_content = process_image( mm_content = process_image(item["image"]) if "image" in item else None
item["image"]) if "image" in item else None
if enable_multimodal_chat: if enable_multimodal_chat:
# Note: when chat is enabled the request prompt_len is no longer # Note: when chat is enabled the request prompt_len is no longer
# accurate and we will be using request output to count the # accurate and we will be using request output to count the
# actual prompt len and output len # actual prompt len and output len
prompt = self.apply_multimodal_chat_transformation( prompt = self.apply_multimodal_chat_transformation(prompt, mm_content)
prompt, mm_content)
sampled_requests.append( sampled_requests.append(
SampleRequest( SampleRequest(
prompt=prompt, prompt=prompt,
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
multi_modal_data=mm_content, multi_modal_data=mm_content,
)) )
)
self.maybe_oversample_requests(sampled_requests, num_requests) self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests return sampled_requests
@ -682,10 +701,8 @@ class VisionArenaDataset(HuggingFaceDataset):
DEFAULT_OUTPUT_LEN = 128 DEFAULT_OUTPUT_LEN = 128
SUPPORTED_DATASET_PATHS = { SUPPORTED_DATASET_PATHS = {
"lmarena-ai/VisionArena-Chat": "lmarena-ai/VisionArena-Chat": lambda x: x["conversation"][0][0]["content"],
lambda x: x["conversation"][0][0]["content"], "lmarena-ai/vision-arena-bench-v0.1": lambda x: x["turns"][0][0]["content"],
"lmarena-ai/vision-arena-bench-v0.1":
lambda x: x["turns"][0][0]["content"]
} }
IS_MULTIMODAL = True IS_MULTIMODAL = True
@ -697,16 +714,14 @@ class VisionArenaDataset(HuggingFaceDataset):
enable_multimodal_chat: bool = False, enable_multimodal_chat: bool = False,
**kwargs, **kwargs,
) -> list: ) -> list:
output_len = (output_len output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
sampled_requests = [] sampled_requests = []
for item in self.data: for item in self.data:
if len(sampled_requests) >= num_requests: if len(sampled_requests) >= num_requests:
break break
parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path) parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path)
if parser_fn is None: if parser_fn is None:
raise ValueError( raise ValueError(f"Unsupported dataset path: {self.dataset_path}")
f"Unsupported dataset path: {self.dataset_path}")
prompt = parser_fn(item) prompt = parser_fn(item)
mm_content = process_image(item["images"][0]) mm_content = process_image(item["images"][0])
prompt_len = len(tokenizer(prompt).input_ids) prompt_len = len(tokenizer(prompt).input_ids)
@ -714,15 +729,15 @@ class VisionArenaDataset(HuggingFaceDataset):
# Note: when chat is enabled the request prompt_len is no longer # Note: when chat is enabled the request prompt_len is no longer
# accurate and we will be using request output to count the # accurate and we will be using request output to count the
# actual prompt len # actual prompt len
prompt = self.apply_multimodal_chat_transformation( prompt = self.apply_multimodal_chat_transformation(prompt, mm_content)
prompt, mm_content)
sampled_requests.append( sampled_requests.append(
SampleRequest( SampleRequest(
prompt=prompt, prompt=prompt,
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
multi_modal_data=mm_content, multi_modal_data=mm_content,
)) )
)
self.maybe_oversample_requests(sampled_requests, num_requests) self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests return sampled_requests
@ -747,14 +762,15 @@ class InstructCoderDataset(HuggingFaceDataset):
"likaixin/InstructCoder", "likaixin/InstructCoder",
} }
def sample(self, def sample(
tokenizer: PreTrainedTokenizerBase, self,
num_requests: int, tokenizer: PreTrainedTokenizerBase,
output_len: Optional[int] = None, num_requests: int,
enable_multimodal_chat: bool = False, output_len: Optional[int] = None,
**kwargs) -> list: enable_multimodal_chat: bool = False,
output_len = (output_len **kwargs,
if output_len is not None else self.DEFAULT_OUTPUT_LEN) ) -> list:
output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
sampled_requests = [] sampled_requests = []
for item in self.data: for item in self.data:
if len(sampled_requests) >= num_requests: if len(sampled_requests) >= num_requests:
@ -766,7 +782,63 @@ class InstructCoderDataset(HuggingFaceDataset):
prompt=prompt, prompt=prompt,
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
)) )
)
self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests
# -----------------------------------------------------------------------------
# MT-Bench Dataset Implementation
# -----------------------------------------------------------------------------
class MTBenchDataset(HuggingFaceDataset):
"""
MT-Bench Dataset.
https://huggingface.co/datasets/philschmid/mt-bench
We create a single turn dataset for MT-Bench.
This is similar to Spec decoding benchmark setup in vLLM
https://github.com/vllm-project/vllm/blob/9d98ab5ec/examples/offline_inference/eagle.py#L14-L18
""" # noqa: E501
DEFAULT_OUTPUT_LEN = 256 # avg len used in SD bench in vLLM
SUPPORTED_DATASET_PATHS = {
"philschmid/mt-bench",
}
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
output_len: Optional[int] = None,
enable_multimodal_chat: bool = False,
**kwargs,
) -> list:
output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
sampled_requests = []
for item in self.data:
if len(sampled_requests) >= num_requests:
break
prompt = item["turns"][0]
# apply template
prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
add_generation_prompt=True,
tokenize=False,
)
prompt_len = len(tokenizer(prompt).input_ids)
sampled_requests.append(
SampleRequest(
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
)
)
self.maybe_oversample_requests(sampled_requests, num_requests) self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests return sampled_requests
@ -780,23 +852,27 @@ class AIMODataset(HuggingFaceDataset):
""" """
Dataset class for processing a AIMO dataset with reasoning questions. Dataset class for processing a AIMO dataset with reasoning questions.
""" """
SUPPORTED_DATASET_PATHS = { SUPPORTED_DATASET_PATHS = {
"AI-MO/aimo-validation-aime", "AI-MO/NuminaMath-1.5", "AI-MO/aimo-validation-aime",
"AI-MO/NuminaMath-CoT" "AI-MO/NuminaMath-1.5",
"AI-MO/NuminaMath-CoT",
} }
def sample(self, def sample(
tokenizer: PreTrainedTokenizerBase, self,
num_requests: int, tokenizer: PreTrainedTokenizerBase,
output_len: Optional[int] = None, num_requests: int,
**kwargs) -> list: output_len: Optional[int] = None,
**kwargs,
) -> list:
sampled_requests = [] sampled_requests = []
dynamic_output = output_len is None dynamic_output = output_len is None
for item in self.data: for item in self.data:
if len(sampled_requests) >= num_requests: if len(sampled_requests) >= num_requests:
break break
prompt, completion = item['problem'], item["solution"] prompt, completion = item["problem"], item["solution"]
prompt_ids = tokenizer(prompt).input_ids prompt_ids = tokenizer(prompt).input_ids
completion_ids = tokenizer(completion).input_ids completion_ids = tokenizer(completion).input_ids
@ -804,10 +880,9 @@ class AIMODataset(HuggingFaceDataset):
completion_len = len(completion_ids) completion_len = len(completion_ids)
output_len = completion_len if dynamic_output else output_len output_len = completion_len if dynamic_output else output_len
assert isinstance(output_len, int) and output_len > 0 assert isinstance(output_len, int) and output_len > 0
if dynamic_output and not is_valid_sequence(prompt_len, if dynamic_output and not is_valid_sequence(
completion_len, prompt_len, completion_len, max_prompt_len=2048, max_total_len=32000
max_prompt_len=2048, ):
max_total_len=32000):
continue continue
sampled_requests.append( sampled_requests.append(
SampleRequest( SampleRequest(
@ -815,11 +890,100 @@ class AIMODataset(HuggingFaceDataset):
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
multi_modal_data=None, multi_modal_data=None,
)) )
)
self.maybe_oversample_requests(sampled_requests, num_requests) self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests return sampled_requests
# -----------------------------------------------------------------------------
# Next Edit Prediction Dataset Implementation
# -----------------------------------------------------------------------------
zeta_prompt = """### Instruction:
You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking into account the cursor location.
### User Edits:
{}
### User Excerpt:
{}
### Response:
""" # noqa: E501
def _format_zeta_prompt(
sample: dict, original_start_marker: str = "<|editable_region_start|>"
) -> dict:
"""Format the zeta prompt for the Next Edit Prediction (NEP) dataset.
This function formats examples from the NEP dataset
into prompts and expected outputs. It could be
further extended to support more NEP datasets.
Args:
sample: The dataset sample containing events,
inputs, and outputs.
original_start_marker: The marker indicating the
start of the editable region. Defaults to
"<|editable_region_start|>".
Returns:
A dictionary with the formatted prompts and expected outputs.
"""
events = sample["events"]
input = sample["input"]
output = sample["output"]
prompt = zeta_prompt.format(events, input)
# following the original implementation, extract the focused region
# from the raw output
output_start_index = output.find(original_start_marker)
output_focused_region = output[output_start_index:]
expected_output = output_focused_region
return {"prompt": prompt, "expected_output": expected_output}
class NextEditPredictionDataset(HuggingFaceDataset):
"""
Dataset class for processing a Next Edit Prediction dataset.
"""
SUPPORTED_DATASET_PATHS = {
"zed-industries/zeta",
}
MAPPING_PROMPT_FUNCS = {
"zed-industries/zeta": _format_zeta_prompt,
}
def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int, **kwargs):
formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(self.dataset_path)
if formatting_prompt_func is None:
raise ValueError(f"Unsupported dataset path: {self.dataset_path}")
samples = []
for sample in self.data:
sample = formatting_prompt_func(sample)
samples.append(
SampleRequest(
prompt=sample["prompt"],
prompt_len=len(tokenizer(sample["prompt"]).input_ids),
expected_output_len=len(
tokenizer(sample["expected_output"]).input_ids
),
)
)
if len(samples) >= num_requests:
break
self.maybe_oversample_requests(samples, num_requests)
return samples
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# ASR Dataset Implementation # ASR Dataset Implementation
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ -842,18 +1006,22 @@ class ASRDataset(HuggingFaceDataset):
| AMI | Meetings | Spontaneous | ihm, sdm | | AMI | Meetings | Spontaneous | ihm, sdm |
+----------------+----------------------------------------+--------------------------+-----------------------------+ +----------------+----------------------------------------+--------------------------+-----------------------------+
""" # noqa: E501 """ # noqa: E501
SUPPORTED_DATASET_PATHS = { SUPPORTED_DATASET_PATHS = {
"openslr/librispeech_asr", "facebook/voxpopuli", "LIUM/tedlium", "openslr/librispeech_asr",
"edinburghcstr/ami", "speechcolab/gigaspeech", "kensho/spgispeech" "facebook/voxpopuli",
"LIUM/tedlium",
"edinburghcstr/ami",
"speechcolab/gigaspeech",
"kensho/spgispeech",
} }
DEFAULT_OUTPUT_LEN = 128 DEFAULT_OUTPUT_LEN = 128
IS_MULTIMODAL = True IS_MULTIMODAL = True
# TODO Whisper-specific. Abstract interface when more models are supported. # TODO Whisper-specific. Abstract interface when more models are supported.
TRANSCRIPTION_PREAMBLE = "<|startoftranscript|><|en|><|transcribe|>"\ TRANSCRIPTION_PREAMBLE = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"
"<|notimestamps|>"
skip_long_audios: bool = True skip_long_audios: bool = True
def sample( def sample(
@ -864,8 +1032,8 @@ class ASRDataset(HuggingFaceDataset):
**kwargs, **kwargs,
) -> list: ) -> list:
import librosa import librosa
output_len = (output_len
if output_len is not None else self.DEFAULT_OUTPUT_LEN) output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
prompt = ASRDataset.TRANSCRIPTION_PREAMBLE prompt = ASRDataset.TRANSCRIPTION_PREAMBLE
prompt_len = len(tokenizer(prompt).input_ids) prompt_len = len(tokenizer(prompt).input_ids)
sampled_requests = [] sampled_requests = []
@ -888,10 +1056,14 @@ class ASRDataset(HuggingFaceDataset):
prompt_len=prompt_len, prompt_len=prompt_len,
expected_output_len=output_len, expected_output_len=output_len,
multi_modal_data=mm_content, multi_modal_data=mm_content,
)) )
)
if skipped: if skipped:
logger.warning("%d samples discarded from dataset due to" \ logger.warning(
" their length being greater than" \ "%d samples discarded from dataset due to"
" what Whisper supports.", skipped) " their length being greater than"
" what Whisper supports.",
skipped,
)
self.maybe_oversample_requests(sampled_requests, num_requests) self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests return sampled_requests

View File

@ -11,9 +11,9 @@ from typing import Any, Optional
import numpy as np import numpy as np
import torch import torch
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
from tqdm import tqdm from tqdm import tqdm
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptType from vllm.inputs import PromptType
@ -21,13 +21,14 @@ from vllm.sampling_params import BeamSearchParams
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
def save_to_pytorch_benchmark_format(args: argparse.Namespace, def save_to_pytorch_benchmark_format(
results: dict[str, Any]) -> None: args: argparse.Namespace, results: dict[str, Any]
) -> None:
pt_records = convert_to_pytorch_benchmark_format( pt_records = convert_to_pytorch_benchmark_format(
args=args, args=args,
metrics={"latency": results["latencies"]}, metrics={"latency": results["latencies"]},
extra_info={k: results[k] extra_info={k: results[k] for k in ["avg_latency", "percentiles"]},
for k in ["avg_latency", "percentiles"]}) )
if pt_records: if pt_records:
pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json" pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json"
write_to_json(pt_file, pt_records) write_to_json(pt_file, pt_records)
@ -42,9 +43,11 @@ def main(args: argparse.Namespace):
# the engine will automatically process the request in multiple batches. # the engine will automatically process the request in multiple batches.
llm = LLM(**dataclasses.asdict(engine_args)) llm = LLM(**dataclasses.asdict(engine_args))
assert llm.llm_engine.model_config.max_model_len >= ( assert llm.llm_engine.model_config.max_model_len >= (
args.input_len + args.input_len + args.output_len
args.output_len), ("Please ensure that max_model_len is greater than" ), (
" the sum of input_len and output_len.") "Please ensure that max_model_len is greater than"
" the sum of input_len and output_len."
)
sampling_params = SamplingParams( sampling_params = SamplingParams(
n=args.n, n=args.n,
@ -55,18 +58,16 @@ def main(args: argparse.Namespace):
detokenize=not args.disable_detokenize, detokenize=not args.disable_detokenize,
) )
print(sampling_params) print(sampling_params)
dummy_prompt_token_ids = np.random.randint(10000, dummy_prompt_token_ids = np.random.randint(
size=(args.batch_size, 10000, size=(args.batch_size, args.input_len)
args.input_len)) )
dummy_prompts: list[PromptType] = [{ dummy_prompts: list[PromptType] = [
"prompt_token_ids": batch {"prompt_token_ids": batch} for batch in dummy_prompt_token_ids.tolist()
} for batch in dummy_prompt_token_ids.tolist()] ]
def llm_generate(): def llm_generate():
if not args.use_beam_search: if not args.use_beam_search:
llm.generate(dummy_prompts, llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False)
sampling_params=sampling_params,
use_tqdm=False)
else: else:
llm.beam_search( llm.beam_search(
dummy_prompts, dummy_prompts,
@ -80,12 +81,13 @@ def main(args: argparse.Namespace):
def run_to_completion(profile_dir: Optional[str] = None): def run_to_completion(profile_dir: Optional[str] = None):
if profile_dir: if profile_dir:
with torch.profiler.profile( with torch.profiler.profile(
activities=[ activities=[
torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA, torch.profiler.ProfilerActivity.CUDA,
], ],
on_trace_ready=torch.profiler.tensorboard_trace_handler( on_trace_ready=torch.profiler.tensorboard_trace_handler(
str(profile_dir)), str(profile_dir)
),
) as p: ) as p:
llm_generate() llm_generate()
print(p.key_averages().table(sort_by="self_cuda_time_total")) print(p.key_averages().table(sort_by="self_cuda_time_total"))
@ -103,8 +105,9 @@ def main(args: argparse.Namespace):
if args.profile: if args.profile:
profile_dir = args.profile_result_dir profile_dir = args.profile_result_dir
if not profile_dir: if not profile_dir:
profile_dir = (Path(".") / "vllm_benchmark_result" / profile_dir = (
f"latency_result_{time.time()}") Path(".") / "vllm_benchmark_result" / f"latency_result_{time.time()}"
)
print(f"Profiling (results will be saved to '{profile_dir}')...") print(f"Profiling (results will be saved to '{profile_dir}')...")
run_to_completion(profile_dir=profile_dir) run_to_completion(profile_dir=profile_dir)
return return
@ -135,7 +138,8 @@ def main(args: argparse.Namespace):
if __name__ == "__main__": if __name__ == "__main__":
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description="Benchmark the latency of processing a single batch of " description="Benchmark the latency of processing a single batch of "
"requests till completion.") "requests till completion."
)
parser.add_argument("--input-len", type=int, default=32) parser.add_argument("--input-len", type=int, default=32)
parser.add_argument("--output-len", type=int, default=128) parser.add_argument("--output-len", type=int, default=128)
parser.add_argument("--batch-size", type=int, default=8) parser.add_argument("--batch-size", type=int, default=8)
@ -152,10 +156,9 @@ if __name__ == "__main__":
default=10, default=10,
help="Number of iterations to run for warmup.", help="Number of iterations to run for warmup.",
) )
parser.add_argument("--num-iters", parser.add_argument(
type=int, "--num-iters", type=int, default=30, help="Number of iterations to run."
default=30, )
help="Number of iterations to run.")
parser.add_argument( parser.add_argument(
"--profile", "--profile",
action="store_true", action="store_true",
@ -165,8 +168,10 @@ if __name__ == "__main__":
"--profile-result-dir", "--profile-result-dir",
type=str, type=str,
default=None, default=None,
help=("path to save the pytorch profiler output. Can be visualized " help=(
"with ui.perfetto.dev or Tensorboard."), "path to save the pytorch profiler output. Can be visualized "
"with ui.perfetto.dev or Tensorboard."
),
) )
parser.add_argument( parser.add_argument(
"--output-json", "--output-json",
@ -177,8 +182,10 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--disable-detokenize", "--disable-detokenize",
action="store_true", action="store_true",
help=("Do not detokenize responses (i.e. do not include " help=(
"detokenization time in the latency measurement)"), "Do not detokenize responses (i.e. do not include "
"detokenization time in the latency measurement)"
),
) )
parser = EngineArgs.add_cli_args(parser) parser = EngineArgs.add_cli_args(parser)

View File

@ -86,20 +86,21 @@ def repeat_prompts(prompts, repeat_count, mode: str):
ValueError: If an invalid mode is provided. ValueError: If an invalid mode is provided.
""" """
print("Repeat mode: ", mode) print("Repeat mode: ", mode)
if mode == 'random': if mode == "random":
repeated_prompts = prompts * repeat_count repeated_prompts = prompts * repeat_count
random.shuffle(repeated_prompts) random.shuffle(repeated_prompts)
return repeated_prompts return repeated_prompts
elif mode == 'tile': elif mode == "tile":
return prompts * repeat_count return prompts * repeat_count
elif mode == 'interleave': elif mode == "interleave":
repeated_prompts = [] repeated_prompts = []
for prompt in prompts: for prompt in prompts:
repeated_prompts.extend([prompt] * repeat_count) repeated_prompts.extend([prompt] * repeat_count)
return repeated_prompts return repeated_prompts
else: else:
raise ValueError(f"Invalid mode: {mode}, only support " raise ValueError(
"'random', 'tile', 'interleave'") f"Invalid mode: {mode}, only support 'random', 'tile', 'interleave'"
)
def main(args): def main(args):
@ -109,16 +110,16 @@ def main(args):
# we append the document id at the beginning to avoid any of the document # we append the document id at the beginning to avoid any of the document
# being the prefix of other documents # being the prefix of other documents
prompts = [ prompts = [
str(i) + ' '.join(['hi'] * args.document_length) str(i) + " ".join(["hi"] * args.document_length)
for i in range(args.num_documents) for i in range(args.num_documents)
] ]
prompts = repeat_prompts(prompts, args.repeat_count, mode=args.repeat_mode) prompts = repeat_prompts(prompts, args.repeat_count, mode=args.repeat_mode)
warmup_prompts = [ warmup_prompts = [
"This is warm up request " + str(i) + \ "This is warm up request " + str(i) + " ".join(["hi"] * args.document_length)
' '.join(['hi'] * args.document_length) for i in range(args.num_documents)
for i in range(args.num_documents)] ]
# Create the LLM engine # Create the LLM engine
engine_args = EngineArgs.from_cli_args(args) engine_args = EngineArgs.from_cli_args(args)
@ -142,42 +143,52 @@ def main(args):
if __name__ == "__main__": if __name__ == "__main__":
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description= description="Benchmark the performance with or "
'Benchmark the performance with or without automatic prefix caching.') "without automatic prefix caching."
)
parser.add_argument( parser.add_argument(
'--document-length', "--document-length",
type=int, type=int,
# Roughly the number of tokens for a system paper, # Roughly the number of tokens for a system paper,
# excluding images # excluding images
default=20000, default=20000,
help='Range of input lengths for sampling prompts,' help="Range of input lengths for sampling prompts, "
'specified as "min:max" (e.g., "128:256").') 'specified as "min:max" (e.g., "128:256").',
)
parser.add_argument('--num-documents', parser.add_argument(
type=int, "--num-documents",
default=8, type=int,
help='Range of input lengths for sampling prompts,' default=8,
'specified as "min:max" (e.g., "128:256").') help="Range of input lengths for sampling prompts, "
'specified as "min:max" (e.g., "128:256").',
)
parser.add_argument('--output-len', type=int, default=10) parser.add_argument("--output-len", type=int, default=10)
parser.add_argument('--repeat-count', parser.add_argument(
type=int, "--repeat-count",
default=2, type=int,
help='Number of times to repeat each prompt') default=2,
help="Number of times to repeat each prompt",
)
parser.add_argument("--repeat-mode", parser.add_argument(
type=str, "--repeat-mode",
default='random', type=str,
help='The mode to repeat prompts. The supported ' default="random",
'modes are "random", "tile", and "interleave". ' help="The mode to repeat prompts. The supported "
'See repeat_prompts() in the source code for details.') 'modes are "random", "tile", and "interleave". '
"See repeat_prompts() in the source code for details.",
)
parser.add_argument("--shuffle-seed", parser.add_argument(
type=int, "--shuffle-seed",
default=0, type=int,
help='Random seed when the repeat mode is "random"') default=0,
help='Random seed when the repeat mode is "random"',
)
parser = EngineArgs.add_cli_args(parser) parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args() args = parser.parse_args()

View File

@ -63,8 +63,7 @@ class Request:
output_len: int output_len: int
def sample_tokens(tokenizer: PreTrainedTokenizerBase, def sample_tokens(tokenizer: PreTrainedTokenizerBase, length: int) -> list[int]:
length: int) -> list[int]:
vocab = tokenizer.get_vocab() vocab = tokenizer.get_vocab()
all_special_ids = set(tokenizer.all_special_ids) all_special_ids = set(tokenizer.all_special_ids)
@ -91,8 +90,10 @@ def sample_requests_from_dataset(
# Filter out the conversations with less than 2 turns. # Filter out the conversations with less than 2 turns.
dataset = [data for data in dataset if len(data["conversations"]) >= 2] dataset = [data for data in dataset if len(data["conversations"]) >= 2]
# Only keep the first two turns of each conversation. # Only keep the first two turns of each conversation.
dataset = [(data["conversations"][0]["value"], dataset = [
data["conversations"][1]["value"]) for data in dataset] (data["conversations"][0]["value"], data["conversations"][1]["value"])
for data in dataset
]
# Shuffle the dataset. # Shuffle the dataset.
random.shuffle(dataset) random.shuffle(dataset)
@ -113,8 +114,9 @@ def sample_requests_from_dataset(
completion = dataset[i][1] completion = dataset[i][1]
completion_token_ids = tokenizer(completion).input_ids completion_token_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_token_ids) prompt_len = len(prompt_token_ids)
output_len = (len(completion_token_ids) output_len = (
if fixed_output_len is None else fixed_output_len) len(completion_token_ids) if fixed_output_len is None else fixed_output_len
)
if min_len <= prompt_len <= max_len: if min_len <= prompt_len <= max_len:
filtered_requests.append(Request(prompt, prompt_len, output_len)) filtered_requests.append(Request(prompt, prompt_len, output_len))
@ -128,27 +130,27 @@ def sample_requests_from_random(
fixed_output_len: Optional[int], fixed_output_len: Optional[int],
prefix_len: int, prefix_len: int,
) -> list[Request]: ) -> list[Request]:
requests = [] requests = []
prefix_token_ids = sample_tokens(tokenizer, prefix_len) prefix_token_ids = sample_tokens(tokenizer, prefix_len)
min_len, max_len = input_length_range min_len, max_len = input_length_range
for i in range(num_requests): for i in range(num_requests):
unique_part_token_ids = sample_tokens( unique_part_token_ids = sample_tokens(
tokenizer, tokenizer, random.randint(min_len - prefix_len, max_len - prefix_len)
random.randint(min_len - prefix_len, max_len - prefix_len)) )
prompt_token_ids = prefix_token_ids + unique_part_token_ids prompt_token_ids = prefix_token_ids + unique_part_token_ids
prompt = tokenizer.decode(prompt_token_ids) prompt = tokenizer.decode(prompt_token_ids)
prompt_len = len(prompt_token_ids) prompt_len = len(prompt_token_ids)
assert (min_len <= prompt_len <= max_len assert min_len <= prompt_len <= max_len, (
), f"prompt_len {prompt_len} out of range {min_len}:{max_len}" f"prompt_len {prompt_len} out of range {min_len}:{max_len}"
)
requests.append(Request(prompt, prompt_len, fixed_output_len)) requests.append(Request(prompt, prompt_len, fixed_output_len))
return requests return requests
def repeat_and_sort_requests(requests: list[Request], def repeat_and_sort_requests(
repeat_count: int, requests: list[Request], repeat_count: int, sort: bool = False
sort: bool = False) -> list[str]: ) -> list[str]:
repeated_requests = requests * repeat_count repeated_requests = requests * repeat_count
if sort: if sort:
repeated_requests.sort(key=lambda x: x[1]) repeated_requests.sort(key=lambda x: x[1])
@ -159,14 +161,14 @@ def repeat_and_sort_requests(requests: list[Request],
def main(args): def main(args):
tokenizer = get_tokenizer(args.model, trust_remote_code=True) tokenizer = get_tokenizer(args.model, trust_remote_code=True)
input_length_range = tuple(map(int, args.input_length_range.split(':'))) input_length_range = tuple(map(int, args.input_length_range.split(":")))
random.seed(args.seed) random.seed(args.seed)
if args.dataset_path is not None: if args.dataset_path is not None:
if args.prefix_len > 0: if args.prefix_len > 0:
raise ValueError("prefix-len is not supported when " raise ValueError(
"dataset-path is provided.") "prefix-len is not supported when dataset-path is provided."
print(f"Start to sample {args.num_prompts} prompts " )
f"from {args.dataset_path}") print(f"Start to sample {args.num_prompts} prompts from {args.dataset_path}")
filtered_requests = sample_requests_from_dataset( filtered_requests = sample_requests_from_dataset(
dataset_path=args.dataset_path, dataset_path=args.dataset_path,
num_requests=args.num_prompts, num_requests=args.num_prompts,
@ -196,14 +198,16 @@ def main(args):
llm = LLM(**dataclasses.asdict(engine_args)) llm = LLM(**dataclasses.asdict(engine_args))
sampling_params = SamplingParams(temperature=0, sampling_params = SamplingParams(
max_tokens=args.output_len, temperature=0,
detokenize=not args.disable_detokenize) max_tokens=args.output_len,
detokenize=not args.disable_detokenize,
)
print("Testing filtered requests") print("Testing filtered requests")
prompts = repeat_and_sort_requests(filtered_requests, prompts = repeat_and_sort_requests(
repeat_count=args.repeat_count, filtered_requests, repeat_count=args.repeat_count, sort=args.sort
sort=args.sort) )
print("------start generating------") print("------start generating------")
test_prefix( test_prefix(
@ -215,29 +219,35 @@ def main(args):
if __name__ == "__main__": if __name__ == "__main__":
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description= description="Benchmark the performance with or without "
'Benchmark the performance with or without automatic prefix caching.') "automatic prefix caching."
parser.add_argument("--dataset-path", )
type=str, parser.add_argument(
default=None, "--dataset-path", type=str, default=None, help="Path to the dataset."
help="Path to the dataset.") )
parser.add_argument('--output-len', type=int, default=10) parser.add_argument("--output-len", type=int, default=10)
parser.add_argument('--num-prompts', parser.add_argument(
type=int, "--num-prompts",
required=True, type=int,
help="Number of the prompts sampled from dataset") required=True,
parser.add_argument('--repeat-count', help="Number of the prompts sampled from dataset",
type=int, )
default=1, parser.add_argument(
help='Number of times to repeat each prompt') "--repeat-count",
parser.add_argument('--sort', type=int,
action='store_true', default=1,
help='Sort prompts by input length') help="Number of times to repeat each prompt",
parser.add_argument('--input-length-range', )
type=str, parser.add_argument(
required=True, "--sort", action="store_true", help="Sort prompts by input length"
help='Range of input lengths for sampling prompts,' )
'specified as "min:max" (e.g., "128:256").') parser.add_argument(
"--input-length-range",
type=str,
required=True,
help="Range of input lengths for sampling prompts,"
'specified as "min:max" (e.g., "128:256").',
)
parser.add_argument( parser.add_argument(
"--prefix-len", "--prefix-len",
type=int, type=int,
@ -248,10 +258,12 @@ if __name__ == "__main__":
"when dataset-path is not provided.", "when dataset-path is not provided.",
) )
parser.add_argument( parser.add_argument(
'--disable-detokenize', "--disable-detokenize",
action='store_true', action="store_true",
help=("Do not detokenize responses (i.e. do not include " help=(
"detokenization time in the latency measurement)"), "Do not detokenize responses (i.e. do not include "
"detokenization time in the latency measurement)"
),
) )
parser = EngineArgs.add_cli_args(parser) parser = EngineArgs.add_cli_args(parser)

View File

@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Benchmark offline prioritization.""" """Benchmark offline prioritization."""
import argparse import argparse
import dataclasses import dataclasses
import json import json
@ -13,7 +14,7 @@ from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
#Select a equi-probable random priority # Select a equi-probable random priority
def get_random_flag(): def get_random_flag():
return 0 if random.random() < 0.5 else 1 return 0 if random.random() < 0.5 else 1
@ -33,8 +34,10 @@ def sample_requests(
# Filter out the conversations with less than 2 turns. # Filter out the conversations with less than 2 turns.
dataset = [data for data in dataset if len(data["conversations"]) >= 2] dataset = [data for data in dataset if len(data["conversations"]) >= 2]
# Only keep the first two turns of each conversation. # Only keep the first two turns of each conversation.
dataset = [(data["conversations"][0]["value"], dataset = [
data["conversations"][1]["value"]) for data in dataset] (data["conversations"][0]["value"], data["conversations"][1]["value"])
for data in dataset
]
# Shuffle the dataset. # Shuffle the dataset.
random.shuffle(dataset) random.shuffle(dataset)
@ -51,8 +54,9 @@ def sample_requests(
completion = dataset[i][1] completion = dataset[i][1]
completion_token_ids = tokenizer(completion).input_ids completion_token_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_token_ids) prompt_len = len(prompt_token_ids)
output_len = len(completion_token_ids output_len = (
) if fixed_output_len is None else fixed_output_len len(completion_token_ids) if fixed_output_len is None else fixed_output_len
)
if prompt_len < 4 or output_len < 4: if prompt_len < 4 or output_len < 4:
# Prune too short sequences. # Prune too short sequences.
continue continue
@ -74,13 +78,16 @@ def run_vllm(
disable_detokenize: bool = False, disable_detokenize: bool = False,
) -> float: ) -> float:
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
llm = LLM(**dataclasses.asdict(engine_args)) llm = LLM(**dataclasses.asdict(engine_args))
assert all( assert all(
llm.llm_engine.model_config.max_model_len >= (request[1] + request[2]) llm.llm_engine.model_config.max_model_len >= (request[1] + request[2])
for request in requests), ( for request in requests
"Please ensure that max_model_len is greater than the sum of" ), (
" input_len and output_len for all requests.") "Please ensure that max_model_len is greater than the sum of"
" input_len and output_len for all requests."
)
# Add the requests to the engine. # Add the requests to the engine.
prompts = [] prompts = []
@ -97,7 +104,8 @@ def run_vllm(
ignore_eos=True, ignore_eos=True,
max_tokens=output_len, max_tokens=output_len,
detokenize=not disable_detokenize, detokenize=not disable_detokenize,
)) )
)
start = time.perf_counter() start = time.perf_counter()
llm.generate(prompts, sampling_params, priority=priority, use_tqdm=True) llm.generate(prompts, sampling_params, priority=priority, use_tqdm=True)
@ -111,26 +119,33 @@ def main(args: argparse.Namespace):
# Sample the requests. # Sample the requests.
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer, trust_remote_code=args.trust_remote_code) args.tokenizer, trust_remote_code=args.trust_remote_code
)
if args.dataset is None: if args.dataset is None:
# Synthesize a prompt with the given input length. # Synthesize a prompt with the given input length.
prompt = "hi" * (args.input_len - 1) prompt = "hi" * (args.input_len - 1)
requests = [(prompt, args.input_len, args.output_len, requests = [
get_random_flag()) for _ in range(args.num_prompts)] (prompt, args.input_len, args.output_len, get_random_flag())
for _ in range(args.num_prompts)
]
else: else:
requests = sample_requests(args.dataset, args.num_prompts, tokenizer, requests = sample_requests(
args.output_len) args.dataset, args.num_prompts, tokenizer, args.output_len
)
if args.backend == "vllm": if args.backend == "vllm":
elapsed_time = run_vllm(requests, args.n, elapsed_time = run_vllm(
EngineArgs.from_cli_args(args), requests, args.n, EngineArgs.from_cli_args(args), args.disable_detokenize
args.disable_detokenize) )
else: else:
raise ValueError(f"Unknown backend: {args.backend}") raise ValueError(f"Unknown backend: {args.backend}")
total_num_tokens = sum(prompt_len + output_len total_num_tokens = sum(
for _, prompt_len, output_len, priority in requests) prompt_len + output_len for _, prompt_len, output_len, priority in requests
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " )
f"{total_num_tokens / elapsed_time:.2f} tokens/s") print(
f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
f"{total_num_tokens / elapsed_time:.2f} tokens/s"
)
# Output JSON results if specified # Output JSON results if specified
if args.output_json: if args.output_json:
@ -147,41 +162,44 @@ def main(args: argparse.Namespace):
if __name__ == "__main__": if __name__ == "__main__":
parser = FlexibleArgumentParser(description="Benchmark the throughput.") parser = FlexibleArgumentParser(description="Benchmark the throughput.")
parser.add_argument("--backend",
type=str,
choices=["vllm", "hf", "mii"],
default="vllm")
parser.add_argument("--dataset",
type=str,
default=None,
help="Path to the dataset.")
parser.add_argument("--input-len",
type=int,
default=None,
help="Input prompt length for each request")
parser.add_argument("--output-len",
type=int,
default=None,
help="Output length for each request. Overrides the "
"output length from the dataset.")
parser.add_argument("--n",
type=int,
default=1,
help="Number of generated sequences per prompt.")
parser.add_argument("--num-prompts",
type=int,
default=200,
help="Number of prompts to process.")
parser.add_argument( parser.add_argument(
'--output-json', "--backend", type=str, choices=["vllm", "hf", "mii"], default="vllm"
)
parser.add_argument(
"--dataset", type=str, default=None, help="Path to the dataset."
)
parser.add_argument(
"--input-len",
type=int,
default=None,
help="Input prompt length for each request",
)
parser.add_argument(
"--output-len",
type=int,
default=None,
help="Output length for each request. Overrides the "
"output length from the dataset.",
)
parser.add_argument(
"--n", type=int, default=1, help="Number of generated sequences per prompt."
)
parser.add_argument(
"--num-prompts", type=int, default=200, help="Number of prompts to process."
)
parser.add_argument(
"--output-json",
type=str, type=str,
default=None, default=None,
help='Path to save the throughput results in JSON format.') help="Path to save the throughput results in JSON format.",
)
parser.add_argument( parser.add_argument(
'--disable-detokenize', "--disable-detokenize",
action='store_true', action="store_true",
help=("Do not detokenize responses (i.e. do not include " help=(
"detokenization time in the latency measurement)"), "Do not detokenize responses (i.e. do not include "
"detokenization time in the latency measurement)"
),
) )
parser = EngineArgs.add_cli_args(parser) parser = EngineArgs.add_cli_args(parser)

File diff suppressed because it is too large Load Diff

View File

@ -19,6 +19,7 @@ On the client side, run:
--endpoint /generate_stream --endpoint /generate_stream
to the end of the command above. to the end of the command above.
""" """
import argparse import argparse
import asyncio import asyncio
import copy import copy
@ -36,11 +37,15 @@ from typing import Optional
import datasets import datasets
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput,
RequestFuncOutput)
from tqdm.asyncio import tqdm from tqdm.asyncio import tqdm
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from backend_request_func import (
ASYNC_REQUEST_FUNCS,
RequestFuncInput,
RequestFuncOutput,
)
try: try:
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
except ImportError: except ImportError:
@ -52,7 +57,8 @@ except ImportError:
from argparse import ArgumentParser as FlexibleArgumentParser from argparse import ArgumentParser as FlexibleArgumentParser
from vllm.v1.structured_output.backend_xgrammar import ( from vllm.v1.structured_output.backend_xgrammar import (
has_xgrammar_unsupported_json_features) has_xgrammar_unsupported_json_features,
)
MILLISECONDS_TO_SECONDS_CONVERSION = 1000 MILLISECONDS_TO_SECONDS_CONVERSION = 1000
@ -98,6 +104,7 @@ class SampleRequest:
prompt_len: The length of the prompt in tokens. prompt_len: The length of the prompt in tokens.
expected_output_len: The expected length of the output in tokens. expected_output_len: The expected length of the output in tokens.
""" """
prompt: str prompt: str
prompt_len: int prompt_len: int
expected_output_len: int expected_output_len: int
@ -106,45 +113,45 @@ class SampleRequest:
completion: str = None completion: str = None
def sample_requests(tokenizer: PreTrainedTokenizerBase, def sample_requests(
args: argparse.Namespace) -> list[SampleRequest]: tokenizer: PreTrainedTokenizerBase, args: argparse.Namespace
if args.dataset == 'json' or args.dataset == 'json-unique': ) -> list[SampleRequest]:
if args.dataset == "json" or args.dataset == "json-unique":
if args.json_schema_path is None: if args.json_schema_path is None:
dir_path = os.path.dirname(os.path.realpath(__file__)) dir_path = os.path.dirname(os.path.realpath(__file__))
args.json_schema_path = os.path.join(dir_path, args.json_schema_path = os.path.join(
"structured_schemas", dir_path, "structured_schemas", "structured_schema_1.json"
"structured_schema_1.json") )
json_schemas = [] json_schemas = []
with open(args.json_schema_path) as f: with open(args.json_schema_path) as f:
schema = json.load(f) schema = json.load(f)
if args.dataset == 'json-unique': if args.dataset == "json-unique":
json_schemas = [ json_schemas = [copy.deepcopy(schema) for _ in range(args.num_prompts)]
copy.deepcopy(schema) for _ in range(args.num_prompts)
]
for i in range(len(json_schemas)): for i in range(len(json_schemas)):
json_schemas[i]["properties"][ if "properties" not in json_schemas[i]:
f"__optional_field_{uuid.uuid4()}"] = { json_schemas[i]["properties"] = {}
"type": json_schemas[i]["properties"][f"__optional_field_{uuid.uuid4()}"] = {
"string", "type": "string",
"description": "description": "An unique optional field to avoid cached schemas",
"An unique optional field to avoid cached schemas" }
}
else: else:
json_schemas = [schema] * args.num_prompts json_schemas = [schema] * args.num_prompts
def gen_prompt(index: int): def gen_prompt(index: int):
return f"Generate an example of a user profile given the following schema: {json.dumps(get_schema(index))}" # noqa: E501 return f"Generate an example of a brief user profile given the following schema: {json.dumps(get_schema(index))}" # noqa: E501
def get_schema(index: int): def get_schema(index: int):
return json_schemas[index % len(json_schemas)] return json_schemas[index % len(json_schemas)]
requests = [ requests = [
SampleRequest(prompt=gen_prompt(i), SampleRequest(
prompt_len=len(tokenizer(gen_prompt(i)).input_ids), prompt=gen_prompt(i),
expected_output_len=args.output_len, prompt_len=len(tokenizer(gen_prompt(i)).input_ids),
schema=get_schema(i), expected_output_len=args.output_len,
structure_type=args.structure_type) schema=get_schema(i),
structure_type=args.structure_type,
)
for i in range(args.num_prompts) for i in range(args.num_prompts)
] ]
@ -168,11 +175,13 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
input_len = len(tokenizer(prompt).input_ids) input_len = len(tokenizer(prompt).input_ids)
print(f"Input length of the prompt: {input_len} tokens") print(f"Input length of the prompt: {input_len} tokens")
requests = [ requests = [
SampleRequest(prompt=prompt, SampleRequest(
prompt_len=input_len, prompt=prompt,
expected_output_len=args.output_len, prompt_len=input_len,
schema=schema, expected_output_len=args.output_len,
structure_type=args.structure_type) schema=schema,
structure_type=args.structure_type,
)
for _ in range(args.num_prompts) for _ in range(args.num_prompts)
] ]
@ -186,11 +195,13 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
input_len = len(tokenizer(prompt).input_ids) input_len = len(tokenizer(prompt).input_ids)
print(f"Input length of the prompt: {input_len} tokens") print(f"Input length of the prompt: {input_len} tokens")
requests = [ requests = [
SampleRequest(prompt=prompt, SampleRequest(
prompt_len=input_len, prompt=prompt,
expected_output_len=args.output_len, prompt_len=input_len,
schema=regex, expected_output_len=args.output_len,
structure_type=args.structure_type) schema=regex,
structure_type=args.structure_type,
)
for _ in range(args.num_prompts) for _ in range(args.num_prompts)
] ]
@ -201,47 +212,55 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
input_len = len(tokenizer(prompt).input_ids) input_len = len(tokenizer(prompt).input_ids)
print(f"Input length of the prompt: {input_len} tokens") print(f"Input length of the prompt: {input_len} tokens")
requests = [ requests = [
SampleRequest(prompt=prompt, SampleRequest(
prompt_len=input_len, prompt=prompt,
expected_output_len=args.output_len, prompt_len=input_len,
schema=choice, expected_output_len=args.output_len,
structure_type=args.structure_type) schema=choice,
structure_type=args.structure_type,
)
for _ in range(args.num_prompts) for _ in range(args.num_prompts)
] ]
elif args.dataset == "xgrammar_bench": elif args.dataset == "xgrammar_bench":
requests: list[SampleRequest] = [] requests: list[SampleRequest] = []
dataset = datasets.load_dataset("NousResearch/json-mode-eval", dataset = datasets.load_dataset("NousResearch/json-mode-eval", split="train")
split="train")
full_dataset_len = len(dataset) full_dataset_len = len(dataset)
def _filter_func(item): def _filter_func(item):
import json import json
schema = json.loads(item["schema"]) schema = json.loads(item["schema"])
return not has_xgrammar_unsupported_json_features(schema) return not has_xgrammar_unsupported_json_features(schema)
dataset = dataset.filter(_filter_func) dataset = dataset.filter(_filter_func)
num_filtered_out = full_dataset_len - len(dataset) num_filtered_out = full_dataset_len - len(dataset)
print(f"dataset has {len(dataset)} entries after filtering " print(
f"out {num_filtered_out} entries with unsupported features") f"dataset has {len(dataset)} entries after filtering "
f"out {num_filtered_out} entries with unsupported features"
)
len_dataset = len(dataset) len_dataset = len(dataset)
for data_point_idx in range(args.num_prompts): for data_point_idx in range(args.num_prompts):
idx = data_point_idx idx = data_point_idx
while idx >= len_dataset: while idx >= len_dataset:
idx -= len_dataset idx -= len_dataset
schema = dataset["schema"][idx] schema = dataset["schema"][idx]
prompt = tokenizer.apply_chat_template(dataset["prompt"][idx], prompt = tokenizer.apply_chat_template(
tokenize=False) dataset["prompt"][idx], tokenize=False, add_generation_prompt=True
)
input_len = len(tokenizer(prompt).input_ids) input_len = len(tokenizer(prompt).input_ids)
completion = dataset["completion"][idx] completion = dataset["completion"][idx]
requests.append( requests.append(
SampleRequest(prompt=prompt, SampleRequest(
prompt_len=input_len, prompt=prompt,
expected_output_len=args.output_len, prompt_len=input_len,
schema=schema, expected_output_len=args.output_len,
structure_type=args.structure_type, schema=schema,
completion=completion)) structure_type=args.structure_type,
completion=completion,
)
)
return requests return requests
@ -273,7 +292,8 @@ async def get_request(
# Calculate scale parameter theta to maintain the desired request_rate. # Calculate scale parameter theta to maintain the desired request_rate.
assert burstiness > 0, ( assert burstiness > 0, (
f"A positive burstiness factor is expected, but given {burstiness}.") f"A positive burstiness factor is expected, but given {burstiness}."
)
theta = 1.0 / (request_rate * burstiness) theta = 1.0 / (request_rate * burstiness)
for i, request in enumerate(input_requests): for i, request in enumerate(input_requests):
@ -315,8 +335,8 @@ def calculate_metrics(
# multiple output tokens may be bundled together # multiple output tokens may be bundled together
# Note : this may inflate the output token count slightly # Note : this may inflate the output token count slightly
output_len = len( output_len = len(
tokenizer(outputs[i].generated_text, tokenizer(outputs[i].generated_text, add_special_tokens=False).input_ids
add_special_tokens=False).input_ids) )
actual_output_lens.append(output_len) actual_output_lens.append(output_len)
total_input += input_requests[i].prompt_len total_input += input_requests[i].prompt_len
tpot = 0 tpot = 0
@ -340,16 +360,19 @@ def calculate_metrics(
if "ttft" in goodput_config_dict: if "ttft" in goodput_config_dict:
valid_metrics.append(ttfts) valid_metrics.append(ttfts)
slo_values.append(goodput_config_dict["ttft"] / slo_values.append(
MILLISECONDS_TO_SECONDS_CONVERSION) goodput_config_dict["ttft"] / MILLISECONDS_TO_SECONDS_CONVERSION
)
if "tpot" in goodput_config_dict: if "tpot" in goodput_config_dict:
valid_metrics.append(all_tpots) valid_metrics.append(all_tpots)
slo_values.append(goodput_config_dict["tpot"] / slo_values.append(
MILLISECONDS_TO_SECONDS_CONVERSION) goodput_config_dict["tpot"] / MILLISECONDS_TO_SECONDS_CONVERSION
)
if "e2el" in goodput_config_dict: if "e2el" in goodput_config_dict:
valid_metrics.append(e2els) valid_metrics.append(e2els)
slo_values.append(goodput_config_dict["e2el"] / slo_values.append(
MILLISECONDS_TO_SECONDS_CONVERSION) goodput_config_dict["e2el"] / MILLISECONDS_TO_SECONDS_CONVERSION
)
for req_metric in zip(*valid_metrics): for req_metric in zip(*valid_metrics):
is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)]) is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)])
@ -360,7 +383,8 @@ def calculate_metrics(
warnings.warn( warnings.warn(
"All requests failed. This is likely due to a misconfiguration " "All requests failed. This is likely due to a misconfiguration "
"on the benchmark arguments.", "on the benchmark arguments.",
stacklevel=2) stacklevel=2,
)
metrics = BenchmarkMetrics( metrics = BenchmarkMetrics(
completed=completed, completed=completed,
total_input=total_input, total_input=total_input,
@ -369,27 +393,31 @@ def calculate_metrics(
request_goodput=good_completed / dur_s, request_goodput=good_completed / dur_s,
output_throughput=sum(actual_output_lens) / dur_s, output_throughput=sum(actual_output_lens) / dur_s,
total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s, total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s,
mean_ttft_ms=np.mean(ttfts or 0) * mean_ttft_ms=np.mean(ttfts or 0)
1000, # ttfts is empty if streaming is not supported by backend * 1000, # ttfts is empty if streaming is not supported by backend
std_ttft_ms=np.std(ttfts or 0) * 1000, std_ttft_ms=np.std(ttfts or 0) * 1000,
median_ttft_ms=np.median(ttfts or 0) * 1000, median_ttft_ms=np.median(ttfts or 0) * 1000,
percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000) percentiles_ttft_ms=[
for p in selected_percentiles], (p, np.percentile(ttfts or 0, p) * 1000) for p in selected_percentiles
],
mean_tpot_ms=np.mean(tpots or 0) * 1000, mean_tpot_ms=np.mean(tpots or 0) * 1000,
std_tpot_ms=np.std(tpots or 0) * 1000, std_tpot_ms=np.std(tpots or 0) * 1000,
median_tpot_ms=np.median(tpots or 0) * 1000, median_tpot_ms=np.median(tpots or 0) * 1000,
percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000) percentiles_tpot_ms=[
for p in selected_percentiles], (p, np.percentile(tpots or 0, p) * 1000) for p in selected_percentiles
],
mean_itl_ms=np.mean(itls or 0) * 1000, mean_itl_ms=np.mean(itls or 0) * 1000,
std_itl_ms=np.std(itls or 0) * 1000, std_itl_ms=np.std(itls or 0) * 1000,
median_itl_ms=np.median(itls or 0) * 1000, median_itl_ms=np.median(itls or 0) * 1000,
percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000) percentiles_itl_ms=[
for p in selected_percentiles], (p, np.percentile(itls or 0, p) * 1000) for p in selected_percentiles
],
mean_e2el_ms=np.mean(e2els or 0) * 1000, mean_e2el_ms=np.mean(e2els or 0) * 1000,
std_e2el_ms=np.std(e2els or 0) * 1000, std_e2el_ms=np.std(e2els or 0) * 1000,
median_e2el_ms=np.median(e2els or 0) * 1000, median_e2el_ms=np.median(e2els or 0) * 1000,
percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000) percentiles_e2el_ms=[
for p in selected_percentiles], (p, np.percentile(e2els or 0, p) * 1000) for p in selected_percentiles
],
) )
return metrics, actual_output_lens return metrics, actual_output_lens
@ -411,7 +439,6 @@ async def benchmark(
ignore_eos: bool, ignore_eos: bool,
max_concurrency: Optional[int], max_concurrency: Optional[int],
structured_output_ratio: float, structured_output_ratio: float,
structured_output_backend: str,
goodput_config_dict: Optional[dict[str, float]] = None, goodput_config_dict: Optional[dict[str, float]] = None,
): ):
if backend in ASYNC_REQUEST_FUNCS: if backend in ASYNC_REQUEST_FUNCS:
@ -423,18 +450,17 @@ async def benchmark(
extra_body = {} extra_body = {}
# Add the schema to the extra_body # Add the schema to the extra_body
extra_body[request.structure_type] = request.schema extra_body[request.structure_type] = request.schema
# Add the specific structured_output_backend
extra_body["guided_decoding_backend"] = structured_output_backend
return extra_body return extra_body
print("Starting initial single prompt test run...") print("Starting initial single prompt test run...")
structured_output_req_idx = random.sample( structured_output_req_idx = random.sample(
range(len(input_requests)), range(len(input_requests)), int(len(input_requests) * structured_output_ratio)
int(len(input_requests) * structured_output_ratio)) )
test_request = input_requests[0] test_request = input_requests[0]
test_req_extra_body = (prepare_extra_body(test_request) test_req_extra_body = (
if 0 in structured_output_req_idx else None) prepare_extra_body(test_request) if 0 in structured_output_req_idx else None
)
test_input = RequestFuncInput( test_input = RequestFuncInput(
model=model_id, model=model_id,
prompt=test_request.prompt, prompt=test_request.prompt,
@ -448,7 +474,8 @@ async def benchmark(
if not test_output.success: if not test_output.success:
raise ValueError( raise ValueError(
"Initial test run failed - Please make sure benchmark arguments " "Initial test run failed - Please make sure benchmark arguments "
f"are correctly specified. Error: {test_output.error}") f"are correctly specified. Error: {test_output.error}"
)
else: else:
print("Initial test run completed. Starting main benchmark run...") print("Initial test run completed. Starting main benchmark run...")
@ -467,10 +494,7 @@ async def benchmark(
if profile_output.success: if profile_output.success:
print("Profiler started") print("Profiler started")
if burstiness == 1.0: distribution = "Poisson process" if burstiness == 1.0 else "Gamma distribution"
distribution = "Poisson process"
else:
distribution = "Gamma distribution"
print(f"Traffic request rate: {request_rate}") print(f"Traffic request rate: {request_rate}")
print(f"Burstiness factor: {burstiness} ({distribution})") print(f"Burstiness factor: {burstiness} ({distribution})")
@ -482,24 +506,21 @@ async def benchmark(
# and it will simplify the code in limited_request_func. # and it will simplify the code in limited_request_func.
# semaphore = (asyncio.Semaphore(max_concurrency) # semaphore = (asyncio.Semaphore(max_concurrency)
# if max_concurrency else contextlib.nullcontext()) # if max_concurrency else contextlib.nullcontext())
semaphore = (asyncio.Semaphore(max_concurrency) semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None
if max_concurrency else None)
async def limited_request_func(request_func_input, pbar): async def limited_request_func(request_func_input, pbar):
if semaphore is None: if semaphore is None:
return await request_func(request_func_input=request_func_input, return await request_func(request_func_input=request_func_input, pbar=pbar)
pbar=pbar)
async with semaphore: async with semaphore:
return await request_func(request_func_input=request_func_input, return await request_func(request_func_input=request_func_input, pbar=pbar)
pbar=pbar)
benchmark_start_time = time.perf_counter() benchmark_start_time = time.perf_counter()
tasks: list[asyncio.Task] = [] tasks: list[asyncio.Task] = []
expected: list[str] = [] expected: list[str] = []
async for i, request in get_request(input_requests, request_rate, async for i, request in get_request(input_requests, request_rate, burstiness):
burstiness): extra_body = (
extra_body = prepare_extra_body( prepare_extra_body(request) if i in structured_output_req_idx else None
request) if i in structured_output_req_idx else None )
request_func_input = RequestFuncInput( request_func_input = RequestFuncInput(
model=model_id, model=model_id,
prompt=request.prompt, prompt=request.prompt,
@ -512,8 +533,9 @@ async def benchmark(
expected.append(request.completion) expected.append(request.completion)
tasks.append( tasks.append(
asyncio.create_task( asyncio.create_task(
limited_request_func(request_func_input=request_func_input, limited_request_func(request_func_input=request_func_input, pbar=pbar)
pbar=pbar))) )
)
outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks) outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
if profile: if profile:
@ -545,54 +567,58 @@ async def benchmark(
goodput_config_dict=goodput_config_dict, goodput_config_dict=goodput_config_dict,
) )
print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='=')) print("{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
print("{:<40} {:<10.2f}".format("Benchmark duration (s):", print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
benchmark_duration))
print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input)) print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
print("{:<40} {:<10}".format("Total generated tokens:", print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
metrics.total_output)) print(
print("{:<40} {:<10.2f}".format("Request throughput (req/s):", "{:<40} {:<10.2f}".format(
metrics.request_throughput)) "Request throughput (req/s):", metrics.request_throughput
)
)
if goodput_config_dict: if goodput_config_dict:
print("{:<40} {:<10.2f}".format("Request goodput (req/s):", print(
metrics.request_goodput)) "{:<40} {:<10.2f}".format(
print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", "Request goodput (req/s):", metrics.request_goodput
metrics.output_throughput)) )
print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):", )
metrics.total_token_throughput)) print(
"{:<40} {:<10.2f}".format(
"Output token throughput (tok/s):", metrics.output_throughput
)
)
print(
"{:<40} {:<10.2f}".format(
"Total Token throughput (tok/s):", metrics.total_token_throughput
)
)
result = { result = {
"duration": "duration": benchmark_duration,
benchmark_duration, "completed": metrics.completed,
"completed": "total_input_tokens": metrics.total_input,
metrics.completed, "total_output_tokens": metrics.total_output,
"total_input_tokens": "request_throughput": metrics.request_throughput,
metrics.total_input, "output_throughput": metrics.output_throughput,
"total_output_tokens": "total_token_throughput": metrics.total_token_throughput,
metrics.total_output, "ttft_description": pd.Series([output.ttft for output in outputs])
"request_throughput": .describe()
metrics.request_throughput, .to_dict(),
"output_throughput": "tpot_description": pd.Series([output.tpot for output in outputs])
metrics.output_throughput, .describe()
"total_token_throughput": .to_dict(),
metrics.total_token_throughput,
"ttft_description":
pd.Series([output.ttft for output in outputs]).describe().to_dict(),
"tpot_description":
pd.Series([output.tpot for output in outputs]).describe().to_dict(),
"input_lens": [output.prompt_len for output in outputs], "input_lens": [output.prompt_len for output in outputs],
"output_lens": "output_lens": actual_output_lens,
actual_output_lens,
"ttfts": [output.ttft for output in outputs], "ttfts": [output.ttft for output in outputs],
"itls": [output.itl for output in outputs], "itls": [output.itl for output in outputs],
"errors": [output.error for output in outputs], "errors": [output.error for output in outputs],
} }
ret = [{ ret = [
'generated': output.generated_text, {"generated": output.generated_text, "expected": gt}
'expected': gt for output, gt in zip(outputs, expected)
} for output, gt in zip(outputs, expected)] ]
def process_one_metric( def process_one_metric(
# E.g., "ttft" # E.g., "ttft"
@ -606,29 +632,35 @@ async def benchmark(
# metric. # metric.
if metric_attribute_name not in selected_percentile_metrics: if metric_attribute_name not in selected_percentile_metrics:
return return
print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-')) print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
print("{:<40} {:<10.2f}".format( print(
f"Mean {metric_name} (ms):", "{:<40} {:<10.2f}".format(
getattr(metrics, f"mean_{metric_attribute_name}_ms"))) f"Mean {metric_name} (ms):",
print("{:<40} {:<10.2f}".format( getattr(metrics, f"mean_{metric_attribute_name}_ms"),
f"Median {metric_name} (ms):", )
getattr(metrics, f"median_{metric_attribute_name}_ms"))) )
print(
"{:<40} {:<10.2f}".format(
f"Median {metric_name} (ms):",
getattr(metrics, f"median_{metric_attribute_name}_ms"),
)
)
result[f"mean_{metric_attribute_name}_ms"] = getattr( result[f"mean_{metric_attribute_name}_ms"] = getattr(
metrics, f"mean_{metric_attribute_name}_ms") metrics, f"mean_{metric_attribute_name}_ms"
)
result[f"median_{metric_attribute_name}_ms"] = getattr( result[f"median_{metric_attribute_name}_ms"] = getattr(
metrics, f"median_{metric_attribute_name}_ms") metrics, f"median_{metric_attribute_name}_ms"
)
result[f"std_{metric_attribute_name}_ms"] = getattr( result[f"std_{metric_attribute_name}_ms"] = getattr(
metrics, f"std_{metric_attribute_name}_ms") metrics, f"std_{metric_attribute_name}_ms"
for p, value in getattr(metrics, )
f"percentiles_{metric_attribute_name}_ms"): for p, value in getattr(metrics, f"percentiles_{metric_attribute_name}_ms"):
p_word = str(int(p)) if int(p) == p else str(p) p_word = str(int(p)) if int(p) == p else str(p)
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value))
value))
result[f"p{p_word}_{metric_attribute_name}_ms"] = value result[f"p{p_word}_{metric_attribute_name}_ms"] = value
process_one_metric("ttft", "TTFT", "Time to First Token") process_one_metric("ttft", "TTFT", "Time to First Token")
process_one_metric("tpot", "TPOT", process_one_metric("tpot", "TPOT", "Time per Output Token (excl. 1st token)")
"Time per Output Token (excl. 1st token)")
process_one_metric("itl", "ITL", "Inter-token Latency") process_one_metric("itl", "ITL", "Inter-token Latency")
process_one_metric("e2el", "E2EL", "End-to-end Latency") process_one_metric("e2el", "E2EL", "End-to-end Latency")
@ -638,13 +670,13 @@ async def benchmark(
def evaluate(ret, args): def evaluate(ret, args):
def _eval_correctness_json(expected, actual): def _eval_correctness_json(expected, actual):
# extract json string from string using regex # extract json string from string using regex
import re import re
actual = actual.replace('\n', '').replace(' ', '').strip()
actual = actual.replace("\n", "").replace(" ", "").strip()
try: try:
actual = re.search(r'\{.*\}', actual).group() actual = re.search(r"\{.*\}", actual).group()
actual = json.loads(actual) actual = json.loads(actual)
except Exception: except Exception:
return False return False
@ -656,28 +688,32 @@ def evaluate(ret, args):
def _eval_correctness_regex(expected, actual): def _eval_correctness_regex(expected, actual):
import re import re
return re.match(args.regex, actual) is not None return re.match(args.regex, actual) is not None
def _eval_correctness(expected, actual): def _eval_correctness(expected, actual):
if args.structure_type == 'guided_json': if args.structure_type == "guided_json":
return _eval_correctness_json(expected, actual) return _eval_correctness_json(expected, actual)
elif args.structure_type == 'guided_regex': elif args.structure_type == "guided_regex":
return _eval_correctness_regex(expected, actual) return _eval_correctness_regex(expected, actual)
elif args.structure_type == 'guided_choice': elif args.structure_type == "guided_choice":
return _eval_correctness_choice(expected, actual) return _eval_correctness_choice(expected, actual)
else: else:
return None return None
scores = [] scores = []
for res in ret: for res in ret:
score = _eval_correctness(res['expected'], res['generated']) score = _eval_correctness(res["expected"], res["generated"])
res['correctness'] = score res["correctness"] = score
scores.append(score) scores.append(score)
not_none_scores = [score for score in scores if score is not None] not_none_scores = [score for score in scores if score is not None]
return (sum(not_none_scores) / len(not_none_scores) * return (
100) if len(not_none_scores) > 0 else None (sum(not_none_scores) / len(not_none_scores) * 100)
if len(not_none_scores) > 0
else None
)
def parse_goodput(slo_pairs): def parse_goodput(slo_pairs):
@ -689,9 +725,10 @@ def parse_goodput(slo_pairs):
except ValueError as err: except ValueError as err:
raise argparse.ArgumentTypeError( raise argparse.ArgumentTypeError(
"Invalid format found for service level objectives. " "Invalid format found for service level objectives. "
"Specify service level objectives for goodput as \"KEY:VALUE\" " 'Specify service level objectives for goodput as "KEY:VALUE" '
"pairs, where the key is a metric name, and the value is a " "pairs, where the key is a metric name, and the value is a "
"number in milliseconds.") from err "number in milliseconds."
) from err
return goodput_config_dict return goodput_config_dict
@ -705,12 +742,14 @@ def check_goodput_args(args):
raise ValueError( raise ValueError(
f"Invalid metric name found, {slo_name}: {slo_val}. " f"Invalid metric name found, {slo_name}: {slo_val}. "
"The service level objective name should be one of " "The service level objective name should be one of "
f"{str(VALID_NAMES)}. ") f"{str(VALID_NAMES)}. "
)
if slo_val < 0: if slo_val < 0:
raise ValueError( raise ValueError(
f"Invalid value found, {slo_name}: {slo_val}. " f"Invalid value found, {slo_name}: {slo_val}. "
"The service level objective value should be " "The service level objective value should be "
"non-negative.") "non-negative."
)
return goodput_config_dict return goodput_config_dict
@ -736,19 +775,19 @@ def main(args: argparse.Namespace):
tokenizer_mode=args.tokenizer_mode, tokenizer_mode=args.tokenizer_mode,
) )
if args.dataset == 'grammar': if args.dataset == "grammar":
args.structure_type = 'guided_grammar' args.structure_type = "guided_grammar"
elif args.dataset == 'regex': elif args.dataset == "regex":
args.structure_type = 'guided_regex' args.structure_type = "guided_regex"
elif args.dataset == 'choice': elif args.dataset == "choice":
args.structure_type = 'guided_choice' args.structure_type = "guided_choice"
else: else:
args.structure_type = 'guided_json' args.structure_type = "guided_json"
if args.no_structured_output: if args.no_structured_output:
args.structured_output_ratio = 0 args.structured_output_ratio = 0
if args.save_results: if args.save_results:
result_file_name = f'{args.structured_output_ratio}guided' result_file_name = f"{args.structured_output_ratio}guided"
result_file_name += f"_{backend}" result_file_name += f"_{backend}"
result_file_name += f"_{args.request_rate}qps" result_file_name += f"_{args.request_rate}qps"
result_file_name += f"_{args.model.split('/')[-1]}" result_file_name += f"_{args.model.split('/')[-1]}"
@ -776,37 +815,29 @@ def main(args: argparse.Namespace):
disable_tqdm=args.disable_tqdm, disable_tqdm=args.disable_tqdm,
profile=args.profile, profile=args.profile,
selected_percentile_metrics=args.percentile_metrics.split(","), selected_percentile_metrics=args.percentile_metrics.split(","),
selected_percentiles=[ selected_percentiles=[float(p) for p in args.metric_percentiles.split(",")],
float(p) for p in args.metric_percentiles.split(",")
],
ignore_eos=args.ignore_eos, ignore_eos=args.ignore_eos,
max_concurrency=args.max_concurrency, max_concurrency=args.max_concurrency,
structured_output_ratio=args.structured_output_ratio, structured_output_ratio=args.structured_output_ratio,
structured_output_backend=args.structured_output_backend,
goodput_config_dict=goodput_config_dict, goodput_config_dict=goodput_config_dict,
)) )
)
# Save config and results to json # Save config and results to json
score = evaluate(ret, args) score = evaluate(ret, args)
print("correct_rate(%)", score, '\n') print("correct_rate(%)", score, "\n")
if args.save_results: if args.save_results:
results = { results = {
"backend": "backend": backend,
backend, "model_id": model_id,
"model_id": "tokenizer_id": tokenizer_id,
model_id, "num_prompts": args.num_prompts,
"tokenizer_id": "request_rate": args.request_rate
tokenizer_id, if args.request_rate < float("inf")
"num_prompts": else "inf",
args.num_prompts, "burstiness": args.burstiness,
"request_rate": "max_concurrency": args.max_concurrency,
args.request_rate if args.request_rate < float("inf") else "inf", "correct_rate(%)": score,
"burstiness":
args.burstiness,
"max_concurrency":
args.max_concurrency,
"correct_rate(%)":
score
} }
results = {"outputs": ret, **results, **benchmark_result} results = {"outputs": ret, **results, **benchmark_result}
@ -815,13 +846,14 @@ def main(args: argparse.Namespace):
result_file_name = args.result_filename result_file_name = args.result_filename
if args.result_dir: if args.result_dir:
result_file_name = os.path.join(args.result_dir, result_file_name) result_file_name = os.path.join(args.result_dir, result_file_name)
with open(result_file_name, "w", encoding='utf-8') as outfile: with open(result_file_name, "w", encoding="utf-8") as outfile:
json.dump(results, outfile, indent=4) json.dump(results, outfile, indent=4)
if __name__ == "__main__": if __name__ == "__main__":
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description="Benchmark the online serving throughput.") description="Benchmark the online serving throughput."
)
parser.add_argument( parser.add_argument(
"--backend", "--backend",
type=str, type=str,
@ -843,16 +875,14 @@ if __name__ == "__main__":
default="/v1/completions", default="/v1/completions",
help="API endpoint.", help="API endpoint.",
) )
parser.add_argument("--dataset", parser.add_argument(
default='json', "--dataset",
choices=[ default="json",
'json', 'json-unique', 'grammar', 'regex', choices=["json", "json-unique", "grammar", "regex", "choice", "xgrammar_bench"],
'choice', 'xgrammar_bench' )
]) parser.add_argument(
parser.add_argument("--json_schema_path", "--json-schema-path", type=str, default=None, help="Path to json schema."
type=str, )
default=None,
help="Path to json schema.")
parser.add_argument( parser.add_argument(
"--max-concurrency", "--max-concurrency",
type=int, type=int,
@ -864,7 +894,8 @@ if __name__ == "__main__":
"initiated, this argument will control how many are actually allowed " "initiated, this argument will control how many are actually allowed "
"to execute at a time. This means that when used in combination, the " "to execute at a time. This means that when used in combination, the "
"actual request rate may be lower than specified with --request-rate, " "actual request rate may be lower than specified with --request-rate, "
"if the server is not processing requests fast enough to keep up.") "if the server is not processing requests fast enough to keep up.",
)
parser.add_argument( parser.add_argument(
"--model", "--model",
type=str, type=str,
@ -874,15 +905,13 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--tokenizer", "--tokenizer",
type=str, type=str,
help= help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
"Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
) )
parser.add_argument( parser.add_argument(
"--tokenizer-mode", "--tokenizer-mode",
type=str, type=str,
default="auto", default="auto",
help= help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
"Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
) )
parser.add_argument( parser.add_argument(
"--num-prompts", "--num-prompts",
@ -959,52 +988,51 @@ if __name__ == "__main__":
"--ignore-eos", "--ignore-eos",
action="store_true", action="store_true",
help="Set ignore_eos flag when sending the benchmark request." help="Set ignore_eos flag when sending the benchmark request."
"Warning: ignore_eos is not supported in deepspeed_mii and tgi.") "Warning: ignore_eos is not supported in deepspeed_mii and tgi.",
)
parser.add_argument( parser.add_argument(
"--percentile-metrics", "--percentile-metrics",
type=str, type=str,
default="ttft,tpot,itl", default="ttft,tpot,itl",
help="Comma-separated list of selected metrics to report percentils. " help="Comma-separated list of selected metrics to report percentils. "
"This argument specifies the metrics to report percentiles. " "This argument specifies the metrics to report percentiles. "
"Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". " 'Allowed metric names are "ttft", "tpot", "itl", "e2el". '
"Default value is \"ttft,tpot,itl\".") 'Default value is "ttft,tpot,itl".',
)
parser.add_argument( parser.add_argument(
"--metric-percentiles", "--metric-percentiles",
type=str, type=str,
default="99", default="99",
help="Comma-separated list of percentiles for selected metrics. " help="Comma-separated list of percentiles for selected metrics. "
"To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". " 'To report 25-th, 50-th, and 75-th percentiles, use "25,50,75". '
"Default value is \"99\". " 'Default value is "99". '
"Use \"--percentile-metrics\" to select metrics.", 'Use "--percentile-metrics" to select metrics.',
) )
parser.add_argument( parser.add_argument(
"--goodput", "--goodput",
nargs="+", nargs="+",
required=False, required=False,
help="Specify service level objectives for goodput as \"KEY:VALUE\" " help='Specify service level objectives for goodput as "KEY:VALUE" '
"pairs, where the key is a metric name, and the value is in " "pairs, where the key is a metric name, and the value is in "
"milliseconds. Multiple \"KEY:VALUE\" pairs can be provided, " 'milliseconds. Multiple "KEY:VALUE" pairs can be provided, '
"separated by spaces. Allowed request level metric names are " "separated by spaces. Allowed request level metric names are "
"\"ttft\", \"tpot\", \"e2el\". For more context on the definition of " '"ttft", "tpot", "e2el". For more context on the definition of '
"goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 " "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
"and the blog: https://hao-ai-lab.github.io/blogs/distserve") "and the blog: https://hao-ai-lab.github.io/blogs/distserve",
)
parser.add_argument("--no-structured-output", parser.add_argument(
action='store_true', "--no-structured-output",
default=False, action="store_true",
help="Whether to disable JSON decoding or not.") default=False,
parser.add_argument("--structured-output-ratio", help="Whether to disable JSON decoding or not.",
type=float, )
default=1.0, parser.add_argument(
help="Ratio of Structured Outputs requests") "--structured-output-ratio",
parser.add_argument("--structured-output-backend", type=float,
type=str, default=1.0,
choices=[ help="Ratio of Structured Outputs requests",
"outlines", "lm-format-enforcer", "xgrammar", )
"guidance", "auto"
],
default="auto",
help="Backend to use for structured outputs")
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)

View File

@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Benchmark offline inference throughput.""" """Benchmark offline inference throughput."""
import argparse import argparse
import dataclasses import dataclasses
import json import json
@ -11,18 +12,25 @@ from typing import Any, Optional, Union
import torch import torch
import uvloop import uvloop
from benchmark_dataset import (AIMODataset, BurstGPTDataset,
ConversationDataset, InstructCoderDataset,
RandomDataset, SampleRequest, ShareGPTDataset,
SonnetDataset, VisionArenaDataset)
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
from tqdm import tqdm from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer, from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase
PreTrainedTokenizerBase)
from benchmark_dataset import (
AIMODataset,
BurstGPTDataset,
ConversationDataset,
InstructCoderDataset,
RandomDataset,
SampleRequest,
ShareGPTDataset,
SonnetDataset,
VisionArenaDataset,
)
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.entrypoints.openai.api_server import ( from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args) build_async_engine_client_from_engine_args,
)
from vllm.inputs import TextPrompt, TokensPrompt from vllm.inputs import TextPrompt, TokensPrompt
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
@ -37,23 +45,30 @@ def run_vllm(
disable_detokenize: bool = False, disable_detokenize: bool = False,
) -> tuple[float, Optional[list[RequestOutput]]]: ) -> tuple[float, Optional[list[RequestOutput]]]:
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
llm = LLM(**dataclasses.asdict(engine_args)) llm = LLM(**dataclasses.asdict(engine_args))
assert all( assert all(
llm.llm_engine.model_config.max_model_len >= ( llm.llm_engine.model_config.max_model_len
request.prompt_len + request.expected_output_len) >= (request.prompt_len + request.expected_output_len)
for request in requests), ( for request in requests
"Please ensure that max_model_len is greater than the sum of" ), (
" prompt_len and expected_output_len for all requests.") "Please ensure that max_model_len is greater than the sum of"
" prompt_len and expected_output_len for all requests."
)
# Add the requests to the engine. # Add the requests to the engine.
prompts: list[Union[TextPrompt, TokensPrompt]] = [] prompts: list[Union[TextPrompt, TokensPrompt]] = []
sampling_params: list[SamplingParams] = [] sampling_params: list[SamplingParams] = []
for request in requests: for request in requests:
prompts.append( prompts.append(
TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"], TokensPrompt(
multi_modal_data=request.multi_modal_data) prompt_token_ids=request.prompt["prompt_token_ids"],
if "prompt_token_ids" in request.prompt else \ multi_modal_data=request.multi_modal_data,
TextPrompt(prompt=request.prompt, )
multi_modal_data=request.multi_modal_data)) if "prompt_token_ids" in request.prompt
else TextPrompt(
prompt=request.prompt, multi_modal_data=request.multi_modal_data
)
)
sampling_params.append( sampling_params.append(
SamplingParams( SamplingParams(
n=n, n=n,
@ -62,7 +77,8 @@ def run_vllm(
ignore_eos=True, ignore_eos=True,
max_tokens=request.expected_output_len, max_tokens=request.expected_output_len,
detokenize=not disable_detokenize, detokenize=not disable_detokenize,
)) )
)
lora_requests: Optional[list[LoRARequest]] = None lora_requests: Optional[list[LoRARequest]] = None
if engine_args.enable_lora: if engine_args.enable_lora:
lora_requests = [request.lora_request for request in requests] lora_requests = [request.lora_request for request in requests]
@ -72,10 +88,9 @@ def run_vllm(
outputs = None outputs = None
if not use_beam_search: if not use_beam_search:
start = time.perf_counter() start = time.perf_counter()
outputs = llm.generate(prompts, outputs = llm.generate(
sampling_params, prompts, sampling_params, lora_request=lora_requests, use_tqdm=True
lora_request=lora_requests, )
use_tqdm=True)
end = time.perf_counter() end = time.perf_counter()
else: else:
assert lora_requests is None, "BeamSearch API does not support LoRA" assert lora_requests is None, "BeamSearch API does not support LoRA"
@ -91,30 +106,35 @@ def run_vllm(
beam_width=n, beam_width=n,
max_tokens=output_len, max_tokens=output_len,
ignore_eos=True, ignore_eos=True,
)) ),
)
end = time.perf_counter() end = time.perf_counter()
return end - start, outputs return end - start, outputs
def run_vllm_chat( def run_vllm_chat(
requests: list[SampleRequest], requests: list[SampleRequest],
n: int, n: int,
engine_args: EngineArgs, engine_args: EngineArgs,
disable_detokenize: bool = False) -> tuple[float, list[RequestOutput]]: disable_detokenize: bool = False,
) -> tuple[float, list[RequestOutput]]:
""" """
Run vLLM chat benchmark. This function is recommended ONLY for benchmarking Run vLLM chat benchmark. This function is recommended ONLY for benchmarking
multimodal models as it properly handles multimodal inputs and chat multimodal models as it properly handles multimodal inputs and chat
formatting. For non-multimodal models, use run_vllm() instead. formatting. For non-multimodal models, use run_vllm() instead.
""" """
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
llm = LLM(**dataclasses.asdict(engine_args)) llm = LLM(**dataclasses.asdict(engine_args))
assert all( assert all(
llm.llm_engine.model_config.max_model_len >= ( llm.llm_engine.model_config.max_model_len
request.prompt_len + request.expected_output_len) >= (request.prompt_len + request.expected_output_len)
for request in requests), ( for request in requests
"Please ensure that max_model_len is greater than the sum of " ), (
"prompt_len and expected_output_len for all requests.") "Please ensure that max_model_len is greater than the sum of "
"prompt_len and expected_output_len for all requests."
)
prompts = [] prompts = []
sampling_params: list[SamplingParams] = [] sampling_params: list[SamplingParams] = []
@ -128,7 +148,8 @@ def run_vllm_chat(
ignore_eos=True, ignore_eos=True,
max_tokens=request.expected_output_len, max_tokens=request.expected_output_len,
detokenize=not disable_detokenize, detokenize=not disable_detokenize,
)) )
)
start = time.perf_counter() start = time.perf_counter()
outputs = llm.chat(prompts, sampling_params, use_tqdm=True) outputs = llm.chat(prompts, sampling_params, use_tqdm=True)
end = time.perf_counter() end = time.perf_counter()
@ -145,13 +166,17 @@ async def run_vllm_async(
from vllm import SamplingParams from vllm import SamplingParams
async with build_async_engine_client_from_engine_args( async with build_async_engine_client_from_engine_args(
engine_args, disable_frontend_multiprocessing) as llm: engine_args, disable_frontend_multiprocessing
) as llm:
model_config = await llm.get_model_config()
assert all( assert all(
llm.model_config.max_model_len >= (request.prompt_len + model_config.max_model_len
request.expected_output_len) >= (request.prompt_len + request.expected_output_len)
for request in requests), ( for request in requests
"Please ensure that max_model_len is greater than the sum of" ), (
" prompt_len and expected_output_len for all requests.") "Please ensure that max_model_len is greater than the sum of"
" prompt_len and expected_output_len for all requests."
)
# Add the requests to the engine. # Add the requests to the engine.
prompts: list[Union[TextPrompt, TokensPrompt]] = [] prompts: list[Union[TextPrompt, TokensPrompt]] = []
@ -159,11 +184,15 @@ async def run_vllm_async(
lora_requests: list[Optional[LoRARequest]] = [] lora_requests: list[Optional[LoRARequest]] = []
for request in requests: for request in requests:
prompts.append( prompts.append(
TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"], TokensPrompt(
multi_modal_data=request.multi_modal_data) prompt_token_ids=request.prompt["prompt_token_ids"],
if "prompt_token_ids" in request.prompt else \ multi_modal_data=request.multi_modal_data,
TextPrompt(prompt=request.prompt, )
multi_modal_data=request.multi_modal_data)) if "prompt_token_ids" in request.prompt
else TextPrompt(
prompt=request.prompt, multi_modal_data=request.multi_modal_data
)
)
sampling_params.append( sampling_params.append(
SamplingParams( SamplingParams(
n=n, n=n,
@ -172,17 +201,16 @@ async def run_vllm_async(
ignore_eos=True, ignore_eos=True,
max_tokens=request.expected_output_len, max_tokens=request.expected_output_len,
detokenize=not disable_detokenize, detokenize=not disable_detokenize,
)) )
)
lora_requests.append(request.lora_request) lora_requests.append(request.lora_request)
generators = [] generators = []
start = time.perf_counter() start = time.perf_counter()
for i, (prompt, sp, for i, (prompt, sp, lr) in enumerate(
lr) in enumerate(zip(prompts, sampling_params, lora_requests)): zip(prompts, sampling_params, lora_requests)
generator = llm.generate(prompt, ):
sp, generator = llm.generate(prompt, sp, lora_request=lr, request_id=f"test{i}")
lora_request=lr,
request_id=f"test{i}")
generators.append(generator) generators.append(generator)
all_gens = merge_async_iterators(*generators) all_gens = merge_async_iterators(*generators)
async for i, res in all_gens: async for i, res in all_gens:
@ -201,7 +229,8 @@ def run_hf(
disable_detokenize: bool = False, disable_detokenize: bool = False,
) -> float: ) -> float:
llm = AutoModelForCausalLM.from_pretrained( llm = AutoModelForCausalLM.from_pretrained(
model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code) model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code
)
if llm.config.model_type == "llama": if llm.config.model_type == "llama":
# To enable padding in the HF backend. # To enable padding in the HF backend.
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
@ -224,14 +253,15 @@ def run_hf(
# Check if we can add more requests to the batch. # Check if we can add more requests to the batch.
next_prompt_len = requests[i + 1].prompt_len next_prompt_len = requests[i + 1].prompt_len
next_output_len = requests[i + 1].expected_output_len next_output_len = requests[i + 1].expected_output_len
if (max(max_prompt_len, next_prompt_len) + if (
max(max_output_len, next_output_len)) <= 2048: max(max_prompt_len, next_prompt_len)
+ max(max_output_len, next_output_len)
) <= 2048:
# We can add more requests to the batch. # We can add more requests to the batch.
continue continue
# Generate the sequences. # Generate the sequences.
input_ids = tokenizer(batch, return_tensors="pt", input_ids = tokenizer(batch, return_tensors="pt", padding=True).input_ids
padding=True).input_ids
llm_outputs = llm.generate( llm_outputs = llm.generate(
input_ids=input_ids.cuda(), input_ids=input_ids.cuda(),
do_sample=True, do_sample=True,
@ -261,6 +291,7 @@ def run_mii(
output_len: int, output_len: int,
) -> float: ) -> float:
from mii import client, serve from mii import client, serve
llm = serve(model, tensor_parallel=tensor_parallel_size) llm = serve(model, tensor_parallel=tensor_parallel_size)
prompts = [request.prompt for request in requests] prompts = [request.prompt for request in requests]
@ -272,8 +303,9 @@ def run_mii(
return end - start return end - start
def save_to_pytorch_benchmark_format(args: argparse.Namespace, def save_to_pytorch_benchmark_format(
results: dict[str, Any]) -> None: args: argparse.Namespace, results: dict[str, Any]
) -> None:
pt_records = convert_to_pytorch_benchmark_format( pt_records = convert_to_pytorch_benchmark_format(
args=args, args=args,
metrics={ metrics={
@ -281,9 +313,9 @@ def save_to_pytorch_benchmark_format(args: argparse.Namespace,
"tokens_per_second": [results["tokens_per_second"]], "tokens_per_second": [results["tokens_per_second"]],
}, },
extra_info={ extra_info={
k: results[k] k: results[k] for k in ["elapsed_time", "num_requests", "total_num_tokens"]
for k in ["elapsed_time", "num_requests", "total_num_tokens"] },
}) )
if pt_records: if pt_records:
# Don't use json suffix here as we don't want CI to pick it up # Don't use json suffix here as we don't want CI to pick it up
pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json" pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json"
@ -315,7 +347,8 @@ def get_requests(args, tokenizer):
sample_kwargs["enable_multimodal_chat"] = True sample_kwargs["enable_multimodal_chat"] = True
elif args.dataset_name == "sonnet": elif args.dataset_name == "sonnet":
assert tokenizer.chat_template or tokenizer.default_chat_template, ( assert tokenizer.chat_template or tokenizer.default_chat_template, (
"Tokenizer/model must have chat template for sonnet dataset.") "Tokenizer/model must have chat template for sonnet dataset."
)
dataset_cls = SonnetDataset dataset_cls = SonnetDataset
sample_kwargs["prefix_len"] = args.prefix_len sample_kwargs["prefix_len"] = args.prefix_len
sample_kwargs["return_prompt_formatted"] = True sample_kwargs["return_prompt_formatted"] = True
@ -324,21 +357,21 @@ def get_requests(args, tokenizer):
elif args.dataset_name == "hf": elif args.dataset_name == "hf":
if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS: if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
dataset_cls = VisionArenaDataset dataset_cls = VisionArenaDataset
common_kwargs['dataset_subset'] = None common_kwargs["dataset_subset"] = None
common_kwargs['dataset_split'] = "train" common_kwargs["dataset_split"] = "train"
sample_kwargs["enable_multimodal_chat"] = True sample_kwargs["enable_multimodal_chat"] = True
elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS: elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
dataset_cls = InstructCoderDataset dataset_cls = InstructCoderDataset
common_kwargs['dataset_split'] = "train" common_kwargs["dataset_split"] = "train"
elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS: elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
dataset_cls = ConversationDataset dataset_cls = ConversationDataset
common_kwargs['dataset_subset'] = args.hf_subset common_kwargs["dataset_subset"] = args.hf_subset
common_kwargs['dataset_split'] = args.hf_split common_kwargs["dataset_split"] = args.hf_split
sample_kwargs["enable_multimodal_chat"] = True sample_kwargs["enable_multimodal_chat"] = True
elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS: elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
dataset_cls = AIMODataset dataset_cls = AIMODataset
common_kwargs['dataset_subset'] = None common_kwargs["dataset_subset"] = None
common_kwargs['dataset_split'] = "train" common_kwargs["dataset_split"] = "train"
else: else:
raise ValueError(f"Unknown dataset name: {args.dataset_name}") raise ValueError(f"Unknown dataset name: {args.dataset_name}")
# Remove None values # Remove None values
@ -353,10 +386,10 @@ def main(args: argparse.Namespace):
random.seed(args.seed) random.seed(args.seed)
# Sample the requests. # Sample the requests.
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer, trust_remote_code=args.trust_remote_code) args.tokenizer, trust_remote_code=args.trust_remote_code
)
requests = get_requests(args, tokenizer) requests = get_requests(args, tokenizer)
is_multi_modal = any(request.multi_modal_data is not None is_multi_modal = any(request.multi_modal_data is not None for request in requests)
for request in requests)
request_outputs: Optional[list[RequestOutput]] = None request_outputs: Optional[list[RequestOutput]] = None
if args.backend == "vllm": if args.backend == "vllm":
if args.async_engine: if args.async_engine:
@ -367,23 +400,34 @@ def main(args: argparse.Namespace):
AsyncEngineArgs.from_cli_args(args), AsyncEngineArgs.from_cli_args(args),
args.disable_frontend_multiprocessing, args.disable_frontend_multiprocessing,
args.disable_detokenize, args.disable_detokenize,
)) )
)
else: else:
elapsed_time, request_outputs = run_vllm( elapsed_time, request_outputs = run_vllm(
requests, args.n, EngineArgs.from_cli_args(args), requests,
args.disable_detokenize) args.n,
EngineArgs.from_cli_args(args),
args.disable_detokenize,
)
elif args.backend == "hf": elif args.backend == "hf":
assert args.tensor_parallel_size == 1 assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n, elapsed_time = run_hf(
args.hf_max_batch_size, args.trust_remote_code, requests,
args.disable_detokenize) args.model,
tokenizer,
args.n,
args.hf_max_batch_size,
args.trust_remote_code,
args.disable_detokenize,
)
elif args.backend == "mii": elif args.backend == "mii":
elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size, elapsed_time = run_mii(
args.output_len) requests, args.model, args.tensor_parallel_size, args.output_len
)
elif args.backend == "vllm-chat": elif args.backend == "vllm-chat":
elapsed_time, request_outputs = run_vllm_chat( elapsed_time, request_outputs = run_vllm_chat(
requests, args.n, EngineArgs.from_cli_args(args), requests, args.n, EngineArgs.from_cli_args(args), args.disable_detokenize
args.disable_detokenize) )
else: else:
raise ValueError(f"Unknown backend: {args.backend}") raise ValueError(f"Unknown backend: {args.backend}")
@ -395,28 +439,31 @@ def main(args: argparse.Namespace):
for ro in request_outputs: for ro in request_outputs:
if not isinstance(ro, RequestOutput): if not isinstance(ro, RequestOutput):
continue continue
total_prompt_tokens += len( total_prompt_tokens += (
ro.prompt_token_ids) if ro.prompt_token_ids else 0 len(ro.prompt_token_ids) if ro.prompt_token_ids else 0
total_output_tokens += sum( )
len(o.token_ids) for o in ro.outputs if o) total_output_tokens += sum(len(o.token_ids) for o in ro.outputs if o)
total_num_tokens = total_prompt_tokens + total_output_tokens total_num_tokens = total_prompt_tokens + total_output_tokens
else: else:
total_num_tokens = sum(r.prompt_len + r.expected_output_len total_num_tokens = sum(r.prompt_len + r.expected_output_len for r in requests)
for r in requests)
total_output_tokens = sum(r.expected_output_len for r in requests) total_output_tokens = sum(r.expected_output_len for r in requests)
total_prompt_tokens = total_num_tokens - total_output_tokens total_prompt_tokens = total_num_tokens - total_output_tokens
if is_multi_modal and args.backend != "vllm-chat": if is_multi_modal and args.backend != "vllm-chat":
print("\033[91mWARNING\033[0m: Multi-modal request with " print(
f"{args.backend} backend detected. The " "\033[91mWARNING\033[0m: Multi-modal request with "
"following metrics are not accurate because image tokens are not" f"{args.backend} backend detected. The "
" counted. See vllm-project/vllm/issues/9778 for details.") "following metrics are not accurate because image tokens are not"
" counted. See vllm-project/vllm/issues/9778 for details."
)
# TODO(vllm-project/vllm/issues/9778): Count multi-modal token length. # TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
# vllm-chat backend counts the image tokens now # vllm-chat backend counts the image tokens now
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " print(
f"{total_num_tokens / elapsed_time:.2f} total tokens/s, " f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
f"{total_output_tokens / elapsed_time:.2f} output tokens/s") f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
f"{total_output_tokens / elapsed_time:.2f} output tokens/s"
)
print(f"Total num prompt tokens: {total_prompt_tokens}") print(f"Total num prompt tokens: {total_prompt_tokens}")
print(f"Total num output tokens: {total_output_tokens}") print(f"Total num output tokens: {total_output_tokens}")
@ -444,7 +491,8 @@ def validate_args(args):
warnings.warn( warnings.warn(
"The '--dataset' argument will be deprecated in the next release. " "The '--dataset' argument will be deprecated in the next release. "
"Please use '--dataset-name' and '--dataset-path' instead.", "Please use '--dataset-name' and '--dataset-path' instead.",
stacklevel=2) stacklevel=2,
)
args.dataset_path = args.dataset args.dataset_path = args.dataset
if not getattr(args, "tokenizer", None): if not getattr(args, "tokenizer", None):
@ -457,9 +505,8 @@ def validate_args(args):
# === Dataset Configuration === # === Dataset Configuration ===
if not args.dataset and not args.dataset_path: if not args.dataset and not args.dataset_path:
print( print("When dataset path is not set, it will default to random dataset")
"When dataset path is not set, it will default to random dataset") args.dataset_name = "random"
args.dataset_name = 'random'
if args.input_len is None: if args.input_len is None:
raise ValueError("input_len must be provided for a random dataset") raise ValueError("input_len must be provided for a random dataset")
@ -467,41 +514,55 @@ def validate_args(args):
# --hf-subset and --hf-split: only used # --hf-subset and --hf-split: only used
# when dataset_name is 'hf' # when dataset_name is 'hf'
if args.dataset_name != "hf" and ( if args.dataset_name != "hf" and (
getattr(args, "hf_subset", None) is not None getattr(args, "hf_subset", None) is not None
or getattr(args, "hf_split", None) is not None): or getattr(args, "hf_split", None) is not None
warnings.warn("--hf-subset and --hf-split will be ignored \ ):
warnings.warn(
"--hf-subset and --hf-split will be ignored \
since --dataset-name is not 'hf'.", since --dataset-name is not 'hf'.",
stacklevel=2) stacklevel=2,
)
elif args.dataset_name == "hf": elif args.dataset_name == "hf":
if args.dataset_path in ( if args.dataset_path in (
VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys() VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys()
| ConversationDataset.SUPPORTED_DATASET_PATHS): | ConversationDataset.SUPPORTED_DATASET_PATHS
assert args.backend == "vllm-chat", f"{args.dataset_path} needs to use vllm-chat as the backend." #noqa: E501 ):
elif args.dataset_path in (InstructCoderDataset.SUPPORTED_DATASET_PATHS assert args.backend == "vllm-chat", (
| AIMODataset.SUPPORTED_DATASET_PATHS): f"{args.dataset_path} needs to use vllm-chat as the backend."
assert args.backend == "vllm", f"{args.dataset_path} needs to use vllm as the backend." #noqa: E501 ) # noqa: E501
elif args.dataset_path in (
InstructCoderDataset.SUPPORTED_DATASET_PATHS
| AIMODataset.SUPPORTED_DATASET_PATHS
):
assert args.backend == "vllm", (
f"{args.dataset_path} needs to use vllm as the backend."
) # noqa: E501
else: else:
raise ValueError( raise ValueError(f"{args.dataset_path} is not supported by hf dataset.")
f"{args.dataset_path} is not supported by hf dataset.")
# --random-range-ratio: only used when dataset_name is 'random' # --random-range-ratio: only used when dataset_name is 'random'
if args.dataset_name != 'random' and args.random_range_ratio is not None: if args.dataset_name != "random" and args.random_range_ratio is not None:
warnings.warn("--random-range-ratio will be ignored since \ warnings.warn(
"--random-range-ratio will be ignored since \
--dataset-name is not 'random'.", --dataset-name is not 'random'.",
stacklevel=2) stacklevel=2,
)
# --prefix-len: only used when dataset_name is 'random', 'sonnet', or not # --prefix-len: only used when dataset_name is 'random', 'sonnet', or not
# set. # set.
if args.dataset_name not in {"random", "sonnet", None if (
} and args.prefix_len is not None: args.dataset_name not in {"random", "sonnet", None}
warnings.warn("--prefix-len will be ignored since --dataset-name\ and args.prefix_len is not None
):
warnings.warn(
"--prefix-len will be ignored since --dataset-name\
is not 'random', 'sonnet', or not set.", is not 'random', 'sonnet', or not set.",
stacklevel=2) stacklevel=2,
)
# === LoRA Settings === # === LoRA Settings ===
if getattr(args, "enable_lora", False) and args.backend != "vllm": if getattr(args, "enable_lora", False) and args.backend != "vllm":
raise ValueError( raise ValueError("LoRA benchmarking is only supported for vLLM backend")
"LoRA benchmarking is only supported for vLLM backend")
if getattr(args, "enable_lora", False) and args.lora_path is None: if getattr(args, "enable_lora", False) and args.lora_path is None:
raise ValueError("LoRA path must be provided when enable_lora is True") raise ValueError("LoRA path must be provided when enable_lora is True")
@ -511,8 +572,10 @@ def validate_args(args):
if args.backend != "hf" and args.hf_max_batch_size is not None: if args.backend != "hf" and args.hf_max_batch_size is not None:
raise ValueError("HF max batch size is only for HF backend.") raise ValueError("HF max batch size is only for HF backend.")
if args.backend in {"hf", "mii"} and getattr(args, "quantization", if (
None) is not None: args.backend in {"hf", "mii"}
and getattr(args, "quantization", None) is not None
):
raise ValueError("Quantization is only for vLLM backend.") raise ValueError("Quantization is only for vLLM backend.")
if args.backend == "mii" and args.dtype != "auto": if args.backend == "mii" and args.dtype != "auto":
@ -520,29 +583,32 @@ def validate_args(args):
if args.backend == "mii" and args.n != 1: if args.backend == "mii" and args.n != 1:
raise ValueError("n must be 1 for MII backend.") raise ValueError("n must be 1 for MII backend.")
if args.backend == "mii" and args.tokenizer != args.model: if args.backend == "mii" and args.tokenizer != args.model:
raise ValueError( raise ValueError("Tokenizer must be the same as the model for MII backend.")
"Tokenizer must be the same as the model for MII backend.")
# --data-parallel is not supported currently. # --data-parallel is not supported currently.
# https://github.com/vllm-project/vllm/issues/16222 # https://github.com/vllm-project/vllm/issues/16222
if args.data_parallel_size > 1: if args.data_parallel_size > 1:
raise ValueError( raise ValueError(
"Data parallel is not supported in offline benchmark, \ "Data parallel is not supported in offline benchmark, \
please use benchmark serving instead") please use benchmark serving instead"
)
if __name__ == "__main__": if __name__ == "__main__":
parser = FlexibleArgumentParser(description="Benchmark the throughput.") parser = FlexibleArgumentParser(description="Benchmark the throughput.")
parser.add_argument("--backend", parser.add_argument(
type=str, "--backend",
choices=["vllm", "hf", "mii", "vllm-chat"], type=str,
default="vllm") choices=["vllm", "hf", "mii", "vllm-chat"],
default="vllm",
)
parser.add_argument( parser.add_argument(
"--dataset-name", "--dataset-name",
type=str, type=str,
choices=["sharegpt", "random", "sonnet", "burstgpt", "hf"], choices=["sharegpt", "random", "sonnet", "burstgpt", "hf"],
help="Name of the dataset to benchmark on.", help="Name of the dataset to benchmark on.",
default="sharegpt") default="sharegpt",
)
parser.add_argument( parser.add_argument(
"--dataset", "--dataset",
type=str, type=str,
@ -550,57 +616,70 @@ if __name__ == "__main__":
help="Path to the ShareGPT dataset, will be deprecated in\ help="Path to the ShareGPT dataset, will be deprecated in\
the next release. The dataset is expected to " the next release. The dataset is expected to "
"be a json in form of list[dict[..., conversations: " "be a json in form of list[dict[..., conversations: "
"list[dict[..., value: <prompt_or_response>]]]]") "list[dict[..., value: <prompt_or_response>]]]]",
parser.add_argument("--dataset-path", )
type=str,
default=None,
help="Path to the dataset")
parser.add_argument("--input-len",
type=int,
default=None,
help="Input prompt length for each request")
parser.add_argument("--output-len",
type=int,
default=None,
help="Output length for each request. Overrides the "
"output length from the dataset.")
parser.add_argument("--n",
type=int,
default=1,
help="Number of generated sequences per prompt.")
parser.add_argument("--num-prompts",
type=int,
default=1000,
help="Number of prompts to process.")
parser.add_argument("--hf-max-batch-size",
type=int,
default=None,
help="Maximum batch size for HF backend.")
parser.add_argument( parser.add_argument(
'--output-json', "--dataset-path", type=str, default=None, help="Path to the dataset"
)
parser.add_argument(
"--input-len",
type=int,
default=None,
help="Input prompt length for each request",
)
parser.add_argument(
"--output-len",
type=int,
default=None,
help="Output length for each request. Overrides the "
"output length from the dataset.",
)
parser.add_argument(
"--n", type=int, default=1, help="Number of generated sequences per prompt."
)
parser.add_argument(
"--num-prompts", type=int, default=1000, help="Number of prompts to process."
)
parser.add_argument(
"--hf-max-batch-size",
type=int,
default=None,
help="Maximum batch size for HF backend.",
)
parser.add_argument(
"--output-json",
type=str, type=str,
default=None, default=None,
help='Path to save the throughput results in JSON format.') help="Path to save the throughput results in JSON format.",
parser.add_argument("--async-engine", )
action='store_true', parser.add_argument(
default=False, "--async-engine",
help="Use vLLM async engine rather than LLM class.") action="store_true",
parser.add_argument("--disable-frontend-multiprocessing", default=False,
action='store_true', help="Use vLLM async engine rather than LLM class.",
default=False, )
help="Disable decoupled async engine frontend.") parser.add_argument(
"--disable-frontend-multiprocessing",
action="store_true",
default=False,
help="Disable decoupled async engine frontend.",
)
parser.add_argument( parser.add_argument(
"--disable-detokenize", "--disable-detokenize",
action="store_true", action="store_true",
help=("Do not detokenize the response (i.e. do not include " help=(
"detokenization time in the measurement)")) "Do not detokenize the response (i.e. do not include "
"detokenization time in the measurement)"
),
)
# LoRA # LoRA
parser.add_argument( parser.add_argument(
"--lora-path", "--lora-path",
type=str, type=str,
default=None, default=None,
help="Path to the lora adapters to use. This can be an absolute path, " help="Path to the LoRA adapters to use. This can be an absolute path, "
"a relative path, or a Hugging Face model identifier.") "a relative path, or a Hugging Face model identifier.",
)
parser.add_argument( parser.add_argument(
"--prefix-len", "--prefix-len",
type=int, type=int,
@ -614,7 +693,8 @@ if __name__ == "__main__":
f"prefix_len (default: {SonnetDataset.DEFAULT_PREFIX_LEN}) " f"prefix_len (default: {SonnetDataset.DEFAULT_PREFIX_LEN}) "
"controls how much of the input is fixed lines versus " "controls how much of the input is fixed lines versus "
"random lines, but the total input length remains approximately " "random lines, but the total input length remains approximately "
"input_len tokens.") "input_len tokens.",
)
# random dataset # random dataset
parser.add_argument( parser.add_argument(
"--random-range-ratio", "--random-range-ratio",
@ -628,14 +708,12 @@ if __name__ == "__main__":
) )
# hf dtaset # hf dtaset
parser.add_argument("--hf-subset", parser.add_argument(
type=str, "--hf-subset", type=str, default=None, help="Subset of the HF dataset."
default=None, )
help="Subset of the HF dataset.") parser.add_argument(
parser.add_argument("--hf-split", "--hf-split", type=str, default=None, help="Split of the HF dataset."
type=str, )
default=None,
help="Split of the HF dataset.")
parser = AsyncEngineArgs.add_cli_args(parser) parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args() args = parser.parse_args()

View File

@ -7,9 +7,9 @@ import os
from typing import Any from typing import Any
def convert_to_pytorch_benchmark_format(args: argparse.Namespace, def convert_to_pytorch_benchmark_format(
metrics: dict[str, list], args: argparse.Namespace, metrics: dict[str, list], extra_info: dict[str, Any]
extra_info: dict[str, Any]) -> list: ) -> list:
""" """
Save the benchmark results in the format used by PyTorch OSS benchmark with Save the benchmark results in the format used by PyTorch OSS benchmark with
on metric per record on metric per record
@ -37,12 +37,12 @@ def convert_to_pytorch_benchmark_format(args: argparse.Namespace,
}, },
} }
tp = record["benchmark"]["extra_info"]["args"].get( tp = record["benchmark"]["extra_info"]["args"].get("tensor_parallel_size")
"tensor_parallel_size")
# Save tensor_parallel_size parameter if it's part of the metadata # Save tensor_parallel_size parameter if it's part of the metadata
if not tp and "tensor_parallel_size" in extra_info: if not tp and "tensor_parallel_size" in extra_info:
record["benchmark"]["extra_info"]["args"][ record["benchmark"]["extra_info"]["args"]["tensor_parallel_size"] = (
"tensor_parallel_size"] = extra_info["tensor_parallel_size"] extra_info["tensor_parallel_size"]
)
records.append(record) records.append(record)
@ -50,7 +50,6 @@ def convert_to_pytorch_benchmark_format(args: argparse.Namespace,
class InfEncoder(json.JSONEncoder): class InfEncoder(json.JSONEncoder):
def clear_inf(self, o: Any): def clear_inf(self, o: Any):
if isinstance(o, dict): if isinstance(o, dict):
return {k: self.clear_inf(v) for k, v in o.items()} return {k: self.clear_inf(v) for k, v in o.items()}

View File

@ -23,8 +23,9 @@ DEFAULT_TP_SIZES = [1]
# bench # bench
def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, def bench_fn(
**kwargs) -> TMeasurement: label: str, sub_label: str, description: str, fn: Callable, *args, **kwargs
) -> TMeasurement:
min_run_time = 1 min_run_time = 1
globals = { globals = {
@ -41,16 +42,18 @@ def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
).blocked_autorange(min_run_time=min_run_time) ).blocked_autorange(min_run_time=min_run_time)
def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str, def bench_int8(
sub_label: str) -> Iterable[TMeasurement]: dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str
) -> Iterable[TMeasurement]:
assert dtype == torch.int8 assert dtype == torch.int8
b_compressed, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k) b_compressed, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k)
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)
out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b, out = ops.cutlass_scaled_sparse_mm(
torch.bfloat16) a, b_compressed, e, scale_a, scale_b, torch.bfloat16
)
out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16) out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
if not torch.allclose(out, out_ref): if not torch.allclose(out, out_ref):
@ -63,54 +66,107 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
timers = [] timers = []
# pytorch impl - bfloat16 # pytorch impl - bfloat16
timers.append( timers.append(
bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", bench_fn(
torch.mm, a.to(dtype=torch.bfloat16), label,
b.to(dtype=torch.bfloat16))) sub_label,
"pytorch_bf16_bf16_bf16_matmul-no-scales",
torch.mm,
a.to(dtype=torch.bfloat16),
b.to(dtype=torch.bfloat16),
)
)
# pytorch impl - float16 # pytorch impl - float16
timers.append( timers.append(
bench_fn(label, sub_label, bench_fn(
"pytorch_fp16_fp16_fp16_matmul-no-scales", torch.mm, label,
a.to(dtype=torch.float16), b.to(dtype=torch.float16))) sub_label,
"pytorch_fp16_fp16_fp16_matmul-no-scales",
torch.mm,
a.to(dtype=torch.float16),
b.to(dtype=torch.float16),
)
)
# cutlass impl # cutlass impl
timers.append( timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm", bench_fn(
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, label,
torch.bfloat16)) sub_label,
"cutlass_i8_i8_bf16_scaled_mm",
ops.cutlass_scaled_mm,
a,
b,
scale_a,
scale_b,
torch.bfloat16,
)
)
# cutlass with bias # cutlass with bias
timers.append( timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_bias", bench_fn(
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16, label,
bias)) sub_label,
"cutlass_i8_i8_bf16_scaled_mm_bias",
ops.cutlass_scaled_mm,
a,
b,
scale_a,
scale_b,
torch.bfloat16,
bias,
)
)
# cutlass sparse impl # cutlass sparse impl
timers.append( timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm", bench_fn(
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, label,
scale_b, torch.bfloat16)) sub_label,
"cutlass_i8_i8_bf16_scaled_sparse_mm",
ops.cutlass_scaled_sparse_mm,
a,
b_compressed,
e,
scale_a,
scale_b,
torch.bfloat16,
)
)
# cutlass sparse with bias # cutlass sparse with bias
timers.append( timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm_bias", bench_fn(
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, label,
scale_b, torch.bfloat16, bias)) sub_label,
"cutlass_i8_i8_bf16_scaled_sparse_mm_bias",
ops.cutlass_scaled_sparse_mm,
a,
b_compressed,
e,
scale_a,
scale_b,
torch.bfloat16,
bias,
)
)
return timers return timers
def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str, def bench_fp8(
sub_label: str) -> Iterable[TMeasurement]: dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str
) -> Iterable[TMeasurement]:
assert dtype == torch.float8_e4m3fn assert dtype == torch.float8_e4m3fn
b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k)
k)
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)
out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b, out = ops.cutlass_scaled_sparse_mm(
torch.bfloat16) a, b_compressed, e, scale_a, scale_b, torch.bfloat16
)
out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16) out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
if not torch.allclose(out, out_ref): if not torch.allclose(out, out_ref):
@ -124,97 +180,165 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
# pytorch impl w. bf16 # pytorch impl w. bf16
timers.append( timers.append(
bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", bench_fn(
torch.mm, a.to(dtype=torch.bfloat16, device="cuda"), label,
b.to(dtype=torch.bfloat16, device="cuda"))) sub_label,
"pytorch_bf16_bf16_bf16_matmul-no-scales",
torch.mm,
a.to(dtype=torch.bfloat16, device="cuda"),
b.to(dtype=torch.bfloat16, device="cuda"),
)
)
# pytorch impl: bf16 output, without fp8 fast accum # pytorch impl: bf16 output, without fp8 fast accum
timers.append( timers.append(
bench_fn(label, bench_fn(
sub_label, label,
"pytorch_fp8_fp8_bf16_scaled_mm", sub_label,
torch._scaled_mm, "pytorch_fp8_fp8_bf16_scaled_mm",
a, torch._scaled_mm,
b, a,
scale_a=scale_a, b,
scale_b=scale_b, scale_a=scale_a,
out_dtype=torch.bfloat16)) scale_b=scale_b,
out_dtype=torch.bfloat16,
)
)
# pytorch impl: bf16 output, with fp8 fast accum # pytorch impl: bf16 output, with fp8 fast accum
timers.append( timers.append(
bench_fn(label, bench_fn(
sub_label, label,
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum", sub_label,
torch._scaled_mm, "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum",
a, torch._scaled_mm,
b, a,
scale_a=scale_a, b,
scale_b=scale_b, scale_a=scale_a,
out_dtype=torch.bfloat16, scale_b=scale_b,
use_fast_accum=True)) out_dtype=torch.bfloat16,
use_fast_accum=True,
)
)
# pytorch impl: fp16 output, without fp8 fast accum # pytorch impl: fp16 output, without fp8 fast accum
timers.append( timers.append(
bench_fn(label, bench_fn(
sub_label, label,
"pytorch_fp8_fp8_fp16_scaled_mm", sub_label,
torch._scaled_mm, "pytorch_fp8_fp8_fp16_scaled_mm",
a, torch._scaled_mm,
b, a,
scale_a=scale_a, b,
scale_b=scale_b, scale_a=scale_a,
out_dtype=torch.float16)) scale_b=scale_b,
out_dtype=torch.float16,
)
)
# pytorch impl: fp16 output, with fp8 fast accum # pytorch impl: fp16 output, with fp8 fast accum
timers.append( timers.append(
bench_fn(label, bench_fn(
sub_label, label,
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum", sub_label,
torch._scaled_mm, "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum",
a, torch._scaled_mm,
b, a,
scale_a=scale_a, b,
scale_b=scale_b, scale_a=scale_a,
out_dtype=torch.float16, scale_b=scale_b,
use_fast_accum=True)) out_dtype=torch.float16,
use_fast_accum=True,
)
)
# cutlass impl: bf16 output # cutlass impl: bf16 output
timers.append( timers.append(
bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm", bench_fn(
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, label,
torch.bfloat16)) sub_label,
"cutlass_fp8_fp8_bf16_scaled_mm",
ops.cutlass_scaled_mm,
a,
b,
scale_a,
scale_b,
torch.bfloat16,
)
)
# cutlass impl: bf16 output # cutlass impl: bf16 output
timers.append( timers.append(
bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_sparse_mm", bench_fn(
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, label,
scale_b, torch.bfloat16)) sub_label,
"cutlass_fp8_fp8_bf16_scaled_sparse_mm",
ops.cutlass_scaled_sparse_mm,
a,
b_compressed,
e,
scale_a,
scale_b,
torch.bfloat16,
)
)
# cutlass impl: fp16 output # cutlass impl: fp16 output
timers.append( timers.append(
bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_sparse_mm", bench_fn(
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, label,
scale_b, torch.float16)) sub_label,
"cutlass_fp8_fp8_fp16_scaled_sparse_mm",
ops.cutlass_scaled_sparse_mm,
a,
b_compressed,
e,
scale_a,
scale_b,
torch.float16,
)
)
# cutlass impl: bf16 output, with bias # cutlass impl: bf16 output, with bias
timers.append( timers.append(
bench_fn(label, sub_label, bench_fn(
"cutlass_fp8_fp8_bf16_scaled_sparse_mm_bias", label,
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, sub_label,
scale_b, torch.bfloat16, bias)) "cutlass_fp8_fp8_bf16_scaled_sparse_mm_bias",
ops.cutlass_scaled_sparse_mm,
a,
b_compressed,
e,
scale_a,
scale_b,
torch.bfloat16,
bias,
)
)
# cutlass impl: fp16 output, with bias # cutlass impl: fp16 output, with bias
timers.append( timers.append(
bench_fn(label, sub_label, bench_fn(
"cutlass_fp8_fp8_fp16_scaled_sparse_mm_bias", label,
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, sub_label,
scale_b, torch.float16, bias.to(dtype=torch.float16))) "cutlass_fp8_fp8_fp16_scaled_sparse_mm_bias",
ops.cutlass_scaled_sparse_mm,
a,
b_compressed,
e,
scale_a,
scale_b,
torch.float16,
bias.to(dtype=torch.float16),
)
)
return timers return timers
def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str, def bench(
sub_label: str) -> Iterable[TMeasurement]: dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str
) -> Iterable[TMeasurement]:
if dtype == torch.int8: if dtype == torch.int8:
return bench_int8(dtype, m, k, n, label, sub_label) return bench_int8(dtype, m, k, n, label, sub_label)
if dtype == torch.float8_e4m3fn: if dtype == torch.float8_e4m3fn:
@ -228,12 +352,12 @@ def print_timers(timers: Iterable[TMeasurement]):
compare.print() compare.print()
def run(dtype: torch.dtype, def run(
MKNs: Iterable[tuple[int, int, int]]) -> Iterable[TMeasurement]: dtype: torch.dtype, MKNs: Iterable[tuple[int, int, int]]
) -> Iterable[TMeasurement]:
results = [] results = []
for m, k, n in MKNs: for m, k, n in MKNs:
timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm", timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm", f"MKN=({m}x{k}x{n})")
f"MKN=({m}x{k}x{n})")
print_timers(timers) print_timers(timers)
results.extend(timers) results.extend(timers)
@ -241,10 +365,12 @@ def run(dtype: torch.dtype,
# output makers # output makers
def make_output(data: Iterable[TMeasurement], def make_output(
MKNs: Iterable[tuple[int, int, int]], data: Iterable[TMeasurement],
base_description: str, MKNs: Iterable[tuple[int, int, int]],
timestamp=None): base_description: str,
timestamp=None,
):
print(f"== All Results {base_description} ====") print(f"== All Results {base_description} ====")
print_timers(data) print_timers(data)
@ -258,8 +384,7 @@ def make_output(data: Iterable[TMeasurement],
def run_square_bench(args): def run_square_bench(args):
dim_sizes = list( dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment))
range(args.dim_start, args.dim_end + 1, args.dim_increment))
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
data = run(args.dtype, MKNs) data = run(args.dtype, MKNs)
@ -319,7 +444,7 @@ def run_model_bench(args):
pkl.dump(all_data, f) pkl.dump(all_data, f)
if __name__ == '__main__': if __name__ == "__main__":
def to_torch_dtype(dt): def to_torch_dtype(dt):
if dt == "int8": if dt == "int8":
@ -344,12 +469,15 @@ Benchmark Cutlass GEMM.
Output: Output:
- a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs. - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs.
""", # noqa: E501 """, # noqa: E501
formatter_class=argparse.RawTextHelpFormatter) formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument("--dtype", parser.add_argument(
type=to_torch_dtype, "--dtype",
required=True, type=to_torch_dtype,
help="Available options are ['int8', 'fp8']") required=True,
help="Available options are ['int8', 'fp8']",
)
subparsers = parser.add_subparsers(dest="cmd") subparsers = parser.add_subparsers(dest="cmd")
square_parser = subparsers.add_parser("square_bench") square_parser = subparsers.add_parser("square_bench")
@ -368,19 +496,19 @@ Benchmark Cutlass GEMM.
range_parser.set_defaults(func=run_range_bench) range_parser.set_defaults(func=run_range_bench)
model_parser = subparsers.add_parser("model_bench") model_parser = subparsers.add_parser("model_bench")
model_parser.add_argument("--models", model_parser.add_argument(
nargs="+", "--models",
type=str, nargs="+",
default=DEFAULT_MODELS, type=str,
choices=WEIGHT_SHAPES.keys()) default=DEFAULT_MODELS,
model_parser.add_argument("--tp-sizes", choices=WEIGHT_SHAPES.keys(),
nargs="+", )
type=int, model_parser.add_argument(
default=DEFAULT_TP_SIZES) "--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES
model_parser.add_argument("--batch-sizes", )
nargs="+", model_parser.add_argument(
type=int, "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
default=DEFAULT_BATCH_SIZES) )
model_parser.set_defaults(func=run_model_bench) model_parser.set_defaults(func=run_model_bench)
args = parser.parse_args() args = parser.parse_args()

View File

@ -10,8 +10,9 @@ import vllm._custom_ops as ops
def to_fp8(tensor: torch.Tensor) -> torch.Tensor: def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
finfo = torch.finfo(torch.float8_e4m3fn) finfo = torch.finfo(torch.float8_e4m3fn)
return torch.round(tensor.clamp( return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to(
min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) dtype=torch.float8_e4m3fn
)
def to_int8(tensor: torch.Tensor) -> torch.Tensor: def to_int8(tensor: torch.Tensor) -> torch.Tensor:
@ -26,10 +27,11 @@ def to_fp16(tensor: torch.Tensor) -> torch.Tensor:
return tensor.to(dtype=torch.float16) return tensor.to(dtype=torch.float16)
def make_rand_tensors(dtype: torch.dtype, m: int, n: int, def make_rand_tensors(
k: int) -> tuple[torch.Tensor, torch.Tensor]: dtype: torch.dtype, m: int, n: int, k: int
a = torch.randn((m, k), device='cuda') * 5 ) -> tuple[torch.Tensor, torch.Tensor]:
b = torch.randn((n, k), device='cuda').t() * 5 a = torch.randn((m, k), device="cuda") * 5
b = torch.randn((n, k), device="cuda").t() * 5
if dtype == torch.int8: if dtype == torch.int8:
return to_int8(a), to_int8(b) return to_int8(a), to_int8(b)
@ -49,9 +51,7 @@ def prune_to_2_4(tensor):
# Create binary mask # Create binary mask
mask = torch.zeros_like(reshaped) mask = torch.zeros_like(reshaped)
mask.scatter_(dim=1, mask.scatter_(dim=1, index=indices, src=torch.ones_like(indices, dtype=mask.dtype))
index=indices,
src=torch.ones_like(indices, dtype=mask.dtype))
# Apply mask and reshape back # Apply mask and reshape back
pruned = reshaped * mask pruned = reshaped * mask
@ -62,10 +62,11 @@ def prune_to_2_4(tensor):
return pruned.reshape(original_shape) return pruned.reshape(original_shape)
def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int, def make_rand_sparse_tensors(
k: int) -> tuple[torch.Tensor, torch.Tensor]: dtype: torch.dtype, m: int, n: int, k: int
a = torch.randn((m, k), device='cuda') * 5 ) -> tuple[torch.Tensor, torch.Tensor]:
b = torch.randn((n, k), device='cuda').t() * 5 a = torch.randn((m, k), device="cuda") * 5
b = torch.randn((n, k), device="cuda").t() * 5
b = prune_to_2_4(b.t()).t() b = prune_to_2_4(b.t()).t()
@ -86,9 +87,9 @@ def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int,
return b_compressed, e, a, b return b_compressed, e, a, b
def make_n_rand_sparse_tensors(num_tensors: int, dtype: torch.dtype, def make_n_rand_sparse_tensors(
m: int, n: int, k: int) -> \ num_tensors: int, dtype: torch.dtype, m: int, n: int, k: int
tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]: ) -> tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]:
ABs = [] ABs = []
for _ in range(num_tensors): for _ in range(num_tensors):
b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k) b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k)

View File

@ -16,7 +16,8 @@ from weight_shapes import WEIGHT_SHAPES
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
w8a8_block_fp8_matmul) w8a8_block_fp8_matmul,
)
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
@ -25,8 +26,9 @@ DEFAULT_TP_SIZES = [1]
# bench # bench
def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, def bench_fn(
**kwargs) -> TMeasurement: label: str, sub_label: str, description: str, fn: Callable, *args, **kwargs
) -> TMeasurement:
min_run_time = 1 min_run_time = 1
globals = { globals = {
@ -44,45 +46,48 @@ def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
def bench_int8( def bench_int8(
dtype: torch.dtype, dtype: torch.dtype,
m: int, m: int,
k: int, k: int,
n: int, n: int,
label: str, label: str,
sub_label: str, sub_label: str,
bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]: bench_kernels: Optional[list[str]] = None,
) -> Iterable[TMeasurement]:
"""Benchmark INT8-based kernels.""" """Benchmark INT8-based kernels."""
assert dtype == torch.int8 assert dtype == torch.int8
a, b = make_rand_tensors(torch.int8, m, n, k) a, b = make_rand_tensors(torch.int8, m, n, k)
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)
azp = torch.zeros((m, ), device="cuda", dtype=torch.int32) azp = torch.zeros((m,), device="cuda", dtype=torch.int32)
azp_adj = torch.zeros((n, ), device="cuda", dtype=torch.int32) azp_adj = torch.zeros((n,), device="cuda", dtype=torch.int32)
bench_fns = { bench_fns = {
"pytorch_bf16_bf16_bf16_matmul-no-scales": "pytorch_bf16_bf16_bf16_matmul-no-scales": lambda: torch.mm(
lambda: torch.mm(a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16) a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16)
), ),
"pytorch_fp16_fp16_fp16_matmul-no-scales": "pytorch_fp16_fp16_fp16_matmul-no-scales": lambda: torch.mm(
lambda: torch.mm(a.to(dtype=torch.float16), b.to(dtype=torch.float16)), a.to(dtype=torch.float16), b.to(dtype=torch.float16)
"cutlass_i8_i8_bf16_scaled_mm": ),
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16), "cutlass_i8_i8_bf16_scaled_mm": lambda: ops.cutlass_scaled_mm(
"cutlass_i8_i8_bf16_scaled_mm_bias": a, b, scale_a, scale_b, torch.bfloat16
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16, ),
bias), "cutlass_i8_i8_bf16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm(
"cutlass_i8_i8_bf16_scaled_mm_azp": a, b, scale_a, scale_b, torch.bfloat16, bias
lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch. ),
bfloat16, azp_adj), "cutlass_i8_i8_bf16_scaled_mm_azp": lambda: ops.cutlass_scaled_mm_azp(
"cutlass_i8_i8_bf16_scaled_mm_azp_bias": a, b, scale_a, scale_b, torch.bfloat16, azp_adj
lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch. ),
bfloat16, azp_adj, None, bias), "cutlass_i8_i8_bf16_scaled_mm_azp_bias": lambda: ops.cutlass_scaled_mm_azp(
"cutlass_i8_i8_bf16_scaled_mm_azp_pt": a, b, scale_a, scale_b, torch.bfloat16, azp_adj, None, bias
lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch. ),
bfloat16, azp_adj, azp), "cutlass_i8_i8_bf16_scaled_mm_azp_pt": lambda: ops.cutlass_scaled_mm_azp(
"cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias": a, b, scale_a, scale_b, torch.bfloat16, azp_adj, azp
lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch. ),
bfloat16, azp_adj, azp, bias), "cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias": lambda: ops.cutlass_scaled_mm_azp(
a, b, scale_a, scale_b, torch.bfloat16, azp_adj, azp, bias
),
} }
timers = [] timers = []
@ -96,73 +101,73 @@ def bench_int8(
def bench_fp8( def bench_fp8(
dtype: torch.dtype, dtype: torch.dtype,
m: int, m: int,
k: int, k: int,
n: int, n: int,
label: str, label: str,
sub_label: str, sub_label: str,
bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]: bench_kernels: Optional[list[str]] = None,
) -> Iterable[TMeasurement]:
"""Benchmark FP8-based kernels.""" """Benchmark FP8-based kernels."""
assert dtype == torch.float8_e4m3fn assert dtype == torch.float8_e4m3fn
a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k) a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k)
a_cont = a.contiguous() a_cont = a.contiguous()
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
block_scale_a = torch.rand((m, k // 128),
device="cuda", def ceil_div(x: int, y: int) -> int:
dtype=torch.float32) return (x + y - 1) // y
block_scale_b = torch.rand((k // 128, n // 128),
device="cuda", block_scale_a = torch.rand(
dtype=torch.float32) (m, ceil_div(k, 128)), device="cuda", dtype=torch.float32
)
block_scale_b = torch.rand(
ceil_div(k, 128), ceil_div(n, 128), device="cuda", dtype=torch.float32
)
block_scale_a_M_major = block_scale_a.t().contiguous().t() block_scale_a_M_major = block_scale_a.t().contiguous().t()
block_scale_b_K_major = block_scale_b.t().contiguous().t() block_scale_b_K_major = block_scale_b.t().contiguous().t()
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)
print(m, k, n) print(m, k, n)
bench_fns = { bench_fns = {
"pytorch_bf16_bf16_bf16_matmul-no-scales": "pytorch_bf16_bf16_bf16_matmul-no-scales": lambda: torch.mm(
lambda: torch.mm(a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16) a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16)
), ),
"pytorch_fp16_fp16_fp16_matmul-no-scales": "pytorch_fp16_fp16_fp16_matmul-no-scales": lambda: torch.mm(
lambda: torch.mm(a.to(dtype=torch.float16), b.to(dtype=torch.float16)), a.to(dtype=torch.float16), b.to(dtype=torch.float16)
"pytorch_fp8_fp8_fp16_scaled_mm": ),
lambda: torch._scaled_mm( "pytorch_fp8_fp8_fp16_scaled_mm": lambda: torch._scaled_mm(
a, b, scale_a, scale_b, out_dtype=torch.float16), a, b, scale_a, scale_b, out_dtype=torch.float16
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum": ),
lambda: torch._scaled_mm(a, "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum": lambda: torch._scaled_mm(
b, a, b, scale_a, scale_b, out_dtype=torch.float16, use_fast_accum=True
scale_a, ),
scale_b, "pytorch_fp8_fp8_bf16_scaled_mm": lambda: torch._scaled_mm(
out_dtype=torch.float16, a, b, scale_a, scale_b, out_dtype=torch.bfloat16
use_fast_accum=True), ),
"pytorch_fp8_fp8_bf16_scaled_mm": "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum": lambda: torch._scaled_mm(
lambda: torch._scaled_mm( a, b, scale_a, scale_b, out_dtype=torch.bfloat16, use_fast_accum=True
a, b, scale_a, scale_b, out_dtype=torch.bfloat16), ),
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum": "cutlass_fp8_fp8_bf16_scaled_mm": lambda: ops.cutlass_scaled_mm(
lambda: torch._scaled_mm(a, a, b, scale_a, scale_b, torch.bfloat16
b, ),
scale_a, "cutlass_fp8_fp8_fp16_scaled_mm": lambda: ops.cutlass_scaled_mm(
scale_b, a, b, scale_a, scale_b, torch.float16
out_dtype=torch.bfloat16, ),
use_fast_accum=True), "cutlass_fp8_fp8_bf16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm(
"cutlass_fp8_fp8_bf16_scaled_mm": a, b, scale_a, scale_b, torch.bfloat16, bias
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16), ),
"cutlass_fp8_fp8_fp16_scaled_mm": "cutlass_fp8_fp8_fp16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm(
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.float16), a, b, scale_a, scale_b, torch.float16, bias.to(dtype=torch.float16)
"cutlass_fp8_fp8_bf16_scaled_mm_bias": ),
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16, "triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_block_fp8_matmul(
bias), a_cont, b.t(), block_scale_a, block_scale_b.t(), (128, 128)
"cutlass_fp8_fp8_fp16_scaled_mm_bias": ),
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.float16, "cutlass_fp8_fp8_fp16_scaled_mm_blockwise": lambda: ops.cutlass_scaled_mm(
bias.to(dtype=torch.float16)), a, b, block_scale_a_M_major, block_scale_b_K_major, torch.float16
"triton_fp8_fp8_fp16_scaled_mm_blockwise": ),
lambda: w8a8_block_fp8_matmul(a_cont, b.t(), block_scale_a,
block_scale_b.t(), (128, 128)),
"cutlass_fp8_fp8_fp16_scaled_mm_blockwise":
lambda: ops.cutlass_scaled_mm(a, b, block_scale_a_M_major,
block_scale_b_K_major, torch.float16),
} }
timers = [] timers = []
@ -175,13 +180,15 @@ def bench_fp8(
return timers return timers
def bench(dtype: torch.dtype, def bench(
m: int, dtype: torch.dtype,
k: int, m: int,
n: int, k: int,
label: str, n: int,
sub_label: str, label: str,
bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]: sub_label: str,
bench_kernels: Optional[list[str]] = None,
) -> Iterable[TMeasurement]:
if dtype == torch.int8: if dtype == torch.int8:
return bench_int8(dtype, m, k, n, label, sub_label, bench_kernels) return bench_int8(dtype, m, k, n, label, sub_label, bench_kernels)
if dtype == torch.float8_e4m3fn: if dtype == torch.float8_e4m3fn:
@ -195,27 +202,33 @@ def print_timers(timers: Iterable[TMeasurement]):
compare.print() compare.print()
def run(dtype: torch.dtype, def run(
MKNs: Iterable[tuple[int, int, int]], dtype: torch.dtype,
bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]: MKNs: Iterable[tuple[int, int, int]],
bench_kernels: Optional[list[str]] = None,
) -> Iterable[TMeasurement]:
results = [] results = []
for m, k, n in MKNs: for m, k, n in MKNs:
timers = bench(dtype, timers = bench(
m, dtype,
k, m,
n, k,
f"scaled-{dtype}-gemm", n,
f"MKN=({m}x{k}x{n})", f"scaled-{dtype}-gemm",
bench_kernels=bench_kernels) f"MKN=({m}x{k}x{n})",
bench_kernels=bench_kernels,
)
print_timers(timers) print_timers(timers)
results.extend(timers) results.extend(timers)
return results return results
def make_output(data: Iterable[TMeasurement], def make_output(
MKNs: Iterable[tuple[int, int, int]], data: Iterable[TMeasurement],
base_description: str, MKNs: Iterable[tuple[int, int, int]],
timestamp=None): base_description: str,
timestamp=None,
):
print(f"== All Results {base_description} ====") print(f"== All Results {base_description} ====")
print_timers(data) print_timers(data)
@ -226,8 +239,7 @@ def make_output(data: Iterable[TMeasurement],
def run_square_bench(args): def run_square_bench(args):
dim_sizes = list( dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment))
range(args.dim_start, args.dim_end + 1, args.dim_increment))
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
data = run(args.dtype, MKNs, bench_kernels=args.kernels) data = run(args.dtype, MKNs, bench_kernels=args.kernels)
make_output(data, MKNs, f"square_bench-{args.dtype}") make_output(data, MKNs, f"square_bench-{args.dtype}")
@ -285,7 +297,7 @@ def run_model_bench(args):
pkl.dump(all_data, f) pkl.dump(all_data, f)
if __name__ == '__main__': if __name__ == "__main__":
def to_torch_dtype(dt): def to_torch_dtype(dt):
if dt == "int8": if dt == "int8":
@ -310,19 +322,21 @@ Benchmark Cutlass GEMM.
Output: Output:
- a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs. - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs.
""", # noqa: E501 """, # noqa: E501
formatter_class=argparse.RawTextHelpFormatter) formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument("--dtype", parser.add_argument(
type=to_torch_dtype, "--dtype",
required=True, type=to_torch_dtype,
help="Available options are ['int8', 'fp8']") required=True,
help="Available options are ['int8', 'fp8']",
)
parser.add_argument( parser.add_argument(
"--kernels", "--kernels",
nargs="+", nargs="+",
type=str, type=str,
default=None, default=None,
help= help="Exact names of the kernels to benchmark. If not set, runs all kernels.",
"Exact names of the kernels to benchmark. If not set, runs all kernels."
) )
subparsers = parser.add_subparsers(dest="cmd") subparsers = parser.add_subparsers(dest="cmd")
@ -343,19 +357,19 @@ Benchmark Cutlass GEMM.
range_parser.set_defaults(func=run_range_bench) range_parser.set_defaults(func=run_range_bench)
model_parser = subparsers.add_parser("model_bench") model_parser = subparsers.add_parser("model_bench")
model_parser.add_argument("--models", model_parser.add_argument(
nargs="+", "--models",
type=str, nargs="+",
default=DEFAULT_MODELS, type=str,
choices=WEIGHT_SHAPES.keys()) default=DEFAULT_MODELS,
model_parser.add_argument("--tp-sizes", choices=WEIGHT_SHAPES.keys(),
nargs="+", )
type=int, model_parser.add_argument(
default=DEFAULT_TP_SIZES) "--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES
model_parser.add_argument("--batch-sizes", )
nargs="+", model_parser.add_argument(
type=int, "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
default=DEFAULT_BATCH_SIZES) )
model_parser.set_defaults(func=run_model_bench) model_parser.set_defaults(func=run_model_bench)
args = parser.parse_args() args = parser.parse_args()

View File

@ -12,39 +12,37 @@ app = Quart(__name__)
async def forward_request(url, data): async def forward_request(url, data):
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
headers = { headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" async with session.post(url=url, json=data, headers=headers) as response:
}
async with session.post(url=url, json=data,
headers=headers) as response:
if response.status == 200: if response.status == 200:
# if response.headers.get('Transfer-Encoding') == 'chunked': # if response.headers.get('Transfer-Encoding') == 'chunked':
if True: if True:
async for chunk_bytes in response.content.iter_chunked( async for chunk_bytes in response.content.iter_chunked(1024):
1024):
yield chunk_bytes yield chunk_bytes
else: else:
content = await response.read() content = await response.read()
yield content yield content
@app.route('/v1/completions', methods=['POST']) @app.route("/v1/completions", methods=["POST"])
async def handle_request(): async def handle_request():
try: try:
original_request_data = await request.get_json() original_request_data = await request.get_json()
prefill_request = original_request_data.copy() prefill_request = original_request_data.copy()
# change max_tokens = 1 to let it only do prefill # change max_tokens = 1 to let it only do prefill
prefill_request['max_tokens'] = 1 prefill_request["max_tokens"] = 1
# finish prefill # finish prefill
async for _ in forward_request('http://localhost:8100/v1/completions', async for _ in forward_request(
prefill_request): "http://localhost:8100/v1/completions", prefill_request
):
continue continue
# return decode # return decode
generator = forward_request('http://localhost:8200/v1/completions', generator = forward_request(
original_request_data) "http://localhost:8200/v1/completions", original_request_data
)
response = await make_response(generator) response = await make_response(generator)
response.timeout = None response.timeout = None
@ -53,11 +51,12 @@ async def handle_request():
except Exception as e: except Exception as e:
import sys import sys
import traceback import traceback
exc_info = sys.exc_info() exc_info = sys.exc_info()
print("Error occurred in disagg prefill proxy server") print("Error occurred in disagg prefill proxy server")
print(e) print(e)
print("".join(traceback.format_exception(*exc_info))) print("".join(traceback.format_exception(*exc_info)))
if __name__ == '__main__': if __name__ == "__main__":
app.run(port=8000) app.run(port=8000)

View File

@ -8,7 +8,6 @@ from aiohttp import web
class RoundRobinProxy: class RoundRobinProxy:
def __init__(self, target_ports): def __init__(self, target_ports):
self.target_ports = target_ports self.target_ports = target_ports
self.port_cycle = itertools.cycle(self.target_ports) self.port_cycle = itertools.cycle(self.target_ports)
@ -21,14 +20,15 @@ class RoundRobinProxy:
try: try:
# Forward the request # Forward the request
async with session.request( async with session.request(
method=request.method, method=request.method,
url=target_url, url=target_url,
headers=request.headers, headers=request.headers,
data=request.content, data=request.content,
) as response: ) as response:
# Start sending the response # Start sending the response
resp = web.StreamResponse(status=response.status, resp = web.StreamResponse(
headers=response.headers) status=response.status, headers=response.headers
)
await resp.prepare(request) await resp.prepare(request)
# Stream the response content # Stream the response content
@ -45,11 +45,11 @@ class RoundRobinProxy:
async def main(): async def main():
proxy = RoundRobinProxy([8100, 8200]) proxy = RoundRobinProxy([8100, 8200])
app = web.Application() app = web.Application()
app.router.add_route('*', '/{path:.*}', proxy.handle_request) app.router.add_route("*", "/{path:.*}", proxy.handle_request)
runner = web.AppRunner(app) runner = web.AppRunner(app)
await runner.setup() await runner.setup()
site = web.TCPSite(runner, 'localhost', 8000) site = web.TCPSite(runner, "localhost", 8000)
await site.start() await site.start()
print("Proxy server started on http://localhost:8000") print("Proxy server started on http://localhost:8000")
@ -58,5 +58,5 @@ async def main():
await asyncio.Event().wait() await asyncio.Event().wait()
if __name__ == '__main__': if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())

View File

@ -6,43 +6,41 @@ import matplotlib.pyplot as plt
import pandas as pd import pandas as pd
if __name__ == "__main__": if __name__ == "__main__":
data = [] data = []
for name in ['disagg_prefill', 'chunked_prefill']: for name in ["disagg_prefill", "chunked_prefill"]:
for qps in [2, 4, 6, 8]: for qps in [2, 4, 6, 8]:
with open(f"results/{name}-qps-{qps}.json") as f: with open(f"results/{name}-qps-{qps}.json") as f:
x = json.load(f) x = json.load(f)
x['name'] = name x["name"] = name
x['qps'] = qps x["qps"] = qps
data.append(x) data.append(x)
df = pd.DataFrame.from_dict(data) df = pd.DataFrame.from_dict(data)
dis_df = df[df['name'] == 'disagg_prefill'] dis_df = df[df["name"] == "disagg_prefill"]
chu_df = df[df['name'] == 'chunked_prefill'] chu_df = df[df["name"] == "chunked_prefill"]
plt.style.use('bmh') plt.style.use("bmh")
plt.rcParams['font.size'] = 20 plt.rcParams["font.size"] = 20
for key in [ for key in [
'mean_ttft_ms', 'median_ttft_ms', 'p99_ttft_ms', 'mean_itl_ms', "mean_ttft_ms",
'median_itl_ms', 'p99_itl_ms' "median_ttft_ms",
"p99_ttft_ms",
"mean_itl_ms",
"median_itl_ms",
"p99_itl_ms",
]: ]:
fig, ax = plt.subplots(figsize=(11, 7)) fig, ax = plt.subplots(figsize=(11, 7))
plt.plot(dis_df['qps'], plt.plot(
dis_df[key], dis_df["qps"], dis_df[key], label="disagg_prefill", marker="o", linewidth=4
label='disagg_prefill', )
marker='o', plt.plot(
linewidth=4) chu_df["qps"], chu_df[key], label="chunked_prefill", marker="o", linewidth=4
plt.plot(chu_df['qps'], )
chu_df[key],
label='chunked_prefill',
marker='o',
linewidth=4)
ax.legend() ax.legend()
ax.set_xlabel('QPS') ax.set_xlabel("QPS")
ax.set_ylabel(key) ax.set_ylabel(key)
ax.set_ylim(bottom=0) ax.set_ylim(bottom=0)
fig.savefig(f'results/{key}.png') fig.savefig(f"results/{key}.png")
plt.close(fig) plt.close(fig)

View File

@ -24,10 +24,12 @@ class bench_params_t:
dtype: torch.dtype dtype: torch.dtype
def description(self): def description(self):
return (f'N {self.num_tokens} ' return (
f'x D {self.hidden_size} ' f"N {self.num_tokens} "
f'x R {self.add_residual} ' f"x D {self.hidden_size} "
f'x DT {self.dtype}') f"x R {self.add_residual} "
f"x DT {self.dtype}"
)
def get_bench_params() -> list[bench_params_t]: def get_bench_params() -> list[bench_params_t]:
@ -38,15 +40,19 @@ def get_bench_params() -> list[bench_params_t]:
DTYPES = [torch.bfloat16, torch.float] DTYPES = [torch.bfloat16, torch.float]
combinations = product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES) combinations = product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES)
bench_params = list(map(lambda x: \ bench_params = list(
bench_params_t(x[0], x[1], x[2], x[3]), combinations)) map(lambda x: bench_params_t(x[0], x[1], x[2], x[3]), combinations)
)
return bench_params return bench_params
# Reference impls # Reference impls
def unfused_int8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor, def unfused_int8_impl(
residual: Optional[torch.Tensor], rms_norm_layer: RMSNorm,
quant_dtype: torch.dtype): x: torch.Tensor,
residual: Optional[torch.Tensor],
quant_dtype: torch.dtype,
):
# Norm # Norm
torch_out = None torch_out = None
if residual is None: if residual is None:
@ -58,9 +64,12 @@ def unfused_int8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor,
torch_out, _, _ = ops.scaled_int8_quant(torch_out) torch_out, _, _ = ops.scaled_int8_quant(torch_out)
def unfused_fp8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor, def unfused_fp8_impl(
residual: Optional[torch.Tensor], rms_norm_layer: RMSNorm,
quant_dtype: torch.dtype): x: torch.Tensor,
residual: Optional[torch.Tensor],
quant_dtype: torch.dtype,
):
# Norm # Norm
torch_out = None torch_out = None
if residual is None: if residual is None:
@ -73,22 +82,27 @@ def unfused_fp8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor,
def fused_impl( def fused_impl(
rms_norm_layer: RMSNorm, # this stores the weights rms_norm_layer: RMSNorm, # this stores the weights
x: torch.Tensor, x: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
quant_dtype: torch.dtype): quant_dtype: torch.dtype,
out, _ = ops.rms_norm_dynamic_per_token_quant(x, ):
rms_norm_layer.weight, out, _ = ops.rms_norm_dynamic_per_token_quant(
1e-6, x, rms_norm_layer.weight, 1e-6, quant_dtype, residual=residual
quant_dtype, )
residual=residual)
# Bench functions # Bench functions
def bench_fn(rms_norm_layer: RMSNorm, x: torch.Tensor, residual: torch.Tensor, def bench_fn(
quant_dtype: torch.dtype, label: str, sub_label: str, rms_norm_layer: RMSNorm,
fn: Callable, description: str) -> TMeasurement: x: torch.Tensor,
residual: torch.Tensor,
quant_dtype: torch.dtype,
label: str,
sub_label: str,
fn: Callable,
description: str,
) -> TMeasurement:
min_run_time = 1 min_run_time = 1
globals = { globals = {
@ -106,43 +120,81 @@ def bench_fn(rms_norm_layer: RMSNorm, x: torch.Tensor, residual: torch.Tensor,
description=description, description=description,
).blocked_autorange(min_run_time=min_run_time) ).blocked_autorange(min_run_time=min_run_time)
def bench(params: bench_params_t, label: str, sub_label: str) \
-> Iterable[TMeasurement]:
def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasurement]:
# Make inputs # Make inputs
layer = RMSNorm(params.hidden_size, 1e-6).to(dtype=params.dtype) layer = RMSNorm(params.hidden_size, 1e-6).to(dtype=params.dtype)
# Make weights # Make weights
layer.weight.data.normal_(mean=1.0, std=0.1) layer.weight.data.normal_(mean=1.0, std=0.1)
# Make inputs # Make inputs
scale = 1 / params.hidden_size scale = 1 / params.hidden_size
x = torch.randn(params.num_tokens, x = (
params.hidden_size, torch.randn(
dtype=params.dtype, params.num_tokens, params.hidden_size, dtype=params.dtype, device="cuda"
device='cuda') * scale )
residual = (torch.randn_like(x) * scale).to(device='cuda') \ * scale
if params.add_residual else None )
residual = (
(torch.randn_like(x) * scale).to(device="cuda") if params.add_residual else None
)
timers = [] timers = []
# unfused int8 impl. # unfused int8 impl.
timers.append( timers.append(
bench_fn(layer, x, residual, torch.int8, label, sub_label, bench_fn(
unfused_int8_impl, "unfused_int8_impl")) layer,
x,
residual,
torch.int8,
label,
sub_label,
unfused_int8_impl,
"unfused_int8_impl",
)
)
# unfused fp8 impl. # unfused fp8 impl.
timers.append( timers.append(
bench_fn(layer, x, residual, torch.float8_e4m3fn, label, sub_label, bench_fn(
unfused_fp8_impl, "unfused_fp8_impl")) layer,
x,
residual,
torch.float8_e4m3fn,
label,
sub_label,
unfused_fp8_impl,
"unfused_fp8_impl",
)
)
# fused int8 impl. # fused int8 impl.
timers.append( timers.append(
bench_fn(layer, x, residual, torch.int8, label, sub_label, fused_impl, bench_fn(
"fused_int8_impl")) layer,
x,
residual,
torch.int8,
label,
sub_label,
fused_impl,
"fused_int8_impl",
)
)
# fused fp8 impl. # fused fp8 impl.
timers.append( timers.append(
bench_fn(layer, x, residual, torch.float8_e4m3fn, label, sub_label, bench_fn(
fused_impl, "fused_fp8_impl")) layer,
x,
residual,
torch.float8_e4m3fn,
label,
sub_label,
fused_impl,
"fused_fp8_impl",
)
)
print_timers(timers) print_timers(timers)
@ -157,13 +209,12 @@ def print_timers(timers: Iterable[TMeasurement]):
def main(): def main():
torch.set_default_device('cuda') torch.set_default_device("cuda")
bench_params = get_bench_params() bench_params = get_bench_params()
timers = [] timers = []
for bp in tqdm(bench_params): for bp in tqdm(bench_params):
timers.extend( timers.extend(bench(bp, "rms-norm-dynamic-per-token-quant", bp.description()))
bench(bp, "rms-norm-dynamic-per-token-quant", bp.description()))
print_timers(timers) print_timers(timers)
# pickle all the results # pickle all the results
@ -172,5 +223,5 @@ def main():
pkl.dump(timers, f) pkl.dump(timers, f)
if __name__ == '__main__': if __name__ == "__main__":
main() main()

View File

@ -9,32 +9,39 @@ import torch.nn.functional as F
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.aqlm import ( from vllm.model_executor.layers.quantization.aqlm import (
dequantize_weight, generic_dequantize_gemm, get_int_dtype, dequantize_weight,
optimized_dequantize_gemm) generic_dequantize_gemm,
get_int_dtype,
optimized_dequantize_gemm,
)
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
os.environ['CUDA_VISIBLE_DEVICES'] = '0' os.environ["CUDA_VISIBLE_DEVICES"] = "0"
def torch_mult( def torch_mult(
input: torch.Tensor, # [..., in_features] # [..., in_features]
weights: torch.Tensor, input: torch.Tensor,
scales: torch.Tensor, # [num_out_groups, 1, 1, 1] weights: torch.Tensor,
# [num_out_groups, 1, 1, 1]
scales: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
output = F.linear(input, weights) output = F.linear(input, weights)
return output return output
def dequant_out_scale( def dequant_out_scale(
input: torch.Tensor, # [..., in_features] # [..., in_features]
codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks] input: torch.Tensor,
codebooks: torch. # [num_out_groups, num_in_groups, num_codebooks]
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] codes: torch.IntTensor,
scales: torch.Tensor, # [num_out_groups, 1, 1, 1] # [num_codebooks, codebook_size, out_group_size, in_group_size]
codebooks: torch.Tensor,
# [num_out_groups, 1, 1, 1]
scales: torch.Tensor,
output_partition_sizes: torch.IntTensor, output_partition_sizes: torch.IntTensor,
bias: Optional[torch.Tensor], bias: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes) weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
if bias is None: if bias is None:
@ -46,40 +53,42 @@ def dequant_out_scale(
flattened_output *= b_scales flattened_output *= b_scales
return flattened_output.view(orig_shape) return flattened_output.view(orig_shape)
else: else:
b_scales = scales.view(scales.shape[:-3] + (-1, )).expand( b_scales = scales.view(scales.shape[:-3] + (-1,)).expand(-1, weights.shape[1])
-1, weights.shape[1])
weights *= b_scales weights *= b_scales
return F.linear(input, weights, bias) return F.linear(input, weights, bias)
def dequant_weight_scale( def dequant_weight_scale(
input: torch.Tensor, # [..., in_features] # [..., in_features]
codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks] input: torch.Tensor,
codebooks: torch. # [num_out_groups, num_in_groups, num_codebooks]
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] codes: torch.IntTensor,
scales: torch.Tensor, # [num_out_groups, 1, 1, 1] # [num_codebooks, codebook_size, out_group_size, in_group_size]
codebooks: torch.Tensor,
# [num_out_groups, 1, 1, 1]
scales: torch.Tensor,
output_partition_sizes: torch.IntTensor, output_partition_sizes: torch.IntTensor,
bias: Optional[torch.Tensor], bias: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes) weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
b_scales = scales.view(scales.shape[:-3] + (-1, )).expand( b_scales = scales.view(scales.shape[:-3] + (-1,)).expand(-1, weights.shape[1])
-1, weights.shape[1])
weights *= b_scales weights *= b_scales
return F.linear(input, weights, bias) return F.linear(input, weights, bias)
def dequant_no_scale( def dequant_no_scale(
input: torch.Tensor, # [..., in_features] # [..., in_features]
codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks] input: torch.Tensor,
codebooks: torch. # [num_out_groups, num_in_groups, num_codebooks]
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] codes: torch.IntTensor,
scales: torch.Tensor, # [num_out_groups, 1, 1, 1] # [num_codebooks, codebook_size, out_group_size, in_group_size]
codebooks: torch.Tensor,
# [num_out_groups, 1, 1, 1]
scales: torch.Tensor,
output_partition_sizes: torch.IntTensor, output_partition_sizes: torch.IntTensor,
bias: Optional[torch.Tensor], bias: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes) weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
return F.linear(input, weights, bias) return F.linear(input, weights, bias)
@ -89,23 +98,26 @@ def dequant_no_scale(
# the generic pytorch version. # the generic pytorch version.
# Just visual comparison. # Just visual comparison.
def dequant_test(k: int, parts: torch.Tensor, nbooks: int, bits: int) -> None: def dequant_test(k: int, parts: torch.Tensor, nbooks: int, bits: int) -> None:
n = int(parts.sum().item()) n = int(parts.sum().item())
device = torch.device('cuda:0') device = torch.device("cuda:0")
code_range = (1 << bits) // 2 code_range = (1 << bits) // 2
ingroups = 8 ingroups = 8
codes = torch.randint(-code_range, codes = torch.randint(
code_range, -code_range,
size=(n, k // ingroups, nbooks), code_range,
dtype=get_int_dtype(bits), size=(n, k // ingroups, nbooks),
device=device) dtype=get_int_dtype(bits),
device=device,
)
codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8), codebooks = torch.randn(
dtype=torch.float16, size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
device=device) dtype=torch.float16,
device=device,
)
count = 0 count = 0
for index in range(16): for index in range(16):
@ -138,24 +150,25 @@ def dequant_test(k: int, parts: torch.Tensor, nbooks: int, bits: int) -> None:
def main(): def main():
parser = FlexibleArgumentParser(description="Benchmark aqlm performance.") parser = FlexibleArgumentParser(description="Benchmark aqlm performance.")
# Add arguments # Add arguments
parser.add_argument("--nbooks", parser.add_argument(
type=int, "--nbooks", type=int, default=1, help="Number of codebooks (default: 1)"
default=1, )
help="Number of codebooks (default: 1)") parser.add_argument(
parser.add_argument("--bits", "--bits",
type=int, type=int,
default=16, default=16,
help="Number of bits per code element (default: 16)") help="Number of bits per code element (default: 16)",
)
parser.add_argument( parser.add_argument(
"--test", "--test",
type=bool, type=bool,
default=False, default=False,
help="Run the decompression/dequant tester rather than benchmarking " help="Run the decompression/dequant tester rather than benchmarking "
"(default: False)") "(default: False)",
)
# Parse the arguments # Parse the arguments
args = parser.parse_args() args = parser.parse_args()
@ -165,7 +178,7 @@ def main():
bits = args.bits bits = args.bits
if args.test: if args.test:
dequant_test(4096, torch.tensor((4096, )), nbooks, bits) dequant_test(4096, torch.tensor((4096,)), nbooks, bits)
return return
# Otherwise, benchmark. # Otherwise, benchmark.
@ -184,31 +197,54 @@ def main():
with open(filename, "w") as f: with open(filename, "w") as f:
sys.stdout = f sys.stdout = f
print('m | k | n | n parts', end='') print("m | k | n | n parts", end="")
for method in methods: for method in methods:
print(f" | {method.__name__.replace('_', ' ')} (µs)", end='') print(f" | {method.__name__.replace('_', ' ')} (µs)", end="")
print('') print("")
# These are reasonable prefill sizes. # These are reasonable prefill sizes.
ksandpartions = ((4096, (4096, 4096, 4096)), (4096, (4096, )), ksandpartions = (
(4096, (11008, 11008)), (11008, (4096, ))) (4096, (4096, 4096, 4096)),
(4096, (4096,)),
(4096, (11008, 11008)),
(11008, (4096,)),
)
# reasonable ranges for m. # reasonable ranges for m.
for m in [ for m in [
1, 2, 4, 8, 10, 12, 14, 16, 24, 32, 48, 52, 56, 64, 96, 112, 1,
128, 256, 512, 1024, 1536, 2048, 3072, 4096 2,
4,
8,
10,
12,
14,
16,
24,
32,
48,
52,
56,
64,
96,
112,
128,
256,
512,
1024,
1536,
2048,
3072,
4096,
]: ]:
print(f'{m}', file=sys.__stdout__) print(f"{m}", file=sys.__stdout__)
for ksp in ksandpartions: for ksp in ksandpartions:
run_grid(m, ksp[0], torch.tensor(ksp[1]), nbooks, bits, run_grid(m, ksp[0], torch.tensor(ksp[1]), nbooks, bits, methods)
methods)
sys.stdout = sys.__stdout__ sys.stdout = sys.__stdout__
def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int, def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int, methods):
methods):
# I didn't see visible improvements from increasing these, but feel free :) # I didn't see visible improvements from increasing these, but feel free :)
num_warmup_trials = 1 num_warmup_trials = 1
num_trials = 1 num_trials = 1
@ -229,7 +265,7 @@ def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int,
) )
n = parts.sum().item() n = parts.sum().item()
print(f'{m} | {k} | {n} | {parts.tolist()}', end='') print(f"{m} | {k} | {n} | {parts.tolist()}", end="")
for method in methods: for method in methods:
best_time_us = 1e20 best_time_us = 1e20
@ -249,32 +285,36 @@ def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int,
if kernel_dur_us < best_time_us: if kernel_dur_us < best_time_us:
best_time_us = kernel_dur_us best_time_us = kernel_dur_us
print(f' | {kernel_dur_us:.0f}', end='') print(f" | {kernel_dur_us:.0f}", end="")
print('') print("")
def run_timing(num_calls: int, m: int, k: int, parts: torch.Tensor, def run_timing(
nbooks: int, bits: int, method) -> float: num_calls: int, m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int, method
) -> float:
n = int(parts.sum().item()) n = int(parts.sum().item())
device = torch.device('cuda:0') device = torch.device("cuda:0")
input = torch.randn((1, m, k), dtype=torch.float16, device=device) input = torch.randn((1, m, k), dtype=torch.float16, device=device)
code_range = (1 << bits) // 2 code_range = (1 << bits) // 2
ingroups = 8 ingroups = 8
codes = torch.randint(-code_range, codes = torch.randint(
code_range, -code_range,
size=(n, k // ingroups, nbooks), code_range,
dtype=get_int_dtype(bits), size=(n, k // ingroups, nbooks),
device=device) dtype=get_int_dtype(bits),
device=device,
)
codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8), codebooks = torch.randn(
dtype=torch.float16, size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
device=device) dtype=torch.float16,
device=device,
)
scales = torch.randn(size=(n, 1, 1, 1), dtype=torch.float16, device=device) scales = torch.randn(size=(n, 1, 1, 1), dtype=torch.float16, device=device)

View File

@ -3,27 +3,33 @@
# Licensed under the MIT License. # Licensed under the MIT License.
from vllm.model_executor.layers.quantization.utils.bitblas_utils import ( from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
MINIMUM_BITBLAS_VERSION) MINIMUM_BITBLAS_VERSION,
)
try: try:
import bitblas import bitblas
if bitblas.__version__ < MINIMUM_BITBLAS_VERSION: if bitblas.__version__ < MINIMUM_BITBLAS_VERSION:
raise ImportError("bitblas version is wrong. Please " raise ImportError(
f"install bitblas>={MINIMUM_BITBLAS_VERSION}") "bitblas version is wrong. Please "
f"install bitblas>={MINIMUM_BITBLAS_VERSION}"
)
except ImportError as e: except ImportError as e:
bitblas_import_exception = e bitblas_import_exception = e
raise ValueError("Trying to use the bitblas backend, but could not import" raise ValueError(
f"with the following error: {bitblas_import_exception}. " "Trying to use the bitblas backend, but could not import"
"Please install bitblas through the following command: " f"with the following error: {bitblas_import_exception}. "
f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`" "Please install bitblas through the following command: "
) from bitblas_import_exception f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`"
) from bitblas_import_exception
from bitblas import Matmul, MatmulConfig, auto_detect_nvidia_target from bitblas import Matmul, MatmulConfig, auto_detect_nvidia_target
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description="Benchmark BitBLAS int4 on a specific target.") description="Benchmark BitBLAS int4 on a specific target."
)
# Add arguments to the parser # Add arguments to the parser
parser.add_argument( parser.add_argument(
@ -32,10 +38,9 @@ parser.add_argument(
default=auto_detect_nvidia_target(), default=auto_detect_nvidia_target(),
help="Specify the target device for benchmarking.", help="Specify the target device for benchmarking.",
) )
parser.add_argument("--group_size", parser.add_argument(
type=int, "--group_size", type=int, default=None, help="Group size for grouped quantization."
default=None, )
help="Group size for grouped quantization.")
parser.add_argument( parser.add_argument(
"--A_dtype", "--A_dtype",
type=str, type=str,
@ -82,17 +87,17 @@ parser.add_argument(
choices=["nt", "nn"], choices=["nt", "nn"],
help="Matrix layout, 'nt' for non-transpose A and transpose W.", help="Matrix layout, 'nt' for non-transpose A and transpose W.",
) )
parser.add_argument("--with_bias", parser.add_argument(
action="store_true", "--with_bias", action="store_true", help="Include bias in the benchmark."
help="Include bias in the benchmark.") )
parser.add_argument( parser.add_argument(
"--with_scaling", "--with_scaling",
action="store_true", action="store_true",
help="Include scaling factor in the quantization.", help="Include scaling factor in the quantization.",
) )
parser.add_argument("--with_zeros", parser.add_argument(
action="store_true", "--with_zeros", action="store_true", help="Include zeros in the quantization."
help="Include zeros in the quantization.") )
parser.add_argument( parser.add_argument(
"--zeros_mode", "--zeros_mode",
type=str, type=str,
@ -170,8 +175,7 @@ shapes = [
] ]
# Build test shapes with all the shared arguments # Build test shapes with all the shared arguments
test_shapes = [(MatmulConfig, Matmul, (*shape, *shared_args)) test_shapes = [(MatmulConfig, Matmul, (*shape, *shared_args)) for shape in shapes]
for shape in shapes]
benchmark_sets = [] benchmark_sets = []
benchmark_sets.extend(test_shapes) benchmark_sets.extend(test_shapes)
@ -206,12 +210,12 @@ for config_key, values in benchmark_results.items():
func_name = args_split[0] func_name = args_split[0]
input_args_str = "-".join(args_split[1:]) input_args_str = "-".join(args_split[1:])
col_widths[0] = max(col_widths[0], len(func_name) + 2, len(headers[0]) + 2) col_widths[0] = max(col_widths[0], len(func_name) + 2, len(headers[0]) + 2)
col_widths[1] = max(col_widths[1], col_widths[1] = max(col_widths[1], len(input_args_str) + 2, len(headers[1]) + 2)
len(input_args_str) + 2, col_widths[2] = max(
len(headers[1]) + 2) col_widths[2],
col_widths[2] = max(col_widths[2], len(f"{values['BitBLAS_top20_latency']:.3f} ms") + 2,
len(f"{values['BitBLAS_top20_latency']:.3f} ms") + 2, len(headers[2]) + 2,
len(headers[2]) + 2) )
# break only if you want to measure widths from a single example; # break only if you want to measure widths from a single example;
# otherwise, let it loop over all items. # otherwise, let it loop over all items.
@ -232,5 +236,6 @@ for config_key, values in benchmark_results.items():
f"{values['BitBLAS_top20_latency']:.3f} ms", f"{values['BitBLAS_top20_latency']:.3f} ms",
] ]
row_str = "".join( row_str = "".join(
[str(cell).ljust(col_widths[idx]) for idx, cell in enumerate(row)]) [str(cell).ljust(col_widths[idx]) for idx, cell in enumerate(row)]
)
print(row_str) print(row_str)

View File

@ -0,0 +1,489 @@
# SPDX-License-Identifier: Apache-2.0
"""
Benchmark the performance of the cutlass_moe_fp4 kernel vs the triton_moe
kernel. The cutlass_moe_fp4 kernel takes in fp4 quantized weights and 16-bit
activations. The triton_moe kernel takes in fp8 weights(tensor scaled to fp8)
and 16-bit activations.
"""
import nvtx
import torch
import torch.utils.benchmark as benchmark
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
from vllm.scalar_type import scalar_types
from vllm.utils import FlexibleArgumentParser
WEIGHT_SHAPES_MOE = {
"nvidia/DeepSeek-R1-FP4": [
[256, 8, 2048, 7168],
],
}
DEFAULT_MODELS = [
"nvidia/DeepSeek-R1-FP4",
]
DEFAULT_BATCH_SIZES = [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
DEFAULT_TP_SIZES = [1]
PER_ACT_TOKEN_OPTS = [False]
PER_OUT_CH_OPTS = [False]
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
def to_fp8(tensor: torch.Tensor):
finfo = torch.finfo(torch.float8_e4m3fn)
return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to(
dtype=torch.float8_e4m3fn
)
def bench_run(
results: list[benchmark.Measurement],
model: str,
num_experts: int,
topk: int,
per_act_token: bool,
per_out_ch: bool,
mkn: tuple[int, int, int],
):
label = "NVFP4 Blockscaled CUTLASS MOE vs FP8 Tensor Scaled Triton"
sub_label = (
"{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, MKN=({})".format(
model, num_experts, topk, per_act_token, per_out_ch, mkn
)
)
print(f"Testing: {sub_label}")
(m, k, n) = mkn
dtype = torch.half
device = "cuda"
a = torch.randn((m, k), device=device, dtype=dtype) / 10
w1 = torch.randn((num_experts, 2 * n, k), device=device, dtype=dtype) / 10
w2 = torch.randn((num_experts, k, n), device=device, dtype=dtype) / 10
_, a_fp8_scale = ops.scaled_fp8_quant(a)
w1_fp8q = torch.empty(
(num_experts, 2 * n, k), device=device, dtype=torch.float8_e4m3fn
)
w2_fp8q = torch.empty((num_experts, k, n), device=device, dtype=torch.float8_e4m3fn)
w1_fp8scale = torch.empty((num_experts, 1, 1), device=device, dtype=torch.float32)
w2_fp8scale = torch.empty((num_experts, 1, 1), device=device, dtype=torch.float32)
for expert in range(num_experts):
w1_fp8q[expert], w1_fp8scale[expert] = ops.scaled_fp8_quant(w1[expert])
w2_fp8q[expert], w2_fp8scale[expert] = ops.scaled_fp8_quant(w2[expert])
w1_fp8q_notransp = w1_fp8q.clone()
w2_fp8q_notransp = w2_fp8q.clone()
w1_fp8q = w1_fp8q.transpose(1, 2)
w2_fp8q = w2_fp8q.transpose(1, 2)
score = torch.randn((m, num_experts), device=device, dtype=dtype)
topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False)
quant_blocksize = 16
w1_blockscale = torch.empty(
(num_experts, 2 * n, k // quant_blocksize),
device=device,
dtype=torch.float8_e4m3fn,
)
w2_blockscale = torch.empty(
(num_experts, k, n // quant_blocksize), device=device, dtype=torch.float8_e4m3fn
)
# n_b_scales = 2 * n if per_out_ch else 1
# k_b_scales = k if per_out_ch else 1
w1_fp4 = torch.empty((num_experts, 2 * n, k // 2), device=device, dtype=torch.uint8)
w2_fp4 = torch.empty((num_experts, k, n // 2), device=device, dtype=torch.uint8)
w1_gs = torch.empty((num_experts,), device=device, dtype=torch.float32)
w2_gs = torch.empty((num_experts,), device=device, dtype=torch.float32)
a1_gs = torch.ones((num_experts,), device=device, dtype=torch.float32)
a2_gs = torch.ones((num_experts,), device=device, dtype=torch.float32)
for expert in range(num_experts):
w1_e = w1[expert]
w2_e = w2[expert]
w1_amax = torch.abs(w1_e).max().to(torch.float32)
w2_amax = torch.abs(w2_e).max().to(torch.float32)
w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax
w1_fp4[expert], w1_blockscale[expert] = ops.scaled_fp4_quant(
w1_e, w1_gs[expert]
)
w2_fp4[expert], w2_blockscale[expert] = ops.scaled_fp4_quant(
w2_e, w2_gs[expert]
)
def run_triton_moe(
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
a_fp8_scale: torch.Tensor,
num_repeats: int,
):
for _ in range(num_repeats):
fused_experts(
a,
w1,
w2,
topk_weights,
topk_ids,
use_fp8_w8a8=True,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a_fp8_scale,
)
def run_cutlass_moe_fp4(
a: torch.Tensor,
w1_fp4: torch.Tensor,
w2_fp4: torch.Tensor,
w1_blockscale: torch.Tensor,
w2_blockscale: torch.Tensor,
w1_gs: torch.Tensor,
w2_gs: torch.Tensor,
a1_gs: torch.Tensor,
a2_gs: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
m: int,
n: int,
k: int,
e: int,
device: torch.device,
num_repeats: int,
):
for _ in range(num_repeats):
with nvtx.annotate("cutlass_moe_fp4", color="green"):
cutlass_moe_fp4(
a=a,
a1_gscale=a1_gs,
a2_gscale=a2_gs,
w1_fp4=w1_fp4,
w1_blockscale=w1_blockscale,
w1_alphas=w1_gs,
w2_fp4=w2_fp4,
w2_blockscale=w2_blockscale,
w2_alphas=w2_gs,
topk_weights=topk_weights,
topk_ids=topk_ids,
m=m,
n=n,
k=k,
e=num_experts,
device=device,
)
def run_cutlass_from_graph(
a: torch.Tensor,
a1_gscale: torch.Tensor,
w1_fp4: torch.Tensor,
w1_blockscale: torch.Tensor,
w1_alphas: torch.Tensor,
a2_gscale: torch.Tensor,
w2_fp4: torch.Tensor,
w2_blockscale: torch.Tensor,
w2_alphas: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
m: int,
n: int,
k: int,
e: int,
device: torch.device,
):
with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
):
return cutlass_moe_fp4(
a=a,
a1_gscale=a1_gs,
w1_fp4=w1_fp4,
w1_blockscale=w1_blockscale,
w1_alphas=w1_alphas,
a2_gscale=a2_gs,
w2_fp4=w2_fp4,
w2_blockscale=w2_blockscale,
w2_alphas=w2_alphas,
topk_weights=topk_weights,
topk_ids=topk_ids,
m=m,
n=n,
k=k,
e=num_experts,
device=device,
)
def run_triton_from_graph(
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
a_fp8_scale: torch.Tensor,
):
with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
):
return fused_experts(
a,
w1,
w2,
topk_weights,
topk_ids,
use_fp8_w8a8=True,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a_fp8_scale,
)
def replay_graph(graph, num_repeats):
for _ in range(num_repeats):
graph.replay()
torch.cuda.synchronize()
cutlass_stream = torch.cuda.Stream()
cutlass_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(cutlass_graph, stream=cutlass_stream):
run_cutlass_from_graph(
a=a,
a1_gscale=a1_gs,
w1_fp4=w1_fp4,
w1_blockscale=w1_blockscale,
w1_alphas=w1_gs,
a2_gscale=a2_gs,
w2_fp4=w2_fp4,
w2_blockscale=w2_blockscale,
w2_alphas=w2_gs,
topk_weights=topk_weights,
topk_ids=topk_ids,
m=m,
n=n,
k=k,
e=num_experts,
device=device,
)
torch.cuda.synchronize()
triton_stream = torch.cuda.Stream()
triton_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(triton_graph, stream=triton_stream):
run_triton_from_graph(
a,
w1_fp8q_notransp,
w2_fp8q_notransp,
topk_weights,
topk_ids,
w1_fp8scale,
w2_fp8scale,
a_fp8_scale,
)
torch.cuda.synchronize()
min_run_time = 5
num_warmup = 5
num_runs = 25
globals = {
# Baseline params
"w1": w1,
"w2": w2,
"score": score,
"topk": topk,
"w1_fp8q_notransp": w1_fp8q_notransp,
"w2_fp8q_notransp": w2_fp8q_notransp,
"w1_fp8scale": w1_fp8scale,
"w2_fp8scale": w2_fp8scale,
"a_fp8_scale": a_fp8_scale,
# Cutlass params
"a": a,
"a1_gscale": a1_gs,
"w1_fp4": w1_fp4,
"w1_blockscale": w1_blockscale,
"w1_alphas": w1_gs,
"a2_gscale": a2_gs,
"w2_fp4": w2_fp4,
"w2_blockscale": w2_blockscale,
"w2_alphas": w2_gs,
"topk_weights": topk_weights,
"topk_ids": topk_ids,
"m": m,
"n": n,
"k": k,
"e": num_experts,
"device": device,
# cuda graph params
"cutlass_graph": cutlass_graph,
"triton_graph": triton_graph,
# Gen params
"num_runs": num_runs,
# Kernels
"run_triton_moe": run_triton_moe,
"run_cutlass_moe_fp4": run_cutlass_moe_fp4,
"replay_graph": replay_graph,
}
# Warmup
run_triton_moe(
a,
w1_fp8q_notransp,
w2_fp8q_notransp,
topk_weights,
topk_ids,
w1_fp8scale,
w2_fp8scale,
a_fp8_scale,
num_warmup,
)
results.append(
benchmark.Timer(
stmt="run_triton_moe(a, w1_fp8q_notransp, w2_fp8q_notransp, topk_weights, topk_ids, w1_fp8scale, w2_fp8scale, a_fp8_scale, num_runs)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="triton_moe",
).blocked_autorange(min_run_time=min_run_time)
)
# Warmup
replay_graph(triton_graph, num_warmup)
results.append(
benchmark.Timer(
stmt="replay_graph(triton_graph, num_runs)",
globals=globals,
label=label,
sub_label=sub_label,
description="triton_moe_cuda_graphs",
).blocked_autorange(min_run_time=min_run_time)
)
# Warmup
run_cutlass_moe_fp4(
a,
w1_fp4,
w2_fp4,
w1_blockscale,
w2_blockscale,
w1_gs,
w2_gs,
a1_gs,
a2_gs,
topk_weights,
topk_ids,
m,
n,
k,
num_experts,
device,
num_warmup,
)
results.append(
benchmark.Timer(
stmt="run_cutlass_moe_fp4(a, w1_fp4, w2_fp4, w1_blockscale, w2_blockscale, w1_alphas, w2_alphas, a1_gscale, a2_gscale, topk_weights, topk_ids, m, n, k, e, device, num_runs)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="cutlass_moe_fp4",
).blocked_autorange(min_run_time=min_run_time)
)
# Warmup
replay_graph(cutlass_graph, num_warmup)
results.append(
benchmark.Timer(
stmt="replay_graph(cutlass_graph, num_runs)",
globals=globals,
label=label,
sub_label=sub_label,
description="cutlass_moe_fp4_cuda_graphs",
).blocked_autorange(min_run_time=min_run_time)
)
def main(args):
print("Benchmarking models:")
for i, model in enumerate(args.models):
print(f"[{i}] {model}")
results: list[benchmark.Measurement] = []
for model in args.models:
for tp in args.tp_sizes:
for layer in WEIGHT_SHAPES_MOE[model]:
num_experts = layer[0]
topk = layer[1]
size_k = layer[2]
size_n = layer[3] // tp
if len(args.limit_k) > 0 and size_k not in args.limit_k:
continue
if len(args.limit_n) > 0 and size_n not in args.limit_n:
continue
for per_act_token in PER_ACT_TOKEN_OPTS:
for per_out_ch in PER_OUT_CH_OPTS:
for size_m in args.batch_sizes:
mkn = (size_m, size_k, size_n)
bench_run(
results,
model,
num_experts,
topk,
per_act_token,
per_out_ch,
mkn,
)
compare = benchmark.Compare(results)
compare.print()
if __name__ == "__main__":
parser = FlexibleArgumentParser(
description="Benchmark NVFP4 CUTLASS MOE across specified models/shapes/batches"
)
parser.add_argument(
"--models",
nargs="+",
type=str,
default=DEFAULT_MODELS,
choices=WEIGHT_SHAPES_MOE.keys(),
)
parser.add_argument("--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES)
parser.add_argument(
"--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
)
parser.add_argument("--limit-k", nargs="+", type=int, default=[])
parser.add_argument("--limit-n", nargs="+", type=int, default=[])
parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[])
parser.add_argument("--limit-per-act-token", nargs="+", type=int, default=[])
parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[])
args = parser.parse_args()
main(args)

View File

@ -6,14 +6,18 @@ from benchmark_shapes import WEIGHT_SHAPES_MOE
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe_fp8, from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts, cutlass_moe_fp8,
fused_topk) fused_experts,
fused_topk,
)
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
DEFAULT_MODELS = [ DEFAULT_MODELS = [
"nm-testing/Mixtral-8x7B-Instruct-v0.1", "nm-testing/deepseekv2-lite", "nm-testing/Mixtral-8x7B-Instruct-v0.1",
"ibm-granite/granite-3.0-1b-a400m", "ibm-granite/granite-3.0-3b-a800m" "nm-testing/deepseekv2-lite",
"ibm-granite/granite-3.0-1b-a400m",
"ibm-granite/granite-3.0-3b-a800m",
] ]
DEFAULT_BATCH_SIZES = [1, 4, 8, 16, 32, 64, 128, 256, 512] DEFAULT_BATCH_SIZES = [1, 4, 8, 16, 32, 64, 128, 256, 512]
DEFAULT_TP_SIZES = [1] DEFAULT_TP_SIZES = [1]
@ -24,19 +28,27 @@ PER_OUT_CH_OPTS = [False]
def to_fp8(tensor: torch.Tensor): def to_fp8(tensor: torch.Tensor):
finfo = torch.finfo(torch.float8_e4m3fn) finfo = torch.finfo(torch.float8_e4m3fn)
return torch.round(tensor.clamp( return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to(
min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) dtype=torch.float8_e4m3fn
)
def bench_run(results: list[benchmark.Measurement], model: str, def bench_run(
num_experts: int, topk: int, per_act_token: bool, results: list[benchmark.Measurement],
per_out_ch: bool, mkn: tuple[int, int, int]): model: str,
num_experts: int,
topk: int,
per_act_token: bool,
per_out_ch: bool,
mkn: tuple[int, int, int],
):
label = "Quant Matmul" label = "Quant Matmul"
sub_label = ( sub_label = (
"{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, " "{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, MKN=({})".format(
"MKN=({})".format(model, num_experts, topk, per_act_token, per_out_ch, model, num_experts, topk, per_act_token, per_out_ch, mkn
mkn)) )
)
print(f"Testing: {sub_label}") print(f"Testing: {sub_label}")
@ -50,35 +62,17 @@ def bench_run(results: list[benchmark.Measurement], model: str,
_, a_scale = ops.scaled_fp8_quant(a) _, a_scale = ops.scaled_fp8_quant(a)
w1_q = torch.empty((num_experts, 2 * n, k), w1_q = torch.empty(
device="cuda", (num_experts, 2 * n, k), device="cuda", dtype=torch.float8_e4m3fn
dtype=torch.float8_e4m3fn) )
w2_q = torch.empty((num_experts, k, n), w2_q = torch.empty((num_experts, k, n), device="cuda", dtype=torch.float8_e4m3fn)
device="cuda", w1_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)
dtype=torch.float8_e4m3fn) w2_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)
w1_scale = torch.empty((num_experts, 1, 1),
device="cuda",
dtype=torch.float32)
w2_scale = torch.empty((num_experts, 1, 1),
device="cuda",
dtype=torch.float32)
ab_strides1 = torch.full((num_experts, ), ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
k, c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64)
device="cuda", ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64)
dtype=torch.int64) c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
c_strides1 = torch.full((num_experts, ),
2 * n,
device="cuda",
dtype=torch.int64)
ab_strides2 = torch.full((num_experts, ),
n,
device="cuda",
dtype=torch.int64)
c_strides2 = torch.full((num_experts, ),
k,
device="cuda",
dtype=torch.int64)
for expert in range(num_experts): for expert in range(num_experts):
w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert]) w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert])
@ -90,82 +84,121 @@ def bench_run(results: list[benchmark.Measurement], model: str,
score = torch.randn((m, num_experts), device="cuda", dtype=dtype) score = torch.randn((m, num_experts), device="cuda", dtype=dtype)
topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) topk_weights, topk_ids, token_expert_indices = fused_topk(
a, score, topk, renormalize=False
)
def run_triton_moe(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, def run_triton_moe(
topk_weights: torch.Tensor, topk_ids: torch.Tensor, a: torch.Tensor,
w1_scale: torch.Tensor, w2_scale: torch.Tensor, w1: torch.Tensor,
a_scale: torch.Tensor, num_repeats: int): w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
a_scale: torch.Tensor,
num_repeats: int,
):
for _ in range(num_repeats): for _ in range(num_repeats):
fused_experts(a, fused_experts(
w1, a,
w2, w1,
topk_weights, w2,
topk_ids, topk_weights,
use_fp8_w8a8=True, topk_ids,
w1_scale=w1_scale, use_fp8_w8a8=True,
w2_scale=w2_scale, w1_scale=w1_scale,
a1_scale=a_scale) w2_scale=w2_scale,
a1_scale=a_scale,
)
def run_cutlass_moe(a: torch.Tensor, a_scale: torch.Tensor, def run_cutlass_moe(
w1: torch.Tensor, w2: torch.Tensor, a: torch.Tensor,
w1_scale: torch.Tensor, w2_scale: torch.Tensor, a_scale: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor, w1: torch.Tensor,
ab_strides1: torch.Tensor, c_strides1: torch.Tensor, w2: torch.Tensor,
ab_strides2: torch.Tensor, c_strides2: torch.Tensor, w1_scale: torch.Tensor,
num_repeats: int): w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
ab_strides1: torch.Tensor,
c_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides2: torch.Tensor,
num_repeats: int,
):
for _ in range(num_repeats): for _ in range(num_repeats):
cutlass_moe_fp8(a, cutlass_moe_fp8(
w1, a,
w2, w1,
w1_scale, w2,
w2_scale, w1_scale,
topk_weights, w2_scale,
topk_ids, topk_weights,
ab_strides1, topk_ids,
c_strides1, ab_strides1,
ab_strides2, c_strides1,
c_strides2, ab_strides2,
a1_scale=a_scale) c_strides2,
a1_scale=a_scale,
)
def run_cutlass_from_graph( def run_cutlass_from_graph(
a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor, a: torch.Tensor,
w2_q: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, a_scale: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor, w1_q: torch.Tensor,
ab_strides1: torch.Tensor, c_strides1: torch.Tensor, w2_q: torch.Tensor,
ab_strides2: torch.Tensor, c_strides2: torch.Tensor): w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
ab_strides1: torch.Tensor,
c_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides2: torch.Tensor,
):
with set_current_vllm_config( with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig( VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
pipeline_parallel_size=1))): ):
return cutlass_moe_fp8(a, return cutlass_moe_fp8(
w1_q, a,
w2_q, w1_q,
w1_scale, w2_q,
w2_scale, w1_scale,
topk_weights, w2_scale,
topk_ids, topk_weights,
ab_strides1, topk_ids,
c_strides1, ab_strides1,
ab_strides2, c_strides1,
c_strides2, ab_strides2,
a1_scale=a_scale) c_strides2,
a1_scale=a_scale,
)
def run_triton_from_graph(a: torch.Tensor, w1: torch.Tensor, def run_triton_from_graph(
w2: torch.Tensor, topk_weights: torch.Tensor, a: torch.Tensor,
topk_ids: torch.Tensor, w1_scale: torch.Tensor, w1: torch.Tensor,
w2_scale: torch.Tensor, a_scale: torch.Tensor): w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
a_scale: torch.Tensor,
):
with set_current_vllm_config( with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig( VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
pipeline_parallel_size=1))): ):
return fused_experts(a, return fused_experts(
w1, a,
w2, w1,
topk_weights, w2,
topk_ids, topk_weights,
use_fp8_w8a8=True, topk_ids,
w1_scale=w1_scale, use_fp8_w8a8=True,
w2_scale=w2_scale, w1_scale=w1_scale,
a1_scale=a_scale) w2_scale=w2_scale,
a1_scale=a_scale,
)
def replay_graph(graph, num_repeats): def replay_graph(graph, num_repeats):
for _ in range(num_repeats): for _ in range(num_repeats):
@ -175,16 +208,35 @@ def bench_run(results: list[benchmark.Measurement], model: str,
cutlass_stream = torch.cuda.Stream() cutlass_stream = torch.cuda.Stream()
cutlass_graph = torch.cuda.CUDAGraph() cutlass_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(cutlass_graph, stream=cutlass_stream): with torch.cuda.graph(cutlass_graph, stream=cutlass_stream):
run_cutlass_from_graph(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, run_cutlass_from_graph(
topk_weights, topk_ids, ab_strides1, c_strides1, a,
ab_strides2, c_strides2) a_scale,
w1_q,
w2_q,
w1_scale,
w2_scale,
topk_weights,
topk_ids,
ab_strides1,
c_strides1,
ab_strides2,
c_strides2,
)
torch.cuda.synchronize() torch.cuda.synchronize()
triton_stream = torch.cuda.Stream() triton_stream = torch.cuda.Stream()
triton_graph = torch.cuda.CUDAGraph() triton_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(triton_graph, stream=triton_stream): with torch.cuda.graph(triton_graph, stream=triton_stream):
run_triton_from_graph(a, w1_q_notransp, w2_q_notransp, topk_weights, run_triton_from_graph(
topk_ids, w1_scale, w2_scale, a_scale) a,
w1_q_notransp,
w2_q_notransp,
topk_weights,
topk_ids,
w1_scale,
w2_scale,
a_scale,
)
torch.cuda.synchronize() torch.cuda.synchronize()
min_run_time = 5 min_run_time = 5
@ -224,18 +276,27 @@ def bench_run(results: list[benchmark.Measurement], model: str,
} }
# Warmup # Warmup
run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, run_triton_moe(
w1_scale, w2_scale, a_scale, num_warmup) a,
w1_q_notransp,
w2_q_notransp,
topk_weights,
topk_ids,
w1_scale,
w2_scale,
a_scale,
num_warmup,
)
results.append( results.append(
benchmark.Timer( benchmark.Timer(
stmt= stmt="run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)", # noqa: E501
"run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)", # noqa: E501
globals=globals, globals=globals,
label=label, label=label,
sub_label=sub_label, sub_label=sub_label,
description="triton_moe", description="triton_moe",
).blocked_autorange(min_run_time=min_run_time)) ).blocked_autorange(min_run_time=min_run_time)
)
# Warmup # Warmup
replay_graph(triton_graph, num_warmup) replay_graph(triton_graph, num_warmup)
@ -247,22 +308,35 @@ def bench_run(results: list[benchmark.Measurement], model: str,
label=label, label=label,
sub_label=sub_label, sub_label=sub_label,
description="triton_moe_cuda_graphs", description="triton_moe_cuda_graphs",
).blocked_autorange(min_run_time=min_run_time)) ).blocked_autorange(min_run_time=min_run_time)
)
# Warmup # Warmup
run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, run_cutlass_moe(
topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, a,
num_warmup) a_scale,
w1_q,
w2_q,
w1_scale,
w2_scale,
topk_weights,
topk_ids,
ab_strides1,
c_strides1,
ab_strides2,
c_strides2,
num_warmup,
)
results.append( results.append(
benchmark.Timer( benchmark.Timer(
stmt= stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, num_runs)", # noqa: E501
"run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, num_runs)", # noqa: E501
globals=globals, globals=globals,
label=label, label=label,
sub_label=sub_label, sub_label=sub_label,
description="grouped_gemm_moe", description="grouped_gemm_moe",
).blocked_autorange(min_run_time=min_run_time)) ).blocked_autorange(min_run_time=min_run_time)
)
# Warmup # Warmup
replay_graph(cutlass_graph, num_warmup) replay_graph(cutlass_graph, num_warmup)
@ -274,7 +348,8 @@ def bench_run(results: list[benchmark.Measurement], model: str,
label=label, label=label,
sub_label=sub_label, sub_label=sub_label,
description="grouped_gemm_moe_cuda_graphs", description="grouped_gemm_moe_cuda_graphs",
).blocked_autorange(min_run_time=min_run_time)) ).blocked_autorange(min_run_time=min_run_time)
)
def main(args): def main(args):
@ -302,8 +377,15 @@ def main(args):
for per_out_ch in PER_OUT_CH_OPTS: for per_out_ch in PER_OUT_CH_OPTS:
for size_m in DEFAULT_BATCH_SIZES: for size_m in DEFAULT_BATCH_SIZES:
mkn = (size_m, size_k, size_n) mkn = (size_m, size_k, size_n)
bench_run(results, model, num_experts, topk, bench_run(
per_act_token, per_out_ch, mkn) results,
model,
num_experts,
topk,
per_act_token,
per_out_ch,
mkn,
)
compare = benchmark.Compare(results) compare = benchmark.Compare(results)
compare.print() compare.print()
@ -311,7 +393,8 @@ def main(args):
if __name__ == "__main__": if __name__ == "__main__":
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description="Benchmark Marlin across specified models/shapes/batches") description="Benchmark Marlin across specified models/shapes/batches"
)
parser.add_argument( parser.add_argument(
"--models", "--models",
nargs="+", nargs="+",
@ -319,21 +402,14 @@ if __name__ == "__main__":
default=DEFAULT_MODELS, default=DEFAULT_MODELS,
choices=WEIGHT_SHAPES_MOE.keys(), choices=WEIGHT_SHAPES_MOE.keys(),
) )
parser.add_argument("--tp-sizes", parser.add_argument("--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES)
nargs="+", parser.add_argument(
type=int, "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
default=DEFAULT_TP_SIZES) )
parser.add_argument("--batch-sizes",
nargs="+",
type=int,
default=DEFAULT_BATCH_SIZES)
parser.add_argument("--limit-k", nargs="+", type=int, default=[]) parser.add_argument("--limit-k", nargs="+", type=int, default=[])
parser.add_argument("--limit-n", nargs="+", type=int, default=[]) parser.add_argument("--limit-n", nargs="+", type=int, default=[])
parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[]) parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[])
parser.add_argument("--limit-per-act-token", parser.add_argument("--limit-per-act-token", nargs="+", type=int, default=[])
nargs="+",
type=int,
default=[])
parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[]) parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[])
args = parser.parse_args() args = parser.parse_args()

View File

@ -10,14 +10,16 @@ from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
@torch.inference_mode() @torch.inference_mode()
def main(num_tokens: int, def main(
hidden_size: int, num_tokens: int,
add_residual: bool, hidden_size: int,
dtype: torch.dtype, add_residual: bool,
seed: int = 0, dtype: torch.dtype,
do_profile: bool = False, seed: int = 0,
num_warmup_iters: int = 5, do_profile: bool = False,
num_iters: int = 100) -> None: num_warmup_iters: int = 5,
num_iters: int = 100,
) -> None:
current_platform.seed_everything(seed) current_platform.seed_everything(seed)
torch.set_default_device("cuda") torch.set_default_device("cuda")
@ -56,33 +58,35 @@ def main(num_tokens: int,
print(f"Kernel running time: {latency * 1000000:.3f} us") print(f"Kernel running time: {latency * 1000000:.3f} us")
if __name__ == '__main__': if __name__ == "__main__":
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(description="Benchmark the layernorm kernel.")
description="Benchmark the layernorm kernel.")
parser.add_argument("--num-tokens", type=int, default=4096) parser.add_argument("--num-tokens", type=int, default=4096)
parser.add_argument("--hidden-size", type=int, default=8192) parser.add_argument("--hidden-size", type=int, default=8192)
parser.add_argument("--add-residual", action="store_true") parser.add_argument("--add-residual", action="store_true")
parser.add_argument("--dtype", parser.add_argument(
type=str, "--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
choices=["half", "bfloat16", "float"], )
default="half")
parser.add_argument("--seed", type=int, default=0) parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--profile", action="store_true") parser.add_argument("--profile", action="store_true")
parser.add_argument("--num-warmup-iters", type=int, default=5) parser.add_argument("--num-warmup-iters", type=int, default=5)
parser.add_argument("--num-iters", parser.add_argument(
type=int, "--num-iters",
default=100, type=int,
help="Number of benchmark iterations. " default=100,
"If --profile is set, this number is ignored") help="Number of benchmark iterations. "
"If --profile is set, this number is ignored",
)
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)
main(num_tokens=args.num_tokens, main(
hidden_size=args.hidden_size, num_tokens=args.num_tokens,
add_residual=args.add_residual, hidden_size=args.hidden_size,
dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype], add_residual=args.add_residual,
seed=args.seed, dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
do_profile=args.profile, seed=args.seed,
num_warmup_iters=args.num_warmup_iters, do_profile=args.profile,
num_iters=args.num_iters) num_warmup_iters=args.num_warmup_iters,
num_iters=args.num_iters,
)

File diff suppressed because it is too large Load Diff

View File

@ -20,12 +20,18 @@ from weight_shapes import WEIGHT_SHAPES
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, marlin_permute_scales, GPTQ_MARLIN_MAX_PARALLEL,
marlin_zero_points) GPTQ_MARLIN_MIN_THREAD_N,
marlin_permute_scales,
marlin_zero_points,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
MarlinWorkspace) MarlinWorkspace,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
pack_rows, quantize_weights) pack_rows,
quantize_weights,
)
from vllm.scalar_type import ScalarType, scalar_types from vllm.scalar_type import ScalarType, scalar_types
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
@ -82,12 +88,14 @@ def rand_data(shape, dtype=torch.float16, scale=1):
return torch.randint(-15, 15, shape, dtype=dtype, device="cuda") return torch.randint(-15, 15, shape, dtype=dtype, device="cuda")
def quantize_and_pack(atype: torch.dtype, def quantize_and_pack(
w: torch.Tensor, atype: torch.dtype,
wtype: ScalarType, w: torch.Tensor,
stype: Optional[torch.dtype], wtype: ScalarType,
group_size: Optional[int], stype: Optional[torch.dtype],
zero_points: bool = False): group_size: Optional[int],
zero_points: bool = False,
):
assert wtype.is_integer(), "TODO: support floating point weights" assert wtype.is_integer(), "TODO: support floating point weights"
w_ref, w_q, w_s, w_zp = quantize_weights( w_ref, w_q, w_s, w_zp = quantize_weights(
@ -96,21 +104,24 @@ def quantize_and_pack(atype: torch.dtype,
group_size=group_size, group_size=group_size,
zero_points=zero_points, zero_points=zero_points,
# to match how the kernel applies zps # to match how the kernel applies zps
ref_zero_points_after_scales=True) ref_zero_points_after_scales=True,
)
w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape) w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape)
return w_ref, w_q, w_s, w_zp return w_ref, w_q, w_s, w_zp
def create_bench_tensors(shape: tuple[int, int, int], types: TypeConfig, def create_bench_tensors(
group_size: Optional[int]) -> list[BenchmarkTensors]: shape: tuple[int, int, int], types: TypeConfig, group_size: Optional[int]
) -> list[BenchmarkTensors]:
m, n, k = shape m, n, k = shape
# we want to make sure that weights don't fit into L2 cache between runs so # we want to make sure that weights don't fit into L2 cache between runs so
# we construct enough weights to exceed L2 cache, which is 50mb on a H100 # we construct enough weights to exceed L2 cache, which is 50mb on a H100
# so we target total weight size > 2*50mb # so we target total weight size > 2*50mb
num_weights = math.ceil(2 * 50 * 1024**2 * 8 / num_weights = math.ceil(
(k * n * types.weight_type.size_bits)) 2 * 50 * 1024**2 * 8 / (k * n * types.weight_type.size_bits)
)
a = rand_data((m, k), types.act_type, scale=5) a = rand_data((m, k), types.act_type, scale=5)
@ -124,8 +135,13 @@ def create_bench_tensors(shape: tuple[int, int, int], types: TypeConfig,
w = w.to(torch.float16) w = w.to(torch.float16)
w_ref, w_q_packed, w_s, w_zp = quantize_and_pack( w_ref, w_q_packed, w_s, w_zp = quantize_and_pack(
a.dtype, w, types.weight_type, types.group_scale_type, group_size, a.dtype,
types.group_zero_type is not None) w,
types.weight_type,
types.group_scale_type,
group_size,
types.group_zero_type is not None,
)
if not a.dtype.is_floating_point: if not a.dtype.is_floating_point:
aiinfo = torch.iinfo(a.dtype) aiinfo = torch.iinfo(a.dtype)
@ -133,21 +149,30 @@ def create_bench_tensors(shape: tuple[int, int, int], types: TypeConfig,
w_ref = w_ref.to(torch.float32) w_ref = w_ref.to(torch.float32)
w_ch_s = None if types.channel_scale_type is None else\ w_ch_s = (
rand_data((n,), types.channel_scale_type) None
w_tok_s = None if types.token_scale_type is None else\ if types.channel_scale_type is None
rand_data((m,), types.token_scale_type) else rand_data((n,), types.channel_scale_type)
)
w_tok_s = (
None
if types.token_scale_type is None
else rand_data((m,), types.token_scale_type)
)
benchmark_tensors.append( benchmark_tensors.append(
BenchmarkTensors(w_ref=w_ref, BenchmarkTensors(
a=a, w_ref=w_ref,
w_q=w_q_packed, a=a,
wtype=types.weight_type, w_q=w_q_packed,
w_g_s=w_s, wtype=types.weight_type,
w_g_zp=w_zp, w_g_s=w_s,
group_size=group_size, w_g_zp=w_zp,
w_ch_s=w_ch_s, group_size=group_size,
w_tok_s=w_tok_s)) w_ch_s=w_ch_s,
w_tok_s=w_tok_s,
)
)
return benchmark_tensors return benchmark_tensors
@ -170,50 +195,57 @@ def cutlass_scaled_mm_create_bench_fn(bt: BenchmarkTensors) -> Callable:
scale_b = torch.tensor(1.0, dtype=torch.float32, device=bt.a.device) scale_b = torch.tensor(1.0, dtype=torch.float32, device=bt.a.device)
w_col_major = bt.w_ref.to(bt.a.dtype).t().contiguous().t() w_col_major = bt.w_ref.to(bt.a.dtype).t().contiguous().t()
return lambda: ops.cutlass_scaled_mm( return lambda: ops.cutlass_scaled_mm(
bt.a, w_col_major, scale_a, scale_b, out_dtype=torch.float16) bt.a, w_col_major, scale_a, scale_b, out_dtype=torch.float16
)
def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable: def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
device = bt.a.device device = bt.a.device
workspace = MarlinWorkspace(bt.w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N, workspace = MarlinWorkspace(
GPTQ_MARLIN_MAX_PARALLEL) bt.w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL
)
if bt.w_g_zp is None: if bt.w_g_zp is None:
w_zp = torch.empty(0, dtype=torch.int, device=device) w_zp = torch.empty(0, dtype=torch.int, device=device)
else: else:
w_zp = marlin_zero_points(bt.w_g_zp, bt.w_ref.shape[0], w_zp = marlin_zero_points(
bt.w_ref.shape[1], bt.wtype.size_bits) bt.w_g_zp, bt.w_ref.shape[0], bt.w_ref.shape[1], bt.wtype.size_bits
)
if bt.group_size is None: if bt.group_size is None:
w_s = torch.tensor([], device="cuda", dtype=torch.half) w_s = torch.tensor([], device="cuda", dtype=torch.half)
else: else:
w_s = marlin_permute_scales(bt.w_g_s, bt.w_ref.shape[0], w_s = marlin_permute_scales(
bt.w_ref.shape[1], bt.group_size) bt.w_g_s, bt.w_ref.shape[0], bt.w_ref.shape[1], bt.group_size
)
sort_indices = torch.empty(0, dtype=torch.int, device=device) sort_indices = torch.empty(0, dtype=torch.int, device=device)
g_idx = torch.empty(0, dtype=torch.int, device=device) g_idx = torch.empty(0, dtype=torch.int, device=device)
w_q = ops.gptq_marlin_repack(bt.w_q, sort_indices, bt.w_ref.shape[0], w_q = ops.gptq_marlin_repack(
bt.w_ref.shape[1], bt.wtype.size_bits) bt.w_q, sort_indices, bt.w_ref.shape[0], bt.w_ref.shape[1], bt.wtype.size_bits
)
if bt.a.dtype.is_floating_point: if bt.a.dtype.is_floating_point:
assert bt.w_ch_s is None assert bt.w_ch_s is None
assert bt.w_tok_s is None assert bt.w_tok_s is None
assert bt.group_size is not None assert bt.group_size is not None
fn = lambda: ops.gptq_marlin_gemm(a=bt.a, fn = lambda: ops.gptq_marlin_gemm(
b_q_weight=w_q, a=bt.a,
b_scales=w_s, b_q_weight=w_q,
b_zeros=w_zp, b_scales=w_s,
g_idx=g_idx, b_zeros=w_zp,
perm=sort_indices, g_idx=g_idx,
workspace=workspace.scratch, perm=sort_indices,
b_q_type=bt.wtype, workspace=workspace.scratch,
size_m=bt.a.shape[0], b_q_type=bt.wtype,
size_n=bt.w_ref.shape[1], size_m=bt.a.shape[0],
size_k=bt.w_ref.shape[0], size_n=bt.w_ref.shape[1],
is_k_full=True, size_k=bt.w_ref.shape[0],
is_zp_float=False) is_k_full=True,
is_zp_float=False,
)
else: else:
assert bt.a.dtype == torch.int8 assert bt.a.dtype == torch.int8
assert bt.wtype == scalar_types.uint4b8 assert bt.wtype == scalar_types.uint4b8
@ -221,36 +253,35 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
if bt.w_ch_s is not None: if bt.w_ch_s is not None:
s_ch = bt.w_ch_s.to(torch.float32) s_ch = bt.w_ch_s.to(torch.float32)
else: else:
s_ch = torch.ones(bt.w_ref.shape[1], s_ch = torch.ones(bt.w_ref.shape[1], dtype=torch.float32, device=device)
dtype=torch.float32,
device=device)
if bt.w_tok_s is not None: if bt.w_tok_s is not None:
s_tok = bt.w_tok_s.to(torch.float32) s_tok = bt.w_tok_s.to(torch.float32)
else: else:
s_tok = torch.ones(bt.a.shape[0], s_tok = torch.ones(bt.a.shape[0], dtype=torch.float32, device=device)
dtype=torch.float32,
device=device)
fn = lambda: ops.marlin_qqq_gemm(a=bt.a, fn = lambda: ops.marlin_qqq_gemm(
b_q_weight=w_q, a=bt.a,
s_group=w_s, b_q_weight=w_q,
s_tok=s_tok, s_group=w_s,
s_ch=s_ch, s_tok=s_tok,
workspace=workspace.scratch, s_ch=s_ch,
size_m=bt.a.shape[0], workspace=workspace.scratch,
size_n=bt.w_ref.shape[1], size_m=bt.a.shape[0],
size_k=bt.w_ref.shape[0]) size_n=bt.w_ref.shape[1],
size_k=bt.w_ref.shape[0],
)
return fn return fn
def machete_create_bench_fn(bt: BenchmarkTensors, def machete_create_bench_fn(
out_type=torch.dtype, bt: BenchmarkTensors, out_type=torch.dtype, schedule=None
schedule=None) -> Callable: ) -> Callable:
w_q = bt.w_q.t().contiguous().t() # make col major w_q = bt.w_q.t().contiguous().t() # make col major
w_q = ops.machete_prepack_B(w_q, bt.a.dtype, bt.wtype, w_q = ops.machete_prepack_B(
None if bt.w_g_s is None else bt.w_g_s.dtype) w_q, bt.a.dtype, bt.wtype, None if bt.w_g_s is None else bt.w_g_s.dtype
)
w_g_zp = bt.w_g_zp w_g_zp = bt.w_g_zp
if w_g_zp is not None: if w_g_zp is not None:
@ -275,26 +306,24 @@ def machete_create_bench_fn(bt: BenchmarkTensors,
# bench # bench
def bench_fns(label: str, sub_label: str, description: str, def bench_fns(label: str, sub_label: str, description: str, fns: list[Callable]):
fns: list[Callable]):
min_run_time = 1 if not NVTX_PROFILE else 0.1 min_run_time = 1 if not NVTX_PROFILE else 0.1
res = TBenchmark.Timer( res = TBenchmark.Timer(
stmt=""" stmt="""
for fn in fns: for fn in fns:
fn() fn()
""", """,
globals={ globals={"fns": fns},
"fns": fns
},
label=label, label=label,
sub_label=sub_label, sub_label=sub_label,
description=description, description=description,
).blocked_autorange(min_run_time=min_run_time) ).blocked_autorange(min_run_time=min_run_time)
if NVTX_PROFILE: if NVTX_PROFILE:
with nvtx.annotate("mm-bench"), nvtx.annotate( with (
f"{label}|{sub_label}|{description}"): nvtx.annotate("mm-bench"),
nvtx.annotate(f"{label}|{sub_label}|{description}"),
):
fns[0]() fns[0]()
return res return res
@ -304,19 +333,20 @@ _SWEEP_SCHEDULES_RESULTS: Optional[pd.DataFrame] = None
_SWEEP_SCHEDULES_RESULTS_CSV: Optional[str] = None _SWEEP_SCHEDULES_RESULTS_CSV: Optional[str] = None
def bench(types: TypeConfig, def bench(
group_size: int, types: TypeConfig,
m: int, group_size: int,
k: int, m: int,
n: int, k: int,
label: str, n: int,
sub_label: str, label: str,
sweep_schedules: bool = True) -> list[TMeasurement]: sub_label: str,
sweep_schedules: bool = True,
) -> list[TMeasurement]:
benchmark_tensors = create_bench_tensors((m, n, k), types, group_size) benchmark_tensors = create_bench_tensors((m, n, k), types, group_size)
sub_label += f", L={len(benchmark_tensors)}" sub_label += f", L={len(benchmark_tensors)}"
name_type_string = f"W{types.weight_type}"+\ name_type_string = f"W{types.weight_type}" + f"-A{terse_type_name(types.act_type)}"
f"-A{terse_type_name(types.act_type)}"
if types.group_scale_type is not None: if types.group_scale_type is not None:
name_type_string += f"-GS{terse_type_name(types.group_scale_type)}" name_type_string += f"-GS{terse_type_name(types.group_scale_type)}"
if types.group_zero_type is not None: if types.group_zero_type is not None:
@ -332,31 +362,45 @@ def bench(types: TypeConfig,
# pytorch impl # pytorch impl
timers.append( timers.append(
bench_fns( bench_fns(
label, sub_label, "torch.matmul (fp16)", label,
[torch_matmul_f16_create_bench_fn(bt) sub_label,
for bt in benchmark_tensors])) "torch.matmul (fp16)",
[torch_matmul_f16_create_bench_fn(bt) for bt in benchmark_tensors],
)
)
if types.act_type == torch.int8 or types.act_type == torch.float8_e4m3fn: if types.act_type == torch.int8 or types.act_type == torch.float8_e4m3fn:
timers.append( timers.append(
bench_fns( bench_fns(
label, sub_label, label,
f"cutlass_scaled_mm ({terse_type_name(types.act_type)})", [ sub_label,
cutlass_scaled_mm_create_bench_fn(bt) f"cutlass_scaled_mm ({terse_type_name(types.act_type)})",
for bt in benchmark_tensors [cutlass_scaled_mm_create_bench_fn(bt) for bt in benchmark_tensors],
])) )
)
if types.act_type != torch.float8_e4m3fn: if types.act_type != torch.float8_e4m3fn:
timers.append( timers.append(
bench_fns(label, sub_label, f"marlin ({name_type_string})", bench_fns(
[marlin_create_bench_fn(bt) label,
for bt in benchmark_tensors])) sub_label,
f"marlin ({name_type_string})",
[marlin_create_bench_fn(bt) for bt in benchmark_tensors],
)
)
# machete # machete
timers.append( timers.append(
bench_fns(label, sub_label, f"machete ({name_type_string})", [ bench_fns(
machete_create_bench_fn(bt, out_type=types.output_type) label,
for bt in benchmark_tensors sub_label,
])) f"machete ({name_type_string})",
[
machete_create_bench_fn(bt, out_type=types.output_type)
for bt in benchmark_tensors
],
)
)
if sweep_schedules: if sweep_schedules:
global _SWEEP_SCHEDULES_RESULTS global _SWEEP_SCHEDULES_RESULTS
@ -371,7 +415,8 @@ def bench(types: TypeConfig,
group_zeros_type=types.group_zero_type, group_zeros_type=types.group_zero_type,
token_scales_type=types.token_scale_type, token_scales_type=types.token_scale_type,
channel_scales_type=types.channel_scale_type, channel_scales_type=types.channel_scale_type,
out_type=types.output_type) out_type=types.output_type,
)
if schedules is None or len(schedules) == 0: if schedules is None or len(schedules) == 0:
raise ValueError("No schedules found to sweep") raise ValueError("No schedules found to sweep")
@ -383,11 +428,17 @@ def bench(types: TypeConfig,
if schedule_M >= 2 * max(m, 16) or schedule_M < m // 4: if schedule_M >= 2 * max(m, 16) or schedule_M < m // 4:
continue continue
res = bench_fns(label, sub_label, "machete_best", [ res = bench_fns(
machete_create_bench_fn( label,
bt, out_type=types.output_type, schedule=schedule) sub_label,
for bt in benchmark_tensors "machete_best",
]) [
machete_create_bench_fn(
bt, out_type=types.output_type, schedule=schedule
)
for bt in benchmark_tensors
],
)
results_row = { results_row = {
"M": m, "M": m,
@ -398,10 +449,8 @@ def bench(types: TypeConfig,
"median": res.median, "median": res.median,
} }
if _SWEEP_SCHEDULES_RESULTS is None: if _SWEEP_SCHEDULES_RESULTS is None:
_SWEEP_SCHEDULES_RESULTS = pd.DataFrame( _SWEEP_SCHEDULES_RESULTS = pd.DataFrame(columns=results_row.keys())
columns=results_row.keys()) _SWEEP_SCHEDULES_RESULTS.loc[len(_SWEEP_SCHEDULES_RESULTS)] = results_row
_SWEEP_SCHEDULES_RESULTS.\
loc[len(_SWEEP_SCHEDULES_RESULTS)] = results_row
print(f" {res.median:5.5} ", schedule) print(f" {res.median:5.5} ", schedule)
if not best or res.median < best.median: if not best or res.median < best.median:
@ -422,8 +471,9 @@ def print_timers(timers: list[TMeasurement]):
def run(args, MKNs: Iterable[tuple[int, int, int]]) -> Iterable[TMeasurement]: def run(args, MKNs: Iterable[tuple[int, int, int]]) -> Iterable[TMeasurement]:
types = TypeConfig( types = TypeConfig(
act_type=args.act_type, act_type=args.act_type,
weight_type=scalar_types.uint4b8 if args.group_zero_type is None \ weight_type=scalar_types.uint4b8
else scalar_types.uint4, if args.group_zero_type is None
else scalar_types.uint4,
output_type=args.out_type, output_type=args.out_type,
group_scale_type=args.group_scale_type, group_scale_type=args.group_scale_type,
group_zero_type=args.group_zero_type, group_zero_type=args.group_zero_type,
@ -433,14 +483,16 @@ def run(args, MKNs: Iterable[tuple[int, int, int]]) -> Iterable[TMeasurement]:
results: list[TMeasurement] = [] results: list[TMeasurement] = []
for m, k, n in MKNs: for m, k, n in MKNs:
timers = bench(types, timers = bench(
args.group_size, types,
m, args.group_size,
k, m,
n, k,
f"{args.act_type}-gemm", n,
f"MKN=({m}x{k}x{n})", f"{args.act_type}-gemm",
sweep_schedules=args.sweep_schedules) f"MKN=({m}x{k}x{n})",
sweep_schedules=args.sweep_schedules,
)
print_timers(timers) print_timers(timers)
results.extend(timers) results.extend(timers)
@ -454,7 +506,6 @@ def make_output(
base_description: str, base_description: str,
timestamp=None, timestamp=None,
): ):
print(f"== All Results {base_description} ====") print(f"== All Results {base_description} ====")
print_timers(data) print_timers(data)
@ -468,8 +519,7 @@ def make_output(
def run_square_bench(args): def run_square_bench(args):
dim_sizes = list( dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment))
range(args.dim_start, args.dim_end + 1, args.dim_increment))
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
data = run(args.dtype, args.sweep_schedules, MKNs) data = run(args.dtype, args.sweep_schedules, MKNs)
@ -479,8 +529,9 @@ def run_square_bench(args):
def run_range_bench(args): def run_range_bench(args):
m_start, k_start, n_start = (int(x) for x in args.dim_start.split(",")) m_start, k_start, n_start = (int(x) for x in args.dim_start.split(","))
m_end, k_end, n_end = (int(x) for x in args.dim_end.split(",")) m_end, k_end, n_end = (int(x) for x in args.dim_end.split(","))
m_increment, k_increment, n_increment = \ m_increment, k_increment, n_increment = (
(int(x) for x in args.dim_increment.split(",")) int(x) for x in args.dim_increment.split(",")
)
Ms = list(range(m_start, m_end + 1, m_increment)) Ms = list(range(m_start, m_end + 1, m_increment))
Ks = list(range(k_start, k_end + 1, k_increment)) Ks = list(range(k_start, k_end + 1, k_increment))
Ns = list(range(n_start, n_end + 1, n_increment)) Ns = list(range(n_start, n_end + 1, n_increment))
@ -492,7 +543,6 @@ def run_range_bench(args):
def run_model_bench(args): def run_model_bench(args):
print("Benchmarking models:") print("Benchmarking models:")
for i, model in enumerate(args.models): for i, model in enumerate(args.models):
print(f"[{i}] {model}") print(f"[{i}] {model}")
@ -535,10 +585,13 @@ def run_model_bench(args):
with open(f"model_bench-{type_string}-{timestr}.pkl", "wb") as f: with open(f"model_bench-{type_string}-{timestr}.pkl", "wb") as f:
args_dict = vars(args) args_dict = vars(args)
args_dict.pop("func") args_dict.pop("func")
pkl.dump({ pkl.dump(
"args": args_dict, {
"results": all_results, "args": args_dict,
}, f) "results": all_results,
},
f,
)
if __name__ == "__main__": if __name__ == "__main__":
@ -554,7 +607,6 @@ if __name__ == "__main__":
}[dt] }[dt]
class ToTorchDtype(argparse.Action): class ToTorchDtype(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None): def __call__(self, parser, namespace, values, option_string=None):
setattr(namespace, self.dest, to_torch_dtype(values)) setattr(namespace, self.dest, to_torch_dtype(values))
@ -580,32 +632,32 @@ Benchmark Machete GEMM.
"--act-type", "--act-type",
action=ToTorchDtype, action=ToTorchDtype,
required=True, required=True,
choices=['bfloat16', 'float16', 'int8', 'float8_e4m3fn'], choices=["bfloat16", "float16", "int8", "float8_e4m3fn"],
) )
parser.add_argument( parser.add_argument(
"--group-scale-type", "--group-scale-type",
action=ToTorchDtype, action=ToTorchDtype,
choices=['bfloat16', 'float16'], choices=["bfloat16", "float16"],
) )
parser.add_argument( parser.add_argument(
"--group-zero-type", "--group-zero-type",
type=to_torch_dtype, type=to_torch_dtype,
choices=['bfloat16', 'float16'], choices=["bfloat16", "float16"],
) )
parser.add_argument( parser.add_argument(
"--channel-scale-type", "--channel-scale-type",
action=ToTorchDtype, action=ToTorchDtype,
choices=['float'], choices=["float"],
) )
parser.add_argument( parser.add_argument(
"--token-scale-type", "--token-scale-type",
action=ToTorchDtype, action=ToTorchDtype,
choices=['float'], choices=["float"],
) )
parser.add_argument( parser.add_argument(
"--out-type", "--out-type",
action=ToTorchDtype, action=ToTorchDtype,
choices=['bfloat16', 'float16'], choices=["bfloat16", "float16"],
) )
parser.add_argument( parser.add_argument(
"--group-size", "--group-size",
@ -618,9 +670,11 @@ Benchmark Machete GEMM.
action="store_true", action="store_true",
help="Run a sweep over all supported schedules", help="Run a sweep over all supported schedules",
) )
parser.add_argument("--sweep-csv-out", parser.add_argument(
help="CSV to store sweep results", "--sweep-csv-out",
default="sch_sweep_results.csv") help="CSV to store sweep results",
default="sch_sweep_results.csv",
)
subparsers = parser.add_subparsers(dest="cmd", required=True) subparsers = parser.add_subparsers(dest="cmd", required=True)
square_parser = subparsers.add_parser("square_bench") square_parser = subparsers.add_parser("square_bench")
@ -634,17 +688,20 @@ Benchmark Machete GEMM.
"--dim-start", "--dim-start",
type=str, type=str,
required=True, required=True,
help="Start value for M,K,N as common separated list") help="Start value for M,K,N as common separated list",
)
range_parser.add_argument( range_parser.add_argument(
"--dim-end", "--dim-end",
type=str, type=str,
required=True, required=True,
help="End value (inclusive) for M,K,N as common separated list") help="End value (inclusive) for M,K,N as common separated list",
)
range_parser.add_argument( range_parser.add_argument(
"--dim-increment", "--dim-increment",
type=str, type=str,
required=True, required=True,
help="Increment value for M,K,N as common separated list") help="Increment value for M,K,N as common separated list",
)
range_parser.set_defaults(func=run_range_bench) range_parser.set_defaults(func=run_range_bench)
model_parser = subparsers.add_parser("model_bench") model_parser = subparsers.add_parser("model_bench")
@ -655,14 +712,12 @@ Benchmark Machete GEMM.
default=DEFAULT_MODELS, default=DEFAULT_MODELS,
choices=WEIGHT_SHAPES.keys(), choices=WEIGHT_SHAPES.keys(),
) )
model_parser.add_argument("--tp-sizes", model_parser.add_argument(
nargs="+", "--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES
type=int, )
default=DEFAULT_TP_SIZES) model_parser.add_argument(
model_parser.add_argument("--batch-sizes", "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
nargs="+", )
type=int,
default=DEFAULT_BATCH_SIZES)
model_parser.set_defaults(func=run_model_bench) model_parser.set_defaults(func=run_model_bench)
args = parser.parse_args() args = parser.parse_args()

View File

@ -6,19 +6,34 @@ from benchmark_shapes import WEIGHT_SHAPES
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES) GPTQ_MARLIN_24_MIN_THREAD_N,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES,
GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES,
)
from vllm.model_executor.layers.quantization.utils.allspark_utils import ( from vllm.model_executor.layers.quantization.utils.allspark_utils import (
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, ALLSPARK_SUPPORTED_QUANT_TYPES) ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
ALLSPARK_SUPPORTED_QUANT_TYPES,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL,
MARLIN_SUPPORTED_GROUP_SIZES, query_marlin_supported_quant_types) GPTQ_MARLIN_MIN_THREAD_N,
MARLIN_SUPPORTED_GROUP_SIZES,
query_marlin_supported_quant_types,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
MarlinWorkspace, marlin_quantize) MarlinWorkspace,
marlin_quantize,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
marlin_24_quantize) marlin_24_quantize,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights) gptq_pack,
gptq_quantize_weights,
quantize_weights,
sort_weights,
)
from vllm.scalar_type import ScalarType from vllm.scalar_type import ScalarType
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
@ -29,22 +44,29 @@ ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True] K_FULL_OPTS = [False, True]
def bench_run(results: list[benchmark.Measurement], model: str, def bench_run(
act_order: bool, is_k_full: bool, quant_type: ScalarType, results: list[benchmark.Measurement],
group_size: int, size_m: int, size_k: int, size_n: int): model: str,
act_order: bool,
is_k_full: bool,
quant_type: ScalarType,
group_size: int,
size_m: int,
size_k: int,
size_n: int,
):
label = "Quant Matmul" label = "Quant Matmul"
sub_label = ("{}, act={} k_full={}, q={}, g={}, " sub_label = "{}, act={} k_full={}, q={}, g={}, MKN=({}x{}x{})".format(
"MKN=({}x{}x{})".format(model, act_order, is_k_full, model, act_order, is_k_full, str(quant_type), group_size, size_m, size_k, size_n
str(quant_type), group_size, size_m, )
size_k, size_n))
print(f"Testing: {sub_label}") print(f"Testing: {sub_label}")
a = torch.randn(size_m, size_k).to(torch.half).cuda() a = torch.randn(size_m, size_k).to(torch.half).cuda()
b = torch.rand(size_k, size_n).to(torch.half).cuda() b = torch.rand(size_k, size_n).to(torch.half).cuda()
a_tmp = (torch.zeros(size_m, size_k).to(torch.half).cuda()) a_tmp = torch.zeros(size_m, size_k).to(torch.half).cuda()
# Marlin quant # Marlin quant
( (
@ -57,14 +79,16 @@ def bench_run(results: list[benchmark.Measurement], model: str,
) = marlin_quantize(b, quant_type, group_size, act_order) ) = marlin_quantize(b, quant_type, group_size, act_order)
# Marlin_24 quant # Marlin_24 quant
(marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = (
marlin_24_s) = marlin_24_quantize(b, quant_type, group_size) marlin_24_quantize(b, quant_type, group_size)
)
marlin_zp = torch.empty(0, dtype=torch.int, device=b.device) marlin_zp = torch.empty(0, dtype=torch.int, device=b.device)
# GPTQ quant # GPTQ quant
(w_ref, q_w, s, g_idx, (w_ref, q_w, s, g_idx, rand_perm) = gptq_quantize_weights(
rand_perm) = gptq_quantize_weights(b, quant_type, group_size, act_order) b, quant_type, group_size, act_order
)
q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n) q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)
# For act_order, sort the "weights" and "g_idx" # For act_order, sort the "weights" and "g_idx"
@ -74,32 +98,37 @@ def bench_run(results: list[benchmark.Measurement], model: str,
(q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx) (q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx)
# Prepare # Prepare
marlin_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N, marlin_workspace = MarlinWorkspace(
GPTQ_MARLIN_MAX_PARALLEL) size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL
)
marlin_24_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N, marlin_24_workspace = MarlinWorkspace(
GPTQ_MARLIN_24_MAX_PARALLEL) size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL
)
marlin_zp = torch.zeros_like(marlin_s, dtype=torch.int) marlin_zp = torch.zeros_like(marlin_s, dtype=torch.int)
# AllSpark W8A16 quant # AllSpark W8A16 quant
as_supported_case = (quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES as_supported_case = (
and group_size == -1 and not act_order and is_k_full) quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES
and group_size == -1
and not act_order
and is_k_full
)
if as_supported_case: if as_supported_case:
properties = torch.cuda.get_device_properties(b.device.index) properties = torch.cuda.get_device_properties(b.device.index)
sm_count = properties.multi_processor_count sm_count = properties.multi_processor_count
sm_version = properties.major * 10 + properties.minor sm_version = properties.major * 10 + properties.minor
supported_arch = (sm_version >= 80 and sm_version < 90) supported_arch = sm_version >= 80 and sm_version < 90
as_supported_case = as_supported_case and supported_arch as_supported_case = as_supported_case and supported_arch
if supported_arch: if supported_arch:
has_zp = False has_zp = False
w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size, w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size, has_zp)
has_zp)
qw = qw.to(torch.uint8) qw = qw.to(torch.uint8)
qw_reorder, s_reorder, zp_reorder = \ qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight(
ops.allspark_repack_weight( qw, s, zp, has_zp
qw, s, zp, has_zp) )
CUBLAS_M_THRESHOLD = ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD CUBLAS_M_THRESHOLD = ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD
globals = { globals = {
@ -136,8 +165,7 @@ def bench_run(results: list[benchmark.Measurement], model: str,
"zp_reorder": zp_reorder if as_supported_case else None, "zp_reorder": zp_reorder if as_supported_case else None,
"sm_count": sm_count if as_supported_case else None, "sm_count": sm_count if as_supported_case else None,
"sm_version": sm_version if as_supported_case else None, "sm_version": sm_version if as_supported_case else None,
"CUBLAS_M_THRESHOLD": "CUBLAS_M_THRESHOLD": CUBLAS_M_THRESHOLD if as_supported_case else None,
CUBLAS_M_THRESHOLD if as_supported_case else None,
# Kernels # Kernels
"gptq_marlin_gemm": ops.gptq_marlin_gemm, "gptq_marlin_gemm": ops.gptq_marlin_gemm,
"gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm, "gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm,
@ -158,60 +186,63 @@ def bench_run(results: list[benchmark.Measurement], model: str,
label=label, label=label,
sub_label=sub_label, sub_label=sub_label,
description="pytorch_gemm", description="pytorch_gemm",
).blocked_autorange(min_run_time=min_run_time)) ).blocked_autorange(min_run_time=min_run_time)
)
results.append( results.append(
benchmark.Timer( benchmark.Timer(
stmt= stmt="output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501
globals=globals, globals=globals,
label=label, label=label,
sub_label=sub_label, sub_label=sub_label,
description="gptq_marlin_gemm_fp16", description="gptq_marlin_gemm_fp16",
).blocked_autorange(min_run_time=min_run_time)) ).blocked_autorange(min_run_time=min_run_time)
)
results.append( results.append(
benchmark.Timer( benchmark.Timer(
stmt= stmt="output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501
globals=globals, globals=globals,
label=label, label=label,
sub_label=sub_label, sub_label=sub_label,
description="gptq_marlin_gemm_fp32", description="gptq_marlin_gemm_fp32",
).blocked_autorange(min_run_time=min_run_time)) ).blocked_autorange(min_run_time=min_run_time)
)
if (quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES if (
and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES): quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
):
results.append( results.append(
benchmark.Timer( benchmark.Timer(
stmt= stmt="output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)", # noqa: E501
"output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)", # noqa: E501
globals=globals, globals=globals,
label=label, label=label,
sub_label=sub_label, sub_label=sub_label,
description="gptq_marlin_24_gemm", description="gptq_marlin_24_gemm",
).blocked_autorange(min_run_time=min_run_time)) ).blocked_autorange(min_run_time=min_run_time)
)
results.append( results.append(
benchmark.Timer( benchmark.Timer(
stmt= stmt="q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)", # noqa: E501
"q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)", # noqa: E501
globals=globals, globals=globals,
label=label, label=label,
sub_label=sub_label, sub_label=sub_label,
description="gptq_marlin_repack", description="gptq_marlin_repack",
).blocked_autorange(min_run_time=min_run_time)) ).blocked_autorange(min_run_time=min_run_time)
)
if as_supported_case: if as_supported_case:
results.append( results.append(
benchmark.Timer( benchmark.Timer(
stmt= stmt="output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)", # noqa: E501
"output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)", # noqa: E501
globals=globals, globals=globals,
label=label, label=label,
sub_label=sub_label, sub_label=sub_label,
description="allspark_w8a16_gemm_fp32", description="allspark_w8a16_gemm_fp32",
).blocked_autorange(min_run_time=min_run_time)) ).blocked_autorange(min_run_time=min_run_time)
)
def main(args): def main(args):
@ -233,37 +264,50 @@ def main(args):
continue continue
for act_order in ACT_ORDER_OPTS: for act_order in ACT_ORDER_OPTS:
if len(args.limit_act_order if (
) > 0 and act_order not in args.limit_act_order: len(args.limit_act_order) > 0
and act_order not in args.limit_act_order
):
continue continue
for is_k_full in K_FULL_OPTS: for is_k_full in K_FULL_OPTS:
if len(args.limit_k_full if (
) > 0 and is_k_full not in args.limit_k_full: len(args.limit_k_full) > 0
and is_k_full not in args.limit_k_full
):
continue continue
for quant_type in query_marlin_supported_quant_types( for quant_type in query_marlin_supported_quant_types(False):
False): if (
if len(args.limit_num_bits) > 0 and \ len(args.limit_num_bits) > 0
quant_type.size_bits not in args.limit_num_bits: and quant_type.size_bits not in args.limit_num_bits
):
continue continue
for group_size in MARLIN_SUPPORTED_GROUP_SIZES: for group_size in MARLIN_SUPPORTED_GROUP_SIZES:
if len( if (
args.limit_group_size len(args.limit_group_size) > 0
) > 0 and group_size not in args.limit_group_size: and group_size not in args.limit_group_size
):
continue continue
# For act_order, the group_size must be less than # For act_order, the group_size must be less than
# size_k # size_k
if act_order and (group_size == size_k if act_order and (group_size == size_k or group_size == -1):
or group_size == -1):
continue continue
for size_m in args.batch_sizes: for size_m in args.batch_sizes:
bench_run(results, model, act_order, is_k_full, bench_run(
quant_type, group_size, size_m, results,
size_k, size_n) model,
act_order,
is_k_full,
quant_type,
group_size,
size_m,
size_k,
size_n,
)
compare = benchmark.Compare(results) compare = benchmark.Compare(results)
compare.print() compare.print()
@ -274,7 +318,8 @@ def main(args):
# #
if __name__ == "__main__": if __name__ == "__main__":
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description="Benchmark Marlin across specified models/shapes/batches") description="Benchmark Marlin across specified models/shapes/batches"
)
parser.add_argument( parser.add_argument(
"--models", "--models",
nargs="+", nargs="+",
@ -282,10 +327,9 @@ if __name__ == "__main__":
default=DEFAULT_MODELS, default=DEFAULT_MODELS,
choices=WEIGHT_SHAPES.keys(), choices=WEIGHT_SHAPES.keys(),
) )
parser.add_argument("--batch-sizes", parser.add_argument(
nargs="+", "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
type=int, )
default=DEFAULT_BATCH_SIZES)
parser.add_argument("--limit-k", nargs="+", type=int, default=[]) parser.add_argument("--limit-k", nargs="+", type=int, default=[])
parser.add_argument("--limit-n", nargs="+", type=int, default=[]) parser.add_argument("--limit-n", nargs="+", type=int, default=[])
parser.add_argument("--limit-group-size", nargs="+", type=int, default=[]) parser.add_argument("--limit-group-size", nargs="+", type=int, default=[])

View File

@ -6,16 +6,17 @@ import time
from contextlib import nullcontext from contextlib import nullcontext
from datetime import datetime from datetime import datetime
from itertools import product from itertools import product
from types import SimpleNamespace
from typing import Any, TypedDict from typing import Any, TypedDict
import ray import ray
import torch import torch
import triton
from ray.experimental.tqdm_ray import tqdm from ray.experimental.tqdm_ray import tqdm
from transformers import AutoConfig
from vllm.model_executor.layers.fused_moe.fused_moe import * from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config
from vllm.triton_utils import triton
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
FP8_DTYPE = current_platform.fp8_dtype() FP8_DTYPE = current_platform.fp8_dtype()
@ -30,56 +31,60 @@ class BenchmarkConfig(TypedDict):
num_stages: int num_stages: int
def benchmark_config(config: BenchmarkConfig, def benchmark_config(
num_tokens: int, config: BenchmarkConfig,
num_experts: int, num_tokens: int,
shard_intermediate_size: int, num_experts: int,
hidden_size: int, shard_intermediate_size: int,
topk: int, hidden_size: int,
dtype: torch.dtype, topk: int,
use_fp8_w8a8: bool, dtype: torch.dtype,
use_int8_w8a16: bool, use_fp8_w8a8: bool,
num_iters: int = 100, use_int8_w8a16: bool,
block_quant_shape: List[int] = None, num_iters: int = 100,
use_deep_gemm: bool = False) -> float: block_quant_shape: List[int] = None,
use_deep_gemm: bool = False,
) -> float:
init_dtype = torch.float16 if use_fp8_w8a8 else dtype init_dtype = torch.float16 if use_fp8_w8a8 else dtype
x = torch.randn(num_tokens, hidden_size, dtype=dtype) x = torch.randn(num_tokens, hidden_size, dtype=dtype)
if use_int8_w8a16: if use_int8_w8a16:
w1 = torch.randint(-127, w1 = torch.randint(
127, ( -127,
num_experts, 127,
shard_intermediate_size, (
hidden_size, num_experts,
), shard_intermediate_size,
dtype=torch.int8) hidden_size,
w2 = torch.randint(-127, ),
127, ( dtype=torch.int8,
num_experts, )
hidden_size, w2 = torch.randint(
shard_intermediate_size // 2, -127,
), 127,
dtype=torch.int8) (
num_experts,
hidden_size,
shard_intermediate_size // 2,
),
dtype=torch.int8,
)
else: else:
w1 = torch.randn(num_experts, w1 = torch.randn(
shard_intermediate_size, num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
hidden_size, )
dtype=init_dtype) w2 = torch.randn(
w2 = torch.randn(num_experts, num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
hidden_size, )
shard_intermediate_size // 2, gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)
dtype=init_dtype)
gating_output = torch.randn(num_iters,
num_tokens,
num_experts,
dtype=torch.float32)
w1_scale = None w1_scale = None
w2_scale = None w2_scale = None
a1_scale = None a1_scale = None
a2_scale = None a2_scale = None
if use_int8_w8a16: if use_int8_w8a16:
w1_scale = torch.randn((num_experts, 2 * shard_intermediate_size), w1_scale = torch.randn(
dtype=torch.float32) (num_experts, 2 * shard_intermediate_size), dtype=torch.float32
)
w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32) w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
if use_fp8_w8a8: if use_fp8_w8a8:
if block_quant_shape: if block_quant_shape:
@ -92,10 +97,14 @@ def benchmark_config(config: BenchmarkConfig,
n_tiles_w2 = (K + block_n - 1) // block_n n_tiles_w2 = (K + block_n - 1) // block_n
k_tiles_w1 = (K + block_k - 1) // block_k k_tiles_w1 = (K + block_k - 1) // block_k
k_tiles_w2 = (N + block_k - 1) // block_k k_tiles_w2 = (N + block_k - 1) // block_k
w1_scale = torch.rand((E, n_tiles_w1, k_tiles_w1), w1_scale = (
dtype=torch.float32) * factor_for_scale torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
w2_scale = torch.rand((E, n_tiles_w2, k_tiles_w2), * factor_for_scale
dtype=torch.float32) * factor_for_scale )
w2_scale = (
torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32)
* factor_for_scale
)
else: else:
w1_scale = torch.randn(num_experts, dtype=torch.float32) w1_scale = torch.randn(num_experts, dtype=torch.float32)
w2_scale = torch.randn(num_experts, dtype=torch.float32) w2_scale = torch.randn(num_experts, dtype=torch.float32)
@ -113,10 +122,12 @@ def benchmark_config(config: BenchmarkConfig,
def run(): def run():
from vllm.model_executor.layers.fused_moe import override_config from vllm.model_executor.layers.fused_moe import override_config
with override_config(config): with override_config(config):
if use_deep_gemm: if use_deep_gemm:
topk_weights, topk_ids = fused_topk(x, input_gating, topk, topk_weights, topk_ids, token_expert_indices = fused_topk(
False) x, input_gating, topk, False
)
return fused_experts( return fused_experts(
x, x,
w1, w1,
@ -212,8 +223,7 @@ def get_rocm_tuning_space(use_fp16):
return param_ranges return param_ranges
def get_configs_compute_bound(use_fp16, def get_configs_compute_bound(use_fp16, block_quant_shape) -> list[dict[str, int]]:
block_quant_shape) -> list[dict[str, int]]:
configs: list[BenchmarkConfig] = [] configs: list[BenchmarkConfig] = []
if current_platform.is_rocm(): if current_platform.is_rocm():
@ -249,20 +259,25 @@ def get_configs_compute_bound(use_fp16,
if block_quant_shape is not None and not use_fp16: if block_quant_shape is not None and not use_fp16:
block_n, block_k = block_quant_shape[0], block_quant_shape[1] block_n, block_k = block_quant_shape[0], block_quant_shape[1]
for config in configs[:]: for config in configs[:]:
if config["BLOCK_SIZE_K"] % block_k != 0 or config[ if (
"BLOCK_SIZE_N"] % block_n != 0: config["BLOCK_SIZE_K"] % block_k != 0
or config["BLOCK_SIZE_N"] % block_n != 0
):
configs.remove(config) configs.remove(config)
return configs return configs
def prune_rocm_search_space(num_tokens, shard_intermediate_size, hidden_size, def prune_rocm_search_space(
search_space, is_fp16, topk): num_tokens, shard_intermediate_size, hidden_size, search_space, is_fp16, topk
):
N1, K1 = shard_intermediate_size, hidden_size N1, K1 = shard_intermediate_size, hidden_size
N2, K2 = hidden_size, shard_intermediate_size // 2 N2, K2 = hidden_size, shard_intermediate_size // 2
pruned_space_1 = prune_rocm_configs(num_tokens * topk, N1, K1, pruned_space_1 = prune_rocm_configs(
search_space, is_fp16) num_tokens * topk, N1, K1, search_space, is_fp16
pruned_space_2 = prune_rocm_configs(num_tokens * topk, N2, K2, )
search_space, is_fp16) pruned_space_2 = prune_rocm_configs(
num_tokens * topk, N2, K2, search_space, is_fp16
)
search_space = merge_unique_dicts(pruned_space_1, pruned_space_2) search_space = merge_unique_dicts(pruned_space_1, pruned_space_2)
return search_space return search_space
@ -300,14 +315,14 @@ def prune_rocm_configs(M, N, K, configs, is_fp16=True):
SPLIT_K = config.get("SPLIT_K", 1) SPLIT_K = config.get("SPLIT_K", 1)
GROUP_M = config.get("GROUP_SIZE_M") GROUP_M = config.get("GROUP_SIZE_M")
if is_fp16: if is_fp16:
if (matrix_instr_nonkdim > BLOCK_SIZE_M if (
or matrix_instr_nonkdim > BLOCK_SIZE_N): matrix_instr_nonkdim > BLOCK_SIZE_M
or matrix_instr_nonkdim > BLOCK_SIZE_N
):
continue continue
if (matrix_instr_nonkdim >= M if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M:
and matrix_instr_nonkdim != BLOCK_SIZE_M):
continue continue
if (matrix_instr_nonkdim >= N if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N:
and matrix_instr_nonkdim != BLOCK_SIZE_N):
continue continue
# Skip BLOCK_SIZE that is too large compare to M/N # Skip BLOCK_SIZE that is too large compare to M/N
# unless BLOCK_SIZE is already small enough # unless BLOCK_SIZE is already small enough
@ -328,8 +343,10 @@ def prune_rocm_configs(M, N, K, configs, is_fp16=True):
continue continue
# out of shared memory resource # out of shared memory resource
# TODO (zhanglx): This does not consider the LDS usage in the epilogue # TODO (zhanglx): This does not consider the LDS usage in the epilogue
LDS = (BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a + LDS = (
BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b) BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a
+ BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b
)
if LDS > 65536: if LDS > 65536:
continue continue
# Skip small block sizes and num_warps for large gemm # Skip small block sizes and num_warps for large gemm
@ -363,7 +380,6 @@ def merge_unique_dicts(list1, list2):
@ray.remote(num_gpus=1) @ray.remote(num_gpus=1)
class BenchmarkWorker: class BenchmarkWorker:
def __init__(self, seed: int) -> None: def __init__(self, seed: int) -> None:
torch.set_default_device("cuda") torch.set_default_device("cuda")
current_platform.seed_everything(seed) current_platform.seed_everything(seed)
@ -387,36 +403,40 @@ class BenchmarkWorker:
use_deep_gemm: bool = False, use_deep_gemm: bool = False,
) -> tuple[dict[str, int], float]: ) -> tuple[dict[str, int], float]:
current_platform.seed_everything(self.seed) current_platform.seed_everything(self.seed)
dtype_str = get_config_dtype_str(dtype, dtype_str = get_config_dtype_str(
use_int8_w8a16=use_int8_w8a16, dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
use_fp8_w8a8=use_fp8_w8a8) )
# NOTE(woosuk): The current naming convention uses w2.shape[2], which # NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul. # is the intermediate size after silu_and_mul.
op_config = get_moe_configs(num_experts, shard_intermediate_size // 2, op_config = get_moe_configs(
dtype_str) num_experts, shard_intermediate_size // 2, dtype_str
)
if op_config is None: if op_config is None:
config = get_default_config(num_tokens, config = get_default_config(
num_experts, num_tokens,
shard_intermediate_size, num_experts,
hidden_size, shard_intermediate_size,
topk, hidden_size,
dtype_str, topk,
is_marlin=False) dtype_str,
is_marlin=False,
)
else: else:
config = op_config[min(op_config.keys(), config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))]
key=lambda x: abs(x - num_tokens))] kernel_time = benchmark_config(
kernel_time = benchmark_config(config, config,
num_tokens, num_tokens,
num_experts, num_experts,
shard_intermediate_size, shard_intermediate_size,
hidden_size, hidden_size,
topk, topk,
dtype, dtype,
use_fp8_w8a8, use_fp8_w8a8,
use_int8_w8a16, use_int8_w8a16,
num_iters=100, num_iters=100,
block_quant_shape=block_quant_shape, block_quant_shape=block_quant_shape,
use_deep_gemm=use_deep_gemm) use_deep_gemm=use_deep_gemm,
)
return config, kernel_time return config, kernel_time
def tune( def tune(
@ -437,13 +457,22 @@ class BenchmarkWorker:
best_time = float("inf") best_time = float("inf")
if current_platform.is_rocm(): if current_platform.is_rocm():
is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16) is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
search_space = prune_rocm_search_space(num_tokens, search_space = prune_rocm_search_space(
shard_intermediate_size, num_tokens,
hidden_size, search_space, shard_intermediate_size,
is_fp16, topk) hidden_size,
search_space,
is_fp16,
topk,
)
with torch.cuda.device(self.device_id) if current_platform.is_rocm( need_device_guard = False
) else nullcontext(): if current_platform.is_rocm():
visible_device = os.environ.get("ROCR_VISIBLE_DEVICES", None)
if visible_device != f"{self.device_id}":
need_device_guard = True
with torch.cuda.device(self.device_id) if need_device_guard else nullcontext():
for config in tqdm(search_space): for config in tqdm(search_space):
try: try:
kernel_time = benchmark_config( kernel_time = benchmark_config(
@ -458,7 +487,8 @@ class BenchmarkWorker:
use_int8_w8a16, use_int8_w8a16,
num_iters=20, num_iters=20,
block_quant_shape=block_quant_shape, block_quant_shape=block_quant_shape,
use_deep_gemm=use_deep_gemm) use_deep_gemm=use_deep_gemm,
)
except triton.runtime.autotuner.OutOfResources: except triton.runtime.autotuner.OutOfResources:
# Some configurations may be invalid and fail to compile. # Some configurations may be invalid and fail to compile.
continue continue
@ -474,42 +504,44 @@ class BenchmarkWorker:
def sort_config(config: BenchmarkConfig) -> BenchmarkConfig: def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
return { return {
"BLOCK_SIZE_M": "BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
config["BLOCK_SIZE_M"], "BLOCK_SIZE_N": config["BLOCK_SIZE_N"],
"BLOCK_SIZE_N": "BLOCK_SIZE_K": config["BLOCK_SIZE_K"],
config["BLOCK_SIZE_N"], "GROUP_SIZE_M": config["GROUP_SIZE_M"],
"BLOCK_SIZE_K": "num_warps": config["num_warps"],
config["BLOCK_SIZE_K"], "num_stages": config["num_stages"],
"GROUP_SIZE_M": **(
config["GROUP_SIZE_M"], {"waves_per_eu": config["waves_per_eu"]} if "waves_per_eu" in config else {}
"num_warps": ),
config["num_warps"], **(
"num_stages": {"matrix_instr_nonkdim": config["matrix_instr_nonkdim"]}
config["num_stages"], if "matrix_instr_nonkdim" in config
**({ else {}
"waves_per_eu": config["waves_per_eu"] ),
} if "waves_per_eu" in config else {}), **({"kpack": config["kpack"]} if "kpack" in config else {}),
**({
"matrix_instr_nonkdim": config["matrix_instr_nonkdim"]
} if "matrix_instr_nonkdim" in config else {}),
**({
"kpack": config["kpack"]
} if "kpack" in config else {}),
} }
def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int, def save_configs(
shard_intermediate_size: int, hidden_size: int, topk: int, configs: dict[int, BenchmarkConfig],
dtype: torch.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool, num_experts: int,
block_quant_shape: List[int]) -> None: shard_intermediate_size: int,
dtype_str = get_config_dtype_str(dtype, hidden_size: int,
use_int8_w8a16=use_int8_w8a16, topk: int,
use_fp8_w8a8=use_fp8_w8a8) dtype: torch.dtype,
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
block_quant_shape: List[int],
) -> None:
dtype_str = get_config_dtype_str(
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
)
# NOTE(woosuk): The current naming convention uses w2.shape[2], which # NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul. # is the intermediate size after silu_and_mul.
filename = get_config_file_name(num_experts, shard_intermediate_size // 2, filename = get_config_file_name(
dtype_str, block_quant_shape) num_experts, shard_intermediate_size // 2, dtype_str, block_quant_shape
)
print(f"Writing best config to {filename}...") print(f"Writing best config to {filename}...")
with open(filename, "w") as f: with open(filename, "w") as f:
@ -518,18 +550,20 @@ def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int,
def get_weight_block_size_safety(config, default_value=None): def get_weight_block_size_safety(config, default_value=None):
quantization_config = getattr(config, "quantization_config", {})
quantization_config = getattr(config, 'quantization_config', {})
if isinstance(quantization_config, dict): if isinstance(quantization_config, dict):
return quantization_config.get('weight_block_size', default_value) return quantization_config.get("weight_block_size", default_value)
return default_value return default_value
def main(args: argparse.Namespace): def main(args: argparse.Namespace):
print(args) print(args)
block_quant_shape = None
config = AutoConfig.from_pretrained( config = get_config(model=args.model, trust_remote_code=args.trust_remote_code)
args.model, trust_remote_code=args.trust_remote_code) if args.model_prefix:
config = getattr(config, args.model_prefix)
config = SimpleNamespace(**config)
if config.architectures[0] == "DbrxForCausalLM": if config.architectures[0] == "DbrxForCausalLM":
E = config.ffn_config.moe_num_experts E = config.ffn_config.moe_num_experts
topk = config.ffn_config.moe_top_k topk = config.ffn_config.moe_top_k
@ -540,14 +574,12 @@ def main(args: argparse.Namespace):
topk = config.num_experts_per_tok topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif (config.architectures[0] == "DeepseekV3ForCausalLM" elif config.architectures[0] in ("DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"):
or config.architectures[0] == "DeepseekV2ForCausalLM"):
E = config.n_routed_experts E = config.n_routed_experts
topk = config.num_experts_per_tok topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size shard_intermediate_size = 2 * intermediate_size // args.tp_size
block_quant_shape = get_weight_block_size_safety(config) elif config.architectures[0] in ("Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"):
elif config.architectures[0] == "Qwen2MoeForCausalLM":
E = config.num_experts E = config.num_experts
topk = config.num_experts_per_tok topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size intermediate_size = config.moe_intermediate_size
@ -562,20 +594,51 @@ def main(args: argparse.Namespace):
shard_intermediate_size = 2 * intermediate_size // args.tp_size shard_intermediate_size = 2 * intermediate_size // args.tp_size
hidden_size = config.hidden_size hidden_size = config.hidden_size
dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype dtype = (
torch.float16
if current_platform.is_rocm()
else getattr(torch, config.torch_dtype)
)
use_fp8_w8a8 = args.dtype == "fp8_w8a8" use_fp8_w8a8 = args.dtype == "fp8_w8a8"
use_int8_w8a16 = args.dtype == "int8_w8a16" use_int8_w8a16 = args.dtype == "int8_w8a16"
block_quant_shape = get_weight_block_size_safety(config)
if args.batch_size is None: if args.batch_size is None:
batch_sizes = [ batch_sizes = [
1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536, 1,
2048, 3072, 4096 2,
4,
8,
16,
24,
32,
48,
64,
96,
128,
256,
512,
1024,
1536,
2048,
3072,
4096,
] ]
else: else:
batch_sizes = [args.batch_size] batch_sizes = [args.batch_size]
use_deep_gemm = bool(args.use_deep_gemm) use_deep_gemm = bool(args.use_deep_gemm)
if current_platform.is_rocm() and "HIP_VISIBLE_DEVICES" in os.environ:
# Ray will set ROCR_VISIBLE_DEVICES for device visibility
logger.warning(
"Ray uses ROCR_VISIBLE_DEVICES to control device accessibility."
"Replacing HIP_VISIBLE_DEVICES with ROCR_VISIBLE_DEVICES."
)
val = os.environ["HIP_VISIBLE_DEVICES"]
os.environ["ROCR_VISIBLE_DEVICES"] = val
del os.environ["HIP_VISIBLE_DEVICES"]
ray.init() ray.init()
num_gpus = int(ray.available_resources()["GPU"]) num_gpus = int(ray.available_resources()["GPU"])
workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)] workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]
@ -598,25 +661,59 @@ def main(args: argparse.Namespace):
start = time.time() start = time.time()
configs = _distribute( configs = _distribute(
"tune", [(batch_size, E, shard_intermediate_size, hidden_size, "tune",
topk, dtype, use_fp8_w8a8, use_int8_w8a16, search_space, [
block_quant_shape, use_deep_gemm) (
for batch_size in batch_sizes]) batch_size,
E,
shard_intermediate_size,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a16,
search_space,
block_quant_shape,
use_deep_gemm,
)
for batch_size in batch_sizes
],
)
best_configs = { best_configs = {
M: sort_config(config) M: sort_config(config) for M, config in zip(batch_sizes, configs)
for M, config in zip(batch_sizes, configs)
} }
save_configs(best_configs, E, shard_intermediate_size, hidden_size, save_configs(
topk, dtype, use_fp8_w8a8, use_int8_w8a16, best_configs,
block_quant_shape) E,
shard_intermediate_size,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a16,
block_quant_shape,
)
end = time.time() end = time.time()
print(f"Tuning took {end - start:.2f} seconds") print(f"Tuning took {end - start:.2f} seconds")
else: else:
outputs = _distribute( outputs = _distribute(
"benchmark", "benchmark",
[(batch_size, E, shard_intermediate_size, hidden_size, topk, dtype, [
use_fp8_w8a8, use_int8_w8a16, block_quant_shape, use_deep_gemm) (
for batch_size in batch_sizes]) batch_size,
E,
shard_intermediate_size,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a16,
block_quant_shape,
use_deep_gemm,
)
for batch_size in batch_sizes
],
)
for batch_size, (config, kernel_time) in zip(batch_sizes, outputs): for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
print(f"Batch size: {batch_size}, config: {config}") print(f"Batch size: {batch_size}, config: {config}")
@ -625,23 +722,21 @@ def main(args: argparse.Namespace):
if __name__ == "__main__": if __name__ == "__main__":
parser = FlexibleArgumentParser() parser = FlexibleArgumentParser()
parser.add_argument("--model", parser.add_argument(
type=str, "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
default="mistralai/Mixtral-8x7B-Instruct-v0.1") )
parser.add_argument("--tp-size", parser.add_argument(
"-tp", "--tp-size", "-tp", "--tensor-parallel-size", type=int, default=2
"--tensor-parallel-size", )
type=int, parser.add_argument(
default=2) "--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto"
parser.add_argument("--dtype", )
type=str,
choices=["auto", "fp8_w8a8", "int8_w8a16"],
default="auto")
parser.add_argument("--use-deep-gemm", action="store_true") parser.add_argument("--use-deep-gemm", action="store_true")
parser.add_argument("--seed", type=int, default=0) parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--batch-size", type=int, required=False) parser.add_argument("--batch-size", type=int, required=False)
parser.add_argument("--tune", action="store_true") parser.add_argument("--tune", action="store_true")
parser.add_argument("--trust-remote-code", action="store_true") parser.add_argument("--trust-remote-code", action="store_true")
parser.add_argument("--model-prefix", type=str, required=False)
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)

View File

@ -0,0 +1,416 @@
# SPDX-License-Identifier: Apache-2.0
import argparse
from typing import Any, TypedDict
import ray
import torch
from transformers import AutoConfig
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
_moe_permute,
_moe_unpermute_and_reduce,
)
from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import *
from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize
from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser
FP8_DTYPE = current_platform.fp8_dtype()
class BenchmarkConfig(TypedDict):
BLOCK_SIZE_M: int
BLOCK_SIZE_N: int
BLOCK_SIZE_K: int
GROUP_SIZE_M: int
num_warps: int
num_stages: int
def benchmark_permute(
num_tokens: int,
num_experts: int,
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
num_iters: int = 100,
use_customized_permute: bool = False,
) -> float:
# init_dtype = torch.float16 if use_fp8_w8a8 else dtype
hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
# output_hidden_states = torch.empty_like(hidden_states)
if use_fp8_w8a8:
align_block_size = 128 # deepgemm needs 128 m aligned block
qhidden_states, scale = _fp8_quantize(hidden_states, None, None)
else:
align_block_size = None
qhidden_states = hidden_states
gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
topk_weights, topk_ids, token_expert_indices = fused_topk(
qhidden_states, input_gating, topk, False
)
def prepare(i: int):
input_gating.copy_(gating_output[i])
def run():
if use_customized_permute:
(permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = (
moe_permute(
qhidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
token_expert_indices=token_expert_indices,
topk=topk,
n_expert=num_experts,
n_local_expert=num_experts,
expert_map=None,
align_block_size=align_block_size,
)
)
else:
(
permuted_hidden_states,
a1q_scale,
sorted_token_ids,
expert_ids,
inv_perm,
) = _moe_permute(
qhidden_states, None, topk_ids, num_experts, None, align_block_size
)
# JIT compilation & warmup
run()
torch.cuda.synchronize()
# Capture 10 invocations with CUDA graph
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
for _ in range(10):
run()
torch.cuda.synchronize()
# Warmup
for _ in range(5):
graph.replay()
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
latencies: list[float] = []
for i in range(num_iters):
prepare(i)
torch.cuda.synchronize()
start_event.record()
graph.replay()
end_event.record()
end_event.synchronize()
latencies.append(start_event.elapsed_time(end_event))
avg = sum(latencies) / (num_iters * 10) * 1000 # us
graph.reset()
return avg
def benchmark_unpermute(
num_tokens: int,
num_experts: int,
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
num_iters: int = 100,
use_customized_permute: bool = False,
) -> float:
# init_dtype = torch.float16 if use_fp8_w8a8 else dtype
hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
output_hidden_states = torch.empty_like(hidden_states)
if use_fp8_w8a8:
align_block_size = 128 # deepgemm needs 128 m aligned block
qhidden_states, scale = _fp8_quantize(hidden_states, None, None)
else:
align_block_size = None
qhidden_states = hidden_states
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
topk_weights, topk_ids, token_expert_indices = fused_topk(
qhidden_states, input_gating, topk, False
)
def prepare():
if use_customized_permute:
(permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = (
moe_permute(
qhidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
token_expert_indices=token_expert_indices,
topk=topk,
n_expert=num_experts,
n_local_expert=num_experts,
expert_map=None,
align_block_size=align_block_size,
)
)
# convert to fp16/bf16 as gemm output
return (
permuted_hidden_states.to(dtype),
first_token_off,
inv_perm_idx,
m_indices,
)
else:
(
permuted_qhidden_states,
a1q_scale,
sorted_token_ids,
expert_ids,
inv_perm,
) = _moe_permute(
qhidden_states, None, topk_ids, num_experts, None, align_block_size
)
# convert to fp16/bf16 as gemm output
return (
permuted_qhidden_states.to(dtype),
a1q_scale,
sorted_token_ids,
expert_ids,
inv_perm,
)
def run(input: tuple):
if use_customized_permute:
(permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = input
moe_unpermute(
permuted_hidden_states,
topk_weights,
topk_ids,
inv_perm_idx,
first_token_off,
topk,
num_experts,
num_experts,
)
else:
(
permuted_hidden_states,
a1q_scale,
sorted_token_ids,
expert_ids,
inv_perm,
) = input
_moe_unpermute_and_reduce(
output_hidden_states, permuted_hidden_states, inv_perm, topk_weights
)
# JIT compilation & warmup
input = prepare()
run(input)
torch.cuda.synchronize()
# Capture 10 invocations with CUDA graph
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
for _ in range(10):
run(input)
torch.cuda.synchronize()
# Warmup
for _ in range(5):
graph.replay()
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
latencies: list[float] = []
for i in range(num_iters):
torch.cuda.synchronize()
start_event.record()
graph.replay()
end_event.record()
end_event.synchronize()
latencies.append(start_event.elapsed_time(end_event))
avg = sum(latencies) / (num_iters * 10) * 1000 # us
graph.reset()
return avg
@ray.remote(num_gpus=1)
class BenchmarkWorker:
def __init__(self, seed: int) -> None:
torch.set_default_device("cuda")
current_platform.seed_everything(seed)
self.seed = seed
# Get the device ID to allocate tensors and kernels
# on the respective GPU. This is required for Ray to work
# correctly with multi-GPU tuning on the ROCm platform.
self.device_id = int(ray.get_gpu_ids()[0])
def benchmark(
self,
num_tokens: int,
num_experts: int,
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
use_customized_permute: bool = False,
) -> tuple[dict[str, int], float]:
current_platform.seed_everything(self.seed)
permute_time = benchmark_permute(
num_tokens,
num_experts,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a16,
num_iters=100,
use_customized_permute=use_customized_permute,
)
unpermute_time = benchmark_unpermute(
num_tokens,
num_experts,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a16,
num_iters=100,
use_customized_permute=use_customized_permute,
)
return permute_time, unpermute_time
def get_weight_block_size_safety(config, default_value=None):
quantization_config = getattr(config, "quantization_config", {})
if isinstance(quantization_config, dict):
return quantization_config.get("weight_block_size", default_value)
return default_value
def main(args: argparse.Namespace):
print(args)
config = AutoConfig.from_pretrained(
args.model, trust_remote_code=args.trust_remote_code
)
if config.architectures[0] == "DbrxForCausalLM":
E = config.ffn_config.moe_num_experts
topk = config.ffn_config.moe_top_k
elif config.architectures[0] == "JambaForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
elif (
config.architectures[0] == "DeepseekV3ForCausalLM"
or config.architectures[0] == "DeepseekV2ForCausalLM"
):
E = config.n_routed_experts
topk = config.num_experts_per_tok
elif config.architectures[0] in ["Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"]:
E = config.num_experts
topk = config.num_experts_per_tok
else:
# Support for llama4
config = config.get_text_config()
# Default: Mixtral.
E = config.num_local_experts
topk = config.num_experts_per_tok
hidden_size = config.hidden_size
dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
use_int8_w8a16 = args.dtype == "int8_w8a16"
use_customized_permute = args.use_customized_permute
if args.batch_size is None:
batch_sizes = [
1,
2,
4,
8,
16,
24,
32,
48,
64,
96,
128,
256,
512,
1024,
1536,
2048,
3072,
4096,
]
else:
batch_sizes = [args.batch_size]
ray.init()
num_gpus = int(ray.available_resources()["GPU"])
workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]
def _distribute(method: str, inputs: list[Any]) -> list[Any]:
outputs = []
worker_idx = 0
for input_args in inputs:
worker = workers[worker_idx]
worker_method = getattr(worker, method)
output = worker_method.remote(*input_args)
outputs.append(output)
worker_idx = (worker_idx + 1) % num_gpus
return ray.get(outputs)
outputs = _distribute(
"benchmark",
[
(
batch_size,
E,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a16,
use_customized_permute,
)
for batch_size in batch_sizes
],
)
for batch_size, (permute, unpermute) in zip(batch_sizes, outputs):
print(f"Batch size: {batch_size}")
print(f"Permute time: {permute:.2f} us")
print(f"Unpermute time: {unpermute:.2f} us")
if __name__ == "__main__":
parser = FlexibleArgumentParser()
parser.add_argument(
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
)
parser.add_argument(
"--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto"
)
parser.add_argument("--use-customized-permute", action="store_true")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--batch-size", type=int, required=False)
parser.add_argument("--trust-remote-code", action="store_true")
args = parser.parse_args()
main(args)

View File

@ -9,8 +9,11 @@ import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser, from vllm.utils import (
create_kv_caches_with_random) STR_DTYPE_TO_TORCH_DTYPE,
FlexibleArgumentParser,
create_kv_caches_with_random,
)
logger = init_logger(__name__) logger = init_logger(__name__)
@ -38,19 +41,15 @@ def main(
current_platform.seed_everything(seed) current_platform.seed_everything(seed)
scale = float(1.0 / (head_size**0.5)) scale = float(1.0 / (head_size**0.5))
query = torch.empty(num_seqs, query = torch.empty(
num_query_heads, num_seqs, num_query_heads, head_size, dtype=dtype, device=device
head_size, )
dtype=dtype,
device=device)
query.uniform_(-scale, scale) query.uniform_(-scale, scale)
assert num_query_heads % num_kv_heads == 0 assert num_query_heads % num_kv_heads == 0
alibi_slopes = None alibi_slopes = None
if use_alibi: if use_alibi:
alibi_slopes = torch.randn(num_query_heads, alibi_slopes = torch.randn(num_query_heads, dtype=torch.float, device=device)
dtype=torch.float,
device=device)
seq_lens = [seq_len for _ in range(num_seqs)] seq_lens = [seq_len for _ in range(num_seqs)]
max_seq_len = max(seq_lens) max_seq_len = max(seq_lens)
@ -61,24 +60,23 @@ def main(
block_tables_lst: list[list[int]] = [] block_tables_lst: list[list[int]] = []
for _ in range(num_seqs): for _ in range(num_seqs):
block_table = [ block_table = [
random.randint(0, NUM_BLOCKS - 1) random.randint(0, NUM_BLOCKS - 1) for _ in range(max_num_blocks_per_seq)
for _ in range(max_num_blocks_per_seq)
] ]
block_tables_lst.append(block_table) block_tables_lst.append(block_table)
block_tables = torch.tensor(block_tables_lst, block_tables = torch.tensor(block_tables_lst, dtype=torch.int, device=device)
dtype=torch.int,
device=device)
# Create the KV cache. # Create the KV cache.
key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS, key_caches, value_caches = create_kv_caches_with_random(
block_size, NUM_BLOCKS,
1, block_size,
num_kv_heads, 1,
head_size, num_kv_heads,
kv_cache_dtype, head_size,
dtype, kv_cache_dtype,
device=device) dtype,
device=device,
)
key_cache, value_cache = key_caches[0], value_caches[0] key_cache, value_cache = key_caches[0], value_caches[0]
# Prepare for the paged attention kernel. # Prepare for the paged attention kernel.
@ -86,11 +84,8 @@ def main(
if version == "v2": if version == "v2":
if current_platform.is_rocm(): if current_platform.is_rocm():
global PARTITION_SIZE global PARTITION_SIZE
if not args.custom_paged_attn: PARTITION_SIZE = 1024 if not args.custom_paged_attn else PARTITION_SIZE_ROCM
PARTITION_SIZE = 1024 num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE
else:
PARTITION_SIZE = PARTITION_SIZE_ROCM
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
tmp_output = torch.empty( tmp_output = torch.empty(
size=(num_seqs, num_query_heads, num_partitions, head_size), size=(num_seqs, num_query_heads, num_partitions, head_size),
dtype=output.dtype, dtype=output.dtype,
@ -110,9 +105,7 @@ def main(
start_time = time.perf_counter() start_time = time.perf_counter()
# Using default kv_scale # Using default kv_scale
k_scale = v_scale = torch.tensor(1.0, k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
dtype=torch.float32,
device=device)
for _ in range(num_iters): for _ in range(num_iters):
if version == "v1": if version == "v1":
@ -195,30 +188,29 @@ def main(
print(f"Kernel running time: {latency * 1000000:.3f} us") print(f"Kernel running time: {latency * 1000000:.3f} us")
if __name__ == '__main__': if __name__ == "__main__":
logger.warning("This script benchmarks the paged attention kernel. " logger.warning(
"By default this is no longer used in vLLM inference.") "This script benchmarks the paged attention kernel. "
"By default this is no longer used in vLLM inference."
)
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(description="Benchmark the paged attention kernel.")
description="Benchmark the paged attention kernel.") parser.add_argument("--version", type=str, choices=["v1", "v2"], default="v2")
parser.add_argument("--version",
type=str,
choices=["v1", "v2"],
default="v2")
parser.add_argument("--batch-size", type=int, default=8) parser.add_argument("--batch-size", type=int, default=8)
parser.add_argument("--seq-len", type=int, default=4096) parser.add_argument("--seq-len", type=int, default=4096)
parser.add_argument("--num-query-heads", type=int, default=64) parser.add_argument("--num-query-heads", type=int, default=64)
parser.add_argument("--num-kv-heads", type=int, default=8) parser.add_argument("--num-kv-heads", type=int, default=8)
parser.add_argument("--head-size", parser.add_argument(
type=int, "--head-size",
choices=[64, 80, 96, 112, 120, 128, 192, 256], type=int,
default=128) choices=[64, 80, 96, 112, 120, 128, 192, 256],
default=128,
)
parser.add_argument("--block-size", type=int, choices=[16, 32], default=16) parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
parser.add_argument("--use-alibi", action="store_true") parser.add_argument("--use-alibi", action="store_true")
parser.add_argument("--dtype", parser.add_argument(
type=str, "--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
choices=["half", "bfloat16", "float"], )
default="half")
parser.add_argument("--seed", type=int, default=0) parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--profile", action="store_true") parser.add_argument("--profile", action="store_true")
parser.add_argument( parser.add_argument(
@ -228,10 +220,11 @@ if __name__ == '__main__':
default="auto", default="auto",
help="Data type for kv cache storage. If 'auto', will use model " help="Data type for kv cache storage. If 'auto', will use model "
"data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. " "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
"ROCm (AMD GPU) supports fp8 (=fp8_e4m3)") "ROCm (AMD GPU) supports fp8 (=fp8_e4m3)",
parser.add_argument("--custom-paged-attn", )
action="store_true", parser.add_argument(
help="Use custom paged attention") "--custom-paged-attn", action="store_true", help="Use custom paged attention"
)
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)

View File

@ -10,15 +10,17 @@ from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
@torch.inference_mode() @torch.inference_mode()
def main(num_tokens: int, def main(
hidden_size: int, num_tokens: int,
static_scale: bool, hidden_size: int,
quant_dtype: torch.dtype, static_scale: bool,
dtype: torch.dtype, quant_dtype: torch.dtype,
seed: int = 0, dtype: torch.dtype,
do_profile: bool = False, seed: int = 0,
num_warmup_iters: int = 5, do_profile: bool = False,
num_iters: int = 100) -> None: num_warmup_iters: int = 5,
num_iters: int = 100,
) -> None:
current_platform.seed_everything(seed) current_platform.seed_everything(seed)
torch.set_default_device("cuda") torch.set_default_device("cuda")
@ -56,7 +58,7 @@ def main(num_tokens: int,
print(f"Kernel running time: {latency * 1000000:.3f} us") print(f"Kernel running time: {latency * 1000000:.3f} us")
if __name__ == '__main__': if __name__ == "__main__":
def to_torch_dtype(dt): def to_torch_dtype(dt):
if dt == "int8": if dt == "int8":
@ -66,37 +68,40 @@ if __name__ == '__main__':
raise ValueError(f"Unsupported dtype: {dt}") raise ValueError(f"Unsupported dtype: {dt}")
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description="Benchmark the quantization (fp8 or int8) kernel.") description="Benchmark the quantization (fp8 or int8) kernel."
)
parser.add_argument("--num-tokens", type=int, default=4096) parser.add_argument("--num-tokens", type=int, default=4096)
parser.add_argument("--hidden-size", type=int, default=8192) parser.add_argument("--hidden-size", type=int, default=8192)
parser.add_argument("--static-scale", action="store_true") parser.add_argument("--static-scale", action="store_true")
parser.add_argument("--quant-dtype", parser.add_argument(
type=str, "--quant-dtype", type=str, choices=["fp8", "int8"], default="int8"
choices=["fp8", "int8"], )
default="int8") parser.add_argument(
parser.add_argument("--dtype", "--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
type=str, )
choices=["half", "bfloat16", "float"],
default="half")
parser.add_argument("--seed", type=int, default=0) parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--profile", action="store_true") parser.add_argument("--profile", action="store_true")
parser.add_argument("--num-warmup-iters", type=int, default=5) parser.add_argument("--num-warmup-iters", type=int, default=5)
parser.add_argument("--num-iters", parser.add_argument(
type=int, "--num-iters",
default=100, type=int,
help="Number of benchmark iterations. " default=100,
"If --profile is set, this number is ignored") help="Number of benchmark iterations. "
"If --profile is set, this number is ignored",
)
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)
main(num_tokens=args.num_tokens, main(
hidden_size=args.hidden_size, num_tokens=args.num_tokens,
static_scale=args.static_scale, hidden_size=args.hidden_size,
quant_dtype=to_torch_dtype(args.quant_dtype), static_scale=args.static_scale,
dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype], quant_dtype=to_torch_dtype(args.quant_dtype),
seed=args.seed, dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
do_profile=args.profile, seed=args.seed,
num_warmup_iters=args.num_warmup_iters, do_profile=args.profile,
num_iters=args.num_iters) num_warmup_iters=args.num_warmup_iters,
num_iters=args.num_iters,
)

View File

@ -4,15 +4,14 @@ import itertools
from typing import Optional, Union from typing import Optional, Union
import torch import torch
import triton
from flashinfer.norm import fused_add_rmsnorm, rmsnorm from flashinfer.norm import fused_add_rmsnorm, rmsnorm
from torch import nn from torch import nn
from vllm import _custom_ops as vllm_ops from vllm import _custom_ops as vllm_ops
from vllm.triton_utils import triton
class HuggingFaceRMSNorm(nn.Module): class HuggingFaceRMSNorm(nn.Module):
def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
super().__init__() super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size)) self.weight = nn.Parameter(torch.ones(hidden_size))
@ -114,23 +113,19 @@ def rmsnorm_vllm(
def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True): def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
dtype = torch.bfloat16 dtype = torch.bfloat16
x = torch.randn(batch_size, x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
seq_len,
hidden_size,
dtype=dtype,
device="cuda")
weight = torch.ones(hidden_size, dtype=dtype, device="cuda") weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
residual = torch.randn_like(x) if use_residual else None residual = torch.randn_like(x) if use_residual else None
output_naive = rmsnorm_naive( output_naive = rmsnorm_naive(
x.clone(), weight, x.clone(), weight, residual.clone() if residual is not None else None
residual.clone() if residual is not None else None) )
output_flashinfer = rmsnorm_flashinfer( output_flashinfer = rmsnorm_flashinfer(
x.clone(), weight, x.clone(), weight, residual.clone() if residual is not None else None
residual.clone() if residual is not None else None) )
output_vllm = rmsnorm_vllm( output_vllm = rmsnorm_vllm(
x.clone(), weight, x.clone(), weight, residual.clone() if residual is not None else None
residual.clone() if residual is not None else None) )
if use_residual: if use_residual:
output_naive = output_naive[0] output_naive = output_naive[0]
@ -141,9 +136,9 @@ def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
print(f"FlashInfer output={output_flashinfer}") print(f"FlashInfer output={output_flashinfer}")
print(f"vLLM output={output_vllm}") print(f"vLLM output={output_vllm}")
if torch.allclose(output_naive, output_flashinfer, atol=1e-2, if torch.allclose(
rtol=1e-2) and torch.allclose( output_naive, output_flashinfer, atol=1e-2, rtol=1e-2
output_naive, output_vllm, atol=1e-2, rtol=1e-2): ) and torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2):
print("✅ All implementations match") print("✅ All implementations match")
else: else:
print("❌ Implementations differ") print("❌ Implementations differ")
@ -152,12 +147,10 @@ def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
batch_size_range = [2**i for i in range(0, 7, 2)] batch_size_range = [2**i for i in range(0, 7, 2)]
seq_length_range = [2**i for i in range(6, 11, 1)] seq_length_range = [2**i for i in range(6, 11, 1)]
head_num_range = [32, 48] head_num_range = [32, 48]
configs = list( configs = list(itertools.product(head_num_range, batch_size_range, seq_length_range))
itertools.product(head_num_range, batch_size_range, seq_length_range))
def get_benchmark(use_residual): def get_benchmark(use_residual):
@triton.testing.perf_report( @triton.testing.perf_report(
triton.testing.Benchmark( triton.testing.Benchmark(
x_names=["head_num", "batch_size", "seq_len"], x_names=["head_num", "batch_size", "seq_len"],
@ -167,19 +160,15 @@ def get_benchmark(use_residual):
line_names=["HuggingFace", "FlashInfer", "vLLM"], line_names=["HuggingFace", "FlashInfer", "vLLM"],
styles=[("blue", "-"), ("green", "-"), ("red", "-")], styles=[("blue", "-"), ("green", "-"), ("red", "-")],
ylabel="us", ylabel="us",
plot_name= plot_name=f"rmsnorm-perf-{'with' if use_residual else 'without'}-residual",
f"rmsnorm-perf-{'with' if use_residual else 'without'}-residual",
args={}, args={},
)) )
)
def benchmark(head_num, batch_size, seq_len, provider): def benchmark(head_num, batch_size, seq_len, provider):
dtype = torch.bfloat16 dtype = torch.bfloat16
hidden_size = head_num * 128 # assuming head_dim = 128 hidden_size = head_num * 128 # assuming head_dim = 128
x = torch.randn(batch_size, x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
seq_len,
hidden_size,
dtype=dtype,
device="cuda")
weight = torch.ones(hidden_size, dtype=dtype, device="cuda") weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
residual = torch.randn_like(x) if use_residual else None residual = torch.randn_like(x) if use_residual else None
@ -240,9 +229,9 @@ if __name__ == "__main__":
default=4096, default=4096,
help="Hidden size (2nd dimension) of the sequence", help="Hidden size (2nd dimension) of the sequence",
) )
parser.add_argument("--use-residual", parser.add_argument(
action="store_true", "--use-residual", action="store_true", help="Whether to use residual connection"
help="Whether to use residual connection") )
parser.add_argument( parser.add_argument(
"--save-path", "--save-path",
type=str, type=str,
@ -253,10 +242,12 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
# Run correctness test # Run correctness test
calculate_diff(batch_size=args.batch_size, calculate_diff(
seq_len=args.seq_len, batch_size=args.batch_size,
hidden_size=args.hidden_size, seq_len=args.seq_len,
use_residual=args.use_residual) hidden_size=args.hidden_size,
use_residual=args.use_residual,
)
# Get the benchmark function with proper use_residual setting # Get the benchmark function with proper use_residual setting
benchmark = get_benchmark(args.use_residual) benchmark = get_benchmark(args.use_residual)

View File

@ -6,8 +6,7 @@ from typing import Optional
import nvtx import nvtx
import torch import torch
from vllm.model_executor.layers.rotary_embedding import (RotaryEmbedding, from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding, get_rope
get_rope)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
@ -32,40 +31,49 @@ def benchmark_rope_kernels_multi_lora(
# silulating serving 4 LoRAs # silulating serving 4 LoRAs
scaling_factors = [1, 2, 4, 8] scaling_factors = [1, 2, 4, 8]
# batched RoPE can take multiple scaling factors # batched RoPE can take multiple scaling factors
batched_rope = get_rope(head_size, rotary_dim, max_position, base, batched_rope = get_rope(
is_neox_style, { head_size,
"rope_type": "linear", rotary_dim,
"factor": tuple(scaling_factors) max_position,
}) base,
is_neox_style,
{"rope_type": "linear", "factor": tuple(scaling_factors)},
)
# non-batched RoPE takes only one scaling factor, we create multiple # non-batched RoPE takes only one scaling factor, we create multiple
# instances to simulate the same behavior # instances to simulate the same behavior
non_batched_ropes: list[RotaryEmbedding] = [] non_batched_ropes: list[RotaryEmbedding] = []
for scaling_factor in scaling_factors: for scaling_factor in scaling_factors:
non_batched_ropes.append( non_batched_ropes.append(
get_rope(head_size, rotary_dim, max_position, base, is_neox_style, get_rope(
{ head_size,
"rope_type": "linear", rotary_dim,
"factor": (scaling_factor, ) max_position,
})) base,
is_neox_style,
{"rope_type": "linear", "factor": (scaling_factor,)},
)
)
positions = torch.randint(0, max_position, (batch_size, seq_len)) positions = torch.randint(0, max_position, (batch_size, seq_len))
query = torch.randn(batch_size, query = torch.randn(batch_size, seq_len, num_heads * head_size, dtype=dtype)
seq_len,
num_heads * head_size,
dtype=dtype)
key = torch.randn_like(query) key = torch.randn_like(query)
# create query offsets for batched RoPE, we concat multiple kv cache # create query offsets for batched RoPE, we concat multiple kv cache
# together and each query needs to find the right kv cache of its type # together and each query needs to find the right kv cache of its type
offset_map = torch.tensor( offset_map = torch.tensor(
list( list(
accumulate([0] + [ accumulate(
max_position * scaling_factor * 2 [0]
for scaling_factor in scaling_factors[:-1] + [
]))) max_position * scaling_factor * 2
query_types = torch.randint(0, for scaling_factor in scaling_factors[:-1]
len(scaling_factors), (batch_size, seq_len), ]
device=device) )
)
)
query_types = torch.randint(
0, len(scaling_factors), (batch_size, seq_len), device=device
)
# map query types to offsets # map query types to offsets
query_offsets = offset_map[query_types] query_offsets = offset_map[query_types]
# the kernel takes flattened offsets # the kernel takes flattened offsets
@ -86,27 +94,28 @@ def benchmark_rope_kernels_multi_lora(
torch.cuda.synchronize() torch.cuda.synchronize()
if __name__ == '__main__': if __name__ == "__main__":
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description="Benchmark the rotary embedding kernels.") description="Benchmark the rotary embedding kernels."
)
parser.add_argument("--is-neox-style", type=bool, default=True) parser.add_argument("--is-neox-style", type=bool, default=True)
parser.add_argument("--batch-size", type=int, default=16) parser.add_argument("--batch-size", type=int, default=16)
parser.add_argument("--seq-len", type=int, default=512) parser.add_argument("--seq-len", type=int, default=512)
parser.add_argument("--num-heads", type=int, default=8) parser.add_argument("--num-heads", type=int, default=8)
parser.add_argument("--head-size", parser.add_argument(
type=int, "--head-size",
choices=[64, 80, 96, 112, 120, 128, 192, 256], type=int,
default=128) choices=[64, 80, 96, 112, 120, 128, 192, 256],
default=128,
)
parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32) parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32)
parser.add_argument("--dtype", parser.add_argument(
type=str, "--dtype", type=str, choices=["bfloat16", "float"], default="float"
choices=["bfloat16", "float"], )
default="float")
parser.add_argument("--seed", type=int, default=0) parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--device", parser.add_argument(
type=str, "--device", type=str, choices=["cuda:0", "cuda:1"], default="cuda:0"
choices=["cuda:0", "cuda:1"], )
default="cuda:0")
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)

View File

@ -14,14 +14,16 @@ import tqdm
import triton import triton
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
_w8a8_block_fp8_matmul) _w8a8_block_fp8_matmul,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
mp.set_start_method("spawn", force=True) mp.set_start_method("spawn", force=True)
assert current_platform.is_cuda( assert current_platform.is_cuda(), (
), "Only support tune w8a8 block fp8 kernel on CUDA device." "Only support tune w8a8 block fp8 kernel on CUDA device."
)
DTYPE_MAP = { DTYPE_MAP = {
"float32": torch.float32, "float32": torch.float32,
@ -71,18 +73,18 @@ def w8a8_block_matmul(
assert triton.cdiv(N, block_n) == Bs.shape[0] assert triton.cdiv(N, block_n) == Bs.shape[0]
assert triton.cdiv(K, block_k) == Bs.shape[1] assert triton.cdiv(K, block_k) == Bs.shape[1]
C_shape = A.shape[:-1] + (N, ) C_shape = A.shape[:-1] + (N,)
C = A.new_empty(C_shape, dtype=output_dtype) C = A.new_empty(C_shape, dtype=output_dtype)
def grid(META): def grid(META):
return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * return (
triton.cdiv(N, META["BLOCK_SIZE_N"]), ) triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
)
if A.dtype == torch.float8_e4m3fn: if A.dtype == torch.float8_e4m3fn:
kernel = _w8a8_block_fp8_matmul kernel = _w8a8_block_fp8_matmul
else: else:
raise RuntimeError( raise RuntimeError("Currently, only support tune w8a8 block fp8 kernel.")
"Currently, only support tune w8a8 block fp8 kernel.")
kernel[grid]( kernel[grid](
A, A,
@ -119,14 +121,16 @@ def get_configs_compute_bound():
for block_n in [32, 64, 128, 256]: for block_n in [32, 64, 128, 256]:
for num_warps in [4, 8]: for num_warps in [4, 8]:
for group_size in [1, 16, 32, 64]: for group_size in [1, 16, 32, 64]:
configs.append({ configs.append(
"BLOCK_SIZE_M": block_m, {
"BLOCK_SIZE_N": block_n, "BLOCK_SIZE_M": block_m,
"BLOCK_SIZE_K": block_k, "BLOCK_SIZE_N": block_n,
"GROUP_SIZE_M": group_size, "BLOCK_SIZE_K": block_k,
"num_warps": num_warps, "GROUP_SIZE_M": group_size,
"num_stages": num_stages, "num_warps": num_warps,
}) "num_stages": num_stages,
}
)
return configs return configs
@ -165,15 +169,9 @@ def get_weight_shapes(tp_size):
return weight_shapes return weight_shapes
def benchmark_config(A, def benchmark_config(
B, A, B, As, Bs, block_size, config, out_dtype=torch.float16, num_iters=10
As, ):
Bs,
block_size,
config,
out_dtype=torch.float16,
num_iters=10):
def run(): def run():
w8a8_block_matmul(A, B, As, Bs, block_size, config, out_dtype) w8a8_block_matmul(A, B, As, Bs, block_size, config, out_dtype)
@ -206,26 +204,26 @@ def tune(M, N, K, block_size, out_dtype, search_space, input_type):
fp8_max, fp8_min = fp8_info.max, fp8_info.min fp8_max, fp8_min = fp8_info.max, fp8_info.min
A_fp32 = ( A_fp32 = (
(torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * (torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
fp8_max) )
A = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) A = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
B_fp32 = ( B_fp32 = (
(torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
fp8_max) )
B = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) B = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
else: else:
raise RuntimeError( raise RuntimeError("Currently, only support tune w8a8 block fp8 kernel.")
"Currently, only support tune w8a8 block fp8 kernel.")
block_n, block_k = block_size[0], block_size[1] block_n, block_k = block_size[0], block_size[1]
n_tiles = (N + block_n - 1) // block_n n_tiles = (N + block_n - 1) // block_n
k_tiles = (K + block_k - 1) // block_k k_tiles = (K + block_k - 1) // block_k
As = torch.rand(M, k_tiles, dtype=torch.float32, As = torch.rand(M, k_tiles, dtype=torch.float32, device="cuda") * factor_for_scale
device="cuda") * factor_for_scale Bs = (
Bs = (torch.rand(n_tiles, k_tiles, dtype=torch.float32, device="cuda") * torch.rand(n_tiles, k_tiles, dtype=torch.float32, device="cuda")
factor_for_scale) * factor_for_scale
)
best_config = None best_config = None
best_time = float("inf") best_time = float("inf")
@ -267,7 +265,8 @@ def save_configs(
device_name = current_platform.get_device_name().replace(" ", "_") device_name = current_platform.get_device_name().replace(" ", "_")
json_file_name = ( json_file_name = (
f"N={N},K={K},device_name={device_name},dtype={input_type}_w8a8," f"N={N},K={K},device_name={device_name},dtype={input_type}_w8a8,"
f"block_shape=[{block_n},{block_k}].json") f"block_shape=[{block_n},{block_k}].json"
)
config_file_path = os.path.join(save_path, json_file_name) config_file_path = os.path.join(save_path, json_file_name)
print(f"Writing best config to {config_file_path}...") print(f"Writing best config to {config_file_path}...")
@ -295,8 +294,7 @@ def tune_on_gpu(args_dict):
search_space = get_configs_compute_bound() search_space = get_configs_compute_bound()
search_space = [ search_space = [
config for config in search_space config for config in search_space if block_k % config["BLOCK_SIZE_K"] == 0
if block_k % config["BLOCK_SIZE_K"] == 0
] ]
start = time.time() start = time.time()
@ -312,15 +310,11 @@ def tune_on_gpu(args_dict):
out_dtype, out_dtype,
search_space, search_space,
input_type, input_type,
) for batch_size in tqdm(batch_sizes, )
desc=f"GPU {gpu_id} - Batch sizes") for batch_size in tqdm(batch_sizes, desc=f"GPU {gpu_id} - Batch sizes")
] ]
best_configs = { best_configs = {M: config for M, config in zip(batch_sizes, benchmark_results)}
M: config save_configs(N, K, block_n, block_k, best_configs, save_path, input_type)
for M, config in zip(batch_sizes, benchmark_results)
}
save_configs(N, K, block_n, block_k, best_configs, save_path,
input_type)
end = time.time() end = time.time()
print(f"Tuning on GPU {gpu_id} took {end - start:.2f} seconds") print(f"Tuning on GPU {gpu_id} took {end - start:.2f} seconds")
@ -376,13 +370,14 @@ def main(args):
process_args = [] process_args = []
for gpu_id in range(num_gpus): for gpu_id in range(num_gpus):
process_args.append({ process_args.append(
"gpu_id": gpu_id, {
"batch_sizes": batches_per_gpu[gpu_id], "gpu_id": gpu_id,
"weight_shapes": "batch_sizes": batches_per_gpu[gpu_id],
weight_shapes, # Each GPU processes all weight shapes "weight_shapes": weight_shapes, # Each GPU processes all weight shapes
"args": args, "args": args,
}) }
)
ctx = mp.get_context("spawn") ctx = mp.get_context("spawn")
with ctx.Pool(num_gpus) as pool: with ctx.Pool(num_gpus) as pool:
@ -398,13 +393,11 @@ Tune triton w8a8 block fp8 for DeepSeek-V3/DeepSeek-R1:
python3 benchmark_w8a8_block_fp8.py --tp-size 8 --input-type fp8 python3 benchmark_w8a8_block_fp8.py --tp-size 8 --input-type fp8
Then copy to model_executor/layers/quantization/utils/configs Then copy to model_executor/layers/quantization/utils/configs
""", """,
formatter_class=argparse.RawTextHelpFormatter) formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument("--tp-size", "-tp", type=int, default=8) parser.add_argument("--tp-size", "-tp", type=int, default=8)
parser.add_argument("--input-type", parser.add_argument("--input-type", type=str, choices=["fp8"], default="fp8")
type=str,
choices=["fp8"],
default="fp8")
parser.add_argument( parser.add_argument(
"--out-dtype", "--out-dtype",
type=str, type=str,

View File

@ -6,13 +6,15 @@ import time
# Import DeepGEMM functions # Import DeepGEMM functions
import deep_gemm import deep_gemm
import torch import torch
import triton
from deep_gemm import calc_diff, ceil_div, get_col_major_tma_aligned_tensor from deep_gemm import calc_diff, ceil_div, get_col_major_tma_aligned_tensor
# Import vLLM functions # Import vLLM functions
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8, w8a8_block_fp8_matmul) per_token_group_quant_fp8,
w8a8_block_fp8_matmul,
)
from vllm.triton_utils import triton
# Copied from # Copied from

View File

@ -14,13 +14,14 @@ from vllm.utils import FlexibleArgumentParser
if __name__ == "__main__": if __name__ == "__main__":
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description='Benchmark the latency of processing a single batch of ' description="Benchmark the latency of processing a single batch of "
'requests till completion.') "requests till completion."
parser.add_argument('filename', type=str) )
parser.add_argument("filename", type=str)
args = parser.parse_args() args = parser.parse_args()
with open(args.filename, 'rb') as f: with open(args.filename, "rb") as f:
data = pickle.load(f) data = pickle.load(f)
raw_results: list[TMeasurement] = data["results"] raw_results: list[TMeasurement] = data["results"]
@ -38,11 +39,7 @@ if __name__ == "__main__":
raise Exception("MKN not found") raise Exception("MKN not found")
kernel = v.task_spec.description kernel = v.task_spec.description
results[KN].append({ results[KN].append({"kernel": kernel, "batch_size": M, "median": v.median})
"kernel": kernel,
"batch_size": M,
"median": v.median
})
rows = int(math.ceil(len(results) / 2)) rows = int(math.ceil(len(results) / 2))
fig, axs = plt.subplots(rows, 2, figsize=(12, 5 * rows)) fig, axs = plt.subplots(rows, 2, figsize=(12, 5 * rows))
@ -50,14 +47,16 @@ if __name__ == "__main__":
for axs_idx, (shape, data) in enumerate(results.items()): for axs_idx, (shape, data) in enumerate(results.items()):
plt.sca(axs[axs_idx]) plt.sca(axs[axs_idx])
df = pd.DataFrame(data) df = pd.DataFrame(data)
sns.lineplot(data=df, sns.lineplot(
x="batch_size", data=df,
y="median", x="batch_size",
hue="kernel", y="median",
style="kernel", hue="kernel",
markers=True, style="kernel",
dashes=False, markers=True,
palette="Dark2") dashes=False,
palette="Dark2",
)
plt.title(f"Shape: {shape}") plt.title(f"Shape: {shape}")
plt.ylabel("time (median, s)") plt.ylabel("time (median, s)")
plt.tight_layout() plt.tight_layout()

View File

@ -23,6 +23,7 @@ class ArgPool:
For every invocation during a benchmarking run, it will choose a For every invocation during a benchmarking run, it will choose a
different value from the list. different value from the list.
""" """
values: Iterable[Any] values: Iterable[Any]
def __getitem__(self, index): def __getitem__(self, index):
@ -30,9 +31,7 @@ class ArgPool:
class Bench: class Bench:
class ArgsIterator: class ArgsIterator:
def __init__(self, args_list, kwargs_list): def __init__(self, args_list, kwargs_list):
assert len(args_list) == len(kwargs_list) assert len(args_list) == len(kwargs_list)
self.args_list = args_list self.args_list = args_list
@ -53,10 +52,16 @@ class Bench:
def n_args(self): def n_args(self):
return self.n return self.n
def __init__(self, cuda_graph_params: Optional[CudaGraphBenchParams], def __init__(
label: str, sub_label: str, description: str, fn: Callable, self,
*args, **kwargs): cuda_graph_params: Optional[CudaGraphBenchParams],
label: str,
sub_label: str,
description: str,
fn: Callable,
*args,
**kwargs,
):
self.cuda_graph_params = cuda_graph_params self.cuda_graph_params = cuda_graph_params
self.use_cuda_graph = self.cuda_graph_params is not None self.use_cuda_graph = self.cuda_graph_params is not None
self.label = label self.label = label
@ -67,10 +72,8 @@ class Bench:
# Process args # Process args
self._args = args self._args = args
self._kwargs = kwargs self._kwargs = kwargs
self.args_list, self.kwargs_list = self.collapse_argpool( self.args_list, self.kwargs_list = self.collapse_argpool(*args, **kwargs)
*args, **kwargs) self.args_iterator = self.ArgsIterator(self.args_list, self.kwargs_list)
self.args_iterator = self.ArgsIterator(self.args_list,
self.kwargs_list)
# Cudagraph runner # Cudagraph runner
self.g = None self.g = None
@ -100,16 +103,13 @@ class Bench:
for i in range(argpool_size): for i in range(argpool_size):
# collapse args; Just pick the ith value # collapse args; Just pick the ith value
args_list[i] = tuple([ args_list[i] = tuple(
arg[i] if isinstance(arg, ArgPool) else arg [arg[i] if isinstance(arg, ArgPool) else arg for arg in args_list[i]]
for arg in args_list[i] )
])
# collapse kwargs # collapse kwargs
kwargs_i = kwargs_list[i] kwargs_i = kwargs_list[i]
arg_pool_keys = [ arg_pool_keys = [k for k, v in kwargs_i.items() if isinstance(v, ArgPool)]
k for k, v in kwargs_i.items() if isinstance(v, ArgPool)
]
for k in arg_pool_keys: for k in arg_pool_keys:
# again just pick the ith value # again just pick the ith value
kwargs_i[k] = kwargs_i[k][i] kwargs_i[k] = kwargs_i[k][i]
@ -142,7 +142,7 @@ class Bench:
def run_cudagrah(self) -> TMeasurement: def run_cudagrah(self) -> TMeasurement:
assert self.use_cuda_graph assert self.use_cuda_graph
globals = {'g': self.g} globals = {"g": self.g}
return TBenchmark.Timer( return TBenchmark.Timer(
stmt="g.replay()", stmt="g.replay()",
@ -162,15 +162,15 @@ class Bench:
has_arg_pool = self.args_iterator.n_args > 1 has_arg_pool = self.args_iterator.n_args > 1
if has_arg_pool: if has_arg_pool:
setup = ''' setup = """
args_iterator.reset() args_iterator.reset()
args_it = args_iterator.__next__() args_it = args_iterator.__next__()
''' """
stmt = ''' stmt = """
args, kwargs = next(args_it) args, kwargs = next(args_it)
fn(*args, **kwargs) fn(*args, **kwargs)
''' """
globals = {'fn': self.fn, 'args_iterator': self.args_iterator} globals = {"fn": self.fn, "args_iterator": self.args_iterator}
else: else:
# no arg pool. Just use the args and kwargs directly # no arg pool. Just use the args and kwargs directly
self.args_iterator.reset() self.args_iterator.reset()
@ -178,10 +178,10 @@ class Bench:
args, kwargs = next(args_it) args, kwargs = next(args_it)
setup = "" setup = ""
stmt = ''' stmt = """
fn(*args, **kwargs) fn(*args, **kwargs)
''' """
globals = {'fn': self.fn, 'args': args, 'kwargs': kwargs} globals = {"fn": self.fn, "args": args, "kwargs": kwargs}
return TBenchmark.Timer( return TBenchmark.Timer(
stmt=stmt, stmt=stmt,

View File

@ -7,9 +7,8 @@ from vllm import LLM, SamplingParams
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
# A very long prompt, total number of tokens is about 15k. # A very long prompt, total number of tokens is about 15k.
LONG_PROMPT = ["You are an expert in large language models, aren't you?" LONG_PROMPT = ["You are an expert in large language models, aren't you?"] * 1000
] * 1000 LONG_PROMPT = " ".join(LONG_PROMPT)
LONG_PROMPT = ' '.join(LONG_PROMPT)
def main(args): def main(args):
@ -30,32 +29,35 @@ def main(args):
print("------start generating------") print("------start generating------")
for i in range(3): for i in range(3):
profiler.runctx('llm.generate(LONG_PROMPT, sampling_params)', profiler.runctx(
globals(), locals()) "llm.generate(LONG_PROMPT, sampling_params)", globals(), locals()
)
# analyze the runtime of hashing function # analyze the runtime of hashing function
stats = pstats.Stats(profiler) stats = pstats.Stats(profiler)
stats.sort_stats('cumulative') stats.sort_stats("cumulative")
total_time = 0 total_time = 0
total_calls = 0 total_calls = 0
for func in stats.stats: for func in stats.stats:
if 'hash_of_block' in func[2]: if "hash_of_block" in func[2]:
total_time = stats.stats[func][3] total_time = stats.stats[func][3]
total_calls = stats.stats[func][0] total_calls = stats.stats[func][0]
percentage = (total_time / stats.total_tt) * 100 percentage = (total_time / stats.total_tt) * 100
print(f"Hashing took {total_time:.2f} seconds," print(
f"{percentage:.2f}% of the total runtime.") f"Hashing took {total_time:.2f} seconds,{percentage:.2f}% of the total runtime."
)
if __name__ == "__main__": if __name__ == "__main__":
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description='Benchmark the performance of hashing function in' description="Benchmark the performance of hashing function in"
'automatic prefix caching.') "automatic prefix caching."
parser.add_argument('--model', type=str, default='lmsys/longchat-7b-16k') )
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1) parser.add_argument("--model", type=str, default="lmsys/longchat-7b-16k")
parser.add_argument('--output-len', type=int, default=10) parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
parser.add_argument('--enable-prefix-caching', parser.add_argument("--output-len", type=int, default=10)
action='store_true', parser.add_argument(
help='enable prefix caching') "--enable-prefix-caching", action="store_true", help="enable prefix caching"
)
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)

54
benchmarks/pyproject.toml Normal file
View File

@ -0,0 +1,54 @@
# This local pyproject file is part of the migration from yapf to ruff format.
# It uses the same core rules as the main pyproject.toml file, but with the
# following differences:
# - ruff line length is overridden to 88
# - deprecated typing ignores (UP006, UP035) have been removed
[tool.ruff]
line-length = 88
exclude = [
# External file, leaving license intact
"examples/other/fp8/quantizer/quantize.py",
"vllm/vllm_flash_attn/flash_attn_interface.pyi"
]
[tool.ruff.lint.per-file-ignores]
"vllm/third_party/**" = ["ALL"]
"vllm/version.py" = ["F401"]
"vllm/_version.py" = ["ALL"]
[tool.ruff.lint]
select = [
# pycodestyle
"E",
# Pyflakes
"F",
# pyupgrade
"UP",
# flake8-bugbear
"B",
# flake8-simplify
"SIM",
# isort
"I",
# flake8-logging-format
"G",
]
ignore = [
# star imports
"F405", "F403",
# lambda expression assignment
"E731",
# Loop control variable not used within loop body
"B007",
# f-string format
"UP032",
# Can remove once 3.10+ is the minimum Python version
"UP007",
]
[tool.ruff.lint.isort]
known-first-party = ["vllm"]
[tool.ruff.format]
docstring-code-format = true

View File

@ -1,41 +1,102 @@
#!/bin/bash #!/bin/bash
# Define the model to use # default values
MODEL=${1:-"Qwen/Qwen2.5-7B-Instruct"} MODEL=${MODEL:-"Qwen/Qwen2.5-7B-Instruct"}
BACKEND=${BACKEND:-"vllm"}
# Define the backend to use DATASET=${DATASET:-"xgrammar_bench"}
BACKEND=${2:-"vllm"}
# Define the dataset to use
DATASET=${3:-"xgrammar_bench"}
# Define the guided decoding backend
GUIDED_BACKEND=${4:-"xgrammar"}
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
OUTPUT_DIR=${5:-"$SCRIPT_DIR/structured_output_benchmark_results"} OUTPUT_DIR=${OUTPUT_DIR:-"$SCRIPT_DIR/structured_output_benchmark_results"}
PORT=${PORT:-8000}
STRUCTURED_OUTPUT_RATIO=${STRUCTURED_OUTPUT_RATIO:-1}
TOTAL_SECONDS=${TOTAL_SECONDS:-90}
MAX_NEW_TOKENS=${MAX_NEW_TOKENS:-300}
TOKENIZER_MODE=${TOKENIZER_MODE:-"auto"}
GUIDED_RATIO=${6:-0.5} usage() {
echo "Usage: $0 [options]"
echo "Options:"
echo " --model MODEL Model to benchmark (default: $MODEL)"
echo " --backend BACKEND Backend to use (default: $BACKEND)"
echo " --dataset DATASET Dataset to use (default: $DATASET)"
echo " --max-new-tokens N Maximum number of tokens to generate (default: $MAX_NEW_TOKENS)"
echo " --output-dir DIR Output directory for results (default: $OUTPUT_DIR)"
echo " --port PORT Port to use (default: $PORT)"
echo " --structured-output-ratio N Ratio of structured outputs (default: $STRUCTURED_OUTPUT_RATIO)"
echo " --tokenizer-mode MODE Tokenizer mode to use (default: $TOKENIZER_MODE)"
echo " --total-seconds N Total seconds to run the benchmark (default: $TOTAL_SECONDS)"
echo " -h, --help Show this help message and exit"
exit 0
}
# parse command line arguments
while [[ $# -gt 0 ]]; do
case $1 in
--model)
MODEL="$2"
shift 2
;;
--backend)
BACKEND="$2"
shift 2
;;
--dataset)
DATASET="$2"
shift 2
;;
--max-new-tokens)
MAX_NEW_TOKENS="$2"
shift 2
;;
--output-dir)
OUTPUT_DIR="$2"
shift 2
;;
--port)
PORT="$2"
shift 2
;;
--structured-output-ratio)
STRUCTURED_OUTPUT_RATIO="$2"
shift 2
;;
--tokenizer-mode)
TOKENIZER_MODE="$2"
shift 2
;;
--total-seconds)
TOTAL_SECONDS="$2"
shift 2
;;
-h|--help)
usage
;;
*)
echo "Unknown argument: $1\n"
usage
;;
esac
done
# Create output directory if it doesn't exist # Create output directory if it doesn't exist
mkdir -p "$OUTPUT_DIR" mkdir -p "$OUTPUT_DIR"
# Define QPS values to test # Define QPS values to test
QPS_VALUES=(70 60 50 25 20 15 10) QPS_VALUES=(25 20 15 10 5 1)
# Common parameters # Common parameters
COMMON_PARAMS="--backend $BACKEND \ COMMON_PARAMS="--backend $BACKEND \
--model $MODEL \ --model $MODEL \
--dataset $DATASET \ --dataset $DATASET \
--structured-output-backend $GUIDED_BACKEND \ --structured-output-ratio $STRUCTURED_OUTPUT_RATIO \
--structured-output-ratio $GUIDED_RATIO \
--save-results \ --save-results \
--result-dir $OUTPUT_DIR" --result-dir $OUTPUT_DIR \
--output-len $MAX_NEW_TOKENS \
--port $PORT \
--tokenizer-mode $TOKENIZER_MODE"
echo "Starting structured output benchmark with model: $MODEL" echo "Starting structured output benchmark with model: $MODEL"
echo "Backend: $BACKEND" echo "Backend: $BACKEND"
echo "Dataset: $DATASET" echo "Dataset: $DATASET"
echo "Structured output backend: $GUIDED_BACKEND"
echo "Results will be saved to: $OUTPUT_DIR" echo "Results will be saved to: $OUTPUT_DIR"
echo "----------------------------------------" echo "----------------------------------------"
@ -48,14 +109,17 @@ for qps in "${QPS_VALUES[@]}"; do
GIT_BRANCH=$(git rev-parse --abbrev-ref HEAD 2>/dev/null || echo "unknown") GIT_BRANCH=$(git rev-parse --abbrev-ref HEAD 2>/dev/null || echo "unknown")
# Construct filename for this run # Construct filename for this run
FILENAME="${GUIDED_BACKEND}_${BACKEND}_${qps}qps_$(basename $MODEL)_${DATASET}_${GIT_HASH}.json" FILENAME="${BACKEND}_${qps}qps_$(basename $MODEL)_${DATASET}_${GIT_HASH}.json"
NUM_PROMPTS=$(echo "$TOTAL_SECONDS * $qps" | bc)
NUM_PROMPTS=${NUM_PROMPTS%.*} # Remove fractional part
echo "Running benchmark with $NUM_PROMPTS prompts"
# Run the benchmark # Run the benchmark
python "$SCRIPT_DIR/benchmark_serving_structured_output.py" $COMMON_PARAMS \ python "$SCRIPT_DIR/benchmark_serving_structured_output.py" $COMMON_PARAMS \
--request-rate $qps \ --request-rate $qps \
--result-filename "$FILENAME" \ --result-filename "$FILENAME" \
--tokenizer-mode ${TOKENIZER_MODE:-"auto"} \ --num-prompts $NUM_PROMPTS
--port ${PORT:-8000}
echo "Completed benchmark with QPS: $qps" echo "Completed benchmark with QPS: $qps"
echo "----------------------------------------" echo "----------------------------------------"

View File

@ -167,6 +167,33 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)
FetchContent_MakeAvailable(oneDNN) FetchContent_MakeAvailable(oneDNN)
list(APPEND LIBS dnnl)
elseif(POWER10_FOUND)
FetchContent_Declare(
oneDNN
GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
GIT_TAG v3.7.2
GIT_PROGRESS TRUE
GIT_SHALLOW TRUE
)
set(ONEDNN_LIBRARY_TYPE "STATIC")
set(ONEDNN_BUILD_DOC "OFF")
set(ONEDNN_BUILD_EXAMPLES "OFF")
set(ONEDNN_BUILD_TESTS "OFF")
set(ONEDNN_ENABLE_WORKLOAD "INFERENCE")
set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER")
set(ONEDNN_BUILD_GRAPH "OFF")
set(ONEDNN_ENABLE_JIT_PROFILING "OFF")
set(ONEDNN_ENABLE_ITT_TASKS "OFF")
set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF")
set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF")
set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)
set(DNNL_CPU_RUNTIME "OMP")
FetchContent_MakeAvailable(oneDNN)
list(APPEND LIBS dnnl) list(APPEND LIBS dnnl)
endif() endif()
@ -197,6 +224,10 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)
"csrc/cpu/quant.cpp" "csrc/cpu/quant.cpp"
"csrc/cpu/shm.cpp" "csrc/cpu/shm.cpp"
${VLLM_EXT_SRC}) ${VLLM_EXT_SRC})
elseif(POWER10_FOUND)
set(VLLM_EXT_SRC
"csrc/cpu/quant.cpp"
${VLLM_EXT_SRC})
endif() endif()
# #

View File

@ -228,11 +228,26 @@ macro(set_gencode_flags_for_srcs)
"${multiValueArgs}" ${ARGN} ) "${multiValueArgs}" ${ARGN} )
foreach(_ARCH ${arg_CUDA_ARCHS}) foreach(_ARCH ${arg_CUDA_ARCHS})
string(REPLACE "." "" _ARCH "${_ARCH}") # handle +PTX suffix: generate both sm and ptx codes if requested
set_gencode_flag_for_srcs( string(FIND "${_ARCH}" "+PTX" _HAS_PTX)
SRCS ${arg_SRCS} if(NOT _HAS_PTX EQUAL -1)
ARCH "compute_${_ARCH}" string(REPLACE "+PTX" "" _BASE_ARCH "${_ARCH}")
CODE "sm_${_ARCH}") string(REPLACE "." "" _STRIPPED_ARCH "${_BASE_ARCH}")
set_gencode_flag_for_srcs(
SRCS ${arg_SRCS}
ARCH "compute_${_STRIPPED_ARCH}"
CODE "sm_${_STRIPPED_ARCH}")
set_gencode_flag_for_srcs(
SRCS ${arg_SRCS}
ARCH "compute_${_STRIPPED_ARCH}"
CODE "compute_${_STRIPPED_ARCH}")
else()
string(REPLACE "." "" _STRIPPED_ARCH "${_ARCH}")
set_gencode_flag_for_srcs(
SRCS ${arg_SRCS}
ARCH "compute_${_STRIPPED_ARCH}"
CODE "sm_${_STRIPPED_ARCH}")
endif()
endforeach() endforeach()
if (${arg_BUILD_PTX_FOR_ARCH}) if (${arg_BUILD_PTX_FOR_ARCH})
@ -251,7 +266,10 @@ endmacro()
# #
# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form # For the given `SRC_CUDA_ARCHS` list of gencode versions in the form
# `<major>.<minor>[letter]` compute the "loose intersection" with the # `<major>.<minor>[letter]` compute the "loose intersection" with the
# `TGT_CUDA_ARCHS` list of gencodes. # `TGT_CUDA_ARCHS` list of gencodes. We also support the `+PTX` suffix in
# `SRC_CUDA_ARCHS` which indicates that the PTX code should be built when there
# is a CUDA_ARCH in `TGT_CUDA_ARCHS` that is equal to or larger than the
# architecture in `SRC_CUDA_ARCHS`.
# The loose intersection is defined as: # The loose intersection is defined as:
# { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} } # { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} }
# where `<=` is the version comparison operator. # where `<=` is the version comparison operator.
@ -268,44 +286,63 @@ endmacro()
# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS) # cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
# OUT_CUDA_ARCHS="8.0;8.6;9.0;9.0a" # OUT_CUDA_ARCHS="8.0;8.6;9.0;9.0a"
# #
# Example With PTX:
# SRC_CUDA_ARCHS="8.0+PTX"
# TGT_CUDA_ARCHS="9.0"
# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
# OUT_CUDA_ARCHS="8.0+PTX"
#
function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS) function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
list(REMOVE_DUPLICATES SRC_CUDA_ARCHS) set(_SRC_CUDA_ARCHS "${SRC_CUDA_ARCHS}")
set(TGT_CUDA_ARCHS_ ${TGT_CUDA_ARCHS}) set(_TGT_CUDA_ARCHS ${TGT_CUDA_ARCHS})
# handle +PTX suffix: separate base arch for matching, record PTX requests
set(_PTX_ARCHS)
foreach(_arch ${_SRC_CUDA_ARCHS})
if(_arch MATCHES "\\+PTX$")
string(REPLACE "+PTX" "" _base "${_arch}")
list(APPEND _PTX_ARCHS "${_base}")
list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}")
list(APPEND _SRC_CUDA_ARCHS "${_base}")
endif()
endforeach()
list(REMOVE_DUPLICATES _PTX_ARCHS)
list(REMOVE_DUPLICATES _SRC_CUDA_ARCHS)
# if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should # if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should
# remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS # remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS
set(_CUDA_ARCHS) set(_CUDA_ARCHS)
if ("9.0a" IN_LIST SRC_CUDA_ARCHS) if ("9.0a" IN_LIST _SRC_CUDA_ARCHS)
list(REMOVE_ITEM SRC_CUDA_ARCHS "9.0a") list(REMOVE_ITEM _SRC_CUDA_ARCHS "9.0a")
if ("9.0" IN_LIST TGT_CUDA_ARCHS_) if ("9.0" IN_LIST TGT_CUDA_ARCHS)
list(REMOVE_ITEM TGT_CUDA_ARCHS_ "9.0") list(REMOVE_ITEM _TGT_CUDA_ARCHS "9.0")
set(_CUDA_ARCHS "9.0a") set(_CUDA_ARCHS "9.0a")
endif() endif()
endif() endif()
if ("10.0a" IN_LIST SRC_CUDA_ARCHS) if ("10.0a" IN_LIST _SRC_CUDA_ARCHS)
list(REMOVE_ITEM SRC_CUDA_ARCHS "10.0a") list(REMOVE_ITEM _SRC_CUDA_ARCHS "10.0a")
if ("10.0" IN_LIST TGT_CUDA_ARCHS) if ("10.0" IN_LIST TGT_CUDA_ARCHS)
list(REMOVE_ITEM TGT_CUDA_ARCHS_ "10.0") list(REMOVE_ITEM _TGT_CUDA_ARCHS "10.0")
set(_CUDA_ARCHS "10.0a") set(_CUDA_ARCHS "10.0a")
endif() endif()
endif() endif()
list(SORT SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING) list(SORT _SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
# for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that # for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that
# is less or equal to ARCH (but has the same major version since SASS binary # is less or equal to ARCH (but has the same major version since SASS binary
# compatibility is only forward compatible within the same major version). # compatibility is only forward compatible within the same major version).
foreach(_ARCH ${TGT_CUDA_ARCHS_}) foreach(_ARCH ${_TGT_CUDA_ARCHS})
set(_TMP_ARCH) set(_TMP_ARCH)
# Extract the major version of the target arch # Extract the major version of the target arch
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" TGT_ARCH_MAJOR "${_ARCH}") string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" TGT_ARCH_MAJOR "${_ARCH}")
foreach(_SRC_ARCH ${SRC_CUDA_ARCHS}) foreach(_SRC_ARCH ${_SRC_CUDA_ARCHS})
# Extract the major version of the source arch # Extract the major version of the source arch
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" SRC_ARCH_MAJOR "${_SRC_ARCH}") string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" SRC_ARCH_MAJOR "${_SRC_ARCH}")
# Check major-version match AND version-less-or-equal # Check version-less-or-equal, and allow PTX arches to match across majors
if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH) if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH)
if (SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR) if (_SRC_ARCH IN_LIST _PTX_ARCHS OR SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR)
set(_TMP_ARCH "${_SRC_ARCH}") set(_TMP_ARCH "${_SRC_ARCH}")
endif() endif()
else() else()
@ -321,6 +358,18 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
endforeach() endforeach()
list(REMOVE_DUPLICATES _CUDA_ARCHS) list(REMOVE_DUPLICATES _CUDA_ARCHS)
# reapply +PTX suffix to architectures that requested PTX
set(_FINAL_ARCHS)
foreach(_arch ${_CUDA_ARCHS})
if(_arch IN_LIST _PTX_ARCHS)
list(APPEND _FINAL_ARCHS "${_arch}+PTX")
else()
list(APPEND _FINAL_ARCHS "${_arch}")
endif()
endforeach()
set(_CUDA_ARCHS ${_FINAL_ARCHS})
set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE) set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE)
endfunction() endfunction()

View File

@ -70,6 +70,9 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
int64_t num_tokens = input.numel() / input.size(-1); \ int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \ dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \ dim3 block(std::min(d, 1024)); \
if (num_tokens == 0) { \
return; \
} \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \ VLLM_DISPATCH_FLOATING_TYPES( \

View File

@ -172,7 +172,7 @@ __device__ void paged_attention_kernel(
// Load the query to registers. // Load the query to registers.
// Each thread in a thread group has a different part of the query. // Each thread in a thread group has a different part of the query.
// For example, if the the thread group size is 4, then the first thread in // For example, if the thread group size is 4, then the first thread in
// the group has 0, 4, 8, ... th vectors of the query, and the second thread // the group has 0, 4, 8, ... th vectors of the query, and the second thread
// has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because // has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because
// q is split from a qkv tensor, it may not be contiguous. // q is split from a qkv tensor, it may not be contiguous.
@ -259,7 +259,7 @@ __device__ void paged_attention_kernel(
// Load a key to registers. // Load a key to registers.
// Each thread in a thread group has a different part of the key. // Each thread in a thread group has a different part of the key.
// For example, if the the thread group size is 4, then the first thread in // For example, if the thread group size is 4, then the first thread in
// the group has 0, 4, 8, ... th vectors of the key, and the second thread // the group has 0, 4, 8, ... th vectors of the key, and the second thread
// has 1, 5, 9, ... th vectors of the key, and so on. // has 1, 5, 9, ... th vectors of the key, and so on.
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {

View File

@ -0,0 +1,401 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#include <assert.h>
#include <cuda.h>
#include <torch/all.h>
__device__ int64_t save_blocks(int* block_offset, int64_t range_start,
int64_t range_end, int64_t block_size,
int64_t input_block_count, int64_t kv_seqlen) {
if (range_start >= kv_seqlen) {
return input_block_count;
}
if (range_end > kv_seqlen) {
range_end = kv_seqlen;
}
int64_t current_block_count = input_block_count;
for (int idx = range_start; idx < range_end; idx += block_size) {
block_offset[current_block_count++] = idx;
}
return current_block_count;
}
__global__ void convert_vertical_slash_indexes_kernel(
const int* q_seqlens, // [BATCH, ]
const int* kv_seqlens, // [BATCH, ]
const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
int64_t NNZ_V, int64_t NNZ_S,
bool causal // True for intra, False for succ
) {
const int batch_idx = blockIdx.y;
const int head_idx = blockIdx.x;
const int group_idx = blockIdx.z;
int64_t q_seqlen = q_seqlens[batch_idx];
int64_t kv_seqlen = kv_seqlens[batch_idx];
int64_t block_idx_m = group_idx * blockDim.x + threadIdx.x;
int64_t start_m = block_idx_m * BLOCK_SIZE_M;
if (start_m >= q_seqlen) {
return;
}
int64_t end_m = start_m + BLOCK_SIZE_M;
vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V;
slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S;
int64_t row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m;
block_count += row_offset;
block_offset += row_offset * NNZ_S;
column_count += row_offset;
column_index += row_offset * NNZ_V;
bool has_slash = true;
int64_t tmp_col_cnt = 0, tmp_blk_cnt = 0;
int64_t s = 0, v = 0;
int64_t v_idx = vertical_indexes[v++];
int64_t s_idx = slash_indexes[s++];
if (causal) {
while (s_idx >= end_m + (kv_seqlen - q_seqlen) && s < NNZ_S) {
s_idx = slash_indexes[s++];
}
if (s_idx > end_m + (kv_seqlen - q_seqlen)) has_slash = false;
s_idx = max((kv_seqlen - q_seqlen) + end_m - s_idx, BLOCK_SIZE_M);
} else {
while (s_idx >= end_m + kv_seqlen && s < NNZ_S) {
s_idx = slash_indexes[s++];
}
if (s_idx > end_m + kv_seqlen) has_slash = false;
s_idx = max(kv_seqlen + end_m - s_idx, BLOCK_SIZE_M);
}
int64_t range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx;
if (!has_slash) {
if (causal) {
range_start = (kv_seqlen - q_seqlen) + end_m;
range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
} else {
range_start = kv_seqlen;
range_end = kv_seqlen + BLOCK_SIZE_N;
}
}
bool slash_finished = false;
while (1) {
if (v_idx < range_end) {
if (v_idx < range_start) {
column_index[tmp_col_cnt++] = v_idx;
}
if (v < NNZ_V) {
v_idx = vertical_indexes[v++];
} else {
if (causal)
v_idx = end_m + BLOCK_SIZE_N + (kv_seqlen - q_seqlen);
else
v_idx = end_m + BLOCK_SIZE_N + kv_seqlen;
}
} else {
if ((s < NNZ_S && causal) ||
(s < NNZ_S && !causal && slash_indexes[s] >= start_m)) {
if (causal)
s_idx = max((kv_seqlen - q_seqlen) + end_m - slash_indexes[s++],
BLOCK_SIZE_M);
else
s_idx = max(kv_seqlen + end_m - slash_indexes[s++], BLOCK_SIZE_M);
} else {
if (v == NNZ_V || (v_idx > range_start && causal)) {
// add the last vertical if no more slash
if (v == NNZ_V && !causal && v_idx < kv_seqlen) {
column_index[tmp_col_cnt++] = v_idx;
}
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
break;
} else {
if (causal) {
range_start = (kv_seqlen - q_seqlen) + end_m;
range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
} else {
// if slash_finished but there are vertical left, save current
// blocks
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
range_start = kv_seqlen;
range_end = kv_seqlen + BLOCK_SIZE_N;
}
slash_finished = true;
}
}
if (!slash_finished) {
if (s_idx > range_end + BLOCK_SIZE_M) {
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
range_start = s_idx - BLOCK_SIZE_M;
range_end = s_idx;
} else if (s_idx > range_end) {
range_end += BLOCK_SIZE_M;
}
}
}
}
block_count[0] = tmp_blk_cnt;
column_count[0] = tmp_col_cnt;
}
void convert_vertical_slash_indexes_64x64(
const int* q_seqlens, // [BATCH, ]
const int* kv_seqlens, // [BATCH, ]
const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
int64_t BATCH_SIZE, int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M,
int64_t BLOCK_SIZE_N, int64_t NNZ_V, int64_t NNZ_S, bool causal) {
const int N_THREADS = 64;
const dim3 dimBlock(N_THREADS);
const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS);
convert_vertical_slash_indexes_kernel<<<dimGrid, dimBlock>>>(
q_seqlens, kv_seqlens, vertical_indexes, slash_indexes, block_count,
block_offset, column_count, column_index, N_HEADS, N_ROWS, BLOCK_SIZE_M,
BLOCK_SIZE_N, NNZ_V, NNZ_S, causal);
}
/**
* Implements the Algorithm 4 in paper https://arxiv.org/abs/2407.02490.
*
* This function builds the index of each row of blocks from vertical indices
* and slash indices. The vertical indices are treated as points, while the
* slash indices are converted as ranges. The output consists of the merged
* ranges and separate column indices, where the ranges are represented by
* block indices.
*
* The implementation is referenced from the original MInference repo:
* https://github.com/microsoft/MInference/blob/main/csrc/vertical_slash_index.cu.
*/
void convert_vertical_slash_indexes(
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
torch::Tensor q_seqlens, // [BATCH, ]
torch::Tensor kv_seqlens, // [BATCH, ]
torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
int64_t context_size, int64_t block_size_M, int64_t block_size_N,
bool causal) {
cudaSetDevice(q_seqlens.get_device());
int batch_size = slash_indexes.size(0);
int num_heads = slash_indexes.size(1);
int nnz_slash = slash_indexes.size(2);
int nnz_vertical = vertical_indexes.size(2);
int num_rows = (context_size + block_size_M - 1) / block_size_M;
convert_vertical_slash_indexes_64x64(
q_seqlens.data_ptr<int>(), kv_seqlens.data_ptr<int>(),
vertical_indexes.data_ptr<int>(), slash_indexes.data_ptr<int>(),
block_count.data_ptr<int>(), block_offset.data_ptr<int>(),
column_count.data_ptr<int>(), column_index.data_ptr<int>(), batch_size,
num_heads, num_rows, block_size_M, block_size_N, nnz_vertical, nnz_slash,
causal);
}
__global__ void convert_vertical_slash_indexes_kernel_mergehead(
const int* q_seqlens, // [BATCH, ]
const int* kv_seqlens, // [BATCH, ]
const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
const int* per_head_vertical_topkv, const int* per_head_slash_topkv,
int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
int64_t NNZ_V, int64_t NNZ_S,
bool causal // True for intra, False for succ
) {
const int batch_idx = blockIdx.y;
const int head_idx = blockIdx.x;
const int group_idx = blockIdx.z;
int64_t q_seqlen = q_seqlens[batch_idx];
int64_t kv_seqlen = kv_seqlens[batch_idx];
int64_t block_idx_m = group_idx * blockDim.x + threadIdx.x;
int64_t start_m = block_idx_m * BLOCK_SIZE_M;
if (start_m >= q_seqlen) {
return;
}
int64_t end_m = start_m + BLOCK_SIZE_M;
vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V;
slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S;
int64_t row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m;
block_count += row_offset;
block_offset += row_offset * NNZ_S;
column_count += row_offset;
column_index += row_offset * NNZ_V;
// MergeHead: each head has it's unique max topk NNZ_VNNZ_S. (NNZ_VNNZ_S
// above is buffer size, use to compute offset)
NNZ_S = per_head_slash_topkv[head_idx];
NNZ_V = per_head_vertical_topkv[head_idx];
bool has_slash = true;
int64_t tmp_col_cnt = 0, tmp_blk_cnt = 0;
int64_t s = 0, v = 0;
int64_t v_idx = vertical_indexes[v++];
int64_t s_idx = slash_indexes[s++];
if (causal) {
while (s_idx >= end_m + (kv_seqlen - q_seqlen) && s < NNZ_S) {
s_idx = slash_indexes[s++];
}
if (s_idx > end_m + (kv_seqlen - q_seqlen)) has_slash = false;
s_idx = max((kv_seqlen - q_seqlen) + end_m - s_idx, BLOCK_SIZE_M);
} else {
while (s_idx >= end_m + kv_seqlen && s < NNZ_S) {
s_idx = slash_indexes[s++];
}
if (s_idx > end_m + kv_seqlen) has_slash = false;
s_idx = max(kv_seqlen + end_m - s_idx, BLOCK_SIZE_M);
}
int64_t range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx;
if (!has_slash) {
if (causal) {
range_start = (kv_seqlen - q_seqlen) + end_m;
range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
} else {
range_start = kv_seqlen;
range_end = kv_seqlen + BLOCK_SIZE_N;
}
}
bool slash_finished = false;
while (1) {
if (v_idx < range_end) {
if (v_idx < range_start) {
column_index[tmp_col_cnt++] = v_idx;
}
if (v < NNZ_V) {
v_idx = vertical_indexes[v++];
} else {
if (causal)
v_idx = end_m + BLOCK_SIZE_N + (kv_seqlen - q_seqlen);
else
v_idx = end_m + BLOCK_SIZE_N + kv_seqlen;
}
} else {
if ((s < NNZ_S && causal) ||
(s < NNZ_S && !causal && slash_indexes[s] >= start_m)) {
if (causal)
s_idx = max((kv_seqlen - q_seqlen) + end_m - slash_indexes[s++],
BLOCK_SIZE_M);
else
s_idx = max(kv_seqlen + end_m - slash_indexes[s++], BLOCK_SIZE_M);
} else {
if (v == NNZ_V || (v_idx > range_start && causal)) {
// add the last vertical if no more slash
if (v == NNZ_V && !causal && v_idx < kv_seqlen) {
column_index[tmp_col_cnt++] = v_idx;
}
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
break;
} else {
if (causal) {
range_start = (kv_seqlen - q_seqlen) + end_m;
range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
} else {
// if slash_finished but there are vertical left, save current
// blocks
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
range_start = kv_seqlen;
range_end = kv_seqlen + BLOCK_SIZE_N;
}
slash_finished = true;
}
}
if (!slash_finished) {
if (s_idx > range_end + BLOCK_SIZE_M) {
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
range_start = s_idx - BLOCK_SIZE_M;
range_end = s_idx;
} else if (s_idx > range_end) {
range_end += BLOCK_SIZE_M;
}
}
}
}
block_count[0] = tmp_blk_cnt;
column_count[0] = tmp_col_cnt;
}
void convert_vertical_slash_indexes_64x64_mergehead(
const int* q_seqlens, // [BATCH, ]
const int* kv_seqlens, // [BATCH, ]
const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
int* per_head_vertical_topkv, int* per_head_slash_topkv,
int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
int64_t BATCH_SIZE, int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M,
int64_t BLOCK_SIZE_N, int64_t NNZ_V, int64_t NNZ_S, bool causal) {
const int N_THREADS = 64;
const dim3 dimBlock(N_THREADS);
const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS);
convert_vertical_slash_indexes_kernel_mergehead<<<dimGrid, dimBlock>>>(
q_seqlens, kv_seqlens, vertical_indexes, slash_indexes,
per_head_vertical_topkv, per_head_slash_topkv, block_count, block_offset,
column_count, column_index, N_HEADS, N_ROWS, BLOCK_SIZE_M, BLOCK_SIZE_N,
NNZ_V, NNZ_S, causal);
}
/**
* Implements the Algorithm 4 in paper https://arxiv.org/abs/2407.02490.
*
* Like the above convert_vertical_slash_indexes, but with
* pre-computed vertical and slash counts.
*/
void convert_vertical_slash_indexes_mergehead(
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
torch::Tensor q_seqlens, // [BATCH, ]
torch::Tensor kv_seqlens, // [BATCH, ]
torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
torch::Tensor vertical_indices_count, // [N_HEADS, ]
torch::Tensor slash_indices_count, // [N_HEADS, ]
int64_t context_size, int64_t block_size_M, int64_t block_size_N,
bool causal) {
cudaSetDevice(q_seqlens.get_device());
int batch_size = slash_indexes.size(0);
int num_heads = slash_indexes.size(1);
int nnz_slash = slash_indexes.size(2);
int nnz_vertical = vertical_indexes.size(2);
int num_rows = (context_size + block_size_M - 1) / block_size_M;
convert_vertical_slash_indexes_64x64_mergehead(
q_seqlens.data_ptr<int>(), kv_seqlens.data_ptr<int>(),
vertical_indexes.data_ptr<int>(), slash_indexes.data_ptr<int>(),
vertical_indices_count.data_ptr<int>(),
slash_indices_count.data_ptr<int>(), block_count.data_ptr<int>(),
block_offset.data_ptr<int>(), column_count.data_ptr<int>(),
column_index.data_ptr<int>(), batch_size, num_heads, num_rows,
block_size_M, block_size_N, nnz_vertical, nnz_slash, causal);
}

View File

@ -7,3 +7,22 @@ inline constexpr uint32_t next_pow_2(uint32_t const num) {
if (num <= 1) return num; if (num <= 1) return num;
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
} }
template <typename A, typename B>
static inline constexpr auto div_ceil(A a, B b) {
return (a + b - 1) / b;
}
// Round a down to the next multiple of b. The caller is responsible for making
// sure that b is non-zero
template <typename T>
inline constexpr T round_to_previous_multiple_of(T a, T b) {
return a % b == 0 ? a : (a / b) * b;
}
// Round a up to the next multiple of b. The caller is responsible for making
// sure that b is non-zero
template <typename T>
inline constexpr T round_to_next_multiple_of(T a, T b) {
return a % b == 0 ? a : ((a / b) + 1) * b;
}

View File

@ -315,6 +315,8 @@ static inline constexpr auto kS8 = ScalarType::int_(8);
static inline constexpr auto kU8 = ScalarType::uint(8); static inline constexpr auto kU8 = ScalarType::uint(8);
static inline constexpr auto kU8B128 = ScalarType::uint(8, 128); static inline constexpr auto kU8B128 = ScalarType::uint(8, 128);
static inline constexpr auto kFE2M1f =
ScalarType::float_(2, 1, true, ScalarType::NAN_NONE);
static inline constexpr auto kFE3M2f = static inline constexpr auto kFE3M2f =
ScalarType::float_(3, 2, true, ScalarType::NAN_NONE); ScalarType::float_(3, 2, true, ScalarType::NAN_NONE);
static inline constexpr auto kFE4M3fn = static inline constexpr auto kFE4M3fn =
@ -332,6 +334,7 @@ static inline constexpr auto kInt8 = kS8;
static inline constexpr auto kUint8 = kU8; static inline constexpr auto kUint8 = kU8;
static inline constexpr auto kUint8b128 = kU8B128; static inline constexpr auto kUint8b128 = kU8B128;
static inline constexpr auto kFloat4_e2m1f = kFE2M1f;
static inline constexpr auto kFloat6_e3m2f = kFE3M2f; static inline constexpr auto kFloat6_e3m2f = kFE3M2f;
static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn; static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn;
static inline constexpr auto kFloat8_e5m2 = kFE5M2; static inline constexpr auto kFloat8_e5m2 = kFE5M2;

View File

@ -4,6 +4,7 @@
#include <altivec.h> #include <altivec.h>
#include <cmath> #include <cmath>
#include <algorithm>
#include <torch/all.h> #include <torch/all.h>
namespace vec_op { namespace vec_op {
@ -62,6 +63,10 @@ typedef struct f32x4x4_t {
__vector float val[4]; __vector float val[4];
} f32x4x4_t; } f32x4x4_t;
typedef struct i32x4x4_t {
__vector int32_t val[4];
} i32x4x4_t;
struct FP32Vec8; struct FP32Vec8;
struct FP32Vec16; struct FP32Vec16;
@ -98,6 +103,28 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
vec_xst(reg.val[0], 0, (signed short*)ptr); vec_xst(reg.val[0], 0, (signed short*)ptr);
vec_xst(reg.val[1], 16, (signed short*)ptr); vec_xst(reg.val[1], 16, (signed short*)ptr);
} }
void save(void* ptr, const int elem_num) const {
const int clamped_elem = std::max(0, std::min(elem_num, 16));
// Calculate elements to store in each 128-bit part (8 elements each)
const int elements_val0 = std::min(clamped_elem, 8);
const int elements_val1 = std::max(clamped_elem - 8, 0);
// Convert elements to bytes (2 bytes per element)
const size_t bytes_val0 = elements_val0 * sizeof(signed short);
const size_t bytes_val1 = elements_val1 * sizeof(signed short);
signed short* dest = static_cast<signed short*>(ptr);
// Store the first part using vec_xst_len
if (bytes_val0 > 0) {
vec_xst_len(reg.val[0], dest, bytes_val0);
}
// Store the second part if needed
if (bytes_val1 > 0) {
vec_xst_len(reg.val[1], dest + elements_val0, bytes_val1);
}
}
}; };
const static __vector signed short zero = vec_splats((signed short)0); const static __vector signed short zero = vec_splats((signed short)0);
@ -257,6 +284,64 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
} }
}; };
struct INT32Vec16 : public Vec<INT32Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
union AliasReg {
i32x4x4_t reg;
int32_t values[VEC_ELEM_NUM];
};
i32x4x4_t reg;
explicit INT32Vec16(const void* data_ptr) {
reg.val[0] = vec_xl(0, reinterpret_cast<const __vector int32_t*>(data_ptr));
reg.val[1] =
vec_xl(16, reinterpret_cast<const __vector int32_t*>(data_ptr));
reg.val[2] =
vec_xl(32, reinterpret_cast<const __vector int32_t*>(data_ptr));
reg.val[3] =
vec_xl(48, reinterpret_cast<const __vector int32_t*>(data_ptr));
}
void save(int32_t* ptr) const {
vec_xst(reg.val[0], 0, reinterpret_cast<__vector int32_t*>(ptr));
vec_xst(reg.val[1], 16, reinterpret_cast<__vector int32_t*>(ptr));
vec_xst(reg.val[2], 32, reinterpret_cast<__vector int32_t*>(ptr));
vec_xst(reg.val[3], 48, reinterpret_cast<__vector int32_t*>(ptr));
}
void save(int32_t* ptr, const int elem_num) const {
const int elements_in_chunk1 =
(elem_num >= 0) ? ((elem_num >= 4) ? 4 : elem_num) : 0;
const int elements_in_chunk2 =
(elem_num > 4) ? ((elem_num >= 8) ? 4 : elem_num - 4) : 0;
const int elements_in_chunk3 =
(elem_num > 8) ? ((elem_num >= 12) ? 4 : elem_num - 8) : 0;
const int elements_in_chunk4 =
(elem_num > 12) ? ((elem_num >= 16) ? 4 : elem_num - 12) : 0;
const size_t bytes_chunk1 =
static_cast<size_t>(elements_in_chunk1 * sizeof(int32_t));
const size_t bytes_chunk2 =
static_cast<size_t>(elements_in_chunk2 * sizeof(int32_t));
const size_t bytes_chunk3 =
static_cast<size_t>(elements_in_chunk3 * sizeof(int32_t));
const size_t bytes_chunk4 =
static_cast<size_t>(elements_in_chunk4 * sizeof(int32_t));
vec_xst_len(reg.val[0], reinterpret_cast<int32_t*>(ptr), bytes_chunk1);
vec_xst_len(reg.val[1],
reinterpret_cast<int32_t*>(reinterpret_cast<char*>(ptr) + 16),
bytes_chunk2);
vec_xst_len(reg.val[2],
reinterpret_cast<int32_t*>(reinterpret_cast<char*>(ptr) + 32),
bytes_chunk3);
vec_xst_len(reg.val[3],
reinterpret_cast<int32_t*>(reinterpret_cast<char*>(ptr) + 48),
bytes_chunk4);
}
};
struct FP32Vec16 : public Vec<FP32Vec16> { struct FP32Vec16 : public Vec<FP32Vec16> {
constexpr static int VEC_ELEM_NUM = 16; constexpr static int VEC_ELEM_NUM = 16;
union AliasReg { union AliasReg {
@ -319,6 +404,13 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {} explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}
explicit FP32Vec16(const INT32Vec16& v) {
reg.val[0] = vec_ctf(v.reg.val[0], 0);
reg.val[1] = vec_ctf(v.reg.val[1], 0);
reg.val[2] = vec_ctf(v.reg.val[2], 0);
reg.val[3] = vec_ctf(v.reg.val[3], 0);
}
FP32Vec16 operator*(const FP32Vec16& b) const { FP32Vec16 operator*(const FP32Vec16& b) const {
return FP32Vec16(f32x4x4_t({vec_mul(reg.val[0], b.reg.val[0]), return FP32Vec16(f32x4x4_t({vec_mul(reg.val[0], b.reg.val[0]),
vec_mul(reg.val[1], b.reg.val[1]), vec_mul(reg.val[1], b.reg.val[1]),
@ -347,6 +439,117 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
vec_div(reg.val[3], b.reg.val[3])})); vec_div(reg.val[3], b.reg.val[3])}));
} }
FP32Vec16 clamp(const FP32Vec16& min, const FP32Vec16& max) const {
return FP32Vec16(f32x4x4_t(
{vec_min(max.reg.val[0], vec_max(min.reg.val[0], reg.val[0])),
vec_min(max.reg.val[1], vec_max(min.reg.val[1], reg.val[1])),
vec_min(max.reg.val[2], vec_max(min.reg.val[2], reg.val[2])),
vec_min(max.reg.val[3], vec_max(min.reg.val[3], reg.val[3]))}));
}
FP32Vec16 max(const FP32Vec16& b) const {
return FP32Vec16(f32x4x4_t({vec_max(reg.val[0], b.reg.val[0]),
vec_max(reg.val[1], b.reg.val[1]),
vec_max(reg.val[2], b.reg.val[2]),
vec_max(reg.val[3], b.reg.val[3])}));
}
FP32Vec16 max(const FP32Vec16& b, int elem_num) const {
FP32Vec16 result;
// Create a vector of element indices for each chunk
__vector unsigned int indices = {0, 1, 2, 3};
__vector unsigned int elem_num_vec =
vec_splats(static_cast<unsigned int>(elem_num));
// Compute masks for each chunk
__vector unsigned int chunk_offset0 = {0, 0, 0,
0}; // Chunk 0: Elements 0-3
__vector unsigned int chunk_offset1 = {4, 4, 4,
4}; // Chunk 1: Elements 4-7
__vector unsigned int chunk_offset2 = {8, 8, 8,
8}; // Chunk 2: Elements 8-11
__vector unsigned int chunk_offset3 = {12, 12, 12,
12}; // Chunk 3: Elements 12-15
// Compute masks for each chunk
__vector bool int mask0 = vec_cmplt(indices + chunk_offset0, elem_num_vec);
__vector bool int mask1 = vec_cmplt(indices + chunk_offset1, elem_num_vec);
__vector bool int mask2 = vec_cmplt(indices + chunk_offset2, elem_num_vec);
__vector bool int mask3 = vec_cmplt(indices + chunk_offset3, elem_num_vec);
// Apply masks to compute the result for each chunk
result.reg.val[0] = vec_sel(this->reg.val[0],
vec_max(this->reg.val[0], b.reg.val[0]), mask0);
result.reg.val[1] = vec_sel(this->reg.val[1],
vec_max(this->reg.val[1], b.reg.val[1]), mask1);
result.reg.val[2] = vec_sel(this->reg.val[2],
vec_max(this->reg.val[2], b.reg.val[2]), mask2);
result.reg.val[3] = vec_sel(this->reg.val[3],
vec_max(this->reg.val[3], b.reg.val[3]), mask3);
return FP32Vec16(result.reg);
}
FP32Vec16 min(const FP32Vec16& b) const {
return FP32Vec16(f32x4x4_t({vec_min(reg.val[0], b.reg.val[0]),
vec_min(reg.val[1], b.reg.val[1]),
vec_min(reg.val[2], b.reg.val[2]),
vec_min(reg.val[3], b.reg.val[3])}));
}
FP32Vec16 min(const FP32Vec16& b, int elem_num) const {
FP32Vec16 result;
vector unsigned int indices = {0, 1, 2, 3};
vector unsigned int elem_num_vec =
vec_splats(static_cast<unsigned int>(elem_num));
vector unsigned int chunk_offset0 = {0, 0, 0, 0};
vector unsigned int chunk_offset1 = {4, 4, 4, 4};
vector unsigned int chunk_offset2 = {8, 8, 8, 8};
vector unsigned int chunk_offset3 = {12, 12, 12, 12};
vector bool int mask0 = vec_cmplt(indices + chunk_offset0, elem_num_vec);
vector bool int mask1 = vec_cmplt(indices + chunk_offset1, elem_num_vec);
vector bool int mask2 = vec_cmplt(indices + chunk_offset2, elem_num_vec);
vector bool int mask3 = vec_cmplt(indices + chunk_offset3, elem_num_vec);
result.reg.val[0] = vec_sel(this->reg.val[0],
vec_min(this->reg.val[0], b.reg.val[0]), mask0);
result.reg.val[1] = vec_sel(this->reg.val[1],
vec_min(this->reg.val[1], b.reg.val[1]), mask1);
result.reg.val[2] = vec_sel(this->reg.val[2],
vec_min(this->reg.val[2], b.reg.val[2]), mask2);
result.reg.val[3] = vec_sel(this->reg.val[3],
vec_min(this->reg.val[3], b.reg.val[3]), mask3);
return FP32Vec16(result.reg);
}
FP32Vec16 abs() const {
return FP32Vec16(f32x4x4_t({vec_abs(reg.val[0]), vec_abs(reg.val[1]),
vec_abs(reg.val[2]), vec_abs(reg.val[3])}));
}
float reduce_max() {
__vector float max01 = vec_max(reg.val[0], reg.val[1]);
__vector float max23 = vec_max(reg.val[2], reg.val[3]);
__vector float max_all = vec_max(max01, max23);
__vector float temp = vec_max(max_all, vec_sld(max_all, max_all, 8));
temp = vec_max(temp, vec_sld(temp, temp, 4));
return vec_extract(temp, 0);
}
float reduce_min() {
__vector float min01 = vec_min(reg.val[0], reg.val[1]);
__vector float min23 = vec_min(reg.val[2], reg.val[3]);
__vector float min_all = vec_min(min01, min23);
__vector float temp = vec_min(min_all, vec_sld(min_all, min_all, 8));
temp = vec_min(temp, vec_sld(temp, temp, 4));
return vec_extract(temp, 0);
}
float reduce_sum() const { float reduce_sum() const {
AliasReg ar; AliasReg ar;
ar.reg = reg; ar.reg = reg;
@ -377,6 +580,68 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
vec_xst(reg.val[2], 32, ptr); vec_xst(reg.val[2], 32, ptr);
vec_xst(reg.val[3], 48, ptr); vec_xst(reg.val[3], 48, ptr);
} }
void save(float* ptr, const int elem_num) const {
const int elements_in_chunk1 =
(elem_num >= 0) ? ((elem_num >= 4) ? 4 : elem_num) : 0;
const int elements_in_chunk2 =
(elem_num > 4) ? ((elem_num >= 8) ? 4 : elem_num - 4) : 0;
const int elements_in_chunk3 =
(elem_num > 8) ? ((elem_num >= 12) ? 4 : elem_num - 8) : 0;
const int elements_in_chunk4 =
(elem_num > 12) ? ((elem_num >= 16) ? 4 : elem_num - 12) : 0;
const size_t bytes_chunk1 =
static_cast<size_t>(elements_in_chunk1 * sizeof(float));
const size_t bytes_chunk2 =
static_cast<size_t>(elements_in_chunk2 * sizeof(float));
const size_t bytes_chunk3 =
static_cast<size_t>(elements_in_chunk3 * sizeof(float));
const size_t bytes_chunk4 =
static_cast<size_t>(elements_in_chunk4 * sizeof(float));
vec_xst_len(reg.val[0], ptr, bytes_chunk1);
vec_xst_len(reg.val[1],
reinterpret_cast<float*>(reinterpret_cast<char*>(ptr) + 16),
bytes_chunk2);
vec_xst_len(reg.val[2],
reinterpret_cast<float*>(reinterpret_cast<char*>(ptr) + 32),
bytes_chunk3);
vec_xst_len(reg.val[3],
reinterpret_cast<float*>(reinterpret_cast<char*>(ptr) + 48),
bytes_chunk4);
}
};
struct INT8Vec16 : public Vec<INT8Vec16> {
constexpr static int VEC_NUM_ELEM = 16; // 128 bits / 8 bits = 16
union AliasReg {
__vector signed char reg;
int8_t values[VEC_NUM_ELEM];
};
__vector signed char reg;
explicit INT8Vec16(const FP32Vec16& vec) {
__vector signed int ret[4];
ret[0] = vec_cts(vec.reg.val[0], 0);
ret[1] = vec_cts(vec.reg.val[1], 0);
ret[2] = vec_cts(vec.reg.val[2], 0);
ret[3] = vec_cts(vec.reg.val[3], 0);
__vector signed short packed1 = vec_packs(ret[0], ret[1]);
__vector signed short packed2 = vec_packs(ret[2], ret[3]);
reg = vec_packs(packed1, packed2);
}
void save(void* ptr) const {
*reinterpret_cast<__vector signed char*>(ptr) = reg;
}
void save(signed char* ptr, const int elem_num) {
vec_xst_len(reg, ptr, static_cast<size_t>(elem_num));
}
}; };
template <typename T> template <typename T>

View File

@ -9,7 +9,8 @@ void rotary_embedding_impl(
scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads, scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads,
/// head_size] or [num_tokens, num_heads, /// head_size] or [num_tokens, num_heads,
/// head_size] /// head_size]
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, scalar_t* __restrict__ key, // nullptr (optional) or
// [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads, // head_size] or [num_tokens, num_kv_heads,
// head_size] // head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
@ -85,10 +86,13 @@ void rotary_embedding_impl(
compute_loop(token_head, cache_ptr, query); compute_loop(token_head, cache_ptr, query);
} }
for (int i = 0; i < num_kv_heads; ++i) { if (key != nullptr) {
const int head_idx = i; for (int i = 0; i < num_kv_heads; ++i) {
const int64_t token_head = token_idx * key_stride + head_idx * head_size; const int head_idx = i;
compute_loop(token_head, cache_ptr, key); const int64_t token_head =
token_idx * key_stride + head_idx * head_size;
compute_loop(token_head, cache_ptr, key);
}
} }
} }
} }
@ -100,7 +104,8 @@ void rotary_embedding_gptj_impl(
scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads, scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads,
/// head_size] or [num_tokens, num_heads, /// head_size] or [num_tokens, num_heads,
/// head_size] /// head_size]
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, scalar_t* __restrict__ key, // nullptr (optional) or
// [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads, // head_size] or [num_tokens, num_kv_heads,
// head_size] // head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
@ -138,6 +143,10 @@ void rotary_embedding_gptj_impl(
} }
} }
if (key == nullptr) {
return;
}
#pragma omp parallel for collapse(2) #pragma omp parallel for collapse(2)
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
for (int i = 0; i < num_kv_heads; ++i) { for (int i = 0; i < num_kv_heads; ++i) {
@ -168,13 +177,13 @@ void rotary_embedding_gptj_impl(
}; // namespace }; // namespace
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor& key, int64_t head_size, std::optional<torch::Tensor> key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox) { torch::Tensor& cos_sin_cache, bool is_neox) {
int num_tokens = positions.numel(); int num_tokens = positions.numel();
int rot_dim = cos_sin_cache.size(1); int rot_dim = cos_sin_cache.size(1);
int num_heads = query.size(-1) / head_size; int num_heads = query.size(-1) / head_size;
int num_kv_heads = key.size(-1) / head_size; int num_kv_heads = key.has_value() ? key->size(-1) / head_size : num_heads;
int64_t key_stride = key.stride(-2); int64_t key_stride = key.has_value() ? key->stride(-2) : 0;
int64_t query_stride = query.stride(-2); int64_t query_stride = query.stride(-2);
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(
@ -183,15 +192,15 @@ void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
if (is_neox) { if (is_neox) {
rotary_embedding_impl( rotary_embedding_impl(
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(), positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
rot_dim, query_stride, key_stride, num_heads, num_kv_heads, cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride,
head_size, num_tokens); key_stride, num_heads, num_kv_heads, head_size, num_tokens);
} else { } else {
rotary_embedding_gptj_impl( rotary_embedding_gptj_impl(
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(), positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
rot_dim, query_stride, key_stride, num_heads, num_kv_heads, cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride,
head_size, num_tokens); key_stride, num_heads, num_kv_heads, head_size, num_tokens);
} }
CPU_KERNEL_GUARD_OUT(rotary_embedding_impl) CPU_KERNEL_GUARD_OUT(rotary_embedding_impl)

View File

@ -239,6 +239,280 @@ void static_quant_epilogue(const float* input, scalar_t* output,
} }
} }
template <bool AZP, bool PerChannel, bool Bias, typename scalar_t>
void dynamic_quant_epilogue(const float* input, scalar_t* output,
const float* a_scale, const float* b_scale,
const int32_t* azp, const int32_t* azp_adj,
const scalar_t* bias, const int num_tokens,
const int hidden_size) {
CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue)
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
using azp_adj_load_vec_t =
typename KernelVecType<scalar_t>::azp_adj_load_vec_type;
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
#pragma omp parallel for
for (int i = 0; i < num_tokens; ++i) {
int j = 0;
cvt_vec_t token_scale_vec(a_scale[i]);
cvt_vec_t token_zp_scale_vec;
if constexpr (AZP) {
float zp_scale_val = a_scale[i] * static_cast<float>(azp[i]);
if constexpr (!PerChannel) {
zp_scale_val *= *b_scale;
}
token_zp_scale_vec = cvt_vec_t(zp_scale_val);
}
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
cvt_vec_t elems_fp32(input + i * hidden_size + j);
elems_fp32 = elems_fp32 * token_scale_vec;
if constexpr (AZP) {
azp_adj_load_vec_t azp_adj_vec(azp_adj + j);
cvt_vec_t azp_adj_fp32(azp_adj_vec);
azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec;
if constexpr (PerChannel) {
cvt_vec_t b_scale_vec(b_scale + j);
azp_adj_fp32 = azp_adj_fp32 * b_scale_vec;
}
elems_fp32 = elems_fp32 - azp_adj_fp32;
}
if constexpr (Bias) {
load_vec_t bias_vec(bias + j);
cvt_vec_t bias_vec_fp32(bias_vec);
elems_fp32 = elems_fp32 + bias_vec_fp32;
}
load_vec_t elems_out(elems_fp32);
elems_out.save(output + i * hidden_size + j);
}
cvt_vec_t elems_fp32(input + i * hidden_size + j);
elems_fp32 = elems_fp32 * token_scale_vec;
if constexpr (AZP) {
azp_adj_load_vec_t azp_adj_vec(azp_adj + j);
cvt_vec_t azp_adj_fp32(azp_adj_vec);
azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec;
if constexpr (PerChannel) {
cvt_vec_t b_scale_vec(b_scale + j);
azp_adj_fp32 = azp_adj_fp32 * b_scale_vec;
}
elems_fp32 = elems_fp32 - azp_adj_fp32;
}
if constexpr (Bias) {
load_vec_t bias_vec(bias + j);
cvt_vec_t bias_vec_fp32(bias_vec);
elems_fp32 = elems_fp32 + bias_vec_fp32;
}
load_vec_t elems_out(elems_fp32);
elems_out.save(output + i * hidden_size + j, hidden_size - j);
}
}
#elif defined(__powerpc64__)
template <bool AZP, typename scalar_t>
void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
const float* scale, const int32_t* azp,
const int num_tokens,
const int hidden_size) {
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
constexpr float i8_min =
static_cast<float>(std::numeric_limits<int8_t>::min());
constexpr float i8_max =
static_cast<float>(std::numeric_limits<int8_t>::max());
const cvt_vec_t inv_scale(1.0 / *scale);
const cvt_vec_t i8_min_vec(i8_min);
const cvt_vec_t i8_max_vec(i8_max);
cvt_vec_t zp_vec;
if constexpr (AZP) {
zp_vec = cvt_vec_t(static_cast<float>(*azp));
}
#pragma omp parallel for
for (int i = 0; i < num_tokens; ++i) {
int j = 0;
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
load_vec_t elems(input + i * hidden_size + j);
cvt_vec_t elems_fp32(elems);
elems_fp32 = elems_fp32 * inv_scale;
if constexpr (AZP) {
elems_fp32 = elems_fp32 + zp_vec;
}
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
vec_op::INT8Vec16 elems_int8(elems_fp32);
elems_int8.save(output + i * hidden_size + j);
}
load_vec_t elems(input + i * hidden_size + j);
cvt_vec_t elems_fp32(elems);
elems_fp32 = elems_fp32 * inv_scale;
if constexpr (AZP) {
elems_fp32 = elems_fp32 + zp_vec;
}
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
vec_op::INT8Vec16 elems_int8(elems_fp32);
elems_int8.save(output + i * hidden_size + j, hidden_size - j);
}
}
template <bool AZP, typename scalar_t>
void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
float* scale, int32_t* azp,
const int num_tokens,
const int hidden_size) {
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
constexpr float i8_min =
static_cast<float>(std::numeric_limits<int8_t>::min());
constexpr float i8_max =
static_cast<float>(std::numeric_limits<int8_t>::max());
const cvt_vec_t i8_min_vec(i8_min);
const cvt_vec_t i8_max_vec(i8_max);
#pragma omp parallel for
for (int i = 0; i < num_tokens; ++i) {
cvt_vec_t max_value(std::numeric_limits<float>::lowest());
cvt_vec_t min_value(std::numeric_limits<float>::max());
{
int j = 0;
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
load_vec_t elems(input + i * hidden_size + j);
cvt_vec_t elems_fp32(elems);
if constexpr (AZP) {
max_value = max_value.max(elems_fp32);
min_value = min_value.min(elems_fp32);
} else {
max_value = max_value.max(elems_fp32.abs());
}
}
load_vec_t elems(input + i * hidden_size + j);
cvt_vec_t elems_fp32(elems);
if (j + vec_elem_num == hidden_size) {
if constexpr (AZP) {
max_value = max_value.max(elems_fp32);
min_value = min_value.min(elems_fp32);
} else {
max_value = max_value.max(elems_fp32.abs());
}
} else {
if constexpr (AZP) {
max_value = max_value.max(elems_fp32, hidden_size - j);
min_value = min_value.min(elems_fp32, hidden_size - j);
} else {
max_value = max_value.max(elems_fp32.abs(), hidden_size - j);
}
}
}
float scale_val, azp_val;
if constexpr (AZP) {
float max_scalar = max_value.reduce_max();
float min_scalar = min_value.reduce_min();
scale_val = (max_scalar - min_scalar) / 255.0f;
azp_val = std::nearbyint(-128.0f - min_scalar / scale_val);
azp[i] = static_cast<int32_t>(azp_val);
scale[i] = scale_val;
} else {
scale_val = max_value.reduce_max() / 127.0f;
scale[i] = scale_val;
}
const cvt_vec_t inv_scale(1.0 / scale_val);
const cvt_vec_t azp_vec(azp_val);
{
int j = 0;
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
load_vec_t elems(input + i * hidden_size + j);
cvt_vec_t elems_fp32(elems);
elems_fp32 = (elems_fp32 * inv_scale);
if constexpr (AZP) {
elems_fp32 = elems_fp32 + azp_vec;
}
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
vec_op::INT8Vec16 elems_int8(elems_fp32);
elems_int8.save(output + i * hidden_size + j);
}
load_vec_t elems(input + i * hidden_size + j);
cvt_vec_t elems_fp32(elems);
elems_fp32 = (elems_fp32 * inv_scale);
if constexpr (AZP) {
elems_fp32 = elems_fp32 + azp_vec;
}
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
vec_op::INT8Vec16 elems_int8(elems_fp32);
elems_int8.save(output + i * hidden_size + j, hidden_size - j);
}
}
}
template <bool PerChannel, typename scalar_t>
void static_quant_epilogue(const float* input, scalar_t* output,
const float a_scale, const float* b_scale,
const int32_t* azp_with_adj, const int num_tokens,
const int hidden_size) {
CPU_KERNEL_GUARD_IN(dynamic_output_scale_impl)
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
using azp_adj_load_vec_t =
typename KernelVecType<scalar_t>::azp_adj_load_vec_type;
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
#pragma omp parallel for
for (int i = 0; i < num_tokens; ++i) {
cvt_vec_t a_scale_vec(a_scale);
cvt_vec_t b_scale_vec(*b_scale);
cvt_vec_t scale_vec = a_scale_vec * b_scale_vec;
int j = 0;
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
cvt_vec_t elems_fp32(input + i * hidden_size + j);
azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j);
cvt_vec_t azp_adj_fp32(azp_adj_vec);
if constexpr (PerChannel) {
b_scale_vec = cvt_vec_t(b_scale + j);
scale_vec = b_scale_vec * a_scale_vec;
}
elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32;
load_vec_t elems_out(elems_fp32);
elems_out.save(output + i * hidden_size + j);
}
cvt_vec_t elems_fp32(input + i * hidden_size + j);
azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j);
cvt_vec_t azp_adj_fp32(azp_adj_vec);
if constexpr (PerChannel) {
b_scale_vec = cvt_vec_t(b_scale + j);
scale_vec = b_scale_vec * a_scale_vec;
}
elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32;
load_vec_t elems_out(elems_fp32);
elems_out.save(output + i * hidden_size + j, hidden_size - j);
}
}
template <bool AZP, bool PerChannel, bool Bias, typename scalar_t> template <bool AZP, bool PerChannel, bool Bias, typename scalar_t>
void dynamic_quant_epilogue(const float* input, scalar_t* output, void dynamic_quant_epilogue(const float* input, scalar_t* output,
const float* a_scale, const float* b_scale, const float* a_scale, const float* b_scale,
@ -324,7 +598,8 @@ void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
const float* scale, const int32_t* azp, const float* scale, const int32_t* azp,
const int num_tokens, const int num_tokens,
const int hidden_size) { const int hidden_size) {
TORCH_CHECK(false, "static_scaled_int8_quant_impl requires AVX512 support.") TORCH_CHECK(
false, "static_scaled_int8_quant_impl requires AVX512/powerpc64 support.")
} }
template <typename scalar_t> template <typename scalar_t>
@ -332,7 +607,9 @@ void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
float* scale, int32_t* azp, float* scale, int32_t* azp,
const int num_tokens, const int num_tokens,
const int hidden_size) { const int hidden_size) {
TORCH_CHECK(false, "dynamic_scaled_int8_quant_impl requires AVX512 support.") TORCH_CHECK(
false,
"dynamic_scaled_int8_quant_impl requires AVX512/powerpc64 support.")
} }
template <bool PerChannel, typename scalar_t> template <bool PerChannel, typename scalar_t>
@ -340,7 +617,7 @@ void static_quant_epilogue(const float* input, scalar_t* output,
const float a_scale, const float* b_scale, const float a_scale, const float* b_scale,
const int32_t* azp_with_adj, const int num_tokens, const int32_t* azp_with_adj, const int num_tokens,
const int hidden_size) { const int hidden_size) {
TORCH_CHECK(false, "static_quant_epilogue requires AVX512 support.") TORCH_CHECK(false, "static_quant_epilogue requires AVX512/powerpc64 support.")
} }
template <typename scalar_t> template <typename scalar_t>
@ -349,7 +626,8 @@ void dynamic_quant_epilogue(const float* input, scalar_t* output,
const int32_t* azp, const int32_t* azp_with_adj, const int32_t* azp, const int32_t* azp_with_adj,
const scalar_t* bias, const int num_tokens, const scalar_t* bias, const int num_tokens,
const int hidden_size) { const int hidden_size) {
TORCH_CHECK(false, "dynamic_quant_epilogue requires AVX512 support.") TORCH_CHECK(false,
"dynamic_quant_epilogue requires AVX512/powerpc64 support.")
} }
#endif #endif
} // namespace } // namespace
@ -611,3 +889,58 @@ void dynamic_scaled_int8_quant(
} }
}); });
} }
#if defined(__powerpc64__)
void int8_scaled_mm_ppc64le(torch::Tensor& c, // [M, OC], row-major
const torch::Tensor& a, // [M, IC], row-major
const torch::Tensor& b, // [IC, OC], column-major
const torch::Tensor& a_scales,
const torch::Tensor& b_scales,
const std::optional<torch::Tensor>& bias // [OC]
) {
CPU_KERNEL_GUARD_IN(cutlass_scaled_mm)
// Checks for conformality
TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8,
"int8_scaled_mm_ppc64le only supports INT8 inputs.");
TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
b.size(1) == c.size(1));
// We dont need this
TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));
// Check for strides and alignment
TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
TORCH_CHECK(b.stride(0) == 1); // Column-major
TORCH_CHECK(c.stride(0) % 16 == 0 &&
b.stride(1) % 16 == 0); // 16 Byte Alignment
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (bias) {
TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
bias->dim() == 1);
}
VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm_ppc64le", [&] {
torch::Tensor tmp_fp32_out = torch::empty_like(c, ::at::ScalarType::Float);
// Compute C_inter=s_b * (A@B)
DNNLPrimitiveHelper<true>::gemm_s8s8_jit<float, void>(
a.data_ptr<int8_t>(), b.data_ptr<int8_t>(),
tmp_fp32_out.data_ptr<float>(), nullptr, a.size(0), b.size(1),
a.size(1), nullptr, b_scales.data_ptr<float>(), 0, b_scales.numel());
if (bias.has_value()) {
// Compute C=s_a * C_inter + bias
dynamic_quant_epilogue<false, true, true>(
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
a_scales.data_ptr<float>(), nullptr, nullptr, nullptr,
bias->data_ptr<scalar_t>(), c.size(0), c.size(1));
} else {
// Compute C=s_a * C_inter
dynamic_quant_epilogue<false, true, false, scalar_t>(
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
a_scales.data_ptr<float>(), nullptr, nullptr, nullptr, nullptr,
c.size(0), c.size(1));
}
});
}
#endif

View File

@ -18,6 +18,14 @@ void int8_scaled_mm_azp(torch::Tensor& c, const torch::Tensor& a,
const std::optional<torch::Tensor>& azp, const std::optional<torch::Tensor>& azp,
const std::optional<torch::Tensor>& bias); const std::optional<torch::Tensor>& bias);
#if defined(__powerpc64__)
void int8_scaled_mm_ppc64le(torch::Tensor& c, const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& a_scales,
const torch::Tensor& b_scales,
const std::optional<torch::Tensor>& bias);
#endif
void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query, void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query,
torch::Tensor& kv_cache, double scale, torch::Tensor& kv_cache, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens); torch::Tensor& block_tables, torch::Tensor& seq_lens);
@ -117,7 +125,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key. // Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops.def( ops.def(
"rotary_embedding(Tensor positions, Tensor! query," "rotary_embedding(Tensor positions, Tensor! query,"
" Tensor! key, int head_size," " Tensor!? key, int head_size,"
" Tensor cos_sin_cache, bool is_neox) -> ()"); " Tensor cos_sin_cache, bool is_neox) -> ()");
ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding); ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding);
@ -150,6 +158,33 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor b_scales, Tensor azp_adj," " Tensor b_scales, Tensor azp_adj,"
" Tensor? azp, Tensor? bias) -> ()"); " Tensor? azp, Tensor? bias) -> ()");
ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp); ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp);
#elif defined(__powerpc64__)
// Compute int8 quantized tensor for given scaling factor.
ops.def(
"static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
"Tensor? azp) -> ()");
ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant);
// Compute int8 quantized tensor and scaling factor
ops.def(
"dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, "
"Tensor!? azp) -> ()");
ops.impl("dynamic_scaled_int8_quant", torch::kCPU,
&dynamic_scaled_int8_quant);
// W8A8 GEMM, supporting symmetric quantization.
ops.def(
"cutlass_scaled_mm(Tensor! out, Tensor a,"
" Tensor b, Tensor a_scales,"
" Tensor b_scales, Tensor? bias) -> ()");
ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm_ppc64le);
// w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
// quantization.
ops.def(
"cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
" Tensor b, Tensor a_scales,"
" Tensor b_scales, Tensor azp_adj,"
" Tensor? azp, Tensor? bias) -> ()");
ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp);
#endif #endif
// SHM CCL // SHM CCL

View File

@ -59,3 +59,13 @@ struct enable_sm90_only : Kernel {
#endif #endif
} }
}; };
template <typename Kernel>
struct enable_sm100_only : Kernel {
template <typename... Args>
CUTLASS_DEVICE void operator()(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 1000
Kernel::operator()(std::forward<Args>(args)...);
#endif
}
};

View File

@ -65,5 +65,19 @@
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
#define VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::UInt16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::UInt32, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::UInt64, __VA_ARGS__)
#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ #define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(__VA_ARGS__))

View File

@ -140,6 +140,10 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size] torch::Tensor& input, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size] torch::Tensor& weight, // [hidden_size]
double epsilon) { double epsilon) {
TORCH_CHECK(out.is_contiguous());
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(weight.is_contiguous());
int hidden_size = input.size(-1); int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size; int num_tokens = input.numel() / hidden_size;

File diff suppressed because it is too large Load Diff

View File

@ -1,31 +0,0 @@
#include "marlin_moe_kernel_ku4.h"
namespace marlin_moe {
// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
bool call_marlin_moe_kernel_ku4(
vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int m_block, int max_par, int cfg_max_m_blocks) {
bool has_zp = true;
if (false) {
}
AWQ_CALL_IF_MOE(vllm::kU4, 16, 4, 256)
AWQ_CALL_IF_MOE(vllm::kU4, 8, 8, 256)
AWQ_CALL_IF_MOE(vllm::kU4, 8, 4, 128)
AWQ_CALL_IF_MOE(vllm::kU4, 4, 8, 128)
else {
return false;
}
return true;
}
} // namespace marlin_moe

View File

@ -1,20 +0,0 @@
#pragma once
#include "marlin_moe_kernel.h"
namespace marlin_moe {
// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
bool call_marlin_moe_kernel_ku4(
vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int m_block, int max_par, int cfg_max_m_blocks);
} // namespace marlin_moe

View File

@ -1,31 +0,0 @@
#include "marlin_moe_kernel_ku4b8.h"
namespace marlin_moe {
// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
bool call_marlin_moe_kernel_ku4b8(
vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int m_block, int max_par, int cfg_max_m_blocks) {
bool has_zp = false;
if (false) {
}
GPTQ_CALL_IF_MOE(vllm::kU4B8, 16, 4, 256)
GPTQ_CALL_IF_MOE(vllm::kU4B8, 8, 8, 256)
GPTQ_CALL_IF_MOE(vllm::kU4B8, 8, 4, 128)
GPTQ_CALL_IF_MOE(vllm::kU4B8, 4, 8, 128)
else {
return false;
}
return true;
}
} // namespace marlin_moe

View File

@ -1,20 +0,0 @@
#pragma once
#include "marlin_moe_kernel.h"
namespace marlin_moe {
// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
bool call_marlin_moe_kernel_ku4b8(
vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int m_block, int max_par, int cfg_max_m_blocks);
} // namespace marlin_moe

View File

@ -1,31 +0,0 @@
#include "marlin_moe_kernel_ku8b128.h"
namespace marlin_moe {
// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
bool call_marlin_moe_kernel_ku8b128(
vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int m_block, int max_par, int cfg_max_m_blocks) {
bool has_zp = false;
if (false) {
}
GPTQ_CALL_IF_MOE(vllm::kU8B128, 16, 4, 256)
GPTQ_CALL_IF_MOE(vllm::kU8B128, 8, 8, 256)
GPTQ_CALL_IF_MOE(vllm::kU8B128, 8, 4, 128)
GPTQ_CALL_IF_MOE(vllm::kU8B128, 4, 8, 128)
else {
return false;
}
return true;
}
} // namespace marlin_moe

View File

@ -1,18 +0,0 @@
#pragma once
#include "marlin_moe_kernel.h"
namespace marlin_moe {
bool call_marlin_moe_kernel_ku8b128(
vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int m_block, int max_par, int cfg_max_m_blocks);
}

View File

@ -1,588 +0,0 @@
/*
* Modified by Neural Magic
* Copyright (C) Marlin.2024 Elias Frantar
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>
#include "core/exception.hpp"
#include "core/scalar_type.hpp"
#include "core/registration.h"
#include "marlin_kernels/marlin_moe_kernel_ku4b8.h"
#include "marlin_kernels/marlin_moe_kernel_ku8b128.h"
#include "marlin_kernels/marlin_moe_kernel_ku4.h"
template <typename T>
inline std::string str(T x) {
return std::to_string(x);
}
namespace marlin_moe {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// For a given "a" of size [M,K] performs a permutation of the K columns based
// on the given "perm" indices.
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
int const* __restrict__ perm_int_ptr,
int4* __restrict__ out_int4_ptr, int size_m,
int size_k, int block_rows) {
int start_row = block_rows * blockIdx.x;
int finish_row = start_row + block_rows;
if (finish_row > size_m) {
finish_row = size_m;
}
int cur_block_rows = finish_row - start_row;
int row_stride = size_k * sizeof(half) / 16;
auto permute_row = [&](int row) {
int iters = size_k / blockDim.x;
int rest = size_k % blockDim.x;
int offset = row * row_stride;
half const* a_row_half = reinterpret_cast<half const*>(a_int4_ptr + offset);
half* out_half = reinterpret_cast<half*>(out_int4_ptr + offset);
int base_k = 0;
for (int i = 0; i < iters; i++) {
int cur_k = base_k + threadIdx.x;
int src_pos = perm_int_ptr[cur_k];
out_half[cur_k] = a_row_half[src_pos];
base_k += blockDim.x;
}
if (rest) {
if (threadIdx.x < rest) {
int cur_k = base_k + threadIdx.x;
int src_pos = perm_int_ptr[cur_k];
out_half[cur_k] = a_row_half[src_pos];
}
}
};
for (int i = 0; i < cur_block_rows; i++) {
int cur_row = start_row + i;
if (cur_row < size_m) {
permute_row(cur_row);
}
}
}
__global__ void compute_expert_offsets(int const* __restrict__ topk_ids,
int* __restrict__ expert_offsets,
int topk_length, int block_size) {
int expert_id = threadIdx.x;
int num_experts = blockDim.x;
int occurrences = 0;
for (int i = 0; i < topk_length; ++i) {
occurrences += (topk_ids[i] == expert_id);
}
expert_offsets[expert_id + 1] = occurrences;
__syncthreads();
if (threadIdx.x == 0) {
int tot_offset = 0;
expert_offsets[0] = 0;
for (int i = 0; i < num_experts; ++i) {
tot_offset += ceildiv(expert_offsets[i + 1], block_size) * block_size;
expert_offsets[i + 1] = tot_offset;
}
}
__syncthreads();
}
#else
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
int const* __restrict__ perm_int_ptr,
int4* __restrict__ out_int4_ptr, int size_m,
int size_k, int block_rows) {
// Marlin is not implemented yet for SM < 8.0
assert(false);
return;
}
__global__ void compute_expert_offsets(int const* __restrict__ topk_ids,
int* __restrict__ expert_offsets,
int topk_length, int block_size) {
// Marlin is not implemented yet for SM < 8.0
assert(false);
return;
}
#endif
typedef struct {
int thread_k;
int thread_n;
int num_threads;
} thread_config_t;
typedef struct {
int max_m_blocks;
thread_config_t tb_cfg;
} exec_config_t;
thread_config_t small_batch_thread_configs[] = {
// Ordered by priority
// thread_k, thread_n, num_threads
{128, 128, 256}, // Default
{128, 64, 128}, // Reduce N 2X, same K
{64, 256, 256}, // Reduce K 2X, increase N 2X
{64, 128, 128}, // Reduce K 2X, same N
{64, 64, 128}, // Reduce both 2X
};
thread_config_t large_batch_thread_configs[] = {
// Ordered by priority
// thread_k, thread_n, num_threads
{64, 256, 256}, // Default
{128, 128, 256}, // Reduce N 2X, increase K 2X
{64, 128, 128}, // Reduce N 2X, same K
{128, 64, 128}, // Reduce N 4X, increase K 2X
{64, 64, 128}, // Reduce N 4X, same K
};
int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
int prob_n, int prob_k, int num_bits, int group_size,
bool has_act_order, bool is_k_full) {
bool cache_scales_chunk = has_act_order && !is_k_full;
int tb_n = th_config.thread_n;
int tb_k = th_config.thread_k;
// Get max scale groups per thread-block
int tb_groups;
if (group_size == -1) {
tb_groups = 1;
} else if (group_size == 0) {
tb_groups = ceildiv(tb_k, 32); // Worst case is 32 group size
} else {
tb_groups = ceildiv(tb_k, group_size);
}
if (cache_scales_chunk) {
int load_groups =
tb_groups * STAGES * 2; // Chunk size is 2x pipeline over dim K
load_groups = max(load_groups, 32); // We load at least 32 scale groups
return load_groups * tb_n * 4;
} else {
int tb_scales = tb_groups * tb_n * 2;
return tb_scales * STAGES;
}
}
bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,
int prob_m, int prob_n, int prob_k, int num_bits,
int scales_cache_size, int max_shared_mem) {
int pack_factor = 32 / num_bits;
// Get B size
int tb_k = th_config.thread_k;
int tb_n = th_config.thread_n;
int b_size = (tb_k * tb_n / pack_factor) * 4;
// Get A size
int m_blocks = ceildiv(prob_m, 16);
int tb_max_m = 16;
while (true) {
if (m_blocks >= max_m_blocks) {
tb_max_m *= max_m_blocks;
break;
}
max_m_blocks--;
if (max_m_blocks == 0) {
TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks);
}
}
int a_size = (tb_max_m * tb_k) * 2;
float pipe_size = (a_size + b_size) * STAGES;
TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity
return pipe_size < 0.95f * (max_shared_mem - scales_cache_size);
}
bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
int prob_m, int prob_n, int prob_k, int num_bits,
int group_size, bool has_act_order, bool is_k_full,
int max_shared_mem) {
// Sanity
if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
th_config.num_threads == -1) {
return false;
}
// Verify K/N are divisible by thread K/N
if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {
return false;
}
// thread_k can be only 128 or 64 (because it must be less than groupsize
// which is 128)
if (th_config.thread_k != 128 && th_config.thread_k != 64) {
return false;
}
// Verify min for thread K/N
if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {
return false;
}
// num_threads must be at least 128 (= 4 warps)
if (th_config.num_threads < 128) {
return false;
}
// Determine cache for scales
int scales_cache_size =
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
group_size, has_act_order, is_k_full);
// Check that pipeline fits into cache
if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k,
num_bits, scales_cache_size, max_shared_mem)) {
return false;
}
return true;
}
exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
int num_bits, int group_size,
bool has_act_order, bool is_k_full,
int max_shared_mem) {
int max_m_blocks = 4;
while (max_m_blocks > 0) {
if (prob_m <= 16) {
for (auto th_config : small_batch_thread_configs) {
if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full,
max_shared_mem)) {
return exec_config_t{max_m_blocks, th_config};
}
}
} else {
for (auto th_config : large_batch_thread_configs) {
if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full,
max_shared_mem)) {
return exec_config_t{max_m_blocks, th_config};
}
}
}
max_m_blocks--; // Process less M blocks per invocation to reduce cache
// usage
}
return exec_config_t{0, {-1, -1, -1}};
}
#define CALL_MOE_KERNEL_FUNCTION(KERNEL_FUNCTION) \
else if (KERNEL_FUNCTION( \
q_type, thread_n_blocks, thread_k_blocks, has_act_order, \
group_blocks, num_threads, blocks, max_shared_mem, stream, \
A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \
zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \
num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \
replicate_input, apply_weights, m_block, max_par, \
exec_cfg.max_m_blocks)) { \
}
void marlin_mm_moe(const void* A, const void* B, void* C,
const void* sorted_ids, const void* topk_weights,
const void* topk_ids, const void* s, void* zp,
const void* g_idx, const void* perm, void* a_tmp,
void* expert_offsets, int prob_m, int prob_n, int prob_k,
void* workspace, vllm::ScalarType const& q_type,
bool has_act_order, bool is_k_full, bool has_zp,
int num_groups, int group_size, int num_experts, int topk,
int moe_block_size, int dev, cudaStream_t stream,
int thread_k, int thread_n, int sms, int max_par,
bool replicate_input, bool apply_weights) {
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
", ", prob_n, ", ", prob_k, "]");
if (sms == -1) {
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
}
int max_shared_mem = 0;
cudaDeviceGetAttribute(&max_shared_mem,
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
TORCH_CHECK(max_shared_mem > 0);
int num_bits = q_type.size_bits();
// Set thread config
exec_config_t exec_cfg;
if (thread_k != -1 && thread_n != -1) {
// User-defined config
exec_cfg =
exec_config_t{4, thread_config_t{thread_k, thread_n, USER_THREADS}};
} else {
// Auto config
exec_cfg =
determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size,
has_act_order, is_k_full, max_shared_mem);
}
TORCH_CHECK(exec_cfg.max_m_blocks > 0 &&
is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks,
prob_m, prob_n, prob_k, num_bits, group_size,
has_act_order, is_k_full, max_shared_mem),
"Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks,
", thread_k = ", exec_cfg.tb_cfg.thread_k,
", thread_n = ", exec_cfg.tb_cfg.thread_n,
", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [",
prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
", group_size = ", group_size,
", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full,
", max_shared_mem = ", max_shared_mem);
int num_threads = exec_cfg.tb_cfg.num_threads;
thread_k = exec_cfg.tb_cfg.thread_k;
thread_n = exec_cfg.tb_cfg.thread_n;
int thread_k_blocks = thread_k / 16;
int thread_n_blocks = thread_n / 16;
int blocks = sms;
TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
" is not divisible by thread_n = ", thread_n);
TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
" is not divisible by thread_k = ", thread_k);
int group_blocks = 0;
if (has_act_order) {
if (is_k_full) {
TORCH_CHECK(group_size != -1);
group_blocks = group_size / 16;
TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
" is not divisible by group_blocks = ", group_blocks);
} else {
TORCH_CHECK(group_size == 0);
group_blocks = 0;
}
} else {
if (group_size == -1) {
group_blocks = -1;
} else {
group_blocks = group_size / 16;
TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
" is not divisible by group_blocks = ", group_blocks);
}
}
int tot_m = prob_m;
const int* topk_ids_ptr = (const int*)topk_ids;
int* expert_offsets_ptr = (int*)expert_offsets;
compute_expert_offsets<<<1, num_experts, 0, stream>>>(
topk_ids_ptr, expert_offsets_ptr, tot_m * topk, moe_block_size);
bool do_permute_a = has_act_order;
// If we have a full K, then we can run the non-act-order version of Marlin
// (since the weight rows are reordered by increasing group ids, and by
// having a full K, we have full original groups)
if (is_k_full) {
has_act_order = false;
}
int pack_factor = 32 / q_type.size_bits();
for (int expert_idx = 0; expert_idx < num_experts; ++expert_idx) {
const int4* A_ptr = (const int4*)A;
int4* a_tmp_ptr = (int4*)a_tmp;
const int4* B_ptr =
(const int4*)B + (prob_n * prob_k / (pack_factor * 4)) * expert_idx;
int4* C_ptr = (int4*)C;
const float* topk_weights_ptr = (const float*)topk_weights;
const int* sorted_ids_ptr = (const int*)sorted_ids;
const int4* s_ptr = (const int4*)s + num_groups * prob_n / 8 * expert_idx;
const int4* zp_ptr =
(const int4*)zp + num_groups * prob_n / (pack_factor * 4) * expert_idx;
const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx;
const int* perm_ptr = (const int*)perm + prob_k * expert_idx;
int* locks = (int*)workspace;
if (do_permute_a) {
// Permute A columns
int topk_rows = replicate_input ? tot_m : tot_m * topk;
int block_rows = ceildiv(topk_rows, blocks);
permute_cols_kernel<<<blocks, num_threads, 0, stream>>>(
A_ptr, perm_ptr, a_tmp_ptr, topk_rows, prob_k, block_rows);
A_ptr = a_tmp_ptr;
}
int tot_m_blocks = ceildiv(tot_m, 16);
for (int m_block = 0; m_block < tot_m_blocks;
m_block += 4 * exec_cfg.max_m_blocks) {
if (false) {
}
CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4b8)
CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8b128)
CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4)
else {
TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " +
str(prob_n) + ", " + str(prob_k) + "]" +
", has_act_order = " + str(has_act_order) +
", num_groups = " + str(num_groups) +
", group_size = " + str(group_size) +
", thread_n_blocks = " + str(thread_n_blocks) +
", thread_k_blocks = " + str(thread_k_blocks));
}
}
}
}
} // namespace marlin_moe
torch::Tensor marlin_gemm_moe(
const torch::Tensor& a, const torch::Tensor& b_q_weights,
const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights,
const torch::Tensor& topk_ids, const torch::Tensor& b_scales,
torch::Tensor& b_zeros, const torch::Tensor& g_idx,
const torch::Tensor& perm, torch::Tensor& workspace,
vllm::ScalarTypeId const b_q_type_id, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full, int64_t num_experts, int64_t topk,
int64_t moe_block_size, bool replicate_input, bool apply_weights) {
vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
bool has_zp = b_zeros.size(1) != 0;
if (has_zp) {
TORCH_CHECK(
b_q_type == vllm::kU4,
"b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str());
} else {
TORCH_CHECK(
b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128,
"b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type.str());
}
int pack_factor = 32 / b_q_type.size_bits();
int max_par = 4;
int dev = a.get_device();
auto options_dtype =
torch::TensorOptions().dtype(a.dtype()).device(a.device());
auto options_int =
torch::TensorOptions().dtype(torch::kInt).device(a.device());
torch::Tensor c = torch::zeros({size_m, topk, size_n}, options_dtype);
torch::Tensor a_tmp =
replicate_input ? torch::zeros({size_m, size_k}, options_dtype)
: torch::zeros({size_m, topk, size_k}, options_dtype);
torch::Tensor expert_offsets = torch::empty({num_experts + 1}, options_int);
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as
// auto -1)
int thread_k = -1;
// thread_n: `n` size of a thread_tile in `weights` (can usually be left as
// auto -1)
int thread_n = -1;
// sms: number of SMs to use for the kernel (can usually be left as auto -1)
int sms = -1;
// Detect groupsize and act_order
int num_groups = -1;
int group_size = -1;
bool has_act_order = g_idx.size(1) != 0;
int b_rank = b_scales.sizes().size();
TORCH_CHECK(b_rank == 3, "b_scales rank = ", b_rank, " is not 3");
TORCH_CHECK(b_scales.size(2) == size_n, "b_scales dim 2 = ", b_scales.size(2),
" is not size_n = ", size_n);
num_groups = b_scales.size(1);
TORCH_CHECK(VLLM_IMPLIES(!is_k_full, has_act_order),
"if is_k_full is false, has_act_order must be true");
if (has_act_order) {
if (is_k_full) {
TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1");
TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k,
", is not divisible by num_groups = ", num_groups);
group_size = size_k / num_groups;
} else {
group_size = 0;
}
} else {
if (num_groups > 1) {
TORCH_CHECK(
size_k % num_groups == 0, "size_k = ", size_k,
", is not divisible by b_scales.size(0) = ", b_scales.size(0));
group_size = size_k / num_groups;
} else {
group_size = -1;
}
}
// Verify b_zeros
if (has_zp) {
int rank = b_zeros.sizes().size();
TORCH_CHECK(rank == 3, "b_zeros rank = ", rank, " is not 3");
TORCH_CHECK(b_zeros.size(1) == num_groups,
"b_zeros dim 1 = ", b_zeros.size(1),
" is not num_groups = ", num_groups);
TORCH_CHECK(b_zeros.size(2) == size_n / pack_factor,
"b_zeros dim 2 = ", b_zeros.size(2),
" is not size_n / pack_factor = ", size_n / pack_factor);
}
marlin_moe::marlin_mm_moe(
a.data_ptr(), b_q_weights.data_ptr(), c.data_ptr(), sorted_ids.data_ptr(),
topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(),
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(),
expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(),
b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size,
num_experts, topk, moe_block_size, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par,
replicate_input, apply_weights);
return c;
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("marlin_gemm_moe", &marlin_gemm_moe);
}

1
csrc/moe/marlin_moe_wna16/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
kernel_*.cu

View File

@ -25,15 +25,16 @@ TEMPLATE = ("template __global__ void Marlin<"
"{{thread_k_blocks}}, " "{{thread_k_blocks}}, "
"{{'true' if m_block_size_8 else 'false'}}, " "{{'true' if m_block_size_8 else 'false'}}, "
"{{stages}}, " "{{stages}}, "
"{{'true' if has_act_order else 'false'}}, "
"{{'true' if has_zp else 'false'}}, "
"{{group_blocks}}, " "{{group_blocks}}, "
"{{'true' if is_zp_float else 'false'}}>" "{{'true' if is_zp_float else 'false'}}>"
"( MARLIN_KERNEL_PARAMS );") "( MARLIN_KERNEL_PARAMS );")
# int8 with zero point case (vllm::kU8) is also supported, # int8 with zero point case (vllm::kU8) is also supported,
# we don't add it to reduce wheel size. # we don't add it to reduce wheel size.
SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128"] SCALAR_TYPES = [
"vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn",
"vllm::kFE2M1f"
]
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)] THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4] THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
@ -41,7 +42,7 @@ THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
# = 0 : act order case # = 0 : act order case
# = -1 : channelwise quantization # = -1 : channelwise quantization
# > 0 : group_size=16*group_blocks # > 0 : group_size=16*group_blocks
GROUP_BLOCKS = [0, -1, 2, 4, 8] GROUP_BLOCKS = [0, -1, 1, 2, 4, 8]
DTYPES = ["fp16", "bf16"] DTYPES = ["fp16", "bf16"]
@ -52,21 +53,35 @@ def remove_old_kernels():
def generate_new_kernels(): def generate_new_kernels():
for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES): for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
has_zp = "B" not in scalar_type
all_template_str_list = [] all_template_str_list = []
for group_blocks, m_blocks, thread_configs in itertools.product( for group_blocks, m_blocks, thread_configs in itertools.product(
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS): GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS):
has_act_order = group_blocks == 0 # act order case only support gptq-int4 and gptq-int8
if has_zp and has_act_order: if group_blocks == 0 and scalar_type not in [
"vllm::kU4B8", "vllm::kU8B128"
]:
continue continue
if thread_configs[2] == 256: if thread_configs[2] == 256:
# for small batch (m_blocks == 1), we only need (128, 128, 256)
# for large batch (m_blocks > 1), we only need (64, 256, 256)
if m_blocks <= 1 and thread_configs[0] != 128: if m_blocks <= 1 and thread_configs[0] != 128:
continue continue
if m_blocks > 1 and thread_configs[0] != 64: if m_blocks > 1 and thread_configs[0] != 64:
continue continue
# we only support channelwise quantization and group_size == 128
# for fp8
if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
continue
# nvfp4 only supports group_size == 16
if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]:
continue
# other quantization methods don't support group_size = 16
if scalar_type != "vllm::kFE2M1f" and group_blocks == 1:
continue
k_blocks = thread_configs[0] // 16 k_blocks = thread_configs[0] // 16
n_blocks = thread_configs[1] // 16 n_blocks = thread_configs[1] // 16
threads = thread_configs[2] threads = thread_configs[2]
@ -82,8 +97,6 @@ def generate_new_kernels():
thread_k_blocks=k_blocks, thread_k_blocks=k_blocks,
m_block_size_8=m_blocks == 0.5, m_block_size_8=m_blocks == 0.5,
stages="pipe_stages", stages="pipe_stages",
has_act_order=has_act_order,
has_zp=has_zp,
group_blocks=group_blocks, group_blocks=group_blocks,
is_zp_float=False, is_zp_float=False,
) )

View File

@ -7,18 +7,19 @@
#include "quantization/gptq_marlin/marlin_dtypes.cuh" #include "quantization/gptq_marlin/marlin_dtypes.cuh"
#include "core/scalar_type.hpp" #include "core/scalar_type.hpp"
#define MARLIN_KERNEL_PARAMS \ #define MARLIN_KERNEL_PARAMS \
const int4 *__restrict__ A, const int4 *__restrict__ B, \ const int4 *__restrict__ A, const int4 *__restrict__ B, \
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr, \ const int4 *__restrict__ scales_ptr, \
const int *__restrict__ g_idx, \ const uint16_t *__restrict__ scale2_ptr, \
const int32_t *__restrict__ sorted_token_ids_ptr, \ const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
const int32_t *__restrict__ expert_ids_ptr, \ const int32_t *__restrict__ sorted_token_ids_ptr, \
const int32_t *__restrict__ num_tokens_past_padded_ptr, \ const int32_t *__restrict__ expert_ids_ptr, \
const float *__restrict__ topk_weights_ptr, int top_k, \ const int32_t *__restrict__ num_tokens_past_padded_ptr, \
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \ const float *__restrict__ topk_weights_ptr, int top_k, \
int prob_n, int prob_k, int *locks, bool use_atomic_add, \ bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
bool use_fp32_reduce int prob_n, int prob_k, int *locks, bool use_atomic_add, \
bool use_fp32_reduce, int max_shared_mem
namespace MARLIN_NAMESPACE_NAME { namespace MARLIN_NAMESPACE_NAME {
template <typename scalar_t, // compute dtype, half or nv_float16 template <typename scalar_t, // compute dtype, half or nv_float16
@ -33,11 +34,9 @@ template <typename scalar_t, // compute dtype, half or nv_float16
// only works when thread_m_blocks == 1 // only works when thread_m_blocks == 1
const int stages, // number of stages for the async global->shared const int stages, // number of stages for the async global->shared
// fetch pipeline // fetch pipeline
const bool has_act_order, // whether act_order is enabled const int group_blocks, // number of consecutive 16x16 blocks
const bool has_zp, // whether zero-points are enabled // with a separate quantization scale
const int group_blocks, // number of consecutive 16x16 blocks const bool is_zp_float // is zero point of float16 type?
// with a separate quantization scale
const bool is_zp_float // is zero point of float16 type?
> >
__global__ void Marlin(MARLIN_KERNEL_PARAMS); __global__ void Marlin(MARLIN_KERNEL_PARAMS);

View File

@ -25,6 +25,7 @@
#include "quantization/gptq_marlin/marlin.cuh" #include "quantization/gptq_marlin/marlin.cuh"
#include "quantization/gptq_marlin/marlin_dtypes.cuh" #include "quantization/gptq_marlin/marlin_dtypes.cuh"
#include "quantization/gptq_marlin/dequant.h"
#include "core/scalar_type.hpp" #include "core/scalar_type.hpp"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
@ -48,11 +49,9 @@ template <typename scalar_t, // compute dtype, half or nv_float16
// only works when thread_m_blocks == 1 // only works when thread_m_blocks == 1
const int stages, // number of stages for the async global->shared const int stages, // number of stages for the async global->shared
// fetch pipeline // fetch pipeline
const bool has_act_order, // whether act_order is enabled const int group_blocks, // number of consecutive 16x16 blocks
const bool has_zp, // whether zero-points are enabled // with a separate quantization scale
const int group_blocks, // number of consecutive 16x16 blocks const bool is_zp_float // is zero point of float16 type?
// with a separate quantization scale
const bool is_zp_float // is zero point of float16 type?
> >
__global__ void Marlin( __global__ void Marlin(
const int4* __restrict__ A, // fp16 input matrix of shape mxk const int4* __restrict__ A, // fp16 input matrix of shape mxk
@ -77,8 +76,8 @@ __global__ void Marlin(
int prob_k, // reduction dimension k int prob_k, // reduction dimension k
int* locks, // extra global storage for barrier synchronization int* locks, // extra global storage for barrier synchronization
bool use_atomic_add, // whether to use atomic add to reduce bool use_atomic_add, // whether to use atomic add to reduce
bool use_fp32_reduce // whether to use fp32 global reduce bool use_fp32_reduce, // whether to use fp32 global reduce
) {} int max_shared_mem) {}
} // namespace MARLIN_NAMESPACE_NAME } // namespace MARLIN_NAMESPACE_NAME
@ -166,144 +165,6 @@ __device__ inline void ldsm(typename ScalarType<scalar_t>::FragA& frag_a,
} }
} }
// Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it in
// all cases.
template <int lut>
__device__ inline int lop3(int a, int b, int c) {
int res;
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(res)
: "r"(a), "r"(b), "r"(c), "n"(lut));
return res;
}
// Constructs destination register by taking bytes from 2 sources (based on
// mask)
template <int start_byte, int mask>
__device__ inline uint32_t prmt(uint32_t a) {
uint32_t res;
asm volatile("prmt.b32 %0, %1, %2, %3;\n"
: "=r"(res)
: "r"(a), "n"(start_byte), "n"(mask));
return res;
}
template <typename scalar_t, int bit>
__device__ inline typename ScalarType<scalar_t>::FragB dequant(
int q, typename ScalarType<scalar_t>::FragB& frag_b);
//
// Efficiently dequantize 4bit values packed in an int32 value into a full
// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below,
// with some small changes:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
//
template <>
__device__ inline typename ScalarType<half>::FragB dequant<half, 4>(
int q, typename ScalarType<half>::FragB& frag_b) {
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
const int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`.
const int SUB = 0x64086408;
const int MUL = 0x2c002c00;
const int ADD = 0xd480d480;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
*reinterpret_cast<const half2*>(&SUB));
frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
*reinterpret_cast<const half2*>(&MUL),
*reinterpret_cast<const half2*>(&ADD));
return frag_b;
}
template <>
__device__ inline typename ScalarType<nv_bfloat16>::FragB
dequant<nv_bfloat16, 4>(int q,
typename ScalarType<nv_bfloat16>::FragB& frag_b) {
static constexpr uint32_t MASK = 0x000f000f;
static constexpr uint32_t EX = 0x43004300;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
q >>= 4;
int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
static constexpr uint32_t MUL = 0x3F803F80;
static constexpr uint32_t ADD = 0xC308C308;
frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),
*reinterpret_cast<const nv_bfloat162*>(&MUL),
*reinterpret_cast<const nv_bfloat162*>(&ADD));
frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
*reinterpret_cast<const nv_bfloat162*>(&MUL),
*reinterpret_cast<const nv_bfloat162*>(&ADD));
return frag_b;
}
//
// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or
// bf16 Reference:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
//
template <>
__device__ inline typename ScalarType<half>::FragB dequant<half, 8>(
int q, typename ScalarType<half>::FragB& frag_b) {
static constexpr uint32_t mask_for_elt_01 = 0x5250;
static constexpr uint32_t mask_for_elt_23 = 0x5351;
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
return frag_b;
}
template <>
__device__ inline typename ScalarType<nv_bfloat16>::FragB
dequant<nv_bfloat16, 8>(int q,
typename ScalarType<nv_bfloat16>::FragB& frag_b) {
float fp32_intermediates[4];
uint32_t* fp32_intermediates_casted =
reinterpret_cast<uint32_t*>(fp32_intermediates);
static constexpr uint32_t fp32_base = 0x4B000000;
fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
fp32_intermediates[0] -= 8388736.f;
fp32_intermediates[1] -= 8388736.f;
fp32_intermediates[2] -= 8388736.f;
fp32_intermediates[3] -= 8388736.f;
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&frag_b);
bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
fp32_intermediates_casted[1], 0x7632);
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
fp32_intermediates_casted[3], 0x7632);
return frag_b;
}
// Multiply dequantized values by the corresponding quantization scale; used // Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization. // only for grouped quantization.
template <typename scalar_t> template <typename scalar_t>
@ -429,11 +290,9 @@ template <typename scalar_t, // compute dtype, half or nv_float16
// only works when thread_m_blocks == 1 // only works when thread_m_blocks == 1
const int stages, // number of stages for the async global->shared const int stages, // number of stages for the async global->shared
// fetch pipeline // fetch pipeline
const bool has_act_order, // whether act_order is enabled const int group_blocks, // number of consecutive 16x16 blocks
const bool has_zp, // whether zero-points are enabled // with a separate quantization scale
const int group_blocks, // number of consecutive 16x16 blocks const bool is_zp_float // is zero point of float16 type?
// with a separate quantization scale
const bool is_zp_float // is zero point of float16 type?
> >
__global__ void Marlin( __global__ void Marlin(
const int4* __restrict__ A, // fp16 input matrix of shape mxk const int4* __restrict__ A, // fp16 input matrix of shape mxk
@ -442,9 +301,11 @@ __global__ void Marlin(
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce) int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
// (k/groupsize)xn // (k/groupsize)xn
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape const uint16_t* __restrict__ scale2_ptr, // fp16 global scale (for nvfp4
// (k/groupsize)x(n/pack_factor) // only)
const int* __restrict__ g_idx, // int32 group indices of shape k const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor)
const int* __restrict__ g_idx, // int32 group indices of shape k
const int32_t* __restrict__ sorted_token_ids_ptr, // moe sorted_ids const int32_t* __restrict__ sorted_token_ids_ptr, // moe sorted_ids
const int32_t* __restrict__ expert_ids_ptr, // moe expert ids const int32_t* __restrict__ expert_ids_ptr, // moe expert ids
const int32_t* __restrict__ num_tokens_past_padded_ptr, // moe num tokens const int32_t* __restrict__ num_tokens_past_padded_ptr, // moe num tokens
@ -458,8 +319,8 @@ __global__ void Marlin(
int prob_k, // reduction dimension k int prob_k, // reduction dimension k
int* locks, // extra global storage for barrier synchronization int* locks, // extra global storage for barrier synchronization
bool use_atomic_add, // whether to use atomic add to reduce bool use_atomic_add, // whether to use atomic add to reduce
bool use_fp32_reduce // whether to use fp32 global reduce bool use_fp32_reduce, // whether to use fp32 global reduce
) { int max_shared_mem) {
// Each threadblock processes one "stripe" of the B matrix with (roughly) the // Each threadblock processes one "stripe" of the B matrix with (roughly) the
// same size, which might involve multiple column "slices" (of width 16 * // same size, which might involve multiple column "slices" (of width 16 *
// `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM
@ -481,13 +342,26 @@ __global__ void Marlin(
extern __shared__ int4 sh[]; extern __shared__ int4 sh[];
static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8;
constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 ||
w_type == vllm::kU4B8 || w_type == vllm::kU8B128;
// see comments of dequant.h for more details
constexpr bool dequant_skip_flop =
!is_int_type ||
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
has_zp && !is_zp_float && !(w_type == vllm::kU8);
scalar_t2 global_scale;
constexpr bool has_act_order = group_blocks == 0;
constexpr int pack_factor = 32 / w_type.size_bits(); constexpr int pack_factor = 32 / w_type.size_bits();
static_assert(thread_m_blocks == 1 || !m_block_size_8); static_assert(thread_m_blocks == 1 || !m_block_size_8);
constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks); constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks);
const int group_size = const int group_size =
(!has_act_order && group_blocks == -1) ? prob_k : prob_k / num_groups; (!has_act_order && group_blocks == -1) ? prob_k : prob_k / num_groups;
const int scales_expert_stride = prob_n * prob_k / group_size / 8; const int scales_expert_stride =
prob_n * prob_k / group_size / (w_type == vllm::kFE2M1f ? 16 : 8);
const int zp_expert_stride = const int zp_expert_stride =
is_zp_float ? prob_n * prob_k / group_size / 8 is_zp_float ? prob_n * prob_k / group_size / 8
: prob_n * prob_k / group_size / (pack_factor * 4); : prob_n * prob_k / group_size / (pack_factor * 4);
@ -534,13 +408,20 @@ __global__ void Marlin(
int64_t B_expert_off = 0; int64_t B_expert_off = 0;
int4* sh_block_sorted_ids_int4 = sh; int4* sh_block_sorted_ids_int4 = sh;
int4* sh_rd_block_sorted_ids_int4 =
sh_block_sorted_ids_int4 + moe_block_size / 4;
int4* sh_block_topk_weights_int4 =
sh_rd_block_sorted_ids_int4 + moe_block_size / 4;
// sh_block_topk_weights_int4 only need (moe_block_size / 4);
// but we pad to align to 256 bytes
int4* sh_new =
sh_block_topk_weights_int4 + moe_block_size / 2 + moe_block_size;
int32_t* sh_block_sorted_ids = int32_t* sh_block_sorted_ids =
reinterpret_cast<int*>(sh_block_sorted_ids_int4); reinterpret_cast<int*>(sh_block_sorted_ids_int4);
int4* sh_block_topk_weights_int4 = int32_t* sh_rd_block_sorted_ids =
sh_block_sorted_ids_int4 + moe_block_size / 4; reinterpret_cast<int*>(sh_rd_block_sorted_ids_int4);
scalar_t2* sh_block_topk_weights = scalar_t2* sh_block_topk_weights =
reinterpret_cast<scalar_t2*>(sh_block_topk_weights_int4); reinterpret_cast<scalar_t2*>(sh_block_topk_weights_int4);
int4* sh_new = sh_block_topk_weights_int4 + moe_block_size / 4;
int32_t block_num_valid_tokens = 0; int32_t block_num_valid_tokens = 0;
int32_t locks_off = 0; int32_t locks_off = 0;
@ -584,12 +465,24 @@ __global__ void Marlin(
sh_block_sorted_ids_int4[tid4] = reinterpret_cast<const int4*>( sh_block_sorted_ids_int4[tid4] = reinterpret_cast<const int4*>(
sorted_token_ids_ptr)[block_id * moe_block_size / 4 + tid4]; sorted_token_ids_ptr)[block_id * moe_block_size / 4 + tid4];
#pragma unroll
for (int i = 0; i < 4; i++)
sh_rd_block_sorted_ids[tid4 * 4 + i] =
sh_block_sorted_ids[tid4 * 4 + i] / top_k;
if (mul_topk_weights) { if (mul_topk_weights) {
#pragma unroll #pragma unroll
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
sh_block_topk_weights[tid4 * 4 + i] = int idx = tid4 * 4 + i;
Dtype::num2num2(Dtype::float2num( idx = idx < block_num_valid_tokens ? idx : 0;
topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]])); if constexpr (w_type == vllm::kFE2M1f) {
sh_block_topk_weights[idx] = __hmul2(
global_scale, Dtype::num2num2(Dtype::float2num(
topk_weights_ptr[sh_block_sorted_ids[idx]])));
} else {
sh_block_topk_weights[idx] = Dtype::num2num2(
Dtype::float2num(topk_weights_ptr[sh_block_sorted_ids[idx]]));
}
} }
} }
} }
@ -620,6 +513,11 @@ __global__ void Marlin(
expert_id = expert_ids_ptr[block_id]; expert_id = expert_ids_ptr[block_id];
} }
if constexpr (w_type == vllm::kFE2M1f) {
uint16_t val = scale2_ptr[expert_id];
global_scale = Dtype::num2num2(*reinterpret_cast<scalar_t*>(&val));
}
B_expert_off = expert_id * prob_n * prob_k / (pack_factor * 4); B_expert_off = expert_id * prob_n * prob_k / (pack_factor * 4);
scales_ptr += (expert_id - old_expert_id) * scales_expert_stride; scales_ptr += (expert_id - old_expert_id) * scales_expert_stride;
if constexpr (has_zp) { if constexpr (has_zp) {
@ -733,7 +631,7 @@ __global__ void Marlin(
constexpr int s_sh_stride = 16 * thread_n_blocks / 8; constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
constexpr int s_tb_groups = constexpr int s_tb_groups =
!has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
? thread_k_blocks / group_blocks ? thread_k_blocks / group_blocks / (w_type == vllm::kFE2M1f ? 2 : 1)
: 1; : 1;
constexpr int s_sh_stage = s_tb_groups * s_sh_stride; constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
int s_gl_rd_delta = s_gl_stride; int s_gl_rd_delta = s_gl_stride;
@ -743,6 +641,7 @@ __global__ void Marlin(
constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0;
// constexpr int act_s_row_stride = 1; // constexpr int act_s_row_stride = 1;
// int act_s_col_stride = act_s_row_stride * num_groups; // int act_s_col_stride = act_s_row_stride * num_groups;
constexpr int act_s_max_num_groups = 32;
int act_s_col_stride = 1; int act_s_col_stride = 1;
int act_s_col_warp_stride = act_s_col_stride * 8; int act_s_col_warp_stride = act_s_col_stride * 8;
int tb_n_warps = thread_n_blocks / 4; int tb_n_warps = thread_n_blocks / 4;
@ -758,9 +657,9 @@ __global__ void Marlin(
int zp_gl_rd_delta = zp_gl_stride; int zp_gl_rd_delta = zp_gl_stride;
// Global A read index of current thread. // Global A read index of current thread.
int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + int a_gl_rd_row = threadIdx.x / a_gl_rd_delta_o;
(threadIdx.x % a_gl_rd_delta_o); int a_gl_rd_col = a_gl_rd_delta_o * slice_row + threadIdx.x % a_gl_rd_delta_o;
a_gl_rd += a_gl_rd_delta_o * slice_row;
// Shared write index of current thread. // Shared write index of current thread.
int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +
(threadIdx.x % a_gl_rd_delta_o); (threadIdx.x % a_gl_rd_delta_o);
@ -774,8 +673,8 @@ __global__ void Marlin(
(threadIdx.x % b_sh_stride_threads) * b_thread_vecs; (threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
b_gl_rd += b_sh_stride * slice_col; b_gl_rd += b_sh_stride * slice_col;
b_gl_rd += b_gl_rd_delta_o * slice_row; b_gl_rd += b_gl_rd_delta_o * slice_row;
int b_sh_wr = threadIdx.x * b_thread_vecs; auto b_sh_wr = threadIdx.x * b_thread_vecs;
int b_sh_rd = threadIdx.x * b_thread_vecs; auto b_sh_rd = threadIdx.x * b_thread_vecs;
// For act_order // For act_order
constexpr int k_iter_size = tb_k / b_sh_wr_iters; constexpr int k_iter_size = tb_k / b_sh_wr_iters;
@ -790,11 +689,12 @@ __global__ void Marlin(
if constexpr (group_blocks == -1) { if constexpr (group_blocks == -1) {
s_gl_rd = s_sh_stride * slice_col + threadIdx.x; s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
} else { } else {
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) /
(w_type == vllm::kFE2M1f ? 2 : 1) +
s_sh_stride * slice_col + threadIdx.x; s_sh_stride * slice_col + threadIdx.x;
} }
} }
int s_sh_wr = threadIdx.x; auto s_sh_wr = threadIdx.x;
bool s_sh_wr_pred = threadIdx.x < s_sh_stride; bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
// Zero-points // Zero-points
@ -807,17 +707,27 @@ __global__ void Marlin(
zp_sh_stride * slice_col + threadIdx.x; zp_sh_stride * slice_col + threadIdx.x;
} }
} }
int zp_sh_wr = threadIdx.x; auto zp_sh_wr = threadIdx.x;
bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride;
// We use a different scale layout for grouped and column-wise quantization as // We use a different scale layout for grouped and column-wise quantization as
// we scale a `half2` tile in column-major layout in the former and in // we scale a `half2` tile in column-major layout in the former and in
// row-major in the latter case. // row-major in the latter case.
int s_sh_rd; int s_sh_rd;
if constexpr (group_blocks != -1) if constexpr (group_blocks != -1 && w_type == vllm::kFE2M1f) {
auto warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps;
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 4; (threadIdx.x % 32) / 4;
else if constexpr (group_blocks == -1 && (m_block_size_8 || has_zp)) s_sh_rd = s_sh_rd * 2 + warp_row % 2;
} else if constexpr (group_blocks != -1)
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 4;
else if constexpr (group_blocks == -1 &&
(m_block_size_8 || (has_zp && !dequant_skip_flop)))
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 8; (threadIdx.x % 32) / 8;
else else
@ -851,7 +761,7 @@ __global__ void Marlin(
// each warp must also write a consecutive memory segment? // each warp must also write a consecutive memory segment?
auto transform_a = [&](int i) { auto transform_a = [&](int i) {
int row = i / a_gl_rd_delta_o; int row = i / a_gl_rd_delta_o;
return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ (row % 8);
}; };
// Since the computation of this remapping is non-trivial and, due to our main // Since the computation of this remapping is non-trivial and, due to our main
// loop unrolls, all shared memory accesses are static, we simply precompute // loop unrolls, all shared memory accesses are static, we simply precompute
@ -879,12 +789,28 @@ __global__ void Marlin(
B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
// Shared memory storage for global fetch pipelines. // Shared memory storage for global fetch pipelines.
int4* sh_a = sh_new; constexpr int sh_red_size = (2 * thread_n_blocks + 1) * 16 * thread_m_blocks;
int4* sh_b = sh_a + (stages * a_sh_stage); constexpr int sh_b_size = stages * b_sh_stage;
int4* sh_g_idx = sh_b + (stages * b_sh_stage); int4* sh_b = sh_new;
int4* sh_red = sh_new;
int4* sh_g_idx = sh_b + (sh_red_size > sh_b_size ? sh_red_size : sh_b_size);
int4* sh_zp = sh_g_idx + (stages * g_idx_stage); int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride)
: (stages * s_sh_stage);
int4* sh_s = sh_zp + (stages * zp_sh_stage); int4* sh_s = sh_zp + (stages * zp_sh_stage);
int4* sh_red = sh_b; // shared memory reused by reduction should be smaller than
// shared memory used by weight.
static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <=
stages * b_sh_stage);
int4* sh_a = sh_s + sh_s_size;
constexpr int shm_size_used =
moe_block_size + stages * (g_idx_stage + zp_sh_stage) + sh_s_size +
(sh_red_size > sh_b_size ? sh_red_size : sh_b_size);
// all remaining shared memory is used to cache A (input)
// sh_a_max_row is at least ` stages * 16 * thread_m_blocks `
int sh_a_max_row =
((max_shared_mem - 1024) / 16 - shm_size_used) / (thread_k_blocks * 2);
// Register storage for double buffer of shared memory reads. // Register storage for double buffer of shared memory reads.
FragA frag_a[2][thread_m_blocks]; FragA frag_a[2][thread_m_blocks];
@ -905,15 +831,14 @@ __global__ void Marlin(
int sh_first_group_id = -1; int sh_first_group_id = -1;
int sh_num_groups = -1; int sh_num_groups = -1;
constexpr int sh_max_num_groups = 32;
auto fetch_act_order_scales_to_shared = [&](bool is_async, int first_group_id, auto fetch_act_order_scales_to_shared = [&](bool is_async, int first_group_id,
int last_group_id) { int last_group_id) {
sh_first_group_id = first_group_id; sh_first_group_id = first_group_id;
sh_num_groups = last_group_id - first_group_id + 1; sh_num_groups = last_group_id - first_group_id + 1;
if (sh_num_groups < sh_max_num_groups) { if (sh_num_groups > act_s_max_num_groups) {
sh_num_groups = sh_max_num_groups; sh_num_groups = act_s_max_num_groups;
} }
if (sh_first_group_id + sh_num_groups > num_groups) { if (sh_first_group_id + sh_num_groups > num_groups) {
@ -940,27 +865,31 @@ __global__ void Marlin(
} }
} }
}; };
// Asynchronously fetch the next A, B and s tile from global to the next // Asynchronously fetch the next A, B and s tile from global to the next
// shared memory pipeline location. // shared memory pipeline location.
int a_remaining_load_count_in_slice = stages; bool should_load_a = true;
auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { int max_num_stage_groups =
((sh_a_max_row - moe_block_size) / moe_block_size + 1) / stages;
max_num_stage_groups = max(max_num_stage_groups, 1);
auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true,
int pipe_a = 0) {
if (pred) { if (pred) {
int4* sh_a_stage = sh_a + a_sh_stage * pipe; if (should_load_a) {
if (prob_k > thread_k_blocks * 16 * stages || slice_col == 0 || int4* sh_a_stage = sh_a + moe_block_size * a_sh_stride * pipe_a;
a_remaining_load_count_in_slice > 0) {
a_remaining_load_count_in_slice--;
#pragma unroll #pragma unroll
for (int i = 0; i < a_sh_wr_iters; i++) { for (int i = 0; i < a_sh_wr_iters; i++) {
int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off; int row = a_gl_rd_delta_i / a_gl_stride * i + a_gl_rd_row;
int row = a_idx / a_gl_stride;
int64_t sorted_row = 0; int64_t sorted_row = 0;
if (!m_block_size_8 || row < 8) if (!m_block_size_8 || row < 8)
sorted_row = sh_block_sorted_ids[row] / top_k; sorted_row = sh_rd_block_sorted_ids[row];
int64_t true_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride; int64_t true_idx =
sorted_row * a_gl_stride + a_gl_rd_col + a_gl_rd_delta_o * a_off;
cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[true_idx], cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[true_idx],
row < block_num_valid_tokens); row < block_num_valid_tokens);
} }
} }
int4* sh_b_stage = sh_b + b_sh_stage * pipe; int4* sh_b_stage = sh_b + b_sh_stage * pipe;
#pragma unroll #pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) { for (int i = 0; i < b_sh_wr_iters; i++) {
@ -1063,8 +992,8 @@ __global__ void Marlin(
// Load the next sub-tile from the current location in the shared memory pipe // Load the next sub-tile from the current location in the shared memory pipe
// into the current register buffer. // into the current register buffer.
auto fetch_to_registers = [&](int k, int pipe) { auto fetch_to_registers = [&](int k, int pipe, int pipe_a = 0) {
int4* sh_a_stage = sh_a + a_sh_stage * pipe; int4* sh_a_stage = sh_a + moe_block_size * a_sh_stride * pipe_a;
#pragma unroll #pragma unroll
for (int i = 0; i < thread_m_blocks; i++) for (int i = 0; i < thread_m_blocks; i++)
ldsm<m_block_size_8 ? 2 : 4, scalar_t>( ldsm<m_block_size_8 ? 2 : 4, scalar_t>(
@ -1109,12 +1038,17 @@ __global__ void Marlin(
} }
} else if constexpr (group_blocks != -1) { } else if constexpr (group_blocks != -1) {
if constexpr (group_blocks >= thread_k_blocks) { if constexpr (group_blocks >= thread_k_blocks) {
int4* sh_s_stage = if (k % b_sh_wr_iters == 0) {
sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * int4* sh_s_stage =
(pipe / (group_blocks / thread_k_blocks))); sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; (pipe / (group_blocks / thread_k_blocks)));
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
} else {
reinterpret_cast<int4*>(&frag_s[1])[0] =
reinterpret_cast<int4*>(&frag_s[0])[0];
}
} else { } else {
int warp_id = threadIdx.x / 32; auto warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4; int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps; int warp_row = warp_id / n_warps;
@ -1123,12 +1057,19 @@ __global__ void Marlin(
cur_k += k_iter_size * (k % b_sh_wr_iters); cur_k += k_iter_size * (k % b_sh_wr_iters);
int k_blocks = cur_k / 16; int k_blocks = cur_k / 16;
int cur_group_id = k_blocks / group_blocks; int cur_group_id =
k_blocks / (group_blocks * (w_type == vllm::kFE2M1f ? 2 : 1));
int4* sh_s_stage = sh_s + s_sh_stage * pipe; int4* sh_s_stage = sh_s + s_sh_stage * pipe;
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = if constexpr (w_type_id != vllm::kFE2M1f.id()) {
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
} else {
reinterpret_cast<int2*>(&frag_s[k % 2])[0] =
reinterpret_cast<int2*>(
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
}
} }
} }
@ -1152,7 +1093,7 @@ __global__ void Marlin(
// Determine "position" inside the thread-block (based on warp and // Determine "position" inside the thread-block (based on warp and
// thread-id) // thread-id)
int warp_id = threadIdx.x / 32; auto warp_id = threadIdx.x / 32;
int n_warps = int n_warps =
thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N
@ -1161,7 +1102,7 @@ __global__ void Marlin(
cur_k += warp_row * 16; cur_k += warp_row * 16;
int th_id = threadIdx.x % 32; auto th_id = threadIdx.x % 32;
cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix
int s_col_shift = int s_col_shift =
@ -1222,15 +1163,18 @@ __global__ void Marlin(
} }
} else if constexpr (group_blocks >= thread_k_blocks) { } else if constexpr (group_blocks >= thread_k_blocks) {
int4* sh_zp_stage = if (k % b_sh_wr_iters == 0) {
sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * int4* sh_zp_stage =
(pipe / (group_blocks / thread_k_blocks))); sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
for (int i = 0; i < num_ints_per_thread; i++) { (pipe / (group_blocks / thread_k_blocks)));
frag_qzp[k % 2][i] = #pragma unroll
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i]; for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] =
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
}
} }
} else { } else {
int warp_id = threadIdx.x / 32; auto warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4; int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps; int warp_row = warp_id / n_warps;
@ -1251,6 +1195,7 @@ __global__ void Marlin(
sh_zp_stage += cur_group_id * zp_sh_stride; sh_zp_stage += cur_group_id * zp_sh_stride;
#pragma unroll
for (int i = 0; i < num_ints_per_thread; i++) { for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] = frag_qzp[k % 2][i] =
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i]; (reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
@ -1263,12 +1208,16 @@ __global__ void Marlin(
if constexpr (group_blocks != -1) { if constexpr (group_blocks != -1) {
if constexpr (group_blocks >= thread_k_blocks) { if constexpr (group_blocks >= thread_k_blocks) {
int4* sh_zp_stage = if (k % b_sh_wr_iters == 0) {
sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * int4* sh_zp_stage =
(pipe / (group_blocks / thread_k_blocks))); sh_zp +
reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd]; zp_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] =
sh_zp_stage[zp_sh_rd];
}
} else { } else {
int warp_id = threadIdx.x / 32; auto warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4; int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps; int warp_row = warp_id / n_warps;
@ -1292,6 +1241,10 @@ __global__ void Marlin(
} }
}; };
auto dequant_data = [&](int q, scalar_t2* frag_b_ptr) {
dequant<scalar_t2, w_type_id, dequant_skip_flop>(q, frag_b_ptr);
};
// Execute the actual tensor core matmul of a sub-tile. // Execute the actual tensor core matmul of a sub-tile.
bool is_first_matmul_in_slice = true; bool is_first_matmul_in_slice = true;
auto matmul = [&](int k) { auto matmul = [&](int k) {
@ -1315,15 +1268,27 @@ __global__ void Marlin(
zp_quant_1 = frag_qzp[k2][1]; zp_quant_1 = frag_qzp[k2][1];
} }
dequant<scalar_t, w_type.size_bits()>(zp_quant_0, frag_zp_0); dequant_data(zp_quant_0, reinterpret_cast<scalar_t2*>(&frag_zp));
dequant<scalar_t, w_type.size_bits()>(zp_quant_1, frag_zp_1); dequant_data(zp_quant_1, reinterpret_cast<scalar_t2*>(&frag_zp) + 2);
frag_zp[0] = frag_zp_0[0];
frag_zp[1] = frag_zp_0[1];
frag_zp[2] = frag_zp_1[0];
frag_zp[3] = frag_zp_1[1];
} }
} }
if constexpr (!dequant_skip_flop && has_zp && is_zp_float) {
if (is_new_zp) {
reinterpret_cast<int4*>(&frag_zp)[0] =
reinterpret_cast<int4*>(&frag_zpf[k2])[0];
}
}
if constexpr (w_type == vllm::kFE2M1f) {
int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];
dequant_fp8_scales<scalar_t2>(s_quant_0,
reinterpret_cast<scalar_t2*>(&frag_s[k2]));
dequant_fp8_scales<scalar_t2>(
s_quant_1, reinterpret_cast<scalar_t2*>(&frag_s[k2]) + 2);
}
// We have the m dimension as the inner loop in order to encourage overlapping // We have the m dimension as the inner loop in order to encourage overlapping
// dequantization and matmul operations. // dequantization and matmul operations.
#pragma unroll #pragma unroll
@ -1332,7 +1297,10 @@ __global__ void Marlin(
FragB frag_b1; FragB frag_b1;
int b_quant_0, b_quant_1; int b_quant_0, b_quant_1;
if constexpr (w_type.size_bits() == 4) { if constexpr (w_type_id == vllm::kFE2M1f.id()) {
b_quant_1 = frag_b_quant[k2][0][j];
b_quant_0 = b_quant_1 << 8;
} else if constexpr (w_type.size_bits() == 4) {
b_quant_0 = frag_b_quant[k2][0][j]; b_quant_0 = frag_b_quant[k2][0][j];
b_quant_1 = b_quant_0 >> 8; b_quant_1 = b_quant_0 >> 8;
} else { } else {
@ -1342,8 +1310,13 @@ __global__ void Marlin(
b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
} }
dequant<scalar_t, w_type.size_bits()>(b_quant_0, frag_b0); dequant_data(b_quant_0, reinterpret_cast<scalar_t2*>(&frag_b0));
dequant<scalar_t, w_type.size_bits()>(b_quant_1, frag_b1); dequant_data(b_quant_1, reinterpret_cast<scalar_t2*>(&frag_b1));
if constexpr (dequant_skip_flop && has_zp && !is_zp_float) {
sub_zp<scalar_t>(frag_b0, frag_zp[j], 0);
sub_zp<scalar_t>(frag_b1, frag_zp[j], 1);
}
// Apply scale to frag_b0 // Apply scale to frag_b0
if constexpr (has_act_order) { if constexpr (has_act_order) {
@ -1351,9 +1324,9 @@ __global__ void Marlin(
scale4<scalar_t>(frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j], scale4<scalar_t>(frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j],
act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0); act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0);
scale4<scalar_t>(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j], scale4<scalar_t>(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j],
act_frag_s[k][2][j], act_frag_s[k2][3][j], 1); act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1);
} else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float &&
} else if constexpr (has_zp && !is_zp_float && group_blocks == -1) { group_blocks == -1) {
int idx = (threadIdx.x / 4) % 2; int idx = (threadIdx.x / 4) % 2;
scalar_t2 s2 = Dtype::nums2num2( scalar_t2 s2 = Dtype::nums2num2(
reinterpret_cast<scalar_t*>(&frag_s[j / 2][j % 2 * 2 + 0])[idx], reinterpret_cast<scalar_t*>(&frag_s[j / 2][j % 2 * 2 + 0])[idx],
@ -1361,18 +1334,12 @@ __global__ void Marlin(
if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2); if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2);
scale_and_sub<scalar_t>(frag_b0, s2.x, frag_zp[j].x); scale_and_sub<scalar_t>(frag_b0, s2.x, frag_zp[j].x);
scale_and_sub<scalar_t>(frag_b1, s2.y, frag_zp[j].y); scale_and_sub<scalar_t>(frag_b1, s2.y, frag_zp[j].y);
} else if constexpr (has_zp && !is_zp_float && group_blocks != -1) { } else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1) {
if (is_new_zp) if (is_new_zp)
frag_zp[j] = __hmul2(frag_zp[j], frag_zp[j] = __hmul2(frag_zp[j],
*reinterpret_cast<scalar_t2*>(&frag_s[k2][j])); *reinterpret_cast<scalar_t2*>(&frag_s[k2][j]));
scale_and_sub<scalar_t>(frag_b0, frag_s[k % 2][j][0].x, frag_zp[j].x); scale_and_sub<scalar_t>(frag_b0, frag_s[k2][j][0].x, frag_zp[j].x);
scale_and_sub<scalar_t>(frag_b1, frag_s[k % 2][j][0].y, frag_zp[j].y); scale_and_sub<scalar_t>(frag_b1, frag_s[k2][j][0].y, frag_zp[j].y);
} else if constexpr (has_zp && is_zp_float && group_blocks != -1) {
if (is_new_zp)
frag_zpf[k2][j] = __hmul2(
frag_zpf[k2][j], *reinterpret_cast<scalar_t2*>(&frag_s[k2][j]));
scale_and_sub<scalar_t>(frag_b0, frag_s[k2][j].x, frag_zpf[k2][j].x);
scale_and_sub<scalar_t>(frag_b1, frag_s[k2][j].y, frag_zpf[k2][j].y);
} else if constexpr (group_blocks != -1) { } else if constexpr (group_blocks != -1) {
scale<scalar_t>(frag_b0, frag_s[k2][j], 0); scale<scalar_t>(frag_b0, frag_s[k2][j], 0);
scale<scalar_t>(frag_b1, frag_s[k2][j], 1); scale<scalar_t>(frag_b1, frag_s[k2][j], 1);
@ -1397,7 +1364,7 @@ __global__ void Marlin(
auto thread_block_reduce = [&]() { auto thread_block_reduce = [&]() {
constexpr int red_off = threads / b_sh_stride_threads / 2; constexpr int red_off = threads / b_sh_stride_threads / 2;
if (red_off >= 1) { if (red_off >= 1) {
int red_idx = threadIdx.x / b_sh_stride_threads; auto red_idx = threadIdx.x / b_sh_stride_threads;
constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
constexpr int red_sh_delta = b_sh_stride_threads; constexpr int red_sh_delta = b_sh_stride_threads;
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
@ -1634,10 +1601,17 @@ __global__ void Marlin(
// For per-column quantization we finally apply the scale here (only for // For per-column quantization we finally apply the scale here (only for
// 4-bit) // 4-bit)
if constexpr (!has_act_order && group_blocks == -1 && if constexpr (!has_act_order && group_blocks == -1 &&
w_type.size_bits() == 4 && !has_zp) { w_type.size_bits() == 4 &&
(has_zp && dequant_skip_flop || !has_zp)) {
res = __hmul2(res, s[0]); res = __hmul2(res, s[0]);
} }
if constexpr (w_type == vllm::kFE2M1f) {
if (!mul_topk_weights) {
res = __hmul2(res, global_scale);
}
}
if constexpr (m_block_size_8) { if constexpr (m_block_size_8) {
((scalar_t*)sh_red)[idx] = res.x; ((scalar_t*)sh_red)[idx] = res.x;
((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y; ((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y;
@ -1728,10 +1702,12 @@ __global__ void Marlin(
if constexpr (has_zp && !is_zp_float && group_blocks == -1) { if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
if (i == 0) { if (i == 0) {
fetch_col_zp_to_shared(); fetch_col_zp_to_shared();
fetch_col_scale_to_shared(); if constexpr (!dequant_skip_flop) {
fetch_col_scale_to_shared();
}
} }
} }
fetch_to_shared(i, i, i < slice_iters); fetch_to_shared(i, i, i < slice_iters, i);
} }
zero_accums(); zero_accums();
@ -1740,8 +1716,10 @@ __global__ void Marlin(
fetch_to_registers(0, 0); fetch_to_registers(0, 0);
fetch_scales_to_registers(0, 0); fetch_scales_to_registers(0, 0);
fetch_zp_to_registers(0, 0); fetch_zp_to_registers(0, 0);
a_gl_rd += a_gl_rd_delta_o * (stages - 1); a_gl_rd_col += a_gl_rd_delta_o * (stages - 1);
slice_k_start_shared_fetch += tb_k * (stages - 1); if constexpr (has_act_order) {
slice_k_start_shared_fetch += tb_k * (stages - 1);
}
}; };
if (slice_iters) { if (slice_iters) {
start_pipes(); start_pipes();
@ -1754,45 +1732,61 @@ __global__ void Marlin(
// have even length meaning that the next iteration will always start at // have even length meaning that the next iteration will always start at
// index 0. // index 0.
for (int stage_group_id = 0; stage_group_id < max_num_stage_groups;
stage_group_id++) {
#pragma unroll #pragma unroll
for (int pipe = 0; pipe < stages;) { for (int pipe = 0; pipe < stages;) {
#pragma unroll #pragma unroll
for (int k = 0; k < b_sh_wr_iters; k++) { for (int k = 0; k < b_sh_wr_iters; k++) {
fetch_to_registers(k + 1, pipe % stages); int idx =
fetch_scales_to_registers(k + 1, pipe); (pipe >= stages && stage_group_id == max_num_stage_groups - 1)
fetch_zp_to_registers(k + 1, pipe); ? (pipe - stages)
if (k == b_sh_wr_iters - 2) { : (pipe + stage_group_id * stages);
fetch_to_shared((pipe + stages - 1) % stages, pipe, fetch_to_registers(k + 1, pipe % stages, idx);
slice_iters >= stages); fetch_scales_to_registers(k + 1, pipe);
pipe++; fetch_zp_to_registers(k + 1, pipe);
wait_for_stage(); if (k == b_sh_wr_iters - 2) {
init_same_group(pipe % stages); int idx = (pipe >= 1 && stage_group_id == max_num_stage_groups - 1)
? (pipe - 1)
: (pipe + (stage_group_id + 1) * stages - 1);
fetch_to_shared((pipe + stages - 1) % stages, pipe,
slice_iters >= stages, idx);
pipe++;
wait_for_stage();
init_same_group(pipe % stages);
}
matmul(k);
}
slice_iters--;
if (slice_iters == 0) {
break;
}
}
a_gl_rd_col += a_gl_rd_delta_o * stages;
if constexpr (has_act_order) {
slice_k_start += tb_k * stages;
if (slice_k_start < prob_k) {
slice_k_start_shared_fetch += tb_k * stages;
int first_group_id = g_idx[slice_k_start];
int last_g_idx = slice_k_start + stages * tb_k * 2;
if (last_g_idx >= prob_k) {
last_g_idx = prob_k - 1;
}
int last_group_id = g_idx[last_g_idx];
if (last_group_id >= sh_first_group_id + sh_num_groups) {
fetch_act_order_scales_to_shared(false, first_group_id,
last_group_id);
__syncthreads();
}
} }
matmul(k);
} }
slice_iters--;
if (slice_iters == 0) { if (slice_iters == 0) {
break; break;
} }
} }
a_remaining_load_count_in_slice = 0;
a_gl_rd += a_gl_rd_delta_o * stages;
slice_k_start += tb_k * stages;
slice_k_start_shared_fetch += tb_k * stages;
if constexpr (has_act_order) {
int first_group_id = g_idx[slice_k_start];
int last_g_idx = slice_k_start + stages * tb_k * 2;
if (last_g_idx >= prob_k) {
last_g_idx = prob_k - 1;
}
int last_group_id = g_idx[last_g_idx];
if (last_group_id >= sh_first_group_id + sh_num_groups) {
fetch_act_order_scales_to_shared(false, first_group_id, last_group_id);
__syncthreads();
}
}
// Process results and, if necessary, proceed to the next column slice. // Process results and, if necessary, proceed to the next column slice.
// While this pattern may not be the most readable, other ways of writing // While this pattern may not be the most readable, other ways of writing
@ -1802,7 +1796,8 @@ __global__ void Marlin(
bool last = slice_idx == slice_count - 1; bool last = slice_idx == slice_count - 1;
// For per-column scales, we only fetch them here in the final step before // For per-column scales, we only fetch them here in the final step before
// write-out // write-out
if constexpr (!has_act_order && group_blocks == -1 && !has_zp) { if constexpr (!has_act_order && group_blocks == -1 &&
(has_zp && dequant_skip_flop || !has_zp)) {
if (w_type.size_bits() == 8 || (last || use_atomic_add)) { if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
if (s_sh_wr_pred) { if (s_sh_wr_pred) {
cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
@ -1812,7 +1807,8 @@ __global__ void Marlin(
} }
thread_block_reduce(); thread_block_reduce();
if constexpr (!has_act_order && group_blocks == -1 && !has_zp) { if constexpr (!has_act_order && group_blocks == -1 &&
(has_zp && dequant_skip_flop || !has_zp)) {
if (w_type.size_bits() == 8 || (last || use_atomic_add)) { if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
cp_async_wait<0>(); cp_async_wait<0>();
__syncthreads(); __syncthreads();
@ -1836,7 +1832,8 @@ __global__ void Marlin(
// that converts the fp32 results to fp16 (so that we avoid possible // that converts the fp32 results to fp16 (so that we avoid possible
// overflow in fp16) // overflow in fp16)
if constexpr (!has_act_order && group_blocks == -1 && if constexpr (!has_act_order && group_blocks == -1 &&
w_type.size_bits() == 8 && !has_zp) { w_type.size_bits() == 8 &&
(has_zp && dequant_skip_flop || !has_zp)) {
if (threadIdx.x / 32 < thread_n_blocks / 4) { if (threadIdx.x / 32 < thread_n_blocks / 4) {
#pragma unroll #pragma unroll
for (int i = 0; i < thread_m_blocks; i++) { for (int i = 0; i < thread_m_blocks; i++) {
@ -1877,15 +1874,30 @@ __global__ void Marlin(
if (last || use_atomic_add) if (last || use_atomic_add)
// only the last block in a slice actually writes the result // only the last block in a slice actually writes the result
write_result(); write_result();
if (slice_row) a_remaining_load_count_in_slice = stages; int old_slice_row = slice_row;
slice_row = 0; slice_row = 0;
slice_col_par++; slice_col_par++;
slice_col++; slice_col++;
is_first_matmul_in_slice = true; is_first_matmul_in_slice = true;
init_slice(); init_slice();
// Should we load A matrix in next slice?
// `slice_col == 0`: when move to a new moe block
// `old_slice_row > 0`:
// when the last slice is not starting from k_index == 0
// (only happen when it is the first slice of a threadblock)
// `prob_k > thread_k_blocks * 16 * stages * max_num_stage_groups`:
// when the required shared memory size is larger than
// the remaining shared memory
if (slice_col == 0 || old_slice_row ||
prob_k > thread_k_blocks * 16 * stages * max_num_stage_groups) {
should_load_a = true;
} else {
should_load_a = false;
}
if (slice_iters) { if (slice_iters) {
a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + a_gl_rd_col = (threadIdx.x % a_gl_rd_delta_o);
(threadIdx.x % a_gl_rd_delta_o);
#pragma unroll #pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) for (int i = 0; i < b_sh_wr_iters; i++)
B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
@ -1900,12 +1912,10 @@ __global__ void Marlin(
slice_k_finish = slice_k_start + tb_k * slice_iters; slice_k_finish = slice_k_start + tb_k * slice_iters;
slice_k_start_shared_fetch = slice_k_start; slice_k_start_shared_fetch = slice_k_start;
slice_n_offset = act_s_col_tb_stride * slice_col; slice_n_offset = act_s_col_tb_stride * slice_col;
} else { } else {
s_gl_rd = s_sh_stride * slice_col + threadIdx.x; s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
} }
start_pipes(); start_pipes();
} }
} }

View File

@ -116,7 +116,7 @@ __global__ void permute_cols_kernel(
int base_k = 0; int base_k = 0;
for (int i = 0; i < iters; i++) { for (int i = 0; i < iters; i++) {
int cur_k = base_k + threadIdx.x; auto cur_k = base_k + threadIdx.x;
int src_pos = perm_int_ptr[cur_k]; int src_pos = perm_int_ptr[cur_k];
out_half[cur_k] = a_row_half[src_pos]; out_half[cur_k] = a_row_half[src_pos];
@ -126,7 +126,7 @@ __global__ void permute_cols_kernel(
if (rest) { if (rest) {
if (threadIdx.x < rest) { if (threadIdx.x < rest) {
int cur_k = base_k + threadIdx.x; auto cur_k = base_k + threadIdx.x;
int src_pos = perm_int_ptr[cur_k]; int src_pos = perm_int_ptr[cur_k];
out_half[cur_k] = a_row_half[src_pos]; out_half[cur_k] = a_row_half[src_pos];
@ -195,7 +195,6 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K
load_groups = max(load_groups, 32); // We load at least 32 scale groups load_groups = max(load_groups, 32); // We load at least 32 scale groups
return load_groups * tb_n * 2; return load_groups * tb_n * 2;
} else { } else {
int tb_scales = tb_groups * tb_n * 2; int tb_scales = tb_groups * tb_n * 2;
@ -203,22 +202,24 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
} }
} }
int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks, int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
int prob_m, int prob_n, int prob_k, int num_bits, int thread_m_blocks, int prob_m, int prob_n,
int group_size, bool has_act_order, bool is_k_full, int prob_k, int num_bits, int group_size,
int has_zp, int is_zp_float) { bool has_act_order, bool is_k_full, int has_zp,
int is_zp_float) {
int pack_factor = 32 / num_bits; int pack_factor = 32 / num_bits;
// Get B size // Get B size
int tb_k = th_config.thread_k; int tb_k = th_config.thread_k;
int tb_n = th_config.thread_n; int tb_n = th_config.thread_n;
int tb_m = thread_m_blocks * 16; int tb_m = thread_m_blocks * (m_block_size_8 ? 8 : 16);
// shm size for block_sorted_ids/block_topk_weights // shm size for block_sorted_ids/rd_block_sorted_ids/block_topk_weights
// both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32) // both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32)
int sh_block_meta_size = tb_m * 4 * 2; int sh_block_meta_size = tb_m * 4;
int sh_a_size = pipe_stages * (tb_m * tb_k) * 2; int sh_a_size = pipe_stages * (tb_m * tb_k) * 2;
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4; int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
int sh_red_size = tb_m * (tb_n + 8) * 2;
int sh_s_size = int sh_s_size =
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
group_size, has_act_order, is_k_full); group_size, has_act_order, is_k_full);
@ -233,16 +234,17 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
sh_zp_size = sh_s_size / 2; sh_zp_size = sh_s_size / 2;
} }
int total_size = sh_a_size + sh_b_size + sh_s_size + sh_zp_size + int total_size = max(sh_b_size, sh_red_size) + sh_a_size + sh_s_size +
sh_g_idx_size + sh_block_meta_size; sh_zp_size + sh_g_idx_size + sh_block_meta_size;
return total_size; return total_size;
} }
bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks, bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
int prob_m, int prob_n, int prob_k, int num_bits, int thread_m_blocks, int prob_m, int prob_n, int prob_k,
int group_size, bool has_act_order, bool is_k_full, int num_bits, int group_size, bool has_act_order,
int has_zp, int is_zp_float, int max_shared_mem) { bool is_k_full, int has_zp, int is_zp_float,
int max_shared_mem) {
// Sanity // Sanity
if (th_config.thread_k == -1 || th_config.thread_n == -1 || if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
th_config.num_threads == -1) { th_config.num_threads == -1) {
@ -266,143 +268,129 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
// Check that pipeline fits into cache // Check that pipeline fits into cache
int cache_size = get_kernel_cache_size( int cache_size = get_kernel_cache_size(
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size, th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k,
has_act_order, is_k_full, has_zp, is_zp_float); num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float);
return cache_size <= max_shared_mem; return cache_size <= max_shared_mem;
} }
#define __GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ #define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
M_BLOCK_SIZE_8, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \ M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \
NUM_THREADS, IS_ZP_FLOAT) \ else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ thread_n_blocks == THREAD_N_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \ thread_k_blocks == THREAD_K_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \ m_block_size_8 == M_BLOCK_SIZE_8 && \
m_block_size_8 == M_BLOCK_SIZE_8 && \ group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \ is_zp_float == IS_ZP_FLOAT) { \
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \ kernel = Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
is_zp_float == IS_ZP_FLOAT) { \ THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, \
kernel = Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \ pipe_stages, GROUP_BLOCKS, IS_ZP_FLOAT>; \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, \
pipe_stages, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \
IS_ZP_FLOAT>; \
} }
#define GPTQ_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ // COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false)
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, true, false, 0, NUM_THREADS, \ // this is the most common cases
false) \ // BIGGROUP: cases for big group size (group_blocks in [-1, 8])
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, false, 0, \ // FZP: cases for float-zero-point (is_zp_float = true)
NUM_THREADS, false) \ // ACT: cases for act order case (group_blocks == 0)
\ // FP4: cases for nvfp4(e2m1) (group_blocks == 1)
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, -1, \ #define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
NUM_THREADS, false) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 2, \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
NUM_THREADS, false) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 4, \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \
NUM_THREADS, false) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 8, \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
NUM_THREADS, false) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, -1, \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
NUM_THREADS, false)
#define GPTQ_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ #define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, false, 0, \ _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
NUM_THREADS, false) \ _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, false, 0, \ _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
NUM_THREADS, false) \ _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, false, 0, \ \
NUM_THREADS, false) \ _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
\ _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, -1, \ _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
NUM_THREADS, false) \ _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 2, \ \
NUM_THREADS, false) \ _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 4, \ _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
NUM_THREADS, false) \ _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 8, \ _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
NUM_THREADS, false)
#define AWQ_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ #define COMMON_GET_IF(W_TYPE) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, -1, \ COMMON_GET_IF_M1(W_TYPE, 8, 8, 256) \
NUM_THREADS, false) \ COMMON_GET_IF_M1(W_TYPE, 8, 4, 128) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 2, NUM_THREADS, \ COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \
false) \ COMMON_GET_IF_M234(W_TYPE, 8, 4, 128)
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 4, NUM_THREADS, \
false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 8, NUM_THREADS, \
false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
NUM_THREADS, false)
#define AWQ_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ #define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, -1, \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
NUM_THREADS, false) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 2, \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
NUM_THREADS, false) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, false) \ #define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 8, \ _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
NUM_THREADS, false) \ _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
\ _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, -1, \ _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
NUM_THREADS, false) \ \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 2, \ _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
NUM_THREADS, false) \ _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, false) \ #define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 8, \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \
NUM_THREADS, false) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
\
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, -1, \ #define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
NUM_THREADS, false) \ _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 2, \ _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
NUM_THREADS, false) \ _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, false) \ #define FP4_GET_IF(W_TYPE) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 8, \ FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
NUM_THREADS, false) FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
FP4_GET_IF_M234(W_TYPE, 8, 4, 128)
#define BIGGROUP_GET_IF(W_TYPE) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \
BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \
BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128)
// We currently have 4-bit models only with group_blocks == 4 // We currently have 4-bit models only with group_blocks == 4
#define HQQ_GET_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ #define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 4, NUM_THREADS, \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \
true) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true)
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, true) \ #define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 4, \ _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \
NUM_THREADS, true) \ _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 4, \ _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true)
NUM_THREADS, true) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 4, \ #define FZP_GET_IF(W_TYPE) \
NUM_THREADS, true) FZP_GET_IF_M1(W_TYPE, 8, 8, 256) \
FZP_GET_IF_M1(W_TYPE, 8, 4, 128) \
FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \
FZP_GET_IF_M234(W_TYPE, 8, 4, 128)
// We currently have 4-bit models only with group_blocks == 4
#define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false)
#define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false)
#define ACT_GET_IF(W_TYPE) \
ACT_GET_IF_M1(W_TYPE, 8, 8, 256) \
ACT_GET_IF_M1(W_TYPE, 8, 4, 128) \
ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \
ACT_GET_IF_M234(W_TYPE, 8, 4, 128)
template <typename scalar_t> template <typename scalar_t>
MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type, MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
@ -415,23 +403,17 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
auto kernel = MarlinDefault; auto kernel = MarlinDefault;
if (false) { if (false) {
} }
GPTQ_GET_IF_M1(vllm::kU4B8, 8, 8, 256)
GPTQ_GET_IF_M1(vllm::kU4B8, 8, 4, 128)
GPTQ_GET_IF_M234(vllm::kU4B8, 16, 4, 256) COMMON_GET_IF(vllm::kU4)
GPTQ_GET_IF_M234(vllm::kU4B8, 8, 4, 128) COMMON_GET_IF(vllm::kU4B8)
COMMON_GET_IF(vllm::kU8B128)
GPTQ_GET_IF_M1(vllm::kU8B128, 8, 8, 256) BIGGROUP_GET_IF(vllm::kFE4M3fn)
GPTQ_GET_IF_M1(vllm::kU8B128, 8, 4, 128)
GPTQ_GET_IF_M234(vllm::kU8B128, 16, 4, 256) FP4_GET_IF(vllm::kFE2M1f)
GPTQ_GET_IF_M234(vllm::kU8B128, 8, 4, 128)
AWQ_GET_IF_M1(vllm::kU4, 8, 8, 256) ACT_GET_IF(vllm::kU4B8)
AWQ_GET_IF_M1(vllm::kU4, 8, 4, 128) ACT_GET_IF(vllm::kU8B128)
AWQ_GET_IF_M234(vllm::kU4, 16, 4, 256)
AWQ_GET_IF_M234(vllm::kU4, 8, 4, 128)
return kernel; return kernel;
} }
@ -457,19 +439,19 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
for (int i = 0; i < thread_configs_size; i++) { for (int i = 0; i < thread_configs_size; i++) {
thread_config_t th_config = thread_configs[i]; thread_config_t th_config = thread_configs[i];
if (!is_valid_config(th_config, thread_m_blocks, prob_m, prob_n, prob_k, if (!is_valid_config(th_config, m_block_size_8, thread_m_blocks, prob_m,
num_bits, group_size, has_act_order, is_k_full, has_zp, prob_n, prob_k, num_bits, group_size, has_act_order,
is_zp_float, max_shared_mem)) { is_k_full, has_zp, is_zp_float, max_shared_mem)) {
continue; continue;
} }
int cache_size = get_kernel_cache_size( int cache_size = get_kernel_cache_size(
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k,
group_size, has_act_order, is_k_full, has_zp, is_zp_float); num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float);
int group_blocks = 0; int group_blocks = 0;
if (!has_act_order) { if (!has_act_order) {
group_blocks = group_size == -1 ? -1 : group_size / 16; group_blocks = group_size == -1 ? -1 : (group_size / 16);
} }
auto kernel = get_marlin_kernel<scalar_t>( auto kernel = get_marlin_kernel<scalar_t>(
@ -501,7 +483,7 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
template <typename scalar_t> template <typename scalar_t>
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
void* zp, void* g_idx, void* perm, void* a_tmp, void* s2, void* zp, void* g_idx, void* perm, void* a_tmp,
void* sorted_token_ids, void* expert_ids, void* sorted_token_ids, void* expert_ids,
void* num_tokens_past_padded, void* topk_weights, void* num_tokens_past_padded, void* topk_weights,
int moe_block_size, int top_k, bool mul_topk_weights, bool is_ep, int moe_block_size, int top_k, bool mul_topk_weights, bool is_ep,
@ -520,8 +502,10 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
"q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str()); "q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str());
} else { } else {
TORCH_CHECK( TORCH_CHECK(
q_type == vllm::kU4B8 || q_type == vllm::kU8B128, q_type == vllm::kU4B8 || q_type == vllm::kU8B128 ||
"q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ", q_type == vllm::kFE4M3fn || q_type == vllm::kFE2M1f,
"q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when "
"has_zp = False. Got = ",
q_type.str()); q_type.str());
} }
@ -555,6 +539,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
int4* C_ptr = (int4*)C; int4* C_ptr = (int4*)C;
int4* C_tmp_ptr = (int4*)C_tmp; int4* C_tmp_ptr = (int4*)C_tmp;
const int4* s_ptr = (const int4*)s; const int4* s_ptr = (const int4*)s;
const uint16_t* s2_ptr = (const uint16_t*)s2;
const int4* zp_ptr = (const int4*)zp; const int4* zp_ptr = (const int4*)zp;
const int* g_idx_ptr = (const int*)g_idx; const int* g_idx_ptr = (const int*)g_idx;
const int* perm_ptr = (const int*)perm; const int* perm_ptr = (const int*)perm;
@ -631,18 +616,18 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
int thread_k_blocks = thread_k / 16; int thread_k_blocks = thread_k / 16;
int thread_n_blocks = thread_n / 16; int thread_n_blocks = thread_n / 16;
TORCH_CHECK(is_valid_config(thread_tfg, thread_m_blocks, prob_m, prob_n, TORCH_CHECK(
prob_k, num_bits, group_size, has_act_order, is_valid_config(thread_tfg, m_block_size_8, thread_m_blocks, prob_m,
is_k_full, has_zp, is_zp_float, max_shared_mem), prob_n, prob_k, num_bits, group_size, has_act_order,
"Invalid thread config: thread_m_blocks = ", thread_m_blocks, is_k_full, has_zp, is_zp_float, max_shared_mem),
", thread_k = ", thread_tfg.thread_k, "Invalid thread config: thread_m_blocks = ", thread_m_blocks,
", thread_n = ", thread_tfg.thread_n, ", thread_k = ", thread_tfg.thread_k,
", num_threads = ", thread_tfg.num_threads, " for MKN = [", ", thread_n = ", thread_tfg.thread_n,
prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, ", num_threads = ", thread_tfg.num_threads, " for MKN = [", prob_m, ", ",
", group_size = ", group_size, prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full, ", group_size = ", group_size, ", has_act_order = ", has_act_order,
", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float, ", is_k_full = ", is_k_full, ", has_zp = ", has_zp,
", max_shared_mem = ", max_shared_mem); ", is_zp_float = ", is_zp_float, ", max_shared_mem = ", max_shared_mem);
auto kernel = get_marlin_kernel<scalar_t>( auto kernel = get_marlin_kernel<scalar_t>(
q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks, m_block_size_8, q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks, m_block_size_8,
@ -663,10 +648,10 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
// avoid ">>>" being formatted to "> > >" // avoid ">>>" being formatted to "> > >"
// clang-format off // clang-format off
kernel<<<blocks, num_threads, max_shared_mem, stream>>>( kernel<<<blocks, num_threads, max_shared_mem, stream>>>(
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr,
sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr,
topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m, topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m,
prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce); prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce, max_shared_mem);
// clang-format on // clang-format on
} }
@ -675,6 +660,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
torch::Tensor moe_wna16_marlin_gemm( torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor& a, std::optional<torch::Tensor> const& c_or_none, torch::Tensor& a, std::optional<torch::Tensor> const& c_or_none,
torch::Tensor& b_q_weight, torch::Tensor& b_scales, torch::Tensor& b_q_weight, torch::Tensor& b_scales,
std::optional<torch::Tensor> const& global_scale_or_none,
std::optional<torch::Tensor> const& b_zeros_or_none, std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none, std::optional<torch::Tensor> const& g_idx_or_none,
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace, std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
@ -826,6 +812,17 @@ torch::Tensor moe_wna16_marlin_gemm(
} }
} }
torch::Tensor global_scale;
if (global_scale_or_none.has_value()) {
global_scale = global_scale_or_none.value();
TORCH_CHECK(b_q_type == vllm::kFE2M1f,
"global_scale can only be used for float4_e2m1f.");
} else {
global_scale = torch::empty({0}, options);
TORCH_CHECK(!(b_q_type == vllm::kFE2M1f),
"the global_scale parameter must be passed for float4_e2m1f.");
}
torch::Tensor b_zeros; torch::Tensor b_zeros;
if (b_zeros_or_none.has_value()) { if (b_zeros_or_none.has_value()) {
b_zeros = b_zeros_or_none.value(); b_zeros = b_zeros_or_none.value();
@ -838,13 +835,15 @@ torch::Tensor moe_wna16_marlin_gemm(
if (has_zp) { if (has_zp) {
TORCH_CHECK( TORCH_CHECK(
b_q_type == vllm::kU4, b_q_type == vllm::kU4 || b_q_type == vllm::kU8,
"b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str()); "b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str());
} else { } else {
TORCH_CHECK( TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128 ||
b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128, b_q_type == vllm::kFE4M3fn || b_q_type == vllm::kFE2M1f,
"b_q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ", "b_q_type must be uint4b8, uint8b128, float8_e4m3fn or "
b_q_type.str()); "float4_e2m1f when "
"has_zp = False. Got = ",
b_q_type.str());
} }
if (has_zp && is_zp_float) { if (has_zp && is_zp_float) {
@ -889,9 +888,16 @@ torch::Tensor moe_wna16_marlin_gemm(
int dev = a.get_device(); int dev = a.get_device();
if (a.scalar_type() == at::ScalarType::Half) { if (a.scalar_type() == at::ScalarType::Half) {
void* scales_ptr;
if (b_q_type == vllm::kFE2M1f) {
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
} else {
scales_ptr = b_scales.data_ptr<at::Half>();
}
MARLIN_NAMESPACE_NAME::marlin_mm<half>( MARLIN_NAMESPACE_NAME::marlin_mm<half>(
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(), a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
c_tmp.data_ptr<float>(), b_scales.data_ptr<at::Half>(), c_tmp.data_ptr<float>(), scales_ptr, global_scale.data_ptr<at::Half>(),
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
a_tmp.data_ptr<at::Half>(), sorted_token_ids.data_ptr(), a_tmp.data_ptr<at::Half>(), sorted_token_ids.data_ptr(),
expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(), expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(),
@ -901,11 +907,18 @@ torch::Tensor moe_wna16_marlin_gemm(
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
use_atomic_add, use_fp32_reduce, is_zp_float); use_atomic_add, use_fp32_reduce, is_zp_float);
} else if (a.scalar_type() == at::ScalarType::BFloat16) { } else if (a.scalar_type() == at::ScalarType::BFloat16) {
void* scales_ptr;
if (b_q_type == vllm::kFE2M1f) {
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
} else {
scales_ptr = b_scales.data_ptr<at::BFloat16>();
}
MARLIN_NAMESPACE_NAME::marlin_mm<nv_bfloat16>( MARLIN_NAMESPACE_NAME::marlin_mm<nv_bfloat16>(
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(), a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(), c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(), scales_ptr,
b_scales.data_ptr<at::BFloat16>(), b_zeros.data_ptr(), g_idx.data_ptr(), global_scale.data_ptr<at::BFloat16>(), b_zeros.data_ptr(),
perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
sorted_token_ids.data_ptr(), expert_ids.data_ptr(), sorted_token_ids.data_ptr(), expert_ids.data_ptr(),
num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(), num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(),
moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k, moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k,

View File

@ -326,7 +326,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
} }
if (use_global_memory) { if (use_global_memory) {
VLLM_DISPATCH_INTEGRAL_TYPES( VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] { topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
// calc needed amount of shared mem for `tokens_cnts` and `cumsum` // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
// tensors // tensors
@ -351,7 +351,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
cumsum_buffer.data_ptr<int32_t>()); cumsum_buffer.data_ptr<int32_t>());
}); });
} else if (use_i16) { } else if (use_i16) {
VLLM_DISPATCH_INTEGRAL_TYPES( VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
// set dynamic shared mem // set dynamic shared mem
auto kernel = auto kernel =
@ -366,7 +366,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
topk_ids.numel()); topk_ids.numel());
}); });
} else { } else {
VLLM_DISPATCH_INTEGRAL_TYPES( VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
auto kernel = auto kernel =
vllm::moe::moe_align_block_size_kernel<scalar_t, int32_t>; vllm::moe::moe_align_block_size_kernel<scalar_t, int32_t>;
@ -391,7 +391,7 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
TORCH_CHECK(num_experts == 256, TORCH_CHECK(num_experts == 256,
"sgl_moe_align_block_size kernel only supports deepseek v3."); "sgl_moe_align_block_size kernel only supports deepseek v3.");
VLLM_DISPATCH_INTEGRAL_TYPES( VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] { topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] {
// calc needed amount of shared mem for `cumsum` tensors // calc needed amount of shared mem for `cumsum` tensors
auto options_int = auto options_int =

View File

@ -0,0 +1,133 @@
#include <c10/core/ScalarType.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include "permute_unpermute_kernels/moe_permute_unpermute_kernel.h"
#include "permute_unpermute_kernels/dispatch.h"
#include "core/registration.h"
void moe_permute(
const torch::Tensor& input, // [n_token, hidden]
const torch::Tensor& topk_weights, //[n_token, topk]
torch::Tensor& topk_ids, // [n_token, topk]
const torch::Tensor& token_expert_indicies, // [n_token, topk]
const std::optional<torch::Tensor>& expert_map, // [n_expert]
int64_t n_expert, int64_t n_local_expert, int64_t topk,
const std::optional<int64_t>& align_block_size,
torch::Tensor&
permuted_input, // [topk * n_token/align_block_size_m, hidden]
torch::Tensor& expert_first_token_offset, // [n_local_expert + 1]
torch::Tensor& src_row_id2dst_row_id_map, // [n_token, topk]
torch::Tensor& m_indices) { // [align_expand_m]
TORCH_CHECK(topk_weights.scalar_type() == at::ScalarType::Float,
"topk_weights must be float32");
TORCH_CHECK(expert_first_token_offset.scalar_type() == at::ScalarType::Long,
"expert_first_token_offset must be int64");
TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int,
"topk_ids must be int32");
TORCH_CHECK(token_expert_indicies.scalar_type() == at::ScalarType::Int,
"token_expert_indicies must be int32");
TORCH_CHECK(src_row_id2dst_row_id_map.scalar_type() == at::ScalarType::Int,
"src_row_id2dst_row_id_map must be int32");
TORCH_CHECK(expert_first_token_offset.size(0) == n_local_expert + 1,
"expert_first_token_offset shape != n_local_expert+1")
TORCH_CHECK(
src_row_id2dst_row_id_map.sizes() == token_expert_indicies.sizes(),
"token_expert_indicies shape must be same as src_row_id2dst_row_id_map");
auto n_token = input.sizes()[0];
auto n_hidden = input.sizes()[1];
auto align_block_size_value =
align_block_size.has_value() ? align_block_size.value() : -1;
auto stream = at::cuda::getCurrentCUDAStream().stream();
const long sorter_size =
CubKeyValueSorter::getWorkspaceSize(n_token * topk, n_expert);
auto sort_workspace = torch::empty(
{sorter_size},
torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
auto permuted_experts_id = torch::empty_like(topk_ids);
auto dst_row_id2src_row_id_map = torch::empty_like(src_row_id2dst_row_id_map);
auto align_expert_first_token_offset =
torch::zeros_like(expert_first_token_offset);
CubKeyValueSorter sorter{};
int64_t* valid_num_ptr = nullptr;
// pre-process kernel for expert-parallelism:
// no local expert id plus "n_expert" offset for priority to local expert
// map local expert id [n, .., n+n_local_expert-1] to [0, n_local_expert -1]
// For example, 4 expert with ep_size=2. ep_rank=1 owns global expert id
// [2,3] with expert_map[-1, -1, 0, 1], preprocess_topk_id process topk_ids
// and map global expert id [2, 3] to local_expert id [0, 1] and map global
// expert id [0, 1] ( not in ep rank=1) to [4, 5] by plus n_expert. This map
// operation is to make local expert high priority in following sort topk_ids
// and scan local expert_first_token_offset for each ep rank for next group
// gemm.
if (expert_map.has_value()) {
const int* expert_map_ptr = get_ptr<int>(expert_map.value());
valid_num_ptr =
get_ptr<int64_t>(expert_first_token_offset) + n_local_expert;
preprocessTopkIdLauncher(get_ptr<int>(topk_ids), n_token * topk,
expert_map_ptr, n_expert, stream);
}
// expert sort topk expert id and scan expert id get expert_first_token_offset
sortAndScanExpert(get_ptr<int>(topk_ids), get_ptr<int>(token_expert_indicies),
get_ptr<int>(permuted_experts_id),
get_ptr<int>(dst_row_id2src_row_id_map),
get_ptr<int64_t>(expert_first_token_offset), n_token,
n_expert, n_local_expert, topk, sorter,
get_ptr<int>(sort_workspace), stream);
// dispatch expandInputRowsKernelLauncher
MOE_DISPATCH(input.scalar_type(), [&] {
expandInputRowsKernelLauncher<scalar_t>(
get_ptr<scalar_t>(input), get_ptr<scalar_t>(permuted_input),
get_ptr<float>(topk_weights), get_ptr<int>(permuted_experts_id),
get_ptr<int>(dst_row_id2src_row_id_map),
get_ptr<int>(src_row_id2dst_row_id_map),
get_ptr<int64_t>(expert_first_token_offset), n_token, valid_num_ptr,
n_hidden, topk, n_local_expert, align_block_size_value, stream);
});
// get m_indices and update expert_first_token_offset with align block
getMIndices(get_ptr<int64_t>(expert_first_token_offset),
get_ptr<int64_t>(align_expert_first_token_offset),
get_ptr<int>(m_indices), n_local_expert, align_block_size_value,
stream);
if (align_block_size.has_value()) {
// update align_expert_first_token_offset
expert_first_token_offset.copy_(align_expert_first_token_offset);
}
}
void moe_unpermute(
const torch::Tensor& permuted_hidden_states, // [n_token * topk, hidden]
const torch::Tensor& topk_weights, //[n_token, topk]
const torch::Tensor& topk_ids, // [n_token, topk]
const torch::Tensor& src_row_id2dst_row_id_map, // [n_token, topk]
const torch::Tensor& expert_first_token_offset, // [n_local_expert+1]
int64_t n_expert, int64_t n_local_expert, int64_t topk,
torch::Tensor& hidden_states // [n_token, hidden]
) {
TORCH_CHECK(src_row_id2dst_row_id_map.sizes() == topk_ids.sizes(),
"topk_ids shape must be same as src_row_id2dst_row_id_map");
TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int,
"topk_ids must be int32");
TORCH_CHECK(
permuted_hidden_states.scalar_type() == hidden_states.scalar_type(),
"topk_ids dtype must be same as src_row_id2dst_row_id_map");
auto n_token = hidden_states.size(0);
auto n_hidden = hidden_states.size(1);
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int64_t* valid_ptr =
get_ptr<int64_t>(expert_first_token_offset) + n_local_expert;
MOE_DISPATCH(hidden_states.scalar_type(), [&] {
finalizeMoeRoutingKernelLauncher<scalar_t, scalar_t>(
get_ptr<scalar_t>(permuted_hidden_states),
get_ptr<scalar_t>(hidden_states), get_ptr<float>(topk_weights),
get_ptr<int>(src_row_id2dst_row_id_map), get_ptr<int>(topk_ids),
n_token, n_hidden, topk, valid_ptr, stream);
});
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("moe_permute", &moe_permute);
m.impl("moe_unpermute", &moe_unpermute);
}

View File

@ -108,11 +108,11 @@ __device__ inline void dequant<half2, 4>(int q, half2* res) {
const int MUL = 0x2c002c00; const int MUL = 0x2c002c00;
const int ADD = 0xd400d400; const int ADD = 0xd400d400;
int lo0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX); int lo0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
int hi0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX); int hi0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
q >>= 8; q >>= 8;
int lo1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX); int lo1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
int hi1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX); int hi1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
res[0] = __hsub2(*reinterpret_cast<half2*>(&lo0), res[0] = __hsub2(*reinterpret_cast<half2*>(&lo0),
*reinterpret_cast<const half2*>(&SUB)); *reinterpret_cast<const half2*>(&SUB));
@ -149,13 +149,13 @@ __device__ inline void dequant<nv_bfloat162, 4>(int q, nv_bfloat162* res) {
static constexpr uint32_t MASK = 0x000f000f; static constexpr uint32_t MASK = 0x000f000f;
static constexpr uint32_t EX = 0x43004300; static constexpr uint32_t EX = 0x43004300;
int lo0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX); int lo0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
q >>= 4; q >>= 4;
int hi0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX); int hi0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
q >>= 4; q >>= 4;
int lo1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX); int lo1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
q >>= 4; q >>= 4;
int hi1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX); int hi1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
static constexpr uint32_t MUL = 0x3F803F80; static constexpr uint32_t MUL = 0x3F803F80;
static constexpr uint32_t ADD = 0xC300C300; static constexpr uint32_t ADD = 0xC300C300;

Some files were not shown because too many files have changed in this diff Show More