Compare commits
598 Commits
fix_use_ep
...
woosuk-jf
| Author | SHA1 | Date | |
|---|---|---|---|
| bcf3c8230d | |||
| 2858830c39 | |||
| d6484ef3c3 | |||
| 46fae69cf0 | |||
| a01af39aa8 | |||
| f66f1e0fa3 | |||
| 887d7af882 | |||
| a92842454c | |||
| c8386fa61d | |||
| 87baebebd8 | |||
| e3d0a1d190 | |||
| d47b605eca | |||
| 22c6f6397f | |||
| 3ec97e2cc5 | |||
| 9b103a1d76 | |||
| b90b0852e9 | |||
| 9352cdb56d | |||
| 182f40ea8b | |||
| 3e887d2e0c | |||
| 0f87d8f7b2 | |||
| 4c33d67321 | |||
| cb234955df | |||
| 3a500cd0b6 | |||
| 868c546da4 | |||
| 99404f53c7 | |||
| 785d75a03b | |||
| 6d1479ca4b | |||
| b8b0859b5c | |||
| d7543862bd | |||
| c777df79f7 | |||
| cc2a77d7f1 | |||
| 9e2de9b9e9 | |||
| 109e15a335 | |||
| f192ca90e6 | |||
| f89d0e11bf | |||
| b4003d11fc | |||
| 292fc59d61 | |||
| afcb3f8863 | |||
| afb12e4294 | |||
| eeb5761cf1 | |||
| 24aebae177 | |||
| 39c0813a7f | |||
| 9b70e2b4c1 | |||
| 173daac19d | |||
| 04f2cfc894 | |||
| 811a6c0972 | |||
| 9b1769dd9a | |||
| 61c299f81f | |||
| 4acfa3354a | |||
| 88c8304104 | |||
| 6768ff4a22 | |||
| f2e7af9b86 | |||
| 7423cf0a9b | |||
| 460a2b1100 | |||
| 28566d73b3 | |||
| 98060b001d | |||
| f5a3c655b2 | |||
| 7169f87ad0 | |||
| b74d888c63 | |||
| 2007d4d54f | |||
| 48e925fab5 | |||
| 1903c0b8a3 | |||
| 86a1f67a3b | |||
| a257d9bccc | |||
| 015069b017 | |||
| fbefc8a78d | |||
| 26bc4bbcd8 | |||
| 3c3d767201 | |||
| 13cf6b6236 | |||
| 90d0a54c4d | |||
| 7a0a146c54 | |||
| 7ab643e425 | |||
| afb4429b4f | |||
| aa4502e7f3 | |||
| 17b4d85f63 | |||
| 1144a8efe7 | |||
| 08fb5587b4 | |||
| dbc18e7816 | |||
| 02bd654846 | |||
| 200bbf92e8 | |||
| 81ecf425f0 | |||
| 42d9a2c4c7 | |||
| 2ac74d098e | |||
| 584f5fb4c6 | |||
| d586ddc691 | |||
| 0b7e701dd4 | |||
| 947f2f5375 | |||
| 739e03b344 | |||
| da4e7687b5 | |||
| 39317cf42b | |||
| 2990cee95b | |||
| 0be6d05b5e | |||
| 77073c77bc | |||
| a7d5b016bd | |||
| d803786731 | |||
| 1534d389af | |||
| ece5a8b0b6 | |||
| 54072f315f | |||
| be633fba0f | |||
| ed6cfb90c8 | |||
| 6ed9f6047e | |||
| a44c4f1d2f | |||
| 88fcf00dda | |||
| d1f569b1b9 | |||
| 13698db634 | |||
| 2c4f59afc3 | |||
| 1c2bc7ead0 | |||
| 4055130a85 | |||
| 34120f5acd | |||
| 7489ec0bab | |||
| 70788bdbdc | |||
| c9c1b59e59 | |||
| 0350809f3a | |||
| a6977dbd15 | |||
| 2fa2a50bf9 | |||
| 08e15defa9 | |||
| b37685afbb | |||
| 792595b59d | |||
| 0c1c788312 | |||
| 56d64fbe30 | |||
| 608968b7c5 | |||
| 06ffc7e1d3 | |||
| d3cf61b89b | |||
| a39203f99e | |||
| 24e6ad3f16 | |||
| 2ef5d106bb | |||
| 0ed27ef66c | |||
| 900edfa8d4 | |||
| 88ad9ec6b2 | |||
| 40896bdf3f | |||
| 00ee37efa2 | |||
| 890f104cdf | |||
| 4a5e13149a | |||
| 97cc8729f0 | |||
| 4464109219 | |||
| 193e78e35d | |||
| bdb2cddafc | |||
| ebb3930d28 | |||
| cde384cd92 | |||
| 96e06e3cb7 | |||
| 17eb306fcc | |||
| 165cb56329 | |||
| d6da8a8ff2 | |||
| b4ac4fa04d | |||
| e136000595 | |||
| 86d9fc29cb | |||
| 506475de5f | |||
| cfe4532093 | |||
| 8fc88d63f1 | |||
| 6e74fd4945 | |||
| dcbac4cb4b | |||
| ed2462030f | |||
| cc5befbced | |||
| 2c89cd96a8 | |||
| a0304dc504 | |||
| c7941cca18 | |||
| b6dd32aa07 | |||
| f94886946e | |||
| 72dfe4c74f | |||
| 8b464d9660 | |||
| 889ebb2638 | |||
| 3ad986c28b | |||
| 344e193b7d | |||
| fb1c933ade | |||
| 72c5b97231 | |||
| fa93cd9f60 | |||
| aec9674dbe | |||
| 7fcc4223dc | |||
| 8262a3e23b | |||
| f211331c48 | |||
| 9053d0b134 | |||
| cb3f2d8d10 | |||
| c12df53b60 | |||
| d1aeea7553 | |||
| d8bccde686 | |||
| 20e489eaa1 | |||
| 4213475ec7 | |||
| d92879baf6 | |||
| 690fe019f0 | |||
| ed7a29d9f8 | |||
| 756848e79e | |||
| 18445edd0f | |||
| 30215ca61f | |||
| 838cedade7 | |||
| 4283a28c2f | |||
| 93a126fbc7 | |||
| 8e4b351a0c | |||
| 9869453c42 | |||
| 3642c59aa8 | |||
| 43eea2953b | |||
| de7eb10ce4 | |||
| fd11a325b8 | |||
| 4d17e20310 | |||
| 10fd1d7380 | |||
| 52b4f4a8d7 | |||
| e782e0a170 | |||
| dc2ceca5c5 | |||
| f8acd01ff7 | |||
| c48334d405 | |||
| 909fdaf152 | |||
| 8c1c926d00 | |||
| df6f3ce883 | |||
| 513f074766 | |||
| b07bf83c7d | |||
| 53e8cf53a4 | |||
| 54271bb766 | |||
| 9e96f56efb | |||
| b278911229 | |||
| 7bd0c7745c | |||
| 1cf0719ebd | |||
| 537d5ee025 | |||
| c8e5be35f7 | |||
| a6e72e1e4f | |||
| 5e83a7277f | |||
| 68af5f6c5c | |||
| 8de2901fea | |||
| c53e0730cb | |||
| a0e619e62a | |||
| 70116459c3 | |||
| 65e262b93b | |||
| 43faa0461a | |||
| 48cb2109b6 | |||
| a5450f11c9 | |||
| 9d98ab5ec6 | |||
| df5c879527 | |||
| 423e9f1cbe | |||
| 0bd7f8fca5 | |||
| d5615af9ae | |||
| 19dcc02a72 | |||
| 7feae92c1f | |||
| f851b84266 | |||
| fc966e9cc6 | |||
| ef19e67d2c | |||
| a41351f363 | |||
| 6aae216b4e | |||
| b22980a1dc | |||
| 881f735827 | |||
| 2f54045508 | |||
| 5aa6efb9a5 | |||
| 6ca0234478 | |||
| 649818995f | |||
| 7a0a9da72b | |||
| 69bff9bc89 | |||
| 41ca7eb491 | |||
| eef364723c | |||
| 0d6e187e88 | |||
| 9420a1fc30 | |||
| 583e900996 | |||
| 05e1fbfc52 | |||
| fe92176321 | |||
| 6d0df0ebeb | |||
| 0fa939e2d1 | |||
| 0422ce109f | |||
| 47bdee409c | |||
| 49f189439d | |||
| 5adf6f6b7f | |||
| 4115f19958 | |||
| 340d7b1b21 | |||
| 1bcbcbf574 | |||
| 82e43b2d7e | |||
| 67309a1cb5 | |||
| b724afe343 | |||
| 21f4f1c9a4 | |||
| b0c1f6202d | |||
| c0dfd97519 | |||
| a9138e85b1 | |||
| 0a05ed57e6 | |||
| 14288d1332 | |||
| b411418ff0 | |||
| 2bc0f72ae5 | |||
| 9c1244de57 | |||
| db2f8d915c | |||
| 6167c0e5d2 | |||
| ed2e464653 | |||
| 2c8ed8ee48 | |||
| ed50f46641 | |||
| 46e678bcff | |||
| 6b2427f995 | |||
| b07d741661 | |||
| 41fb013d29 | |||
| 32d4b669d0 | |||
| 3cde34a4a4 | |||
| bdb3660312 | |||
| f3a21e9c68 | |||
| 8e630d680e | |||
| af869f6dff | |||
| 53c0fa1e25 | |||
| f7912cba3d | |||
| 6317a5174a | |||
| aa72d9a4ea | |||
| ce17db8085 | |||
| 8c87a9ad46 | |||
| ec69124eb4 | |||
| d0da99fb70 | |||
| b2f195c429 | |||
| 047797ef90 | |||
| eb8ef4224d | |||
| 56a735261c | |||
| e1cf90e099 | |||
| 6bc1e30ef9 | |||
| 7e081ba7ca | |||
| 1e013fa388 | |||
| bc7c4d206b | |||
| f67e9e9f22 | |||
| 36fe78769f | |||
| 83d933718c | |||
| 5175b884f7 | |||
| 5536b30a4c | |||
| 7f58fb9718 | |||
| 30bc3e0f66 | |||
| f34410715f | |||
| 68d4c33202 | |||
| f961d7f6ef | |||
| d059110498 | |||
| 571e8dd65e | |||
| 4b91c927f6 | |||
| 0e237f0035 | |||
| 8f7bace7c3 | |||
| e4d6144232 | |||
| 8d32dc603d | |||
| c4ab9f3e71 | |||
| 2689d5c027 | |||
| acba33a0f1 | |||
| a114bf20a3 | |||
| 3097ce3a32 | |||
| d6da9322c8 | |||
| 71ce44047f | |||
| 188b7f9b8c | |||
| b9b4746950 | |||
| 7b8a2ab76f | |||
| c9acbf1141 | |||
| 5b794cae8d | |||
| 0e4254492f | |||
| 1311913f55 | |||
| 29f395c97c | |||
| fa3bba2a53 | |||
| 986537f1c3 | |||
| 210207525e | |||
| 71eda0bb76 | |||
| 471fe65630 | |||
| 3a0fba5cf4 | |||
| 299ebb62b2 | |||
| f728ab8e35 | |||
| 63e26fff78 | |||
| fe3462c774 | |||
| 3b34fd5273 | |||
| 55d6d3fdb8 | |||
| 7272bfae77 | |||
| d9ac9e3dc5 | |||
| d41faaf9df | |||
| b34f33438a | |||
| 26c0406555 | |||
| 4c41278b77 | |||
| bb3605db85 | |||
| fe742aef5a | |||
| 4b07d36891 | |||
| 87aaadef73 | |||
| 682e0b6d2f | |||
| d6195a748b | |||
| 205d84aaa9 | |||
| 5124f5bf51 | |||
| 83f3c3bd91 | |||
| d9737ca1c6 | |||
| 9d4ca19d50 | |||
| 2ef0dc53b8 | |||
| 1d4680fad2 | |||
| 2c1bd848a6 | |||
| 5c9121203c | |||
| 490b1698a5 | |||
| 5a5e29de88 | |||
| 3d3ab3689f | |||
| 686623c5e7 | |||
| aadb656562 | |||
| 87e067de41 | |||
| 26507f8973 | |||
| 9c1d5b456d | |||
| e31045f95c | |||
| aaec845f8e | |||
| 7bdfd29a35 | |||
| e78587a64c | |||
| 7eb4255628 | |||
| 6a0f547561 | |||
| 30ed81b7ca | |||
| 7a4a5de729 | |||
| c16fb5dae8 | |||
| e37073efd7 | |||
| 183dad7a85 | |||
| 3408e47159 | |||
| 0377b8310b | |||
| e4755f7fac | |||
| 92edf35826 | |||
| eb5819b2d9 | |||
| 5989f4684d | |||
| 5125d72f02 | |||
| a018e555fd | |||
| 6211b92273 | |||
| 05fcd1b430 | |||
| 7c02d6a137 | |||
| 11c3b98491 | |||
| dbe7f07001 | |||
| c69bf4ee06 | |||
| d27ea94034 | |||
| 99ed526101 | |||
| 207da28186 | |||
| 5b1aca2ae3 | |||
| d8e557b5e5 | |||
| 61a44a0b22 | |||
| a6481525b8 | |||
| 8cac35ba43 | |||
| 9dbf7a2dc1 | |||
| 607029e515 | |||
| cb072ce93b | |||
| 95aca283b4 | |||
| 2b05b8ce69 | |||
| 3c776dcefb | |||
| 2cbd4d2999 | |||
| 3092375e27 | |||
| 3cd91dc955 | |||
| 8a7368e069 | |||
| 93e561ec4d | |||
| e1b004839a | |||
| ee378f3d49 | |||
| e82ee40de3 | |||
| facbe2a114 | |||
| 7168920491 | |||
| 21378a2323 | |||
| 976711d9db | |||
| 44fa4d556c | |||
| 3ac98edcb1 | |||
| 966c742ed2 | |||
| 0d7d05f4b6 | |||
| 96bb8aa68b | |||
| 3badb0213b | |||
| fdcb850f14 | |||
| 54a66e5fee | |||
| 280d62b8a2 | |||
| 1666e66443 | |||
| 1575c1701a | |||
| 6ae996a873 | |||
| b590adfdc1 | |||
| b4fe16c75b | |||
| bc5dd4f669 | |||
| dbb036cf61 | |||
| 70e7ed841d | |||
| d06ba4ed3f | |||
| 6b40996ae8 | |||
| d2020acac7 | |||
| 1eb3c2ed48 | |||
| c64ee87267 | |||
| b1308b84a3 | |||
| 7b5ecf79bd | |||
| 9883a18859 | |||
| b3f2fddd17 | |||
| aa29841ede | |||
| 6bf27affb6 | |||
| 1dd23386ec | |||
| 7cbfc10943 | |||
| ce4ddd2d1a | |||
| e51929ebca | |||
| dc1b4a6f13 | |||
| 63d2705edb | |||
| d085a44082 | |||
| f49e5aff11 | |||
| 6c11ecf8d3 | |||
| 93e5f3c5fb | |||
| 70363bccfa | |||
| 3cdc57669f | |||
| 68bb122eb4 | |||
| d9fc8cd9da | |||
| f069f3ea74 | |||
| c5bc0e7fcc | |||
| 4a3a518722 | |||
| fbf722c6e6 | |||
| e92d7085bf | |||
| bd6028d6b0 | |||
| 802329dee9 | |||
| 41cc883c29 | |||
| 57504a4bcf | |||
| ed4792c990 | |||
| 87b836ba77 | |||
| 56c76c2e0e | |||
| c09632a66c | |||
| a3bf8d4a2b | |||
| 16eda8c43a | |||
| cd77382ac1 | |||
| 71b9cde010 | |||
| 5285589f37 | |||
| f41647ee6b | |||
| 4d022cbc75 | |||
| 70de35a881 | |||
| 34b2cf3b33 | |||
| 9e90c9f73f | |||
| e9528f6dc6 | |||
| 51baa9c333 | |||
| 35e076b3a8 | |||
| a26f59ccbc | |||
| aa3b3d76e0 | |||
| f7030df3be | |||
| 905e91e9ac | |||
| f8f9c0ba62 | |||
| dda811021a | |||
| 93195146ea | |||
| ed37599544 | |||
| 99ef59cf7f | |||
| d544d141ec | |||
| 3e397a9484 | |||
| 268c325078 | |||
| 3cc9af88ff | |||
| 7cd0bd7212 | |||
| 56d4aefa33 | |||
| dd143ef541 | |||
| daefed052c | |||
| 5fbab20e02 | |||
| e8224f3dca | |||
| 9665313c39 | |||
| 0c54fc7273 | |||
| c1b57855ec | |||
| 83b824c8b4 | |||
| 7678fcd5b6 | |||
| 8661c0241d | |||
| ce8d6b75fc | |||
| 61de3ef74b | |||
| ec1f9c8c91 | |||
| 65e09094c4 | |||
| c70cf0fe06 | |||
| a5d11a54dc | |||
| 3d4c87758e | |||
| a9bd832fc5 | |||
| 417bcefbae | |||
| baada0e737 | |||
| 82eb61dd4c | |||
| 0d4d06fe2f | |||
| 4aed0ca6a2 | |||
| 1621b25288 | |||
| a564797151 | |||
| 1da6a09274 | |||
| 1e44ffc3ff | |||
| a454748544 | |||
| 1bff42c4b7 | |||
| cb391d85dc | |||
| fee5b8d37f | |||
| b2ce859bd2 | |||
| 566f10a929 | |||
| c3b5189137 | |||
| a25866ac8d | |||
| 098900d7c2 | |||
| 98d01d3ce2 | |||
| d55244df31 | |||
| 04149cce27 | |||
| 24834f4894 | |||
| ec7da6fcf3 | |||
| 819d548e8a | |||
| 477d2a8aa2 | |||
| e484e02857 | |||
| 24f6b9a713 | |||
| 9cdde47289 | |||
| b1eb4ca152 | |||
| 87b4ac56c2 | |||
| cb84e45ac7 | |||
| 4716377fbc | |||
| 4e9cf8c1dd | |||
| 2976dc27e9 | |||
| 102bf967f0 | |||
| 1f4b09b525 | |||
| 86c3369eb8 | |||
| 2755c34a8f | |||
| db10422184 | |||
| e1a2c699dd | |||
| 0115ccd5c0 | |||
| 40b4284fe3 | |||
| 4ebc0b9640 | |||
| dc96fd54c6 | |||
| 1f5d13ab9f | |||
| 90cb44eb02 | |||
| e11880deea | |||
| 9351f91be9 | |||
| 5a1e1c8353 | |||
| 69ecaa7c79 | |||
| 7f00899ff7 | |||
| 995e3d1f41 | |||
| b4ac449a83 | |||
| 8e5314a468 | |||
| 87918e40c4 | |||
| f6b32efb7f | |||
| b99733d092 | |||
| 05a015d6a5 | |||
| ad971af8c7 | |||
| f2ebb6f541 | |||
| 1d01211264 | |||
| f94ab12f79 | |||
| a865bc1ca6 | |||
| 21802c4b6d | |||
| 652907b354 | |||
| 24f1c01e0f | |||
| fad6e2538e | |||
| 7f6d47c1a2 | |||
| 3147586ebd | |||
| ed636d99ca |
@ -1,3 +1,4 @@
|
||||
# For vllm script, with -t option (tensor parallel size).
|
||||
# bash ./run-lm-eval-gsm-vllm-baseline.sh -m deepseek-ai/DeepSeek-V2-Lite-Chat -b "auto" -l 1000 -f 5 -t 2
|
||||
model_name: "deepseek-ai/DeepSeek-V2-Lite-Chat"
|
||||
tasks:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
# For hf script, without -t option (tensor parallel size).
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m nm-testing/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform -b auto -l 1000 -f 5
|
||||
model_name: "nm-testing/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform"
|
||||
tasks:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
# For hf script, without -t option (tensor parallel size).
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m meta-llama/Meta-Llama-3-70B-Instruct -b 32 -l 250 -f 5
|
||||
model_name: "meta-llama/Meta-Llama-3-70B-Instruct"
|
||||
tasks:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
# For vllm script, with -t option (tensor parallel size).
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-W8A8-FP8-Channelwise-compressed-tensors -b auto -l 1000 -f 5 -t 1
|
||||
model_name: "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-FP8-Channelwise-compressed-tensors"
|
||||
tasks:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
# For vllm script, with -t option (tensor parallel size).
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform -b auto -l 1000 -f 5 -t 1
|
||||
model_name: "nm-testing/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform"
|
||||
tasks:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
# For vllm script, with -t option (tensor parallel size).
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test -b 32 -l 1000 -f 5 -t 1
|
||||
model_name: "nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test"
|
||||
tasks:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
# For vllm script, with -t option (tensor parallel size).
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Meta-Llama-3-8B-Instruct-FP8 -b 32 -l 250 -f 5 -t 1
|
||||
model_name: "neuralmagic/Meta-Llama-3-8B-Instruct-FP8"
|
||||
tasks:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
# For vllm script, with -t option (tensor parallel size).
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Asym-Per-Token-Test -b "auto" -l 250 -f 5 -t 1
|
||||
model_name: "nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Asym-Per-Token-Test"
|
||||
tasks:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
# For vllm script, with -t option (tensor parallel size).
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Per-Token-Test -b "auto" -l 250 -f 5 -t 1
|
||||
model_name: "nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Per-Token-Test"
|
||||
tasks:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
# For vllm script, with -t option (tensor parallel size).
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-nonuniform-test -b auto -l 1000 -f 5 -t 1
|
||||
model_name: "nm-testing/Meta-Llama-3-8B-Instruct-nonuniform-test"
|
||||
tasks:
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m meta-llama/Meta-Llama-3-8B-Instruct -b 32 -l 250 -f 5 -t 1
|
||||
# For hf script, without -t option (tensor parallel size).
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m meta-llama/Meta-Llama-3-8B-Instruct -b 32 -l 250 -f 5
|
||||
model_name: "meta-llama/Meta-Llama-3-8B-Instruct"
|
||||
tasks:
|
||||
- name: "gsm8k"
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
# For vllm script, with -t option (tensor parallel size).
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m HandH1998/QQQ-Llama-3-8b-g128 -b 32 -l 1000 -f 5 -t 1
|
||||
model_name: "HandH1998/QQQ-Llama-3-8b-g128"
|
||||
tasks:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
# For vllm script, with -t option (tensor parallel size).
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8 -b "auto" -l 1000 -f 5 -t 1
|
||||
model_name: "neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8"
|
||||
tasks:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
# For vllm script, with -t option (tensor parallel size).
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m mgoin/Minitron-4B-Base-FP8 -b auto -l 1000 -f 5 -t 1
|
||||
model_name: "mgoin/Minitron-4B-Base-FP8"
|
||||
tasks:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
# For vllm script, with -t option (tensor parallel size).
|
||||
# bash ./run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Mixtral-8x22B-Instruct-v0.1-FP8-dynamic -b "auto" -l 250 -f 5 -t 8
|
||||
model_name: "neuralmagic/Mixtral-8x22B-Instruct-v0.1-FP8-dynamic"
|
||||
tasks:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
# For vllm script, with -t option (tensor parallel size).
|
||||
# bash ./run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8 -b "auto" -l 250 -f 5 -t 4
|
||||
model_name: "neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8"
|
||||
tasks:
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m neuralmagic/Mixtral-8x7B-Instruct-v0.1 -b 32 -l 250 -f 5 -t 4
|
||||
# For hf script, without -t option (tensor parallel size).
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m neuralmagic/Mixtral-8x7B-Instruct-v0.1 -b 32 -l 250 -f 5
|
||||
model_name: "mistralai/Mixtral-8x7B-Instruct-v0.1"
|
||||
tasks:
|
||||
- name: "gsm8k"
|
||||
|
||||
@ -0,0 +1,12 @@
|
||||
# For vllm script, with -t option (tensor parallel size).
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16 -b auto -l 1319 -f 5 -t 1
|
||||
model_name: "nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16"
|
||||
tasks:
|
||||
- name: "gsm8k"
|
||||
metrics:
|
||||
- name: "exact_match,strict-match"
|
||||
value: 0.30
|
||||
- name: "exact_match,flexible-extract"
|
||||
value: 0.465
|
||||
limit: 1319
|
||||
num_fewshot: 5
|
||||
@ -1,3 +1,4 @@
|
||||
# For vllm script, with -t option (tensor parallel size).
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Qwen2-1.5B-Instruct-FP8W8 -b auto -l 1000 -f 5 -t 1
|
||||
model_name: "nm-testing/Qwen2-1.5B-Instruct-FP8W8"
|
||||
tasks:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
# For vllm script, with -t option (tensor parallel size).
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Qwen2-1.5B-Instruct-quantized.w8a8 -b "auto" -l 1000 -f 5 -t 1
|
||||
model_name: "neuralmagic/Qwen2-1.5B-Instruct-quantized.w8a8"
|
||||
tasks:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
# For vllm script, with -t option (tensor parallel size).
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Qwen2-1.5B-Instruct-W8A16-Channelwise -b "auto" -l 1000 -f 5 -t 1
|
||||
model_name: "nm-testing/Qwen2-1.5B-Instruct-W8A16-Channelwise"
|
||||
tasks:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
# For vllm script, with -t option (tensor parallel size).
|
||||
# bash ./run-lm-eval-gsm-vllm-baseline.sh -m Qwen/Qwen2-57B-A14B-Instruct -b "auto" -l 250 -f 5 -t 4
|
||||
model_name: "Qwen/Qwen2-57B-A14B-Instruct"
|
||||
tasks:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
# For vllm script, with -t option (tensor parallel size).
|
||||
# bash ./run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM -b "auto" -t 2
|
||||
model_name: "nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM"
|
||||
tasks:
|
||||
|
||||
@ -4,7 +4,7 @@ Meta-Llama-3.2-1B-Instruct-INT8-compressed-tensors.yaml
|
||||
Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml
|
||||
Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
|
||||
Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
|
||||
Minitron-4B-Base-FP8.yaml
|
||||
Qwen1.5-MoE-W4A16-compressed-tensors.yaml
|
||||
Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
|
||||
Qwen2-1.5B-Instruct-FP8W8.yaml
|
||||
Meta-Llama-3-8B-QQQ.yaml
|
||||
|
||||
@ -16,7 +16,7 @@ import numpy
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
RTOL = 0.05
|
||||
RTOL = 0.08
|
||||
TEST_DATA_FILE = os.environ.get(
|
||||
"LM_EVAL_TEST_DATA_FILE",
|
||||
".buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct.yaml")
|
||||
|
||||
@ -1,20 +1,20 @@
|
||||
steps:
|
||||
- label: "Build wheel - CUDA 12.4"
|
||||
- label: "Build wheel - CUDA 12.8"
|
||||
agents:
|
||||
queue: cpu_queue_postmerge
|
||||
commands:
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.4.0 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
|
||||
- "mkdir artifacts"
|
||||
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
|
||||
- "bash .buildkite/scripts/upload-wheels.sh"
|
||||
env:
|
||||
DOCKER_BUILDKIT: "1"
|
||||
|
||||
- label: "Build wheel - CUDA 12.1"
|
||||
- label: "Build wheel - CUDA 12.6"
|
||||
agents:
|
||||
queue: cpu_queue_postmerge
|
||||
commands:
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.1.0 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.6.3 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
|
||||
- "mkdir artifacts"
|
||||
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
|
||||
- "bash .buildkite/scripts/upload-wheels.sh"
|
||||
@ -48,7 +48,7 @@ steps:
|
||||
queue: cpu_queue_postmerge
|
||||
commands:
|
||||
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.4.0 --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain -f docker/Dockerfile ."
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain -f docker/Dockerfile ."
|
||||
- "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT"
|
||||
|
||||
- label: "Build and publish TPU release image"
|
||||
@ -57,6 +57,8 @@ steps:
|
||||
agents:
|
||||
queue: tpu_queue_postmerge
|
||||
commands:
|
||||
- "yes | docker system prune -a"
|
||||
- "git fetch --all"
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --tag vllm/vllm-tpu:nightly --tag vllm/vllm-tpu:$BUILDKITE_COMMIT --progress plain -f docker/Dockerfile.tpu ."
|
||||
- "docker push vllm/vllm-tpu:nightly"
|
||||
- "docker push vllm/vllm-tpu:$BUILDKITE_COMMIT"
|
||||
@ -86,3 +88,18 @@ steps:
|
||||
- "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version)"
|
||||
env:
|
||||
DOCKER_BUILDKIT: "1"
|
||||
|
||||
- block: "Build Neuron release image"
|
||||
key: block-neuron-release-image-build
|
||||
depends_on: ~
|
||||
|
||||
- label: "Build and publish Neuron release image"
|
||||
depends_on: block-neuron-release-image-build
|
||||
agents:
|
||||
queue: neuron-postmerge
|
||||
commands:
|
||||
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:latest --progress plain -f docker/Dockerfile.neuron ."
|
||||
- "docker push public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:$(buildkite-agent meta-data get release-version)"
|
||||
env:
|
||||
DOCKER_BUILDKIT: "1"
|
||||
|
||||
@ -75,30 +75,51 @@ HF_MOUNT="/root/.cache/huggingface"
|
||||
commands=$@
|
||||
echo "Commands:$commands"
|
||||
#ignore certain kernels tests
|
||||
if [[ $commands == *" kernels "* ]]; then
|
||||
if [[ $commands == *" kernels/core"* ]]; then
|
||||
commands="${commands} \
|
||||
--ignore=kernels/test_attention_selector.py \
|
||||
--ignore=kernels/test_blocksparse_attention.py \
|
||||
--ignore=kernels/test_causal_conv1d.py \
|
||||
--ignore=kernels/test_cutlass.py \
|
||||
--ignore=kernels/test_encoder_decoder_attn.py \
|
||||
--ignore=kernels/test_flash_attn.py \
|
||||
--ignore=kernels/test_flashinfer.py \
|
||||
--ignore=kernels/test_int8_quant.py \
|
||||
--ignore=kernels/test_machete_gemm.py \
|
||||
--ignore=kernels/test_mamba_ssm.py \
|
||||
--ignore=kernels/test_marlin_gemm.py \
|
||||
--ignore=kernels/test_moe.py \
|
||||
--ignore=kernels/test_prefix_prefill.py \
|
||||
--ignore=kernels/test_rand.py \
|
||||
--ignore=kernels/test_sampler.py \
|
||||
--ignore=kernels/test_cascade_flash_attn.py \
|
||||
--ignore=kernels/test_mamba_mixer2.py \
|
||||
--ignore=kernels/test_aqlm.py \
|
||||
--ignore=kernels/test_machete_mm.py \
|
||||
--ignore=kernels/test_mha_attn.py \
|
||||
--ignore=kernels/test_block_fp8.py \
|
||||
--ignore=kernels/test_permute_cols.py"
|
||||
--ignore=kernels/core/test_fused_quant_layernorm.py \
|
||||
--ignore=kernels/core/test_permute_cols.py"
|
||||
fi
|
||||
|
||||
if [[ $commands == *" kernels/attention"* ]]; then
|
||||
commands="${commands} \
|
||||
--ignore=kernels/attention/stest_attention_selector.py \
|
||||
--ignore=kernels/attention/test_blocksparse_attention.py \
|
||||
--ignore=kernels/attention/test_encoder_decoder_attn.py \
|
||||
--ignore=kernels/attention/test_attention_selector.py \
|
||||
--ignore=kernels/attention/test_flash_attn.py \
|
||||
--ignore=kernels/attention/test_flashinfer.py \
|
||||
--ignore=kernels/attention/test_prefix_prefill.py \
|
||||
--ignore=kernels/attention/test_cascade_flash_attn.py \
|
||||
--ignore=kernels/attention/test_mha_attn.py \
|
||||
--ignore=kernels/attention/test_lightning_attn.py \
|
||||
--ignore=kernels/attention/test_attention.py"
|
||||
fi
|
||||
|
||||
if [[ $commands == *" kernels/quantization"* ]]; then
|
||||
commands="${commands} \
|
||||
--ignore=kernels/quantization/test_int8_quant.py \
|
||||
--ignore=kernels/quantization/test_aqlm.py \
|
||||
--ignore=kernels/quantization/test_machete_mm.py \
|
||||
--ignore=kernels/quantization/test_block_fp8.py \
|
||||
--ignore=kernels/quantization/test_block_int8.py \
|
||||
--ignore=kernels/quantization/test_marlin_gemm.py \
|
||||
--ignore=kernels/quantization/test_cutlass_scaled_mm.py \
|
||||
--ignore=kernels/quantization/test_int8_kernel.py"
|
||||
fi
|
||||
|
||||
if [[ $commands == *" kernels/mamba"* ]]; then
|
||||
commands="${commands} \
|
||||
--ignore=kernels/mamba/test_mamba_mixer2.py \
|
||||
--ignore=kernels/mamba/test_causal_conv1d.py \
|
||||
--ignore=kernels/mamba/test_mamba_ssm_ssd.py"
|
||||
fi
|
||||
|
||||
if [[ $commands == *" kernels/moe"* ]]; then
|
||||
commands="${commands} \
|
||||
--ignore=kernels/moe/test_moe.py \
|
||||
--ignore=kernels/moe/test_cutlass_moe.py \
|
||||
--ignore=kernels/moe/test_triton_moe_ptpc_fp8.py"
|
||||
fi
|
||||
|
||||
#ignore certain Entrypoints/openai tests
|
||||
|
||||
@ -5,10 +5,41 @@
|
||||
set -ex
|
||||
|
||||
# Setup cleanup
|
||||
remove_docker_container() { docker rm -f cpu-test || true; docker system prune -f; }
|
||||
remove_docker_container() {
|
||||
if [[ -n "$container_id" ]]; then
|
||||
podman rm -f "$container_id" || true
|
||||
fi
|
||||
podman system prune -f
|
||||
}
|
||||
trap remove_docker_container EXIT
|
||||
remove_docker_container
|
||||
|
||||
# Try building the docker image
|
||||
docker build -t cpu-test -f docker/Dockerfile.ppc64le .
|
||||
podman build -t cpu-test-ubi9-ppc -f docker/Dockerfile.ppc64le .
|
||||
|
||||
# Run the image
|
||||
container_id=$(podman run -itd --entrypoint /bin/bash -v /tmp/:/root/.cache/huggingface --privileged=true --network host -e HF_TOKEN cpu-test-ubi9-ppc)
|
||||
|
||||
function cpu_tests() {
|
||||
|
||||
# offline inference
|
||||
podman exec -it "$container_id" bash -c "
|
||||
set -e
|
||||
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m"
|
||||
|
||||
# Run basic model test
|
||||
podman exec -it "$container_id" bash -c "
|
||||
set -e
|
||||
pip install pytest pytest-asyncio einops peft Pillow soundfile transformers_stream_generator matplotlib
|
||||
pip install sentence-transformers datamodel_code_generator
|
||||
pytest -v -s tests/models/embedding/language/test_cls_models.py::test_classification_models[float-jason9693/Qwen2.5-1.5B-apeach]
|
||||
pytest -v -s tests/models/embedding/language/test_embedding.py::test_models[half-BAAI/bge-base-en-v1.5]
|
||||
pytest -v -s tests/models/encoder_decoder/language -m cpu_model"
|
||||
}
|
||||
|
||||
# All of CPU tests are expected to be finished less than 40 mins.
|
||||
|
||||
export container_id
|
||||
export -f cpu_tests
|
||||
timeout 40m bash -c cpu_tests
|
||||
|
||||
|
||||
13
.buildkite/scripts/hardware_ci/run-cpu-test-s390x.sh
Executable file
13
.buildkite/scripts/hardware_ci/run-cpu-test-s390x.sh
Executable file
@ -0,0 +1,13 @@
|
||||
#!/bin/bash
|
||||
|
||||
# This script build the CPU docker image and run the offline inference inside the container.
|
||||
# It serves a sanity check for compilation and basic model usage.
|
||||
set -ex
|
||||
|
||||
# Setup cleanup
|
||||
remove_docker_container() { docker rm -f cpu-test || true; docker system prune -f; }
|
||||
trap remove_docker_container EXIT
|
||||
remove_docker_container
|
||||
|
||||
# Try building the docker image
|
||||
docker build -t cpu-test -f docker/Dockerfile.s390x .
|
||||
@ -17,10 +17,13 @@ source /etc/environment
|
||||
docker run --privileged --net host --shm-size=16G -it \
|
||||
-e "HF_TOKEN=$HF_TOKEN" --name tpu-test \
|
||||
vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git \
|
||||
&& python3 -m pip install pytest \
|
||||
&& python3 -m pip install pytest pytest-asyncio tpu-info \
|
||||
&& python3 -m pip install lm_eval[api]==0.4.4 \
|
||||
&& export VLLM_XLA_CACHE_PATH= \
|
||||
&& export VLLM_USE_V1=1 \
|
||||
&& export VLLM_XLA_CHECK_RECOMPILATION=1 \
|
||||
&& echo HARDWARE \
|
||||
&& tpu-info \
|
||||
&& echo TEST_0 \
|
||||
&& pytest -v -s /workspace/vllm/tests/v1/tpu/test_perf.py \
|
||||
&& echo TEST_1 \
|
||||
@ -40,7 +43,11 @@ docker run --privileged --net host --shm-size=16G -it \
|
||||
&& echo TEST_8 \
|
||||
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_topk_topp_sampler.py \
|
||||
&& echo TEST_9 \
|
||||
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py" \
|
||||
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_multimodal.py \
|
||||
&& echo TEST_10 \
|
||||
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py \
|
||||
&& echo TEST_11 \
|
||||
&& pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py" \
|
||||
|
||||
|
||||
# TODO: This test fails because it uses RANDOM_SEED sampling
|
||||
|
||||
@ -50,11 +50,11 @@ aws s3 cp "$normal_wheel" "s3://vllm-wheels/$BUILDKITE_COMMIT/"
|
||||
if [[ $normal_wheel == *"cu118"* ]]; then
|
||||
# if $normal_wheel matches cu118, do not upload the index.html
|
||||
echo "Skipping index files for cu118 wheels"
|
||||
elif [[ $normal_wheel == *"cu121"* ]]; then
|
||||
# if $normal_wheel matches cu121, do not upload the index.html
|
||||
echo "Skipping index files for cu121 wheels"
|
||||
elif [[ $normal_wheel == *"cu126"* ]]; then
|
||||
# if $normal_wheel matches cu126, do not upload the index.html
|
||||
echo "Skipping index files for cu126 wheels"
|
||||
else
|
||||
# only upload index.html for cu124 wheels (default wheels)
|
||||
# only upload index.html for cu128 wheels (default wheels)
|
||||
aws s3 cp index.html "s3://vllm-wheels/$BUILDKITE_COMMIT/vllm/index.html"
|
||||
aws s3 cp "s3://vllm-wheels/nightly/index.html" "s3://vllm-wheels/$BUILDKITE_COMMIT/index.html"
|
||||
fi
|
||||
@ -66,12 +66,12 @@ aws s3 cp "$normal_wheel" "s3://vllm-wheels/nightly/"
|
||||
if [[ $normal_wheel == *"cu118"* ]]; then
|
||||
# if $normal_wheel matches cu118, do not upload the index.html
|
||||
echo "Skipping index files for cu118 wheels"
|
||||
elif [[ $normal_wheel == *"cu121"* ]]; then
|
||||
# if $normal_wheel matches cu121, do not upload the index.html
|
||||
echo "Skipping index files for cu121 wheels"
|
||||
elif [[ $normal_wheel == *"cu126"* ]]; then
|
||||
# if $normal_wheel matches cu126, do not upload the index.html
|
||||
echo "Skipping index files for cu126 wheels"
|
||||
else
|
||||
# only upload index.html for cu124 wheels (default wheels)
|
||||
# only upload index.html for cu128 wheels (default wheels)
|
||||
aws s3 cp index.html "s3://vllm-wheels/nightly/vllm/index.html"
|
||||
fi
|
||||
|
||||
aws s3 cp "$wheel" "s3://vllm-wheels/$version/"
|
||||
aws s3 cp "$wheel" "s3://vllm-wheels/$version/"
|
||||
|
||||
@ -8,6 +8,7 @@
|
||||
# Documentation
|
||||
# label(str): the name of the test. emoji allowed.
|
||||
# fast_check(bool): whether to run this on each commit on fastcheck pipeline.
|
||||
# torch_nightly(bool): whether to run this on vllm against torch nightly pipeline.
|
||||
# fast_check_only(bool): run this test on fastcheck pipeline only
|
||||
# optional(bool): never run this test by default (i.e. need to unblock manually) unless it's scheduled nightly run.
|
||||
# command(str): the single command to run for tests. incompatible with commands.
|
||||
@ -38,7 +39,7 @@ steps:
|
||||
- pip install -r ../../requirements/docs.txt
|
||||
- SPHINXOPTS=\"-W\" make html
|
||||
# Check API reference (if it fails, you may have missing mock imports)
|
||||
- grep \"sig sig-object py\" build/html/api/inference_params.html
|
||||
- grep \"sig sig-object py\" build/html/api/vllm/vllm.sampling_params.html
|
||||
|
||||
- label: Async Engine, Inputs, Utils, Worker Test # 24min
|
||||
source_file_dependencies:
|
||||
@ -70,6 +71,7 @@ steps:
|
||||
- label: Basic Correctness Test # 30min
|
||||
#mirror_hardwares: [amd]
|
||||
fast_check: true
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/basic_correctness/test_basic_correctness
|
||||
@ -104,6 +106,7 @@ steps:
|
||||
- label: Entrypoints Test # 40min
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
fast_check: true
|
||||
torch_nightly: true
|
||||
#mirror_hardwares: [amd]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
@ -118,7 +121,7 @@ steps:
|
||||
- pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
|
||||
- pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process
|
||||
- VLLM_USE_V1=0 pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process
|
||||
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/correctness/
|
||||
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/test_openai_schema.py
|
||||
- pytest -v -s entrypoints/test_chat_utils.py
|
||||
- VLLM_USE_V1=0 pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests
|
||||
|
||||
@ -163,11 +166,6 @@ steps:
|
||||
- tests/tracing
|
||||
commands:
|
||||
- pytest -v -s metrics
|
||||
- "pip install \
|
||||
'opentelemetry-sdk>=1.26.0,<1.27.0' \
|
||||
'opentelemetry-api>=1.26.0,<1.27.0' \
|
||||
'opentelemetry-exporter-otlp>=1.26.0,<1.27.0' \
|
||||
'opentelemetry-semantic-conventions-ai>=0.4.1,<0.5.0'"
|
||||
- pytest -v -s tracing
|
||||
|
||||
##### fast check tests #####
|
||||
@ -210,6 +208,8 @@ steps:
|
||||
- pytest -v -s v1/sample
|
||||
- pytest -v -s v1/worker
|
||||
- pytest -v -s v1/structured_output
|
||||
- pytest -v -s v1/spec_decode
|
||||
- pytest -v -s v1/test_serial_utils.py
|
||||
- pytest -v -s v1/test_stats.py
|
||||
- pytest -v -s v1/test_utils.py
|
||||
- pytest -v -s v1/test_oracle.py
|
||||
@ -292,7 +292,18 @@ steps:
|
||||
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py
|
||||
parallelism: 4
|
||||
|
||||
- label: PyTorch Compilation Unit Tests
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/compile
|
||||
commands:
|
||||
- pytest -v -s compile/test_pass_manager.py
|
||||
- pytest -v -s compile/test_fusion.py
|
||||
- pytest -v -s compile/test_sequence_parallelism.py
|
||||
|
||||
- label: PyTorch Fullgraph Smoke Test # 9min
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/compile
|
||||
@ -301,24 +312,60 @@ steps:
|
||||
# these tests need to be separated, cannot combine
|
||||
- pytest -v -s compile/piecewise/test_simple.py
|
||||
- pytest -v -s compile/piecewise/test_toy_llama.py
|
||||
- pytest -v -s compile/test_pass_manager.py
|
||||
|
||||
- label: PyTorch Fullgraph Test # 18min
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/compile
|
||||
commands:
|
||||
- pytest -v -s compile/test_full_graph.py
|
||||
|
||||
- label: Kernels Test %N # 1h each
|
||||
# mirror_hardwares: [amd]
|
||||
- label: Kernels Core Operation Test
|
||||
mirror_hardwares: [amd]
|
||||
source_file_dependencies:
|
||||
- csrc/
|
||||
- vllm/attention
|
||||
- tests/kernels
|
||||
- tests/kernels/core
|
||||
commands:
|
||||
- pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
|
||||
parallelism: 4
|
||||
- pytest -v -s kernels/core
|
||||
|
||||
- label: Kernels Attention Test %N
|
||||
mirror_hardwares: [amd]
|
||||
source_file_dependencies:
|
||||
- csrc/attention/
|
||||
- vllm/attention
|
||||
- vllm/v1/attention
|
||||
- tests/kernels/attention
|
||||
commands:
|
||||
- pytest -v -s kernels/attention --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
|
||||
parallelism: 2
|
||||
|
||||
- label: Kernels Quantization Test %N
|
||||
mirror_hardwares: [amd]
|
||||
source_file_dependencies:
|
||||
- csrc/quantization/
|
||||
- vllm/model_executor/layers/quantization
|
||||
- tests/kernels/quantization
|
||||
commands:
|
||||
- pytest -v -s kernels/quantization --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
|
||||
parallelism: 2
|
||||
|
||||
- label: Kernels MoE Test
|
||||
#mirror_hardwares: [amd]
|
||||
source_file_dependencies:
|
||||
- csrc/moe/
|
||||
- tests/kernels/moe
|
||||
- vllm/model_executor/layers/fused_moe/
|
||||
commands:
|
||||
- pytest -v -s kernels/moe
|
||||
|
||||
- label: Kernels Mamba Test
|
||||
#mirror_hardwares: [amd]
|
||||
source_file_dependencies:
|
||||
- csrc/mamba/
|
||||
- tests/kernels/mamba
|
||||
commands:
|
||||
- pytest -v -s kernels/mamba
|
||||
|
||||
- label: Tensorizer Test # 11min
|
||||
# mirror_hardwares: [amd]
|
||||
@ -339,12 +386,20 @@ steps:
|
||||
commands:
|
||||
- bash scripts/run-benchmarks.sh
|
||||
|
||||
- label: Quantization Test # 33min
|
||||
- label: Benchmarks CLI Test # 10min
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/benchmarks/
|
||||
commands:
|
||||
- pytest -v -s benchmarks/
|
||||
|
||||
- label: Quantization Test
|
||||
source_file_dependencies:
|
||||
- csrc/
|
||||
- vllm/model_executor/layers/quantization
|
||||
- tests/quantization
|
||||
command: VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization
|
||||
commands:
|
||||
- VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization
|
||||
|
||||
- label: LM Eval Small Models # 53min
|
||||
working_dir: "/vllm-workspace/.buildkite/lm-eval-harness"
|
||||
@ -376,92 +431,93 @@ steps:
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/tool_use
|
||||
- tests/mistral_tool_use
|
||||
commands:
|
||||
- pytest -v -s tool_use
|
||||
- pytest -v -s mistral_tool_use
|
||||
|
||||
##### models test #####
|
||||
|
||||
- label: Basic Models Test # 24min
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/models
|
||||
commands:
|
||||
- pytest -v -s models/test_transformers.py
|
||||
- pytest -v -s models/test_registry.py
|
||||
- pytest -v -s models/test_utils.py
|
||||
- pytest -v -s models/test_vision.py
|
||||
# V1 Test: https://github.com/vllm-project/vllm/issues/14531
|
||||
- VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4'
|
||||
- VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2'
|
||||
- VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'llama4'
|
||||
- VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'plamo2'
|
||||
|
||||
- label: Language Models Test (Standard) # 32min
|
||||
- label: Language Models Test (Standard)
|
||||
#mirror_hardwares: [amd]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/models/decoder_only/language
|
||||
- tests/models/embedding/language
|
||||
- tests/models/encoder_decoder/language
|
||||
- tests/models/language
|
||||
commands:
|
||||
- pytest -v -s models/decoder_only/language -m 'core_model or quant_model'
|
||||
- pytest -v -s models/embedding/language -m core_model
|
||||
# Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile.
|
||||
- pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8'
|
||||
- pytest -v -s models/language -m core_model
|
||||
|
||||
- label: Language Models Test (Extended) # 1h10min
|
||||
- label: Language Models Test (Extended)
|
||||
optional: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/models/decoder_only/language
|
||||
- tests/models/embedding/language
|
||||
- tests/models/encoder_decoder/language
|
||||
- tests/models/language
|
||||
commands:
|
||||
- pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model'
|
||||
- pytest -v -s models/embedding/language -m 'not core_model'
|
||||
# Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile.
|
||||
- pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8'
|
||||
- pytest -v -s models/language -m 'not core_model'
|
||||
|
||||
- label: Multi-Modal Models Test (Standard) # 40min
|
||||
- label: Multi-Modal Models Test (Standard)
|
||||
#mirror_hardwares: [amd]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/models/decoder_only/audio_language
|
||||
- tests/models/decoder_only/vision_language
|
||||
- tests/models/embedding/vision_language
|
||||
- tests/models/encoder_decoder/audio_language
|
||||
- tests/models/encoder_decoder/vision_language
|
||||
- tests/models/multimodal
|
||||
commands:
|
||||
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
|
||||
- pytest -v -s models/multimodal
|
||||
- pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model'
|
||||
- pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model or quant_model'
|
||||
- pytest -v -s models/embedding/vision_language -m core_model
|
||||
- pytest -v -s models/encoder_decoder/audio_language -m core_model
|
||||
- pytest -v -s models/encoder_decoder/language -m core_model
|
||||
- pytest -v -s models/encoder_decoder/vision_language -m core_model
|
||||
- pytest -v -s models/decoder_only/vision_language/test_interleaved.py
|
||||
- pytest -v -s models/multimodal/processing
|
||||
- pytest -v -s --ignore models/multimodal/generation/test_whisper.py models/multimodal -m core_model
|
||||
- cd .. && pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work
|
||||
|
||||
- label: Multi-Modal Models Test (Extended) 1 # 48m
|
||||
- label: Multi-Modal Models Test (Extended) 1
|
||||
optional: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/models/decoder_only/audio_language
|
||||
- tests/models/decoder_only/vision_language
|
||||
- tests/models/embedding/vision_language
|
||||
- tests/models/encoder_decoder/vision_language
|
||||
- tests/models/multimodal
|
||||
commands:
|
||||
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
|
||||
- pytest -v -s models/decoder_only/audio_language -m 'not core_model and not quant_model'
|
||||
- pytest -v -s models/decoder_only/vision_language/test_models.py -m 'split(group=0) and not core_model and not quant_model'
|
||||
# HACK - run phi3v tests separately to sidestep this transformers bug
|
||||
# https://github.com/huggingface/transformers/issues/34307
|
||||
- pytest -v -s models/decoder_only/vision_language/test_phi3v.py
|
||||
- pytest -v -s --ignore models/decoder_only/vision_language/test_models.py --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'not core_model and not quant_model'
|
||||
- pytest -v -s models/embedding/vision_language -m 'not core_model'
|
||||
- pytest -v -s models/encoder_decoder/language -m 'not core_model'
|
||||
- pytest -v -s models/encoder_decoder/vision_language -m 'not core_model'
|
||||
- pytest -v -s --ignore models/multimodal/generation/test_common.py --ignore models/multimodal/processing models/multimodal -m 'not core_model'
|
||||
|
||||
- label: Multi-Modal Models Test (Extended) 2 # 38m
|
||||
- label: Multi-Modal Models Test (Extended) 2
|
||||
optional: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/models/decoder_only/vision_language
|
||||
- tests/models/multimodal
|
||||
commands:
|
||||
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
|
||||
- pytest -v -s models/decoder_only/vision_language/test_models.py -m 'split(group=1) and not core_model and not quant_model'
|
||||
- pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=0) and not core_model'
|
||||
|
||||
- label: Multi-Modal Models Test (Extended) 3
|
||||
optional: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/models/multimodal
|
||||
commands:
|
||||
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
|
||||
- pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=1) and not core_model'
|
||||
|
||||
- label: Quantized Models Test
|
||||
#mirror_hardwares: [amd]
|
||||
source_file_dependencies:
|
||||
- vllm/model_executor/layers/quantization
|
||||
- tests/models/quantization
|
||||
commands:
|
||||
- pytest -v -s models/quantization
|
||||
|
||||
# This test is used only in PR development phase to test individual models and should never run on main
|
||||
- label: Custom Models Test
|
||||
@ -531,14 +587,16 @@ steps:
|
||||
- TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)'
|
||||
# Avoid importing model tests that cause CUDA reinitialization error
|
||||
- pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)'
|
||||
- pytest models/encoder_decoder/language/test_bart.py -v -s -m 'distributed(num_gpus=2)'
|
||||
- pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m 'distributed(num_gpus=2)'
|
||||
- pytest models/decoder_only/vision_language/test_models.py -v -s -m 'distributed(num_gpus=2)'
|
||||
- pytest models/language -v -s -m 'distributed(num_gpus=2)'
|
||||
- pytest models/multimodal -v -s -m 'distributed(num_gpus=2)'
|
||||
# test sequence parallel
|
||||
- pytest -v -s distributed/test_sequence_parallel.py
|
||||
# this test fails consistently.
|
||||
# TODO: investigate and fix
|
||||
# - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
|
||||
- VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
|
||||
- VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/test_disagg.py
|
||||
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown
|
||||
|
||||
- label: Plugin Tests (2 GPUs) # 40min
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
|
||||
1
.github/CODEOWNERS
vendored
1
.github/CODEOWNERS
vendored
@ -12,6 +12,7 @@
|
||||
/vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth
|
||||
/vllm/model_executor/guided_decoding @mgoin @russellb
|
||||
/vllm/multimodal @DarkLight1337 @ywang96
|
||||
/vllm/vllm_flash_attn @LucasWilkinson
|
||||
CMakeLists.txt @tlrmchlsmth
|
||||
|
||||
# vLLM V1
|
||||
|
||||
2
.github/ISSUE_TEMPLATE/200-installation.yml
vendored
2
.github/ISSUE_TEMPLATE/200-installation.yml
vendored
@ -14,7 +14,7 @@ body:
|
||||
description: |
|
||||
Please run the following and paste the output below.
|
||||
```sh
|
||||
wget https://raw.githubusercontent.com/vllm-project/vllm/main/collect_env.py
|
||||
wget https://raw.githubusercontent.com/vllm-project/vllm/main/vllm/collect_env.py
|
||||
# For security purposes, please feel free to check the contents of collect_env.py before running it.
|
||||
python collect_env.py
|
||||
```
|
||||
|
||||
2
.github/ISSUE_TEMPLATE/300-usage.yml
vendored
2
.github/ISSUE_TEMPLATE/300-usage.yml
vendored
@ -14,7 +14,7 @@ body:
|
||||
description: |
|
||||
Please run the following and paste the output below.
|
||||
```sh
|
||||
wget https://raw.githubusercontent.com/vllm-project/vllm/main/collect_env.py
|
||||
wget https://raw.githubusercontent.com/vllm-project/vllm/main/vllm/collect_env.py
|
||||
# For security purposes, please feel free to check the contents of collect_env.py before running it.
|
||||
python collect_env.py
|
||||
```
|
||||
|
||||
6
.github/ISSUE_TEMPLATE/400-bug-report.yml
vendored
6
.github/ISSUE_TEMPLATE/400-bug-report.yml
vendored
@ -14,19 +14,19 @@ body:
|
||||
description: |
|
||||
Please run the following and paste the output below.
|
||||
```sh
|
||||
wget https://raw.githubusercontent.com/vllm-project/vllm/main/collect_env.py
|
||||
wget https://raw.githubusercontent.com/vllm-project/vllm/main/vllm/collect_env.py
|
||||
# For security purposes, please feel free to check the contents of collect_env.py before running it.
|
||||
python collect_env.py
|
||||
```
|
||||
It is suggested to download and execute the latest script, as vllm might frequently update the diagnosis information needed for accurately and quickly responding to issues.
|
||||
value: |
|
||||
<details>
|
||||
<summary>The output of `python collect_env.py`</summary>
|
||||
<summary>The output of <code>python collect_env.py</code></summary>
|
||||
|
||||
```text
|
||||
Your output of `python collect_env.py` here
|
||||
```
|
||||
|
||||
|
||||
</details>
|
||||
validations:
|
||||
required: true
|
||||
|
||||
2
.github/ISSUE_TEMPLATE/600-new-model.yml
vendored
2
.github/ISSUE_TEMPLATE/600-new-model.yml
vendored
@ -9,7 +9,7 @@ body:
|
||||
value: >
|
||||
#### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/vllm-project/vllm/issues?q=is%3Aissue+sort%3Acreated-desc+).
|
||||
|
||||
#### We also highly recommend you read https://docs.vllm.ai/en/latest/contributing/model/adding_model.html first to understand how to add a new model.
|
||||
#### We also highly recommend you read https://docs.vllm.ai/en/latest/contributing/model/index.html first to understand how to add a new model.
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: The model to consider.
|
||||
|
||||
@ -35,7 +35,7 @@ body:
|
||||
description: |
|
||||
Please run the following and paste the output below.
|
||||
```sh
|
||||
wget https://raw.githubusercontent.com/vllm-project/vllm/main/collect_env.py
|
||||
wget https://raw.githubusercontent.com/vllm-project/vllm/main/vllm/collect_env.py
|
||||
# For security purposes, please feel free to check the contents of collect_env.py before running it.
|
||||
python collect_env.py
|
||||
```
|
||||
|
||||
2
.github/PULL_REQUEST_TEMPLATE.md
vendored
2
.github/PULL_REQUEST_TEMPLATE.md
vendored
@ -3,4 +3,4 @@ FILL IN THE PR DESCRIPTION HERE
|
||||
FIX #xxxx (*link existing issues this PR will resolve*)
|
||||
|
||||
<!--- pyml disable-next-line no-emphasis-as-heading -->
|
||||
**BEFORE SUBMITTING, PLEASE READ <https://docs.vllm.ai/en/latest/contributing/overview.html>**
|
||||
**BEFORE SUBMITTING, PLEASE READ <https://docs.vllm.ai/en/latest/contributing/overview.html>** (anything written below this line will be removed by GitHub Actions)
|
||||
|
||||
34
.github/mergify.yml
vendored
34
.github/mergify.yml
vendored
@ -55,11 +55,19 @@ pull_request_rules:
|
||||
description: Automatically apply structured-output label
|
||||
conditions:
|
||||
- or:
|
||||
- files~=^benchmarks/structured_schemas/
|
||||
- files=benchmarks/benchmark_serving_structured_output.py
|
||||
- files=benchmarks/run_structured_output_benchmark.sh
|
||||
- files=docs/source/features/structured_outputs.md
|
||||
- files=examples/offline_inference/structured_outputs.py
|
||||
- files=examples/online_serving/openai_chat_completion_structured_outputs.py
|
||||
- files=examples/online_serving/openai_chat_completion_structured_outputs_with_reasoning.py
|
||||
- files~=^vllm/model_executor/guided_decoding/
|
||||
- files=tests/model_executor/test_guided_processors.py
|
||||
- files=tests/entrypoints/llm/test_guided_generate.py
|
||||
- files=benchmarks/benchmark_serving_guided.py
|
||||
- files=benchmarks/benchmark_guided.py
|
||||
- files~=^tests/v1/structured_output/
|
||||
- files=tests/v1/entrypoints/llm/test_guided_generate.py
|
||||
- files~=^vllm/v1/structured_output/
|
||||
actions:
|
||||
label:
|
||||
add:
|
||||
@ -118,6 +126,28 @@ pull_request_rules:
|
||||
remove:
|
||||
- tpu
|
||||
|
||||
- name: label-tool-calling
|
||||
description: Automatically add tool-calling label
|
||||
conditions:
|
||||
- or:
|
||||
- files~=^tests/tool_use/
|
||||
- files~=^tests/mistral_tool_use/
|
||||
- files~=^tests/entrypoints/openai/tool_parsers/
|
||||
- files=tests/entrypoints/openai/test_chat_with_tool_reasoning.py
|
||||
- files~=^vllm/entrypoints/openai/tool_parsers/
|
||||
- files=docs/source/features/tool_calling.md
|
||||
- files=docs/source/getting_started/examples/openai_chat_completion_client_with_tools.md
|
||||
- files=docs/source/getting_started/examples/chat_with_tools.md
|
||||
- files~=^examples/tool_chat_*
|
||||
- files=examples/offline_inference/chat_with_tools.py
|
||||
- files=examples/online_serving/openai_chat_completion_client_with_tools_required.py
|
||||
- files=examples/online_serving/openai_chat_completion_tool_calls_with_reasoning.py
|
||||
- files=examples/online_serving/openai_chat_completion_client_with_tools.py
|
||||
actions:
|
||||
label:
|
||||
add:
|
||||
- tool-calling
|
||||
|
||||
- name: ping author on conflicts and add 'needs-rebase' label
|
||||
conditions:
|
||||
- conflict
|
||||
|
||||
4
.github/workflows/lint-and-deploy.yaml
vendored
4
.github/workflows/lint-and-deploy.yaml
vendored
@ -66,7 +66,7 @@ jobs:
|
||||
export AWS_SECRET_ACCESS_KEY=minioadmin
|
||||
sleep 30 && kubectl -n ns-vllm logs -f "$(kubectl -n ns-vllm get pods | awk '/deployment/ {print $1;exit}')" &
|
||||
helm install --wait --wait-for-jobs --timeout 5m0s --debug --create-namespace --namespace=ns-vllm test-vllm examples/online_serving/chart-helm -f examples/online_serving/chart-helm/values.yaml --set secrets.s3endpoint=http://minio:9000 --set secrets.s3bucketname=testbucket --set secrets.s3accesskeyid=$AWS_ACCESS_KEY_ID --set secrets.s3accesskey=$AWS_SECRET_ACCESS_KEY --set resources.requests.cpu=1 --set resources.requests.memory=4Gi --set resources.limits.cpu=2 --set resources.limits.memory=5Gi --set image.env[0].name=VLLM_CPU_KVCACHE_SPACE --set image.env[1].name=VLLM_LOGGING_LEVEL --set-string image.env[0].value="1" --set-string image.env[1].value="DEBUG" --set-string extraInit.s3modelpath="opt-125m/" --set-string 'resources.limits.nvidia\.com/gpu=0' --set-string 'resources.requests.nvidia\.com/gpu=0' --set-string image.repository="vllm-cpu-env"
|
||||
|
||||
|
||||
- name: curl test
|
||||
run: |
|
||||
kubectl -n ns-vllm port-forward service/test-vllm-service 8001:80 &
|
||||
@ -79,4 +79,4 @@ jobs:
|
||||
"max_tokens": 7,
|
||||
"temperature": 0
|
||||
}'):$CODE"
|
||||
echo "$CODE"
|
||||
echo "$CODE"
|
||||
|
||||
5
.gitignore
vendored
5
.gitignore
vendored
@ -3,7 +3,6 @@
|
||||
|
||||
# vllm-flash-attn built from source
|
||||
vllm/vllm_flash_attn/*
|
||||
!vllm/vllm_flash_attn/fa_utils.py
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
@ -81,6 +80,7 @@ instance/
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
docs/source/getting_started/examples/
|
||||
docs/source/api/vllm
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
@ -203,3 +203,6 @@ benchmarks/**/*.json
|
||||
# Linting
|
||||
actionlint
|
||||
shellcheck*/
|
||||
|
||||
# Ingore moe/marlin_moe gen code
|
||||
csrc/moe/marlin_moe_wna16/kernel_*
|
||||
|
||||
@ -11,31 +11,30 @@ repos:
|
||||
hooks:
|
||||
- id: yapf
|
||||
args: [--in-place, --verbose]
|
||||
additional_dependencies: [toml] # TODO: Remove when yapf is upgraded
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.9.3
|
||||
rev: v0.11.7
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--output-format, github, --fix]
|
||||
- repo: https://github.com/codespell-project/codespell
|
||||
rev: v2.4.0
|
||||
rev: v2.4.1
|
||||
hooks:
|
||||
- id: codespell
|
||||
additional_dependencies: ['tomli']
|
||||
args: ['--toml', 'pyproject.toml']
|
||||
- repo: https://github.com/PyCQA/isort
|
||||
rev: 0a0b7a830386ba6a31c2ec8316849ae4d1b8240d # 6.0.0
|
||||
rev: 6.0.1
|
||||
hooks:
|
||||
- id: isort
|
||||
- repo: https://github.com/pre-commit/mirrors-clang-format
|
||||
rev: v19.1.7
|
||||
rev: v20.1.3
|
||||
hooks:
|
||||
- id: clang-format
|
||||
exclude: 'csrc/(moe/topk_softmax_kernels.cu|quantization/gguf/(ggml-common.h|dequantize.cuh|vecdotq.cuh|mmq.cuh|mmvq.cuh))|vllm/third_party/.*'
|
||||
types_or: [c++, cuda]
|
||||
args: [--style=file, --verbose]
|
||||
- repo: https://github.com/jackdewinter/pymarkdown
|
||||
rev: v0.9.27
|
||||
rev: v0.9.29
|
||||
hooks:
|
||||
- id: pymarkdown
|
||||
args: [fix]
|
||||
@ -44,10 +43,10 @@ repos:
|
||||
hooks:
|
||||
- id: actionlint
|
||||
- repo: https://github.com/astral-sh/uv-pre-commit
|
||||
rev: 0.6.2
|
||||
rev: 0.6.17
|
||||
hooks:
|
||||
- id: pip-compile
|
||||
args: [requirements/test.in, -o, requirements/test.txt]
|
||||
args: [requirements/test.in, -o, requirements/test.txt, --index-strategy, unsafe-best-match, --torch-backend, cu128]
|
||||
files: ^requirements/test\.(in|txt)$
|
||||
- repo: local
|
||||
hooks:
|
||||
@ -122,6 +121,12 @@ repos:
|
||||
language: system
|
||||
always_run: true
|
||||
pass_filenames: false
|
||||
- id: update-dockerfile-graph
|
||||
name: Update Dockerfile dependency graph
|
||||
entry: tools/update-dockerfile-graph.sh
|
||||
language: script
|
||||
files: ^docker/Dockerfile$
|
||||
pass_filenames: false
|
||||
# Keep `suggestion` last
|
||||
- id: suggestion
|
||||
name: Suggestion
|
||||
|
||||
104
CMakeLists.txt
104
CMakeLists.txt
@ -15,7 +15,6 @@ project(vllm_extensions LANGUAGES CXX)
|
||||
|
||||
# CUDA by default, can be overridden by using -DVLLM_TARGET_DEVICE=... (used by setup.py)
|
||||
set(VLLM_TARGET_DEVICE "cuda" CACHE STRING "Target device backend for vLLM")
|
||||
|
||||
message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
|
||||
message(STATUS "Target device: ${VLLM_TARGET_DEVICE}")
|
||||
|
||||
@ -46,8 +45,8 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1
|
||||
# requirements.txt files and should be kept consistent. The ROCm torch
|
||||
# versions are derived from docker/Dockerfile.rocm
|
||||
#
|
||||
set(TORCH_SUPPORTED_VERSION_CUDA "2.6.0")
|
||||
set(TORCH_SUPPORTED_VERSION_ROCM "2.6.0")
|
||||
set(TORCH_SUPPORTED_VERSION_CUDA "2.7.0")
|
||||
set(TORCH_SUPPORTED_VERSION_ROCM "2.7.0")
|
||||
|
||||
#
|
||||
# Try to find python package with an executable that exactly matches
|
||||
@ -230,6 +229,7 @@ set(VLLM_EXT_SRC
|
||||
"csrc/cache_kernels.cu"
|
||||
"csrc/attention/paged_attention_v1.cu"
|
||||
"csrc/attention/paged_attention_v2.cu"
|
||||
"csrc/attention/merge_attn_states.cu"
|
||||
"csrc/pos_encoding_kernels.cu"
|
||||
"csrc/activation_kernels.cu"
|
||||
"csrc/layernorm_kernels.cu"
|
||||
@ -240,6 +240,7 @@ set(VLLM_EXT_SRC
|
||||
"csrc/quantization/fp8/common.cu"
|
||||
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
|
||||
"csrc/quantization/gguf/gguf_kernel.cu"
|
||||
"csrc/quantization/activation_kernels.cu"
|
||||
"csrc/cuda_utils_kernels.cu"
|
||||
"csrc/prepare_inputs/advance_step.cu"
|
||||
"csrc/custom_all_reduce.cu"
|
||||
@ -248,9 +249,8 @@ set(VLLM_EXT_SRC
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
|
||||
|
||||
# Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case.
|
||||
# Please keep this in sync with FetchContent_Declare line below.
|
||||
set(CUTLASS_REVISION "v3.8.0" CACHE STRING "CUTLASS revision to use")
|
||||
# Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building.
|
||||
set(CUTLASS_REVISION "v3.9.1" CACHE STRING "CUTLASS revision to use")
|
||||
|
||||
# Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
|
||||
if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR})
|
||||
@ -268,7 +268,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
cutlass
|
||||
GIT_REPOSITORY https://github.com/nvidia/cutlass.git
|
||||
# Please keep this in sync with CUTLASS_REVISION line above.
|
||||
GIT_TAG v3.8.0
|
||||
GIT_TAG ${CUTLASS_REVISION}
|
||||
GIT_PROGRESS TRUE
|
||||
|
||||
# Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
|
||||
@ -289,7 +289,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
"csrc/quantization/fp4/nvfp4_quant_entry.cu"
|
||||
"csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
|
||||
"csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
|
||||
"csrc/cutlass_extensions/common.cpp")
|
||||
"csrc/cutlass_extensions/common.cpp"
|
||||
"csrc/attention/mla/cutlass_mla_entry.cu")
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${VLLM_EXT_SRC}"
|
||||
@ -462,7 +463,26 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
set(FP4_ARCHS)
|
||||
endif()
|
||||
|
||||
#
|
||||
# CUTLASS MLA Archs and flags
|
||||
cuda_archs_loose_intersection(MLA_ARCHS "10.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND MLA_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/attention/mla/cutlass_mla_kernels.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${MLA_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MLA=1")
|
||||
# Add MLA-specific include directories only to MLA source files
|
||||
set_source_files_properties(${SRCS}
|
||||
PROPERTIES INCLUDE_DIRECTORIES "${CUTLASS_DIR}/examples/77_blackwell_fmha;${CUTLASS_DIR}/examples/common")
|
||||
message(STATUS "Building CUTLASS MLA for archs: ${MLA_ARCHS}")
|
||||
else()
|
||||
message(STATUS "Not building CUTLASS MLA as no compatible archs were found.")
|
||||
# clear MLA_ARCHS
|
||||
set(MLA_ARCHS)
|
||||
endif()
|
||||
|
||||
# CUTLASS MoE kernels
|
||||
|
||||
# The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and only works
|
||||
@ -608,21 +628,51 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}")
|
||||
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
|
||||
if (MARLIN_MOE_ARCHS)
|
||||
set(MARLIN_MOE_SRC
|
||||
"csrc/moe/marlin_kernels/marlin_moe_kernel.h"
|
||||
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h"
|
||||
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu"
|
||||
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h"
|
||||
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu"
|
||||
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h"
|
||||
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu"
|
||||
"csrc/moe/marlin_moe_ops.cu")
|
||||
|
||||
#
|
||||
# For the Marlin MOE kernels we automatically generate sources for various
|
||||
# preselected input type pairs and schedules.
|
||||
# Generate sources:
|
||||
set(MOE_MARLIN_GEN_SCRIPT
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/csrc/moe/marlin_moe_wna16/generate_kernels.py)
|
||||
file(MD5 ${MOE_MARLIN_GEN_SCRIPT} MOE_MARLIN_GEN_SCRIPT_HASH)
|
||||
|
||||
message(STATUS "Marlin MOE generation script hash: ${MOE_MARLIN_GEN_SCRIPT_HASH}")
|
||||
message(STATUS "Last run Marlin MOE generate script hash: $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH}")
|
||||
|
||||
if (NOT DEFINED CACHE{MOE_MARLIN_GEN_SCRIPT_HASH}
|
||||
OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH} STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH})
|
||||
execute_process(
|
||||
COMMAND ${CMAKE_COMMAND} -E env
|
||||
PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$PYTHONPATH
|
||||
${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT}
|
||||
RESULT_VARIABLE moe_marlin_generation_result
|
||||
OUTPUT_VARIABLE moe_marlin_generation_output
|
||||
OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log
|
||||
ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log
|
||||
)
|
||||
|
||||
if (NOT moe_marlin_generation_result EQUAL 0)
|
||||
message(FATAL_ERROR "Marlin MOE generation failed."
|
||||
" Result: \"${moe_marlin_generation_result}\""
|
||||
"\nCheck the log for details: "
|
||||
"${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log")
|
||||
else()
|
||||
set(MOE_MARLIN_GEN_SCRIPT_HASH ${MOE_MARLIN_GEN_SCRIPT_HASH}
|
||||
CACHE STRING "Last run Marlin MOE generate script hash" FORCE)
|
||||
message(STATUS "Marlin MOE generation completed successfully.")
|
||||
endif()
|
||||
else()
|
||||
message(STATUS "Marlin MOE generation script has not changed, skipping generation.")
|
||||
endif()
|
||||
|
||||
file(GLOB MOE_WNAA16_MARLIN_SRC "csrc/moe/marlin_moe_wna16/*.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_MOE_SRC}"
|
||||
SRCS "${MOE_WNAA16_MARLIN_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_MOE_ARCHS}")
|
||||
|
||||
list(APPEND VLLM_MOE_EXT_SRC "${MARLIN_MOE_SRC}")
|
||||
list(APPEND VLLM_MOE_EXT_SRC ${MOE_WNAA16_MARLIN_SRC})
|
||||
|
||||
message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_ARCHS}")
|
||||
else()
|
||||
message(STATUS "Not building Marlin MOE kernels as no compatible archs found"
|
||||
@ -630,6 +680,17 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
set(MOE_PERMUTE_SRC
|
||||
"csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu"
|
||||
"csrc/moe/moe_permute_unpermute_op.cu")
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_PERMUTE_SRC}"
|
||||
CUDA_ARCHS "${MOE_PERMUTE_ARCHS}")
|
||||
|
||||
list(APPEND VLLM_MOE_EXT_SRC "${MOE_PERMUTE_SRC}")
|
||||
endif()
|
||||
message(STATUS "Enabling moe extension.")
|
||||
define_gpu_extension_target(
|
||||
_moe_C
|
||||
@ -638,6 +699,8 @@ define_gpu_extension_target(
|
||||
SOURCES ${VLLM_MOE_EXT_SRC}
|
||||
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
|
||||
ARCHITECTURES ${VLLM_GPU_ARCHES}
|
||||
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
|
||||
INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
|
||||
USE_SABI 3
|
||||
WITH_SOABI)
|
||||
|
||||
@ -647,6 +710,7 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
|
||||
#
|
||||
set(VLLM_ROCM_EXT_SRC
|
||||
"csrc/rocm/torch_bindings.cpp"
|
||||
"csrc/rocm/skinny_gemms.cu"
|
||||
"csrc/rocm/attention.cu")
|
||||
|
||||
define_gpu_extension_target(
|
||||
|
||||
@ -10,16 +10,13 @@ Easy, fast, and cheap LLM serving for everyone
|
||||
</h3>
|
||||
|
||||
<p align="center">
|
||||
| <a href="https://docs.vllm.ai"><b>Documentation</b></a> | <a href="https://vllm.ai"><b>Blog</b></a> | <a href="https://arxiv.org/abs/2309.06180"><b>Paper</b></a> | <a href="https://x.com/vllm_project"><b>Twitter/X</b></a> | <a href="https://discuss.vllm.ai"><b>User Forum</b></a> | <a href="https://slack.vllm.ai"><b>Developer Slack</b></a> |
|
||||
| <a href="https://docs.vllm.ai"><b>Documentation</b></a> | <a href="https://blog.vllm.ai/"><b>Blog</b></a> | <a href="https://arxiv.org/abs/2309.06180"><b>Paper</b></a> | <a href="https://x.com/vllm_project"><b>Twitter/X</b></a> | <a href="https://discuss.vllm.ai"><b>User Forum</b></a> | <a href="https://slack.vllm.ai"><b>Developer Slack</b></a> |
|
||||
</p>
|
||||
|
||||
---
|
||||
|
||||
[2025/04] We're hosting our first-ever *vLLM Asia Developer Day* in Singapore on *April 3rd*! This is a full-day event (9 AM - 9 PM SGT) in partnership with SGInnovate, AMD, and Embedded LLM. Meet the vLLM team and learn about LLM inference for RL, MI300X, and more! [Register Now](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day)
|
||||
|
||||
---
|
||||
|
||||
*Latest News* 🔥
|
||||
- [2025/04] We hosted [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing).
|
||||
- [2025/03] We hosted [vLLM x Ollama Inference Night](https://lu.ma/vllm-ollama)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/16T2PDD1YwRnZ4Tu8Q5r6n53c5Lr5c73UV9Vd2_eBo4U/edit?usp=sharing).
|
||||
- [2025/03] We hosted [the first vLLM China Meetup](https://mp.weixin.qq.com/s/n77GibL2corAtQHtVEAzfg)! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1REHvfQMKGnvz6p3Fd23HhSO4c8j5WPGZV0bKYLwnHyQ/edit?usp=sharing).
|
||||
- [2025/03] We hosted [the East Coast vLLM Meetup](https://lu.ma/7mu4k4xx)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1NHiv8EUFF1NLd3fEYODm56nDmL26lEeXCaDgyDlTsRs/edit#slide=id.g31441846c39_0_0).
|
||||
|
||||
212
benchmarks/auto_tune.sh
Normal file
212
benchmarks/auto_tune.sh
Normal file
@ -0,0 +1,212 @@
|
||||
#!/bin/bash
|
||||
|
||||
# This script aims to tune the best server parameter combinations to maximize throughput for given requirement.
|
||||
# The current server parameter combination is max_num_seqs and max_num_batched_tokens
|
||||
# It also supports additional requirement: e2e latency and prefix cache.
|
||||
|
||||
# Pre-requisite:
|
||||
# 1. Checkout to your branch, install/ update the correct running env. For TPU, activate conda env and install the corresponding torch, xla version.
|
||||
# 2. If the model is customized, replace the MODEL's config with the customized config.
|
||||
# 3. Set variables (ALL REQUIRED)
|
||||
# BASE: your directory for vllm repo
|
||||
# MODEL: the model served by vllm
|
||||
# DOWNLOAD_DIR: directory to download and load model weights.
|
||||
# INPUT_LEN: request input len
|
||||
# OUTPUT_LEN: request output len
|
||||
# MIN_CACHE_HIT_PCT: prefix cache rate
|
||||
# MAX_LATENCY_ALLOWED_MS: (e2e) latency requirement. If there's no latency requirement, set it to a large number like 1000000000
|
||||
# 4. Run the script, it might take a long time, you can use tmux to avoid the script stop if disconnection happens.
|
||||
# 5. The final result will be saved in RESULT file.
|
||||
|
||||
|
||||
# Example use cases
|
||||
# 1. Given input_len=1800, output_len=20, what's the best max_num_seqs and max_num_batched_tokens to get highest throughput?
|
||||
# Use INPUT_LEN=1800, OUTPUT_LEN=20, MIN_CACHE_HIT_PCT=0, MAX_LATENCY_ALLOWED_MS=100000000000
|
||||
# 2. If we have latency requirement to be lower than 500ms, what's the best server parameter?
|
||||
# Use INPUT_LEN=1800, OUTPUT_LEN=20, MIN_CACHE_HIT_PCT=0, MAX_LATENCY_ALLOWED_MS=500
|
||||
# 3. If we want to reach 60% prefix cache, what's the best server parameter?
|
||||
# Use INPUT_LEN=1800, OUTPUT_LEN=20, MIN_CACHE_HIT_PCT=60, MAX_LATENCY_ALLOWED_MS=500
|
||||
|
||||
TAG=$(date +"%Y_%m_%d_%H_%M")
|
||||
BASE=""
|
||||
MODEL="meta-llama/Llama-3.1-8B-Instruct"
|
||||
DOWNLOAD_DIR=""
|
||||
INPUT_LEN=4000
|
||||
OUTPUT_LEN=16
|
||||
MIN_CACHE_HIT_PCT_PCT=0
|
||||
MAX_LATENCY_ALLOWED_MS=100000000000
|
||||
|
||||
LOG_FOLDER="$BASE/auto-benchmark/$TAG"
|
||||
RESULT="$LOG_FOLDER/result.txt"
|
||||
|
||||
echo "result file$ $RESULT"
|
||||
echo "model: $MODEL"
|
||||
echo
|
||||
|
||||
rm -rf $LOG_FOLDER
|
||||
mkdir -p $LOG_FOLDER
|
||||
|
||||
cd "$BASE/vllm"
|
||||
# create sonnet-4x.txt so that we can sample 2048 tokens for input
|
||||
echo "" > benchmarks/sonnet_4x.txt
|
||||
for _ in {1..4}
|
||||
do
|
||||
cat benchmarks/sonnet.txt >> benchmarks/sonnet_4x.txt
|
||||
done
|
||||
|
||||
pip install datasets
|
||||
|
||||
current_hash=$(git rev-parse HEAD)
|
||||
echo "hash:$current_hash" >> "$RESULT"
|
||||
echo "current_hash: $current_hash"
|
||||
|
||||
best_throughput=0
|
||||
best_max_num_seqs=0
|
||||
best_num_batched_tokens=0
|
||||
best_goodput=0
|
||||
run_benchmark() {
|
||||
local max_num_seqs=$1
|
||||
local max_num_batched_tokens=$2
|
||||
echo "max_num_seq: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens"
|
||||
local vllm_log="$LOG_FOLDER/vllm_log_${max_num_seqs}_${max_num_batched_tokens}.txt"
|
||||
echo "vllm_log: $vllm_log"
|
||||
echo
|
||||
rm -f $vllm_log
|
||||
|
||||
# start the server
|
||||
VLLM_USE_V1=1 VLLM_SERVER_DEV_MODE=1 vllm serve $MODEL \
|
||||
--disable-log-requests \
|
||||
--port 8004 \
|
||||
--gpu-memory-utilization 0.98 \
|
||||
--max-num-seqs $max_num_seqs \
|
||||
--max-num-batched-tokens $max_num_batched_tokens \
|
||||
--tensor-parallel-size 1 \
|
||||
--enable-prefix-caching \
|
||||
--load-format dummy \
|
||||
--download-dir $DOWNLOAD_DIR \
|
||||
--max-model-len $(( INPUT_LEN+OUTPUT_LEN )) > "$vllm_log" 2>&1 &
|
||||
echo "wait for 10 minutes.."
|
||||
echo
|
||||
# wait for 10 minutes...
|
||||
server_started=0
|
||||
for i in {1..60}; do
|
||||
if grep -Fq "Application startup complete" "$vllm_log"; then
|
||||
echo "Application started"
|
||||
server_started=1
|
||||
break
|
||||
else
|
||||
# echo "wait for 10 seconds..."
|
||||
sleep 10
|
||||
fi
|
||||
done
|
||||
|
||||
if (( ! server_started )); then
|
||||
echo "server did not start within 10 minutes, terminate the benchmarking. Please check server log at $vllm_log"
|
||||
echo "pkill -f vllm"
|
||||
echo
|
||||
pkill vllm
|
||||
sleep 10
|
||||
return 1
|
||||
fi
|
||||
|
||||
echo "run benchmark test..."
|
||||
echo
|
||||
meet_latency_requirement=0
|
||||
# get a basic qps by using request-rate inf
|
||||
bm_log="$LOG_FOLDER/bm_log_${max_num_seqs}_${max_num_batched_tokens}_requestrate_inf.txt"
|
||||
prefix_len=$(( INPUT_LEN * MIN_CACHE_HIT_PCT / 100 ))
|
||||
python benchmarks/benchmark_serving.py \
|
||||
--backend vllm \
|
||||
--model $MODEL \
|
||||
--dataset-name sonnet \
|
||||
--dataset-path benchmarks/sonnet_4x.txt \
|
||||
--sonnet-input-len $INPUT_LEN \
|
||||
--sonnet-output-len $OUTPUT_LEN \
|
||||
--ignore-eos \
|
||||
--disable-tqdm \
|
||||
--request-rate inf \
|
||||
--percentile-metrics ttft,tpot,itl,e2el \
|
||||
--goodput e2el:$MAX_LATENCY_ALLOWED_MS \
|
||||
--num-prompts 100 \
|
||||
--sonnet-prefix-len $prefix_len \
|
||||
--port 8004 > "$bm_log"
|
||||
through_put=$(grep "Request throughput (req/s):" "$bm_log" | sed 's/[^0-9.]//g')
|
||||
e2el=$(grep "P99 E2EL (ms):" "$bm_log" | awk '{print $NF}')
|
||||
goodput=$(grep "Request goodput (req/s):" "$bm_log" | sed 's/[^0-9.]//g')
|
||||
|
||||
if (( $(echo "$e2el <= $MAX_LATENCY_ALLOWED_MS" | bc -l) )); then
|
||||
meet_latency_requirement=1
|
||||
fi
|
||||
|
||||
if (( ! meet_latency_requirement )); then
|
||||
# start from request-rate as int(through_put) + 1
|
||||
request_rate=$((${through_put%.*} + 1))
|
||||
while ((request_rate > 0)); do
|
||||
# clear prefix cache
|
||||
curl -X POST http://0.0.0.0:8004/reset_prefix_cache
|
||||
sleep 5
|
||||
bm_log="$LOG_FOLDER/bm_log_${max_num_seqs}_${max_num_batched_tokens}_requestrate_${request_rate}.txt"
|
||||
python benchmarks/benchmark_serving.py \
|
||||
--backend vllm \
|
||||
--model $MODEL \
|
||||
--dataset-name sonnet \
|
||||
--dataset-path benchmarks/sonnet_4x.txt \
|
||||
--sonnet-input-len $INPUT_LEN \
|
||||
--sonnet-output-len $OUTPUT_LEN \
|
||||
--ignore_eos \
|
||||
--disable-tqdm \
|
||||
--request-rate $request_rate \
|
||||
--percentile-metrics ttft,tpot,itl,e2el \
|
||||
--goodput e2el:$MAX_LATENCY_ALLOWED_MS \
|
||||
--num-prompts 100 \
|
||||
--sonnet-prefix-len $prefix_len \
|
||||
--port 8004 > "$bm_log"
|
||||
through_put=$(grep "Request throughput (req/s):" "$bm_log" | sed 's/[^0-9.]//g')
|
||||
e2el=$(grep "P99 E2EL (ms):" "$bm_log" | awk '{print $NF}')
|
||||
goodput=$(grep "Request goodput (req/s):" "$bm_log" | sed 's/[^0-9.]//g')
|
||||
if (( $(echo "$e2el <= $MAX_LATENCY_ALLOWED_MS" | bc -l) )); then
|
||||
meet_latency_requirement=1
|
||||
break
|
||||
fi
|
||||
request_rate=$((request_rate-1))
|
||||
done
|
||||
fi
|
||||
# write the results and update the best result.
|
||||
if ((meet_latency_requirement)); then
|
||||
echo "max_num_seqs: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens, request_rate: $request_rate, e2el: $e2el, through put: $through_put, goodput: $goodput"
|
||||
echo "max_num_seqs: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens, request_rate: $request_rate, e2el: $e2el, through put: $through_put, goodput: $goodput" >> "$RESULT"
|
||||
if (( $(echo "$through_put > $best_throughput" | bc -l) )); then
|
||||
best_throughput=$through_put
|
||||
best_max_num_seqs=$max_num_seqs
|
||||
best_num_batched_tokens=$max_num_batched_tokens
|
||||
best_goodput=$goodput
|
||||
fi
|
||||
else
|
||||
echo "max_num_seqs: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens does not meet latency requirement ${MAX_LATENCY_ALLOWED_MS}"
|
||||
echo "max_num_seqs: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens does not meet latency requirement ${MAX_LATENCY_ALLOWED_MS}" >> "$RESULT"
|
||||
fi
|
||||
|
||||
echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput"
|
||||
|
||||
echo "pkill -f vllm"
|
||||
echo
|
||||
pkill vllm
|
||||
sleep 10
|
||||
rm -f $vllm_log
|
||||
printf '=%.0s' $(seq 1 20)
|
||||
return 0
|
||||
}
|
||||
|
||||
|
||||
num_seqs_list="128 256"
|
||||
num_batched_tokens_list="512 1024 2048 4096"
|
||||
for num_seqs in $num_seqs_list; do
|
||||
for num_batched_tokens in $num_batched_tokens_list; do
|
||||
run_benchmark $num_seqs $num_batched_tokens
|
||||
exit 0
|
||||
done
|
||||
done
|
||||
echo "finish permutations"
|
||||
echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput"
|
||||
echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput" >> "$RESULT"
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
@ -32,6 +33,7 @@ class RequestFuncInput:
|
||||
extra_body: Optional[dict] = None
|
||||
multi_modal_content: Optional[dict] = None
|
||||
ignore_eos: bool = False
|
||||
language: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -199,6 +201,7 @@ async def async_request_deepspeed_mii(
|
||||
timeout=AIOHTTP_TIMEOUT) as session:
|
||||
|
||||
payload = {
|
||||
"model": request_func_input.model,
|
||||
"prompt": request_func_input.prompt,
|
||||
"max_tokens": request_func_input.output_len,
|
||||
"temperature": 0.01, # deepspeed-mii does not accept 0.0 temp.
|
||||
@ -258,6 +261,7 @@ async def async_request_openai_completions(
|
||||
if request_func_input.model_name else request_func_input.model,
|
||||
"prompt": request_func_input.prompt,
|
||||
"temperature": 0.0,
|
||||
"repetition_penalty": 1.0,
|
||||
"max_tokens": request_func_input.output_len,
|
||||
"logprobs": request_func_input.logprobs,
|
||||
"stream": True,
|
||||
@ -436,6 +440,110 @@ async def async_request_openai_chat_completions(
|
||||
return output
|
||||
|
||||
|
||||
async def async_request_openai_audio(
|
||||
request_func_input: RequestFuncInput,
|
||||
pbar: Optional[tqdm] = None,
|
||||
) -> RequestFuncOutput:
|
||||
# Lazy import without PlaceholderModule to avoid vllm dep.
|
||||
import soundfile
|
||||
api_url = request_func_input.api_url
|
||||
assert api_url.endswith(
|
||||
("transcriptions", "translations"
|
||||
)), "OpenAI Chat Completions API URL must end with 'transcriptions' "
|
||||
"or `translations`."
|
||||
|
||||
async with aiohttp.ClientSession(trust_env=True,
|
||||
timeout=AIOHTTP_TIMEOUT) as session:
|
||||
content = [{"type": "text", "text": request_func_input.prompt}]
|
||||
payload = {
|
||||
"model": request_func_input.model_name \
|
||||
if request_func_input.model_name else request_func_input.model,
|
||||
"temperature": 0.0,
|
||||
"max_completion_tokens": request_func_input.output_len,
|
||||
"stream": True,
|
||||
"language": "en",
|
||||
# Flattened due to multipart/form-data
|
||||
"stream_include_usage": True,
|
||||
"stream_continuous_usage_stats": True
|
||||
}
|
||||
if request_func_input.extra_body:
|
||||
payload.update(request_func_input.extra_body)
|
||||
headers = {
|
||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
|
||||
}
|
||||
|
||||
# Send audio file
|
||||
def to_bytes(y, sr):
|
||||
buffer = io.BytesIO()
|
||||
soundfile.write(buffer, y, sr, format="WAV")
|
||||
buffer.seek(0)
|
||||
return buffer
|
||||
|
||||
with to_bytes(*request_func_input.multi_modal_content['audio']) as f:
|
||||
form = aiohttp.FormData()
|
||||
form.add_field('file', f, content_type='audio/wav')
|
||||
for key, value in payload.items():
|
||||
form.add_field(key, str(value))
|
||||
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
|
||||
generated_text = ""
|
||||
ttft = 0.0
|
||||
st = time.perf_counter()
|
||||
most_recent_timestamp = st
|
||||
try:
|
||||
async with session.post(url=api_url,
|
||||
data=form,
|
||||
headers=headers) as response:
|
||||
if response.status == 200:
|
||||
async for chunk_bytes in response.content:
|
||||
chunk_bytes = chunk_bytes.strip()
|
||||
if not chunk_bytes:
|
||||
continue
|
||||
|
||||
chunk = chunk_bytes.decode("utf-8").removeprefix(
|
||||
"data: ")
|
||||
if chunk != "[DONE]":
|
||||
timestamp = time.perf_counter()
|
||||
data = json.loads(chunk)
|
||||
|
||||
if choices := data.get("choices"):
|
||||
content = choices[0]["delta"].get(
|
||||
"content")
|
||||
# First token
|
||||
if ttft == 0.0:
|
||||
ttft = timestamp - st
|
||||
output.ttft = ttft
|
||||
|
||||
# Decoding phase
|
||||
else:
|
||||
output.itl.append(
|
||||
timestamp - most_recent_timestamp)
|
||||
|
||||
generated_text += content or ""
|
||||
elif usage := data.get("usage"):
|
||||
output.output_tokens = usage.get(
|
||||
"completion_tokens")
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
|
||||
output.generated_text = generated_text
|
||||
output.success = True
|
||||
output.latency = most_recent_timestamp - st
|
||||
else:
|
||||
output.error = response.reason or ""
|
||||
output.success = False
|
||||
except Exception:
|
||||
output.success = False
|
||||
exc_info = sys.exc_info()
|
||||
output.error = "".join(traceback.format_exception(*exc_info))
|
||||
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
return output
|
||||
|
||||
|
||||
def get_model(pretrained_model_name_or_path: str) -> str:
|
||||
if os.getenv('VLLM_USE_MODELSCOPE', 'False').lower() == 'true':
|
||||
from modelscope import snapshot_download
|
||||
@ -493,6 +601,7 @@ ASYNC_REQUEST_FUNCS = {
|
||||
"deepspeed-mii": async_request_deepspeed_mii,
|
||||
"openai": async_request_openai_completions,
|
||||
"openai-chat": async_request_openai_chat_completions,
|
||||
"openai-audio": async_request_openai_audio,
|
||||
"tensorrt-llm": async_request_trt_llm,
|
||||
"scalellm": async_request_openai_completions,
|
||||
"sglang": async_request_openai_completions,
|
||||
|
||||
@ -64,6 +64,7 @@ class SampleRequest:
|
||||
|
||||
class BenchmarkDataset(ABC):
|
||||
DEFAULT_SEED = 0
|
||||
IS_MULTIMODAL = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -288,7 +289,7 @@ def process_image(image: Any) -> Mapping[str, Any]:
|
||||
class RandomDataset(BenchmarkDataset):
|
||||
# Default values copied from benchmark_serving.py for the random dataset.
|
||||
DEFAULT_PREFIX_LEN = 0
|
||||
DEFAULT_RANGE_RATIO = 1.0
|
||||
DEFAULT_RANGE_RATIO = 0.0
|
||||
DEFAULT_INPUT_LEN = 1024
|
||||
DEFAULT_OUTPUT_LEN = 128
|
||||
|
||||
@ -308,19 +309,32 @@ class RandomDataset(BenchmarkDataset):
|
||||
output_len: int = DEFAULT_OUTPUT_LEN,
|
||||
**kwargs,
|
||||
) -> list[SampleRequest]:
|
||||
# Enforce range_ratio < 1
|
||||
assert range_ratio < 1.0, (
|
||||
"random_range_ratio must be < 1.0 to ensure a valid sampling range"
|
||||
)
|
||||
|
||||
vocab_size = tokenizer.vocab_size
|
||||
|
||||
prefix_token_ids = (np.random.randint(
|
||||
0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else [])
|
||||
|
||||
input_low = int(input_len * range_ratio)
|
||||
output_low = int(output_len * range_ratio)
|
||||
# New sampling logic: [X * (1 - b), X * (1 + b)]
|
||||
input_low = int(input_len * (1 - range_ratio))
|
||||
input_high = int(input_len * (1 + range_ratio))
|
||||
output_low = int(output_len * (1 - range_ratio))
|
||||
output_high = int(output_len * (1 + range_ratio))
|
||||
|
||||
# Add logging for debugging
|
||||
logger.info("Sampling input_len from [%s, %s]", input_low, input_high)
|
||||
logger.info("Sampling output_len from [%s, %s]", output_low,
|
||||
output_high)
|
||||
|
||||
input_lens = np.random.randint(input_low,
|
||||
input_len + 1,
|
||||
input_high + 1,
|
||||
size=num_requests)
|
||||
output_lens = np.random.randint(output_low,
|
||||
output_len + 1,
|
||||
output_high + 1,
|
||||
size=num_requests)
|
||||
offsets = np.random.randint(0, vocab_size, size=num_requests)
|
||||
|
||||
@ -472,11 +486,11 @@ class SonnetDataset(BenchmarkDataset):
|
||||
|
||||
# Determine how many poem lines to use.
|
||||
num_input_lines = round((input_len - base_offset) / avg_len)
|
||||
num_prefix_lines = round((prefix_len - base_offset) / avg_len)
|
||||
num_prefix_lines = max(round((prefix_len - base_offset) / avg_len), 0)
|
||||
prefix_lines = self.data[:num_prefix_lines]
|
||||
|
||||
samples = []
|
||||
for _ in range(num_requests):
|
||||
while len(samples) < num_requests:
|
||||
extra_lines = random.choices(self.data,
|
||||
k=num_input_lines - num_prefix_lines)
|
||||
prompt = f"{base_prompt}{''.join(prefix_lines + extra_lines)}"
|
||||
@ -484,13 +498,14 @@ class SonnetDataset(BenchmarkDataset):
|
||||
prompt_formatted = tokenizer.apply_chat_template(
|
||||
msg, add_generation_prompt=True, tokenize=False)
|
||||
prompt_len = len(tokenizer(prompt_formatted).input_ids)
|
||||
samples.append(
|
||||
SampleRequest(
|
||||
prompt=prompt_formatted
|
||||
if return_prompt_formatted else prompt,
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=output_len,
|
||||
))
|
||||
if prompt_len <= input_len:
|
||||
samples.append(
|
||||
SampleRequest(
|
||||
prompt=prompt_formatted
|
||||
if return_prompt_formatted else prompt,
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=output_len,
|
||||
))
|
||||
return samples
|
||||
|
||||
|
||||
@ -607,6 +622,7 @@ class ConversationDataset(HuggingFaceDataset):
|
||||
SUPPORTED_DATASET_PATHS = {
|
||||
'lmms-lab/LLaVA-OneVision-Data', 'Aeala/ShareGPT_Vicuna_unfiltered'
|
||||
}
|
||||
IS_MULTIMODAL = True
|
||||
|
||||
def sample(self,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
@ -671,6 +687,7 @@ class VisionArenaDataset(HuggingFaceDataset):
|
||||
"lmarena-ai/vision-arena-bench-v0.1":
|
||||
lambda x: x["turns"][0][0]["content"]
|
||||
}
|
||||
IS_MULTIMODAL = True
|
||||
|
||||
def sample(
|
||||
self,
|
||||
@ -754,6 +771,60 @@ class InstructCoderDataset(HuggingFaceDataset):
|
||||
return sampled_requests
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# MT-Bench Dataset Implementation
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class MTBenchDataset(HuggingFaceDataset):
|
||||
"""
|
||||
MT-Bench Dataset.
|
||||
https://huggingface.co/datasets/philschmid/mt-bench
|
||||
|
||||
We create a single turn dataset for MT-Bench.
|
||||
This is similar to Spec decoding benchmark setup in vLLM
|
||||
https://github.com/vllm-project/vllm/blob/9d98ab5ec/examples/offline_inference/eagle.py#L14-L18
|
||||
""" # noqa: E501
|
||||
|
||||
DEFAULT_OUTPUT_LEN = 256 # avg len used in SD bench in vLLM
|
||||
SUPPORTED_DATASET_PATHS = {
|
||||
"philschmid/mt-bench",
|
||||
}
|
||||
|
||||
def sample(self,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
num_requests: int,
|
||||
output_len: Optional[int] = None,
|
||||
enable_multimodal_chat: bool = False,
|
||||
**kwargs) -> list:
|
||||
output_len = (output_len
|
||||
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
|
||||
sampled_requests = []
|
||||
|
||||
for item in self.data:
|
||||
if len(sampled_requests) >= num_requests:
|
||||
break
|
||||
prompt = item['turns'][0]
|
||||
|
||||
# apply template
|
||||
prompt = tokenizer.apply_chat_template([{
|
||||
"role": "user",
|
||||
"content": prompt
|
||||
}],
|
||||
add_generation_prompt=True,
|
||||
tokenize=False)
|
||||
|
||||
prompt_len = len(tokenizer(prompt).input_ids)
|
||||
sampled_requests.append(
|
||||
SampleRequest(
|
||||
prompt=prompt,
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=output_len,
|
||||
))
|
||||
self.maybe_oversample_requests(sampled_requests, num_requests)
|
||||
return sampled_requests
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# AIMO Dataset Implementation
|
||||
# -----------------------------------------------------------------------------
|
||||
@ -801,3 +872,80 @@ class AIMODataset(HuggingFaceDataset):
|
||||
))
|
||||
self.maybe_oversample_requests(sampled_requests, num_requests)
|
||||
return sampled_requests
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# ASR Dataset Implementation
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ASRDataset(HuggingFaceDataset):
|
||||
"""
|
||||
Dataset class for processing a ASR dataset for transcription.
|
||||
Tested on the following set:
|
||||
|
||||
+----------------+----------------------------------------+--------------------------+-----------------------------+
|
||||
| Dataset | Domain | Speaking Style | hf-subset |
|
||||
+----------------+----------------------------------------+--------------------------+-----------------------------+
|
||||
| TED-LIUM | TED talks | Oratory | release1, release2, release3|
|
||||
| | | | release3-speaker-adaptation |
|
||||
| VoxPopuli | European Parliament | Oratory | en, de, it, fr, ... |
|
||||
| LibriSpeech | Audiobook | Narrated | "LIUM/tedlium" |
|
||||
| GigaSpeech | Audiobook, podcast, YouTube | Narrated, spontaneous | xs, s, m, l, xl, dev, test |
|
||||
| SPGISpeech | Financial meetings | Oratory, spontaneous | S, M, L, dev, test |
|
||||
| AMI | Meetings | Spontaneous | ihm, sdm |
|
||||
+----------------+----------------------------------------+--------------------------+-----------------------------+
|
||||
|
||||
""" # noqa: E501
|
||||
SUPPORTED_DATASET_PATHS = {
|
||||
"openslr/librispeech_asr", "facebook/voxpopuli", "LIUM/tedlium",
|
||||
"edinburghcstr/ami", "speechcolab/gigaspeech", "kensho/spgispeech"
|
||||
}
|
||||
|
||||
DEFAULT_OUTPUT_LEN = 128
|
||||
IS_MULTIMODAL = True
|
||||
|
||||
# TODO Whisper-specific. Abstract interface when more models are supported.
|
||||
TRANSCRIPTION_PREAMBLE = "<|startoftranscript|><|en|><|transcribe|>"\
|
||||
"<|notimestamps|>"
|
||||
skip_long_audios: bool = True
|
||||
|
||||
def sample(
|
||||
self,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
num_requests: int,
|
||||
output_len: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> list:
|
||||
import librosa
|
||||
output_len = (output_len
|
||||
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
|
||||
prompt = ASRDataset.TRANSCRIPTION_PREAMBLE
|
||||
prompt_len = len(tokenizer(prompt).input_ids)
|
||||
sampled_requests = []
|
||||
skipped = 0
|
||||
for item in self.data:
|
||||
if len(sampled_requests) >= num_requests:
|
||||
break
|
||||
audio = item["audio"]
|
||||
y, sr = audio["array"], audio["sampling_rate"]
|
||||
duration_s = librosa.get_duration(y=y, sr=sr)
|
||||
# Whisper max supported duration
|
||||
if self.skip_long_audios and duration_s > 30:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
mm_content = {"audio": (y, sr)}
|
||||
sampled_requests.append(
|
||||
SampleRequest(
|
||||
prompt=prompt,
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=output_len,
|
||||
multi_modal_data=mm_content,
|
||||
))
|
||||
if skipped:
|
||||
logger.warning("%d samples discarded from dataset due to" \
|
||||
" their length being greater than" \
|
||||
" what Whisper supports.", skipped)
|
||||
self.maybe_oversample_requests(sampled_requests, num_requests)
|
||||
return sampled_requests
|
||||
|
||||
@ -63,14 +63,16 @@ class Request:
|
||||
output_len: int
|
||||
|
||||
|
||||
def sample_tokens(tokenizer: PreTrainedTokenizerBase, length: int) -> str:
|
||||
def sample_tokens(tokenizer: PreTrainedTokenizerBase,
|
||||
length: int) -> list[int]:
|
||||
vocab = tokenizer.get_vocab()
|
||||
all_special_ids = set(tokenizer.all_special_ids)
|
||||
|
||||
# Remove the special tokens.
|
||||
vocab = {
|
||||
k: v
|
||||
for k, v in vocab.items() if k not in tokenizer.all_special_ids
|
||||
}
|
||||
return random.choices(list(vocab.values()), k=length)
|
||||
return random.choices(
|
||||
[v for k, v in vocab.items() if k not in all_special_ids],
|
||||
k=length,
|
||||
)
|
||||
|
||||
|
||||
def sample_requests_from_dataset(
|
||||
|
||||
@ -50,11 +50,11 @@ try:
|
||||
except ImportError:
|
||||
from argparse import ArgumentParser as FlexibleArgumentParser
|
||||
|
||||
from benchmark_dataset import (AIMODataset, BurstGPTDataset,
|
||||
from benchmark_dataset import (AIMODataset, ASRDataset, BurstGPTDataset,
|
||||
ConversationDataset, HuggingFaceDataset,
|
||||
InstructCoderDataset, RandomDataset,
|
||||
SampleRequest, ShareGPTDataset, SonnetDataset,
|
||||
VisionArenaDataset)
|
||||
InstructCoderDataset, MTBenchDataset,
|
||||
RandomDataset, SampleRequest, ShareGPTDataset,
|
||||
SonnetDataset, VisionArenaDataset)
|
||||
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
|
||||
|
||||
MILLISECONDS_TO_SECONDS_CONVERSION = 1000
|
||||
@ -156,7 +156,7 @@ def calculate_metrics(
|
||||
if outputs[i].success:
|
||||
output_len = outputs[i].output_tokens
|
||||
|
||||
if output_len is None:
|
||||
if not output_len:
|
||||
# We use the tokenizer to count the number of output tokens
|
||||
# for some serving backends instead of looking at
|
||||
# len(outputs[i].itl) since multiple output tokens may be
|
||||
@ -274,10 +274,6 @@ async def benchmark(
|
||||
input_requests[0].expected_output_len, \
|
||||
input_requests[0].multi_modal_data
|
||||
|
||||
if backend != "openai-chat" and test_mm_content is not None:
|
||||
# multi-modal benchmark is only available on OpenAI Chat backend.
|
||||
raise ValueError(
|
||||
"Multi-modal content is only supported on 'openai-chat' backend.")
|
||||
assert test_mm_content is None or isinstance(test_mm_content, dict)
|
||||
test_input = RequestFuncInput(
|
||||
model=model_id,
|
||||
@ -599,11 +595,17 @@ def main(args: argparse.Namespace):
|
||||
elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_class = InstructCoderDataset
|
||||
args.hf_split = "train"
|
||||
elif args.dataset_path in MTBenchDataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_class = MTBenchDataset
|
||||
args.hf_split = "train"
|
||||
elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_class = ConversationDataset
|
||||
elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_class = AIMODataset
|
||||
args.hf_split = "train"
|
||||
elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_class = ASRDataset
|
||||
args.hf_split = "train"
|
||||
else:
|
||||
supported_datasets = set([
|
||||
dataset_name for cls in HuggingFaceDataset.__subclasses__()
|
||||
@ -615,6 +617,13 @@ def main(args: argparse.Namespace):
|
||||
f" from one of following: {supported_datasets}. "
|
||||
"Please consider contributing if you would "
|
||||
"like to add support for additional dataset formats.")
|
||||
|
||||
if (dataset_class.IS_MULTIMODAL and backend not in \
|
||||
["openai-chat", "openai-audio"]):
|
||||
# multi-modal benchmark is only available on OpenAI Chat backend.
|
||||
raise ValueError(
|
||||
"Multi-modal content is only supported on 'openai-chat' and " \
|
||||
"'openai-audio' backend.")
|
||||
input_requests = dataset_class(
|
||||
dataset_path=args.dataset_path,
|
||||
dataset_subset=args.hf_subset,
|
||||
@ -707,7 +716,7 @@ def main(args: argparse.Namespace):
|
||||
))
|
||||
|
||||
# Save config and results to json
|
||||
if args.save_result:
|
||||
if args.save_result or args.append_result:
|
||||
result_json: dict[str, Any] = {}
|
||||
|
||||
# Setup
|
||||
@ -728,6 +737,14 @@ def main(args: argparse.Namespace):
|
||||
raise ValueError(
|
||||
"Invalid metadata format. Please use KEY=VALUE format."
|
||||
)
|
||||
# Traffic
|
||||
result_json["request_rate"] = (args.request_rate if args.request_rate
|
||||
< float("inf") else "inf")
|
||||
result_json["burstiness"] = args.burstiness
|
||||
result_json["max_concurrency"] = args.max_concurrency
|
||||
|
||||
# Merge with benchmark result
|
||||
result_json = {**result_json, **benchmark_result}
|
||||
|
||||
if not args.save_detailed:
|
||||
# Remove fields with too many data points
|
||||
@ -738,15 +755,6 @@ def main(args: argparse.Namespace):
|
||||
if field in result_json:
|
||||
del result_json[field]
|
||||
|
||||
# Traffic
|
||||
result_json["request_rate"] = (args.request_rate if args.request_rate
|
||||
< float("inf") else "inf")
|
||||
result_json["burstiness"] = args.burstiness
|
||||
result_json["max_concurrency"] = args.max_concurrency
|
||||
|
||||
# Merge with benchmark result
|
||||
result_json = {**result_json, **benchmark_result}
|
||||
|
||||
# Save to file
|
||||
base_model_id = model_id.split("/")[-1]
|
||||
max_concurrency_str = (f"-concurrency{args.max_concurrency}"
|
||||
@ -756,7 +764,12 @@ def main(args: argparse.Namespace):
|
||||
file_name = args.result_filename
|
||||
if args.result_dir:
|
||||
file_name = os.path.join(args.result_dir, file_name)
|
||||
with open(file_name, "w", encoding='utf-8') as outfile:
|
||||
with open(file_name,
|
||||
mode="a+" if args.append_result else "w",
|
||||
encoding='utf-8') as outfile:
|
||||
# Append a newline.
|
||||
if args.append_result and outfile.tell() != 0:
|
||||
outfile.write("\n")
|
||||
json.dump(result_json, outfile)
|
||||
save_to_pytorch_benchmark_format(args, result_json, file_name)
|
||||
|
||||
@ -888,6 +901,11 @@ if __name__ == "__main__":
|
||||
help="When saving the results, whether to include per request "
|
||||
"information such as response, error, ttfs, tpots, etc.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--append-result",
|
||||
action="store_true",
|
||||
help="Append the benchmark result to the existing json file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--metadata",
|
||||
metavar="KEY=VALUE",
|
||||
@ -921,7 +939,7 @@ if __name__ == "__main__":
|
||||
"--percentile-metrics",
|
||||
type=str,
|
||||
default="ttft,tpot,itl",
|
||||
help="Comma-seperated list of selected metrics to report percentils. "
|
||||
help="Comma-separated list of selected metrics to report percentils. "
|
||||
"This argument specifies the metrics to report percentiles. "
|
||||
"Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". "
|
||||
"Default value is \"ttft,tpot,itl\".")
|
||||
@ -929,7 +947,7 @@ if __name__ == "__main__":
|
||||
"--metric-percentiles",
|
||||
type=str,
|
||||
default="99",
|
||||
help="Comma-seperated list of percentiles for selected metrics. "
|
||||
help="Comma-separated list of percentiles for selected metrics. "
|
||||
"To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". "
|
||||
"Default value is \"99\". "
|
||||
"Use \"--percentile-metrics\" to select metrics.",
|
||||
@ -996,18 +1014,23 @@ if __name__ == "__main__":
|
||||
random_group.add_argument(
|
||||
"--random-range-ratio",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Range of sampled ratio of input/output length, "
|
||||
"used only for random sampling.",
|
||||
default=0.0,
|
||||
help="Range ratio for sampling input/output length, "
|
||||
"used only for random sampling. Must be in the range [0, 1) to define "
|
||||
"a symmetric sampling range"
|
||||
"[length * (1 - range_ratio), length * (1 + range_ratio)].",
|
||||
)
|
||||
random_group.add_argument(
|
||||
"--random-prefix-len",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Number of fixed prefix tokens before random "
|
||||
" context. The length range of context in a random "
|
||||
" request is [random-prefix-len, "
|
||||
" random-prefix-len + random-prefix-len * random-range-ratio).")
|
||||
help=("Number of fixed prefix tokens before the random context "
|
||||
"in a request. "
|
||||
"The total input length is the sum of `random-prefix-len` and "
|
||||
"a random "
|
||||
"context length sampled from [input_len * (1 - range_ratio), "
|
||||
"input_len * (1 + range_ratio)]."),
|
||||
)
|
||||
|
||||
hf_group = parser.add_argument_group("hf dataset options")
|
||||
hf_group.add_argument("--hf-subset",
|
||||
|
||||
@ -11,7 +11,7 @@ On the client side, run:
|
||||
--model <your_model> \
|
||||
--dataset json \
|
||||
--structured-output-ratio 1.0 \
|
||||
--structured-output-backend xgrammar \
|
||||
--structured-output-backend auto \
|
||||
--request-rate 10 \
|
||||
--num-prompts 1000
|
||||
|
||||
@ -51,7 +51,7 @@ try:
|
||||
except ImportError:
|
||||
from argparse import ArgumentParser as FlexibleArgumentParser
|
||||
|
||||
from vllm.v1.structured_output.utils import (
|
||||
from vllm.v1.structured_output.backend_xgrammar import (
|
||||
has_xgrammar_unsupported_json_features)
|
||||
|
||||
MILLISECONDS_TO_SECONDS_CONVERSION = 1000
|
||||
@ -123,6 +123,8 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
|
||||
copy.deepcopy(schema) for _ in range(args.num_prompts)
|
||||
]
|
||||
for i in range(len(json_schemas)):
|
||||
if "properties" not in json_schemas[i]:
|
||||
json_schemas[i]["properties"] = {}
|
||||
json_schemas[i]["properties"][
|
||||
f"__optional_field_{uuid.uuid4()}"] = {
|
||||
"type":
|
||||
@ -130,10 +132,11 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
|
||||
"description":
|
||||
"An unique optional field to avoid cached schemas"
|
||||
}
|
||||
else:
|
||||
json_schemas = [schema] * args.num_prompts
|
||||
|
||||
def gen_prompt(index: int):
|
||||
schema = json_schemas[index % len(json_schemas)]
|
||||
return f"Generate an example of a user profile given the following schema: {json.dumps(schema)}" # noqa: E501
|
||||
return f"Generate an example of a brief user profile given the following schema: {json.dumps(get_schema(index))}" # noqa: E501
|
||||
|
||||
def get_schema(index: int):
|
||||
return json_schemas[index % len(json_schemas)]
|
||||
@ -149,17 +152,17 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
|
||||
|
||||
elif args.dataset == "grammar":
|
||||
schema = """
|
||||
?start: select_statement
|
||||
root ::= select_statement
|
||||
|
||||
?select_statement: "SELECT " column_list " FROM " table_name
|
||||
select_statement ::= "SELECT " column " from " table " where " condition
|
||||
|
||||
?column_list: column_name ("," column_name)*
|
||||
column ::= "col_1 " | "col_2 "
|
||||
|
||||
?table_name: identifier
|
||||
table ::= "table_1 " | "table_2 "
|
||||
|
||||
?column_name: identifier
|
||||
condition ::= column "= " number
|
||||
|
||||
?identifier: /[a-zA-Z_][a-zA-Z0-9_]*/
|
||||
number ::= "1 " | "2 "
|
||||
"""
|
||||
prompt = "Generate an SQL query to show the 'username' \
|
||||
and 'email' from the 'users' table."
|
||||
@ -230,7 +233,8 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
|
||||
idx -= len_dataset
|
||||
schema = dataset["schema"][idx]
|
||||
prompt = tokenizer.apply_chat_template(dataset["prompt"][idx],
|
||||
tokenize=False)
|
||||
tokenize=False,
|
||||
add_generation_prompt=True)
|
||||
input_len = len(tokenizer(prompt).input_ids)
|
||||
completion = dataset["completion"][idx]
|
||||
|
||||
@ -848,7 +852,7 @@ if __name__ == "__main__":
|
||||
'json', 'json-unique', 'grammar', 'regex',
|
||||
'choice', 'xgrammar_bench'
|
||||
])
|
||||
parser.add_argument("--json_schema_path",
|
||||
parser.add_argument("--json-schema-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to json schema.")
|
||||
@ -963,7 +967,7 @@ if __name__ == "__main__":
|
||||
"--percentile-metrics",
|
||||
type=str,
|
||||
default="ttft,tpot,itl",
|
||||
help="Comma-seperated list of selected metrics to report percentils. "
|
||||
help="Comma-separated list of selected metrics to report percentils. "
|
||||
"This argument specifies the metrics to report percentiles. "
|
||||
"Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". "
|
||||
"Default value is \"ttft,tpot,itl\".")
|
||||
@ -971,7 +975,7 @@ if __name__ == "__main__":
|
||||
"--metric-percentiles",
|
||||
type=str,
|
||||
default="99",
|
||||
help="Comma-seperated list of percentiles for selected metrics. "
|
||||
help="Comma-separated list of percentiles for selected metrics. "
|
||||
"To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". "
|
||||
"Default value is \"99\". "
|
||||
"Use \"--percentile-metrics\" to select metrics.",
|
||||
@ -996,12 +1000,14 @@ if __name__ == "__main__":
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Ratio of Structured Outputs requests")
|
||||
parser.add_argument(
|
||||
"--structured-output-backend",
|
||||
type=str,
|
||||
choices=["outlines", "lm-format-enforcer", "xgrammar", "guidance"],
|
||||
default="xgrammar",
|
||||
help="Backend to use for structured outputs")
|
||||
parser.add_argument("--structured-output-backend",
|
||||
type=str,
|
||||
choices=[
|
||||
"outlines", "lm-format-enforcer", "xgrammar",
|
||||
"guidance", "auto"
|
||||
],
|
||||
default="auto",
|
||||
help="Backend to use for structured outputs")
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
||||
@ -213,14 +213,17 @@ def run_hf(
|
||||
max_prompt_len = 0
|
||||
max_output_len = 0
|
||||
for i in range(len(requests)):
|
||||
prompt, prompt_len, output_len = requests[i]
|
||||
prompt = requests[i].prompt
|
||||
prompt_len = requests[i].prompt_len
|
||||
output_len = requests[i].expected_output_len
|
||||
# Add the prompt to the batch.
|
||||
batch.append(prompt)
|
||||
max_prompt_len = max(max_prompt_len, prompt_len)
|
||||
max_output_len = max(max_output_len, output_len)
|
||||
if len(batch) < max_batch_size and i != len(requests) - 1:
|
||||
# Check if we can add more requests to the batch.
|
||||
_, next_prompt_len, next_output_len = requests[i + 1]
|
||||
next_prompt_len = requests[i + 1].prompt_len
|
||||
next_output_len = requests[i + 1].expected_output_len
|
||||
if (max(max_prompt_len, next_prompt_len) +
|
||||
max(max_output_len, next_output_len)) <= 2048:
|
||||
# We can add more requests to the batch.
|
||||
@ -520,6 +523,13 @@ def validate_args(args):
|
||||
raise ValueError(
|
||||
"Tokenizer must be the same as the model for MII backend.")
|
||||
|
||||
# --data-parallel is not supported currently.
|
||||
# https://github.com/vllm-project/vllm/issues/16222
|
||||
if args.data_parallel_size > 1:
|
||||
raise ValueError(
|
||||
"Data parallel is not supported in offline benchmark, \
|
||||
please use benchmark serving instead")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser(description="Benchmark the throughput.")
|
||||
@ -591,18 +601,30 @@ if __name__ == "__main__":
|
||||
default=None,
|
||||
help="Path to the lora adapters to use. This can be an absolute path, "
|
||||
"a relative path, or a Hugging Face model identifier.")
|
||||
parser.add_argument("--prefix-len",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Number of prefix tokens per request."
|
||||
"This is for the RandomDataset and SonnetDataset")
|
||||
parser.add_argument(
|
||||
"--prefix-len",
|
||||
type=int,
|
||||
default=None,
|
||||
help=f"Number of prefix tokens to be used in RandomDataset "
|
||||
"and SonnetDataset. For RandomDataset, the total input "
|
||||
"length is the sum of prefix-len (default: "
|
||||
f"{RandomDataset.DEFAULT_PREFIX_LEN}) and a random context length "
|
||||
"sampled from [input_len * (1 - range_ratio), "
|
||||
"input_len * (1 + range_ratio)]. For SonnetDataset, "
|
||||
f"prefix_len (default: {SonnetDataset.DEFAULT_PREFIX_LEN}) "
|
||||
"controls how much of the input is fixed lines versus "
|
||||
"random lines, but the total input length remains approximately "
|
||||
"input_len tokens.")
|
||||
# random dataset
|
||||
parser.add_argument(
|
||||
"--random-range-ratio",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Range of sampled ratio of input/output length, "
|
||||
"used only for RandomDataSet.",
|
||||
help=f"Range ratio (default : {RandomDataset.DEFAULT_RANGE_RATIO}) "
|
||||
"for sampling input/output length, "
|
||||
"used only for RandomDataset. Must be in the range [0, 1) to "
|
||||
"define a symmetric sampling range "
|
||||
"[length * (1 - range_ratio), length * (1 + range_ratio)].",
|
||||
)
|
||||
|
||||
# hf dtaset
|
||||
|
||||
236
benchmarks/kernels/benchmark_bitblas.py
Normal file
236
benchmarks/kernels/benchmark_bitblas.py
Normal file
@ -0,0 +1,236 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
|
||||
MINIMUM_BITBLAS_VERSION)
|
||||
|
||||
try:
|
||||
import bitblas
|
||||
if bitblas.__version__ < MINIMUM_BITBLAS_VERSION:
|
||||
raise ImportError("bitblas version is wrong. Please "
|
||||
f"install bitblas>={MINIMUM_BITBLAS_VERSION}")
|
||||
except ImportError as e:
|
||||
bitblas_import_exception = e
|
||||
raise ValueError("Trying to use the bitblas backend, but could not import"
|
||||
f"with the following error: {bitblas_import_exception}. "
|
||||
"Please install bitblas through the following command: "
|
||||
f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`"
|
||||
) from bitblas_import_exception
|
||||
|
||||
from bitblas import Matmul, MatmulConfig, auto_detect_nvidia_target
|
||||
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
parser = FlexibleArgumentParser(
|
||||
description="Benchmark BitBLAS int4 on a specific target.")
|
||||
|
||||
# Add arguments to the parser
|
||||
parser.add_argument(
|
||||
"--target",
|
||||
type=str,
|
||||
default=auto_detect_nvidia_target(),
|
||||
help="Specify the target device for benchmarking.",
|
||||
)
|
||||
parser.add_argument("--group_size",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Group size for grouped quantization.")
|
||||
parser.add_argument(
|
||||
"--A_dtype",
|
||||
type=str,
|
||||
default="float16",
|
||||
choices=["float16", "float32", "float64", "int32", "int8"],
|
||||
help="Data type of activation A.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--W_dtype",
|
||||
type=str,
|
||||
default="int4",
|
||||
choices=[
|
||||
"float16",
|
||||
"float32",
|
||||
"float64",
|
||||
"int32",
|
||||
"int8",
|
||||
"int4",
|
||||
"int2",
|
||||
"int1",
|
||||
"nf4",
|
||||
"fp4_e2m1",
|
||||
],
|
||||
help="Data type of weight W.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--accum_dtype",
|
||||
type=str,
|
||||
default="float16",
|
||||
choices=["float16", "int32"],
|
||||
help="Data type for accumulation.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--out_dtype",
|
||||
type=str,
|
||||
default="float16",
|
||||
choices=["float16", "float32", "int32", "int8"],
|
||||
help="Data type for output.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--layout",
|
||||
type=str,
|
||||
default="nt",
|
||||
choices=["nt", "nn"],
|
||||
help="Matrix layout, 'nt' for non-transpose A and transpose W.",
|
||||
)
|
||||
parser.add_argument("--with_bias",
|
||||
action="store_true",
|
||||
help="Include bias in the benchmark.")
|
||||
parser.add_argument(
|
||||
"--with_scaling",
|
||||
action="store_true",
|
||||
help="Include scaling factor in the quantization.",
|
||||
)
|
||||
parser.add_argument("--with_zeros",
|
||||
action="store_true",
|
||||
help="Include zeros in the quantization.")
|
||||
parser.add_argument(
|
||||
"--zeros_mode",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["original", "rescale", "quantized"],
|
||||
help="Specify the mode for calculating zeros.",
|
||||
)
|
||||
|
||||
# Parse the arguments
|
||||
args = parser.parse_args()
|
||||
|
||||
# Assign arguments to variables
|
||||
target = args.target
|
||||
A_dtype = args.A_dtype
|
||||
W_dtype = args.W_dtype
|
||||
accum_dtype = args.accum_dtype
|
||||
out_dtype = args.out_dtype
|
||||
layout = args.layout
|
||||
with_bias = args.with_bias
|
||||
group_size = args.group_size
|
||||
with_scaling = args.with_scaling
|
||||
with_zeros = args.with_zeros
|
||||
zeros_mode = args.zeros_mode
|
||||
|
||||
# Define a list of shared arguments that repeat in every config
|
||||
shared_args = [
|
||||
A_dtype,
|
||||
W_dtype,
|
||||
out_dtype,
|
||||
accum_dtype,
|
||||
layout,
|
||||
with_bias,
|
||||
group_size,
|
||||
with_scaling,
|
||||
with_zeros,
|
||||
zeros_mode,
|
||||
]
|
||||
|
||||
# Define just the (M, K, N) shapes in a more compact list
|
||||
shapes = [
|
||||
# square test
|
||||
(1, 16384, 16384),
|
||||
# BLOOM-176B
|
||||
(1, 43008, 14336),
|
||||
(1, 14336, 14336),
|
||||
(1, 57344, 14336),
|
||||
(1, 14336, 57344),
|
||||
# OPT-65B
|
||||
(1, 9216, 9216),
|
||||
(1, 36864, 9216),
|
||||
(1, 9216, 36864),
|
||||
(1, 22016, 8192),
|
||||
# LLAMA-70B/65B
|
||||
(1, 8192, 22016),
|
||||
(1, 8192, 8192),
|
||||
(1, 28672, 8192),
|
||||
(1, 8192, 28672),
|
||||
# square test
|
||||
(16384, 16384, 16384),
|
||||
# BLOOM-176B
|
||||
(8192, 43008, 14336),
|
||||
(8192, 14336, 14336),
|
||||
(8192, 57344, 14336),
|
||||
(8192, 14336, 57344),
|
||||
# OPT-65B
|
||||
(8192, 9216, 9216),
|
||||
(8192, 36864, 9216),
|
||||
(8192, 9216, 36864),
|
||||
(8192, 22016, 8192),
|
||||
# LLAMA-70B/65B
|
||||
(8192, 8192, 22016),
|
||||
(8192, 8192, 8192),
|
||||
(8192, 28672, 8192),
|
||||
(8192, 8192, 28672),
|
||||
]
|
||||
|
||||
# Build test shapes with all the shared arguments
|
||||
test_shapes = [(MatmulConfig, Matmul, (*shape, *shared_args))
|
||||
for shape in shapes]
|
||||
|
||||
benchmark_sets = []
|
||||
benchmark_sets.extend(test_shapes)
|
||||
|
||||
benchmark_results = {}
|
||||
for config_class, operator, input_args in benchmark_sets:
|
||||
config = config_class(*input_args)
|
||||
matmul = operator(config, target=target, enable_tuning=True)
|
||||
kernel_latency = matmul.profile_latency()
|
||||
|
||||
print("Time cost is: {:.3f} ms".format(kernel_latency))
|
||||
|
||||
profile_config = {
|
||||
f"{operator.__name__}-{'-'.join([str(i) for i in input_args])}": {
|
||||
"BitBLAS_top20_latency": kernel_latency,
|
||||
}
|
||||
}
|
||||
|
||||
benchmark_results.update(profile_config)
|
||||
|
||||
# Define headers for the table
|
||||
headers = [
|
||||
"PrimFunc",
|
||||
"Input Arguments",
|
||||
"BitBLAS Top20 Latency",
|
||||
]
|
||||
|
||||
# Calculate column widths for pretty printing
|
||||
col_widths = [0, 0, 0]
|
||||
for config_key, values in benchmark_results.items():
|
||||
args_split = config_key.split("-")
|
||||
func_name = args_split[0]
|
||||
input_args_str = "-".join(args_split[1:])
|
||||
col_widths[0] = max(col_widths[0], len(func_name) + 2, len(headers[0]) + 2)
|
||||
col_widths[1] = max(col_widths[1],
|
||||
len(input_args_str) + 2,
|
||||
len(headers[1]) + 2)
|
||||
col_widths[2] = max(col_widths[2],
|
||||
len(f"{values['BitBLAS_top20_latency']:.3f} ms") + 2,
|
||||
len(headers[2]) + 2)
|
||||
# break only if you want to measure widths from a single example;
|
||||
# otherwise, let it loop over all items.
|
||||
|
||||
# Print header
|
||||
for i, header in enumerate(headers):
|
||||
headers[i] = header.ljust(col_widths[i])
|
||||
print("".join(headers))
|
||||
print("-" * sum(col_widths))
|
||||
|
||||
# Print rows
|
||||
for config_key, values in benchmark_results.items():
|
||||
args_split = config_key.split("-")
|
||||
func_name = args_split[0]
|
||||
input_args_str = "-".join(args_split[1:])
|
||||
row = [
|
||||
func_name,
|
||||
input_args_str,
|
||||
f"{values['BitBLAS_top20_latency']:.3f} ms",
|
||||
]
|
||||
row_str = "".join(
|
||||
[str(cell).ljust(col_widths[idx]) for idx, cell in enumerate(row)])
|
||||
print(row_str)
|
||||
@ -90,7 +90,8 @@ def bench_run(results: list[benchmark.Measurement], model: str,
|
||||
|
||||
score = torch.randn((m, num_experts), device="cuda", dtype=dtype)
|
||||
|
||||
topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False)
|
||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||
a, score, topk, renormalize=False)
|
||||
|
||||
def run_triton_moe(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
|
||||
|
||||
@ -17,8 +17,14 @@ from torch.utils.benchmark import Measurement as TMeasurement
|
||||
from utils import ArgPool, Bench, CudaGraphBenchParams
|
||||
from weight_shapes import WEIGHT_SHAPES
|
||||
|
||||
from vllm.lora.ops.triton_ops import LoRAKernelMeta, lora_expand, lora_shrink
|
||||
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
|
||||
if HAS_TRITON:
|
||||
from vllm.lora.ops.triton_ops import (LoRAKernelMeta, lora_expand,
|
||||
lora_shrink)
|
||||
from vllm.lora.ops.triton_ops.utils import (_LORA_A_PTR_DICT,
|
||||
_LORA_B_PTR_DICT)
|
||||
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
|
||||
|
||||
@ -115,8 +115,8 @@ def benchmark_config(config: BenchmarkConfig,
|
||||
from vllm.model_executor.layers.fused_moe import override_config
|
||||
with override_config(config):
|
||||
if use_deep_gemm:
|
||||
topk_weights, topk_ids = fused_topk(x, input_gating, topk,
|
||||
False)
|
||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||
x, input_gating, topk, False)
|
||||
return fused_experts(
|
||||
x,
|
||||
w1,
|
||||
@ -442,8 +442,14 @@ class BenchmarkWorker:
|
||||
hidden_size, search_space,
|
||||
is_fp16, topk)
|
||||
|
||||
with torch.cuda.device(self.device_id) if current_platform.is_rocm(
|
||||
) else nullcontext():
|
||||
need_device_guard = False
|
||||
if current_platform.is_rocm():
|
||||
visible_device = os.environ.get("ROCR_VISIBLE_DEVICES", None)
|
||||
if visible_device != f"{self.device_id}":
|
||||
need_device_guard = True
|
||||
|
||||
with torch.cuda.device(
|
||||
self.device_id) if need_device_guard else nullcontext():
|
||||
for config in tqdm(search_space):
|
||||
try:
|
||||
kernel_time = benchmark_config(
|
||||
@ -527,7 +533,7 @@ def get_weight_block_size_safety(config, default_value=None):
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
print(args)
|
||||
block_quant_shape = None
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
args.model, trust_remote_code=args.trust_remote_code)
|
||||
if config.architectures[0] == "DbrxForCausalLM":
|
||||
@ -546,16 +552,16 @@ def main(args: argparse.Namespace):
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||
block_quant_shape = get_weight_block_size_safety(config)
|
||||
elif config.architectures[0] == "Qwen2MoeForCausalLM":
|
||||
elif config.architectures[0] in [
|
||||
"Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"
|
||||
]:
|
||||
E = config.num_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||
else:
|
||||
if not hasattr(config, "hidden_size"):
|
||||
# Support for llama4
|
||||
config = config.text_config
|
||||
# Support for llama4
|
||||
config = config.get_text_config()
|
||||
# Default: Mixtral.
|
||||
E = config.num_local_experts
|
||||
topk = config.num_experts_per_tok
|
||||
@ -566,6 +572,7 @@ def main(args: argparse.Namespace):
|
||||
dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
|
||||
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
|
||||
use_int8_w8a16 = args.dtype == "int8_w8a16"
|
||||
block_quant_shape = get_weight_block_size_safety(config)
|
||||
|
||||
if args.batch_size is None:
|
||||
batch_sizes = [
|
||||
@ -577,6 +584,15 @@ def main(args: argparse.Namespace):
|
||||
|
||||
use_deep_gemm = bool(args.use_deep_gemm)
|
||||
|
||||
if current_platform.is_rocm() and "HIP_VISIBLE_DEVICES" in os.environ:
|
||||
# Ray will set ROCR_VISIBLE_DEVICES for device visibility
|
||||
logger.warning(
|
||||
"Ray uses ROCR_VISIBLE_DEVICES to control device accessibility."
|
||||
"Replacing HIP_VISIBLE_DEVICES with ROCR_VISIBLE_DEVICES.")
|
||||
val = os.environ["HIP_VISIBLE_DEVICES"]
|
||||
os.environ["ROCR_VISIBLE_DEVICES"] = val
|
||||
del os.environ["HIP_VISIBLE_DEVICES"]
|
||||
|
||||
ray.init()
|
||||
num_gpus = int(ray.available_resources()["GPU"])
|
||||
workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]
|
||||
|
||||
349
benchmarks/kernels/benchmark_moe_permute_unpermute.py
Normal file
349
benchmarks/kernels/benchmark_moe_permute_unpermute.py
Normal file
@ -0,0 +1,349 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import argparse
|
||||
from typing import Any, TypedDict
|
||||
|
||||
import ray
|
||||
import torch
|
||||
from transformers import AutoConfig
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
||||
_moe_permute, _moe_unpermute_and_reduce)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import *
|
||||
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import *
|
||||
from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
|
||||
class BenchmarkConfig(TypedDict):
|
||||
BLOCK_SIZE_M: int
|
||||
BLOCK_SIZE_N: int
|
||||
BLOCK_SIZE_K: int
|
||||
GROUP_SIZE_M: int
|
||||
num_warps: int
|
||||
num_stages: int
|
||||
|
||||
|
||||
def benchmark_permute(num_tokens: int,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
num_iters: int = 100,
|
||||
use_customized_permute: bool = False) -> float:
|
||||
# init_dtype = torch.float16 if use_fp8_w8a8 else dtype
|
||||
hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
||||
# output_hidden_states = torch.empty_like(hidden_states)
|
||||
if use_fp8_w8a8:
|
||||
align_block_size = 128 # deepgemm needs 128 m aligned block
|
||||
qhidden_states, scale = _fp8_quantize(hidden_states, None, None)
|
||||
else:
|
||||
align_block_size = None
|
||||
qhidden_states = hidden_states
|
||||
|
||||
gating_output = torch.randn(num_iters,
|
||||
num_tokens,
|
||||
num_experts,
|
||||
dtype=torch.float32)
|
||||
|
||||
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
|
||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||
qhidden_states, input_gating, topk, False)
|
||||
|
||||
def prepare(i: int):
|
||||
input_gating.copy_(gating_output[i])
|
||||
|
||||
def run():
|
||||
if use_customized_permute:
|
||||
(permuted_hidden_states, first_token_off, inv_perm_idx,
|
||||
m_indices) = moe_permute(
|
||||
qhidden_states,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
token_expert_indices=token_expert_indices,
|
||||
topk=topk,
|
||||
n_expert=num_experts,
|
||||
n_local_expert=num_experts,
|
||||
expert_map=None,
|
||||
align_block_size=align_block_size,
|
||||
)
|
||||
else:
|
||||
(permuted_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
|
||||
inv_perm) = _moe_permute(qhidden_states, None, topk_ids,
|
||||
num_experts, None, align_block_size)
|
||||
|
||||
# JIT compilation & warmup
|
||||
run()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Capture 10 invocations with CUDA graph
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph):
|
||||
for _ in range(10):
|
||||
run()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Warmup
|
||||
for _ in range(5):
|
||||
graph.replay()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
latencies: list[float] = []
|
||||
for i in range(num_iters):
|
||||
prepare(i)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
start_event.record()
|
||||
graph.replay()
|
||||
end_event.record()
|
||||
end_event.synchronize()
|
||||
latencies.append(start_event.elapsed_time(end_event))
|
||||
avg = sum(latencies) / (num_iters * 10) * 1000 # us
|
||||
graph.reset()
|
||||
return avg
|
||||
|
||||
|
||||
def benchmark_unpermute(num_tokens: int,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
num_iters: int = 100,
|
||||
use_customized_permute: bool = False) -> float:
|
||||
# init_dtype = torch.float16 if use_fp8_w8a8 else dtype
|
||||
hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
||||
output_hidden_states = torch.empty_like(hidden_states)
|
||||
if use_fp8_w8a8:
|
||||
align_block_size = 128 # deepgemm needs 128 m aligned block
|
||||
qhidden_states, scale = _fp8_quantize(hidden_states, None, None)
|
||||
else:
|
||||
align_block_size = None
|
||||
qhidden_states = hidden_states
|
||||
|
||||
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
|
||||
|
||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||
qhidden_states, input_gating, topk, False)
|
||||
|
||||
def prepare():
|
||||
if use_customized_permute:
|
||||
(permuted_hidden_states, first_token_off, inv_perm_idx,
|
||||
m_indices) = moe_permute(
|
||||
qhidden_states,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
token_expert_indices=token_expert_indices,
|
||||
topk=topk,
|
||||
n_expert=num_experts,
|
||||
n_local_expert=num_experts,
|
||||
expert_map=None,
|
||||
align_block_size=align_block_size,
|
||||
)
|
||||
# convert to fp16/bf16 as gemm output
|
||||
return (permuted_hidden_states.to(dtype), first_token_off,
|
||||
inv_perm_idx, m_indices)
|
||||
else:
|
||||
(permuted_qhidden_states, a1q_scale, sorted_token_ids, expert_ids,
|
||||
inv_perm) = _moe_permute(qhidden_states, None, topk_ids,
|
||||
num_experts, None, align_block_size)
|
||||
# convert to fp16/bf16 as gemm output
|
||||
return (permuted_qhidden_states.to(dtype), a1q_scale,
|
||||
sorted_token_ids, expert_ids, inv_perm)
|
||||
|
||||
def run(input: tuple):
|
||||
if use_customized_permute:
|
||||
(permuted_hidden_states, first_token_off, inv_perm_idx,
|
||||
m_indices) = input
|
||||
moe_unpermute(permuted_hidden_states, topk_weights, topk_ids,
|
||||
inv_perm_idx, first_token_off, topk, num_experts,
|
||||
num_experts)
|
||||
else:
|
||||
(permuted_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
|
||||
inv_perm) = input
|
||||
_moe_unpermute_and_reduce(output_hidden_states,
|
||||
permuted_hidden_states, inv_perm,
|
||||
topk_weights)
|
||||
|
||||
# JIT compilation & warmup
|
||||
input = prepare()
|
||||
run(input)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Capture 10 invocations with CUDA graph
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph):
|
||||
for _ in range(10):
|
||||
run(input)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Warmup
|
||||
for _ in range(5):
|
||||
graph.replay()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
latencies: list[float] = []
|
||||
for i in range(num_iters):
|
||||
torch.cuda.synchronize()
|
||||
start_event.record()
|
||||
graph.replay()
|
||||
end_event.record()
|
||||
end_event.synchronize()
|
||||
latencies.append(start_event.elapsed_time(end_event))
|
||||
avg = sum(latencies) / (num_iters * 10) * 1000 # us
|
||||
graph.reset()
|
||||
return avg
|
||||
|
||||
|
||||
@ray.remote(num_gpus=1)
|
||||
class BenchmarkWorker:
|
||||
|
||||
def __init__(self, seed: int) -> None:
|
||||
torch.set_default_device("cuda")
|
||||
current_platform.seed_everything(seed)
|
||||
self.seed = seed
|
||||
# Get the device ID to allocate tensors and kernels
|
||||
# on the respective GPU. This is required for Ray to work
|
||||
# correctly with multi-GPU tuning on the ROCm platform.
|
||||
self.device_id = int(ray.get_gpu_ids()[0])
|
||||
|
||||
def benchmark(
|
||||
self,
|
||||
num_tokens: int,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
use_customized_permute: bool = False,
|
||||
) -> tuple[dict[str, int], float]:
|
||||
current_platform.seed_everything(self.seed)
|
||||
|
||||
permute_time = benchmark_permute(
|
||||
num_tokens,
|
||||
num_experts,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a16,
|
||||
num_iters=100,
|
||||
use_customized_permute=use_customized_permute)
|
||||
unpermute_time = benchmark_unpermute(
|
||||
num_tokens,
|
||||
num_experts,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a16,
|
||||
num_iters=100,
|
||||
use_customized_permute=use_customized_permute)
|
||||
return permute_time, unpermute_time
|
||||
|
||||
|
||||
def get_weight_block_size_safety(config, default_value=None):
|
||||
|
||||
quantization_config = getattr(config, 'quantization_config', {})
|
||||
if isinstance(quantization_config, dict):
|
||||
return quantization_config.get('weight_block_size', default_value)
|
||||
return default_value
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
print(args)
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
args.model, trust_remote_code=args.trust_remote_code)
|
||||
if config.architectures[0] == "DbrxForCausalLM":
|
||||
E = config.ffn_config.moe_num_experts
|
||||
topk = config.ffn_config.moe_top_k
|
||||
elif config.architectures[0] == "JambaForCausalLM":
|
||||
E = config.num_experts
|
||||
topk = config.num_experts_per_tok
|
||||
elif (config.architectures[0] == "DeepseekV3ForCausalLM"
|
||||
or config.architectures[0] == "DeepseekV2ForCausalLM"):
|
||||
E = config.n_routed_experts
|
||||
topk = config.num_experts_per_tok
|
||||
elif config.architectures[0] in [
|
||||
"Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"
|
||||
]:
|
||||
E = config.num_experts
|
||||
topk = config.num_experts_per_tok
|
||||
|
||||
else:
|
||||
# Support for llama4
|
||||
config = config.get_text_config()
|
||||
# Default: Mixtral.
|
||||
E = config.num_local_experts
|
||||
topk = config.num_experts_per_tok
|
||||
|
||||
hidden_size = config.hidden_size
|
||||
dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
|
||||
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
|
||||
use_int8_w8a16 = args.dtype == "int8_w8a16"
|
||||
use_customized_permute = args.use_customized_permute
|
||||
|
||||
if args.batch_size is None:
|
||||
batch_sizes = [
|
||||
1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
|
||||
2048, 3072, 4096
|
||||
]
|
||||
else:
|
||||
batch_sizes = [args.batch_size]
|
||||
|
||||
ray.init()
|
||||
num_gpus = int(ray.available_resources()["GPU"])
|
||||
workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]
|
||||
|
||||
def _distribute(method: str, inputs: list[Any]) -> list[Any]:
|
||||
outputs = []
|
||||
worker_idx = 0
|
||||
for input_args in inputs:
|
||||
worker = workers[worker_idx]
|
||||
worker_method = getattr(worker, method)
|
||||
output = worker_method.remote(*input_args)
|
||||
outputs.append(output)
|
||||
worker_idx = (worker_idx + 1) % num_gpus
|
||||
return ray.get(outputs)
|
||||
|
||||
outputs = _distribute(
|
||||
"benchmark", [(batch_size, E, hidden_size, topk, dtype, use_fp8_w8a8,
|
||||
use_int8_w8a16, use_customized_permute)
|
||||
for batch_size in batch_sizes])
|
||||
|
||||
for batch_size, (permute, unpermute) in zip(batch_sizes, outputs):
|
||||
print(f"Batch size: {batch_size}")
|
||||
print(f"Permute time: {permute:.2f} us")
|
||||
print(f"Unpermute time: {unpermute:.2f} us")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser()
|
||||
parser.add_argument("--model",
|
||||
type=str,
|
||||
default="mistralai/Mixtral-8x7B-Instruct-v0.1")
|
||||
parser.add_argument("--dtype",
|
||||
type=str,
|
||||
choices=["auto", "fp8_w8a8", "int8_w8a16"],
|
||||
default="auto")
|
||||
parser.add_argument("--use-customized-permute", action="store_true")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--batch-size", type=int, required=False)
|
||||
parser.add_argument("--trust-remote-code", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
@ -38,7 +38,7 @@ else()
|
||||
FetchContent_Declare(
|
||||
vllm-flash-attn
|
||||
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
|
||||
GIT_TAG dc9d410b3e2d6534a4c70724c2515f4def670a22
|
||||
GIT_TAG 8798f27777fb57f447070301bf33a9f9c607f491
|
||||
GIT_PROGRESS TRUE
|
||||
# Don't share the vllm-flash-attn build between build types
|
||||
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
|
||||
|
||||
178
csrc/attention/merge_attn_states.cu
Normal file
178
csrc/attention/merge_attn_states.cu
Normal file
@ -0,0 +1,178 @@
|
||||
#include <optional>
|
||||
#include <torch/all.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <algorithm>
|
||||
|
||||
#include "attention_dtypes.h"
|
||||
#include "attention_utils.cuh"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
|
||||
// can be used to combine partial attention results (in the split-KV case)
|
||||
template <typename scalar_t, const uint NUM_THREADS>
|
||||
__global__ void merge_attn_states_kernel(
|
||||
scalar_t* output, float* output_lse, const scalar_t* prefix_output,
|
||||
const float* prefix_lse, const scalar_t* suffix_output,
|
||||
const float* suffix_lse, const uint num_tokens, const uint num_heads,
|
||||
const uint head_size) {
|
||||
using pack_128b_t = uint4;
|
||||
const uint pack_size = 16 / sizeof(scalar_t);
|
||||
const uint threads_per_head = head_size / pack_size;
|
||||
|
||||
const uint global_idx = blockIdx.x * NUM_THREADS + threadIdx.x;
|
||||
const uint token_head_threads = num_tokens * num_heads * threads_per_head;
|
||||
|
||||
if (global_idx >= token_head_threads) return;
|
||||
|
||||
// global_idx -> token_idx + head_idx + pack_idx
|
||||
const uint token_head_idx = global_idx / threads_per_head;
|
||||
const uint pack_idx = global_idx % threads_per_head;
|
||||
|
||||
const uint token_idx = token_head_idx / num_heads;
|
||||
const uint head_idx = token_head_idx % num_heads;
|
||||
|
||||
const uint pack_offset = pack_idx * pack_size; // (0~15)*8, etc.
|
||||
const uint head_offset =
|
||||
token_idx * num_heads * head_size + head_idx * head_size;
|
||||
const scalar_t* prefix_head_ptr = prefix_output + head_offset;
|
||||
const scalar_t* suffix_head_ptr = suffix_output + head_offset;
|
||||
scalar_t* output_head_ptr = output + head_offset;
|
||||
|
||||
float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
|
||||
float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
|
||||
p_lse = std::isinf(p_lse) ? -std::numeric_limits<float>::infinity() : p_lse;
|
||||
s_lse = std::isinf(s_lse) ? -std::numeric_limits<float>::infinity() : s_lse;
|
||||
|
||||
const float max_lse = fmaxf(p_lse, s_lse);
|
||||
p_lse = p_lse - max_lse;
|
||||
s_lse = s_lse - max_lse;
|
||||
const float p_se = expf(p_lse);
|
||||
const float s_se = expf(s_lse);
|
||||
const float out_se = p_se + s_se;
|
||||
const float p_scale = p_se / out_se;
|
||||
const float s_scale = s_se / out_se;
|
||||
|
||||
if (pack_offset < head_size) {
|
||||
// Pack 128b load
|
||||
pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(
|
||||
prefix_head_ptr)[pack_offset / pack_size];
|
||||
pack_128b_t s_out_pack = reinterpret_cast<const pack_128b_t*>(
|
||||
suffix_head_ptr)[pack_offset / pack_size];
|
||||
pack_128b_t o_out_pack;
|
||||
|
||||
#pragma unroll
|
||||
for (uint i = 0; i < pack_size; ++i) {
|
||||
// Always use float for FMA to keep high precision.
|
||||
// half(uint16_t), bfloat16, float -> float.
|
||||
const float p_out_f =
|
||||
vllm::to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
|
||||
const float s_out_f =
|
||||
vllm::to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
|
||||
// fma: a * b + c = p_out_f * p_scale + (s_out_f * s_scale)
|
||||
const float o_out_f = p_out_f * p_scale + (s_out_f * s_scale);
|
||||
// float -> half(uint16_t), bfloat16, float.
|
||||
vllm::from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i], o_out_f);
|
||||
}
|
||||
|
||||
// Pack 128b storage
|
||||
reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] =
|
||||
o_out_pack;
|
||||
}
|
||||
// We only need to write to output_lse once per head.
|
||||
if (output_lse != nullptr && pack_idx == 0) {
|
||||
float out_lse = logf(out_se) + max_lse;
|
||||
output_lse[head_idx * num_tokens + token_idx] = out_lse;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
// The following macro is used to dispatch the conversion function based on
|
||||
// the output data type. The FN is a macro that calls a function with
|
||||
// template<typename scalar_t>.
|
||||
#define DISPATCH_BY_SCALAR_DTYPE(scalar_dtype, fn) \
|
||||
{ \
|
||||
if (scalar_dtype == at::ScalarType::Float) { \
|
||||
fn(float); \
|
||||
} else if (scalar_dtype == at::ScalarType::Half) { \
|
||||
fn(uint16_t); \
|
||||
} else if (scalar_dtype == at::ScalarType::BFloat16) { \
|
||||
fn(__nv_bfloat16); \
|
||||
} else { \
|
||||
TORCH_CHECK(false, "Unsupported data type of O: ", scalar_dtype); \
|
||||
} \
|
||||
}
|
||||
|
||||
#define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS) \
|
||||
{ \
|
||||
vllm::merge_attn_states_kernel<scalar_t, NUM_THREADS> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<scalar_t*>(output.data_ptr()), output_lse_ptr, \
|
||||
reinterpret_cast<scalar_t*>(prefix_output.data_ptr()), \
|
||||
reinterpret_cast<float*>(prefix_lse.data_ptr()), \
|
||||
reinterpret_cast<scalar_t*>(suffix_output.data_ptr()), \
|
||||
reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens, \
|
||||
num_heads, head_size); \
|
||||
}
|
||||
|
||||
/*@brief Merges the attention states from prefix and suffix
|
||||
* into the output tensor. NUM_TOKENS: n, NUM_HEADS: h, HEAD_SIZE: d
|
||||
*
|
||||
* @param output [n,h,d] The output tensor to store the merged attention states.
|
||||
* @param output_lse [h,d] Optional tensor to store the log-sum-exp values.
|
||||
* @param prefix_output [n,h,d] The prefix attention states.
|
||||
* @param prefix_lse [h,n] The log-sum-exp values for the prefix attention
|
||||
* states.
|
||||
* @param suffix_output [n,h,d] The suffix attention states.
|
||||
* @param suffix_lse [h,n] The log-sum-exp values for the suffix attention
|
||||
* states.
|
||||
*/
|
||||
template <typename scalar_t>
|
||||
void merge_attn_states_launcher(torch::Tensor& output,
|
||||
std::optional<torch::Tensor> output_lse,
|
||||
const torch::Tensor& prefix_output,
|
||||
const torch::Tensor& prefix_lse,
|
||||
const torch::Tensor& suffix_output,
|
||||
const torch::Tensor& suffix_lse) {
|
||||
constexpr uint NUM_THREADS = 128;
|
||||
const uint num_tokens = output.size(0);
|
||||
const uint num_heads = output.size(1);
|
||||
const uint head_size = output.size(2);
|
||||
const uint pack_size = 16 / sizeof(scalar_t);
|
||||
TORCH_CHECK(head_size % pack_size == 0,
|
||||
"headsize must be multiple of pack_size:", pack_size);
|
||||
float* output_lse_ptr = nullptr;
|
||||
if (output_lse.has_value()) {
|
||||
output_lse_ptr = output_lse.value().data_ptr<float>();
|
||||
}
|
||||
// Process one pack elements per thread. for float, the
|
||||
// pack_size is 4 for half/bf16, the pack_size is 8.
|
||||
const uint threads_per_head = head_size / pack_size;
|
||||
const uint total_threads = num_tokens * num_heads * threads_per_head;
|
||||
|
||||
dim3 block(NUM_THREADS);
|
||||
dim3 grid((total_threads + NUM_THREADS - 1) / NUM_THREADS);
|
||||
|
||||
const c10::cuda::OptionalCUDAGuard device_guard(prefix_output.device());
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS);
|
||||
}
|
||||
|
||||
#define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t) \
|
||||
{ \
|
||||
merge_attn_states_launcher<scalar_t>(output, output_lse, prefix_output, \
|
||||
prefix_lse, suffix_output, \
|
||||
suffix_lse); \
|
||||
}
|
||||
|
||||
void merge_attn_states(torch::Tensor& output,
|
||||
std::optional<torch::Tensor> output_lse,
|
||||
const torch::Tensor& prefix_output,
|
||||
const torch::Tensor& prefix_lse,
|
||||
const torch::Tensor& suffix_output,
|
||||
const torch::Tensor& suffix_lse) {
|
||||
DISPATCH_BY_SCALAR_DTYPE(output.dtype(), CALL_MERGE_ATTN_STATES_LAUNCHER);
|
||||
}
|
||||
38
csrc/attention/mla/cutlass_mla_entry.cu
Normal file
38
csrc/attention/mla/cutlass_mla_entry.cu
Normal file
@ -0,0 +1,38 @@
|
||||
/*
|
||||
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
#if defined ENABLE_CUTLASS_MLA && ENABLE_CUTLASS_MLA
|
||||
void cutlass_mla_decode_sm100a(torch::Tensor const& out,
|
||||
torch::Tensor const& q_nope,
|
||||
torch::Tensor const& q_pe,
|
||||
torch::Tensor const& kv_c_and_k_pe_cache,
|
||||
torch::Tensor const& seq_lens,
|
||||
torch::Tensor const& page_table, double scale);
|
||||
#endif
|
||||
|
||||
void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope,
|
||||
torch::Tensor const& q_pe,
|
||||
torch::Tensor const& kv_c_and_k_pe_cache,
|
||||
torch::Tensor const& seq_lens,
|
||||
torch::Tensor const& page_table, double scale) {
|
||||
#if defined ENABLE_CUTLASS_MLA && ENABLE_CUTLASS_MLA
|
||||
return cutlass_mla_decode_sm100a(out, q_nope, q_pe, kv_c_and_k_pe_cache,
|
||||
seq_lens, page_table, scale);
|
||||
#endif
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled cutlass MLA");
|
||||
}
|
||||
225
csrc/attention/mla/cutlass_mla_kernels.cu
Normal file
225
csrc/attention/mla/cutlass_mla_kernels.cu
Normal file
@ -0,0 +1,225 @@
|
||||
/*
|
||||
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/kernel_hardware_info.h"
|
||||
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
|
||||
#include "device/sm100_mla.hpp"
|
||||
#include "kernel/sm100_mla_tile_scheduler.hpp"
|
||||
|
||||
using namespace cute;
|
||||
using namespace cutlass::fmha::kernel;
|
||||
|
||||
template <typename T, bool PersistenceOption = true>
|
||||
struct MlaSm100 {
|
||||
using Element = T;
|
||||
using ElementAcc = float;
|
||||
using ElementOut = T;
|
||||
|
||||
using TileShape = Shape<_128, _128, Shape<_512, _64>>;
|
||||
using TileShapeH = cute::tuple_element_t<0, TileShape>;
|
||||
using TileShapeD = cute::tuple_element_t<2, TileShape>;
|
||||
|
||||
// H K (D_latent D_rope) B
|
||||
using ProblemShape = cute::tuple<TileShapeH, int, TileShapeD, int>;
|
||||
|
||||
using StrideQ = cute::tuple<int64_t, _1, int64_t>; // H D B
|
||||
using StrideK = cute::tuple<int64_t, _1, int64_t>; // K D B
|
||||
using StrideO = StrideK; // H D B
|
||||
using StrideLSE = cute::tuple<_1, int>; // H B
|
||||
|
||||
using TileScheduler =
|
||||
std::conditional_t<PersistenceOption, Sm100MlaPersistentTileScheduler,
|
||||
Sm100MlaIndividualTileScheduler>;
|
||||
|
||||
using FmhaKernel =
|
||||
cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized<
|
||||
TileShape, Element, ElementAcc, ElementOut, ElementAcc, TileScheduler,
|
||||
/*kIsCpAsync=*/true>;
|
||||
using Fmha = cutlass::fmha::device::MLA<FmhaKernel>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
typename T::Fmha::Arguments args_from_options(
|
||||
at::Tensor const& out, at::Tensor const& q_nope, at::Tensor const& q_pe,
|
||||
at::Tensor const& kv_c_and_k_pe_cache, at::Tensor const& seq_lens,
|
||||
at::Tensor const& page_table, double scale) {
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
hw_info.device_id = q_nope.device().index();
|
||||
hw_info.sm_count =
|
||||
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
|
||||
hw_info.device_id);
|
||||
|
||||
int batches = q_nope.sizes()[0];
|
||||
int page_count_per_seq = page_table.sizes()[1];
|
||||
int page_count_total = kv_c_and_k_pe_cache.sizes()[0];
|
||||
int page_size = kv_c_and_k_pe_cache.sizes()[1];
|
||||
int max_seq_len = page_size * page_count_per_seq;
|
||||
using TileShapeH = typename T::TileShapeH;
|
||||
using TileShapeD = typename T::TileShapeD;
|
||||
auto problem_shape =
|
||||
cute::make_tuple(TileShapeH{}, max_seq_len, TileShapeD{}, batches);
|
||||
|
||||
auto [H, K, D, B] = problem_shape;
|
||||
auto [D_latent, D_rope] = D;
|
||||
|
||||
using StrideQ = typename T::StrideQ;
|
||||
using StrideK = typename T::StrideK;
|
||||
using StrideO = typename T::StrideO;
|
||||
using StrideLSE = typename T::StrideLSE;
|
||||
|
||||
StrideQ stride_Q_latent = cute::make_tuple(
|
||||
static_cast<int64_t>(D_latent), _1{}, static_cast<int64_t>(H * D_latent));
|
||||
StrideQ stride_Q_rope = cute::make_tuple(static_cast<int64_t>(D_rope), _1{},
|
||||
static_cast<int64_t>(H * D_rope));
|
||||
StrideK stride_C =
|
||||
cute::make_tuple(static_cast<int64_t>(D_latent + D_rope), _1{},
|
||||
static_cast<int64_t>(page_size * (D_latent + D_rope)));
|
||||
StrideLSE stride_PT = cute::make_stride(_1{}, page_count_per_seq);
|
||||
StrideLSE stride_LSE = cute::make_tuple(_1{}, static_cast<int>(H));
|
||||
StrideO stride_O = cute::make_tuple(static_cast<int64_t>(D_latent), _1{},
|
||||
static_cast<int64_t>(H * D_latent));
|
||||
|
||||
using Element = typename T::Element;
|
||||
using ElementOut = typename T::ElementOut;
|
||||
using ElementAcc = typename T::ElementAcc;
|
||||
auto Q_latent_ptr = static_cast<Element*>(q_nope.data_ptr());
|
||||
auto Q_rope_ptr = static_cast<Element*>(q_pe.data_ptr());
|
||||
auto C_ptr = static_cast<Element*>(kv_c_and_k_pe_cache.data_ptr());
|
||||
auto scale_f = static_cast<float>(scale);
|
||||
typename T::Fmha::Arguments arguments{
|
||||
problem_shape,
|
||||
{scale_f, Q_latent_ptr, stride_Q_latent, Q_rope_ptr, stride_Q_rope, C_ptr,
|
||||
stride_C, C_ptr + D_latent, stride_C,
|
||||
static_cast<int*>(seq_lens.data_ptr()),
|
||||
static_cast<int*>(page_table.data_ptr()), stride_PT, page_count_total,
|
||||
page_size},
|
||||
{static_cast<ElementOut*>(out.data_ptr()), stride_O,
|
||||
static_cast<ElementAcc*>(nullptr), stride_LSE},
|
||||
hw_info,
|
||||
-1, // split_kv
|
||||
nullptr, // is_var_split_kv
|
||||
};
|
||||
// TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
|
||||
// split_kv automatically based on batch size and sequence length to balance
|
||||
// workload across available SMs. Consider using var_split_kv for manual
|
||||
// control if needed.
|
||||
T::Fmha::set_split_kv(arguments);
|
||||
return arguments;
|
||||
}
|
||||
|
||||
template <typename Element>
|
||||
void runMla(at::Tensor const& out, at::Tensor const& q_nope,
|
||||
at::Tensor const& q_pe, at::Tensor const& kv_c_and_k_pe_cache,
|
||||
at::Tensor const& seq_lens, at::Tensor const& page_table,
|
||||
float scale, cudaStream_t stream) {
|
||||
using MlaSm100Type = MlaSm100<Element>;
|
||||
typename MlaSm100Type::Fmha fmha;
|
||||
auto arguments = args_from_options<MlaSm100Type>(
|
||||
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, scale);
|
||||
size_t workspace_size = MlaSm100Type::Fmha::get_workspace_size(arguments);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(q_nope.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
|
||||
CUTLASS_CHECK(fmha.can_implement(arguments));
|
||||
|
||||
CUTLASS_CHECK(fmha.initialize(arguments, workspace.data_ptr(), stream));
|
||||
|
||||
CUTLASS_CHECK(fmha.run(arguments, workspace.data_ptr(), stream));
|
||||
}
|
||||
|
||||
void cutlass_mla_decode_sm100a(torch::Tensor const& out,
|
||||
torch::Tensor const& q_nope,
|
||||
torch::Tensor const& q_pe,
|
||||
torch::Tensor const& kv_c_and_k_pe_cache,
|
||||
torch::Tensor const& seq_lens,
|
||||
torch::Tensor const& page_table, double scale) {
|
||||
TORCH_CHECK(q_nope.device().is_cuda(), "q_nope must be on CUDA");
|
||||
TORCH_CHECK(q_nope.dim() == 3, "q_nope must be a 3D tensor");
|
||||
TORCH_CHECK(q_pe.dim() == 3, "q_pe must be a 3D tensor");
|
||||
TORCH_CHECK(kv_c_and_k_pe_cache.dim() == 3,
|
||||
"kv_c_and_k_pe_cache must be a 3D tensor");
|
||||
TORCH_CHECK(seq_lens.dim() == 1, "seq_lens must be a 1D tensor");
|
||||
TORCH_CHECK(page_table.dim() == 2, "page_table must be a 2D tensor");
|
||||
TORCH_CHECK(out.dim() == 3, "out must be a 3D tensor");
|
||||
|
||||
auto B_q_nope = q_nope.size(0);
|
||||
auto H_q_nope = q_nope.size(1);
|
||||
auto D_q_nope = q_nope.size(2);
|
||||
auto B_q_pe = q_pe.size(0);
|
||||
auto H_q_pe = q_pe.size(1);
|
||||
auto D_q_pe = q_pe.size(2);
|
||||
auto B_pt = page_table.size(0);
|
||||
auto PAGE_NUM = page_table.size(1);
|
||||
auto PAGE_SIZE = kv_c_and_k_pe_cache.size(1);
|
||||
auto D_ckv = kv_c_and_k_pe_cache.size(2);
|
||||
auto B_o = out.size(0);
|
||||
auto H_o = out.size(1);
|
||||
auto D_o = out.size(2);
|
||||
|
||||
TORCH_CHECK(D_q_nope == 512, "D_q_nope must be equal to 512");
|
||||
TORCH_CHECK(D_q_pe == 64, "D_q_pe must be equal to 64");
|
||||
TORCH_CHECK(D_ckv == 576, "D_ckv must be equal to 576");
|
||||
TORCH_CHECK(H_q_nope == H_q_pe && H_q_nope == H_o && H_o == 128,
|
||||
"H_q_nope, H_q_pe, and H_o must be equal to 128");
|
||||
TORCH_CHECK(PAGE_SIZE > 0 && (PAGE_SIZE & (PAGE_SIZE - 1)) == 0,
|
||||
"PAGE_SIZE must be a power of 2");
|
||||
TORCH_CHECK(
|
||||
B_q_nope == B_q_pe && B_q_nope == B_pt && B_q_nope == B_o,
|
||||
"Batch dims must be same for page_table, q_nope and q_pe, and out");
|
||||
TORCH_CHECK(PAGE_NUM % (128 / PAGE_SIZE) == 0,
|
||||
"PAGE_NUM must be divisible by 128 / PAGE_SIZE");
|
||||
TORCH_CHECK(D_o == 512, "D_o must be equal to 512");
|
||||
|
||||
TORCH_CHECK(q_nope.dtype() == at::ScalarType::Half ||
|
||||
q_nope.dtype() == at::ScalarType::BFloat16 ||
|
||||
q_nope.dtype() == at::ScalarType::Float8_e4m3fn,
|
||||
"q_nope must be a half, bfloat16, or float8_e4m3fn tensor");
|
||||
TORCH_CHECK(kv_c_and_k_pe_cache.dtype() == q_nope.dtype() &&
|
||||
q_nope.dtype() == q_pe.dtype(),
|
||||
"kv_c_and_k_pe_cache, q_nope, and q_pe must be the same type");
|
||||
TORCH_CHECK(seq_lens.dtype() == torch::kInt32,
|
||||
"seq_lens must be a 32-bit integer tensor");
|
||||
TORCH_CHECK(page_table.dtype() == torch::kInt32,
|
||||
"page_table must be a 32-bit integer tensor");
|
||||
|
||||
auto in_dtype = q_nope.dtype();
|
||||
at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()};
|
||||
const cudaStream_t stream =
|
||||
at::cuda::getCurrentCUDAStream(q_nope.get_device());
|
||||
if (in_dtype == at::ScalarType::Half) {
|
||||
runMla<cutlass::half_t>(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens,
|
||||
page_table, scale, stream);
|
||||
} else if (in_dtype == at::ScalarType::BFloat16) {
|
||||
runMla<cutlass::bfloat16_t>(out, q_nope, q_pe, kv_c_and_k_pe_cache,
|
||||
seq_lens, page_table, scale, stream);
|
||||
} else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
|
||||
runMla<cutlass::float_e4m3_t>(out, q_nope, q_pe, kv_c_and_k_pe_cache,
|
||||
seq_lens, page_table, scale, stream);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported input data type of MLA");
|
||||
}
|
||||
}
|
||||
@ -270,9 +270,10 @@ __global__ void reshape_and_cache_flash_kernel(
|
||||
cache_t* __restrict__ value_cache, // [num_blocks, block_size, num_heads,
|
||||
// head_size]
|
||||
const int64_t* __restrict__ slot_mapping, // [num_tokens]
|
||||
const int block_stride, const int key_stride, const int value_stride,
|
||||
const int num_heads, const int head_size, const int block_size,
|
||||
const float* k_scale, const float* v_scale) {
|
||||
const int64_t block_stride, const int64_t page_stride,
|
||||
const int64_t head_stride, const int64_t key_stride,
|
||||
const int64_t value_stride, const int num_heads, const int head_size,
|
||||
const int block_size, const float* k_scale, const float* v_scale) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
const int64_t slot_idx = slot_mapping[token_idx];
|
||||
// NOTE: slot_idx can be -1 if the token is padded
|
||||
@ -288,8 +289,8 @@ __global__ void reshape_and_cache_flash_kernel(
|
||||
const int head_idx = i / head_size;
|
||||
const int head_offset = i % head_size;
|
||||
const int64_t tgt_key_value_idx = block_idx * block_stride +
|
||||
block_offset * num_heads * head_size +
|
||||
head_idx * head_size + head_offset;
|
||||
block_offset * page_stride +
|
||||
head_idx * head_stride + head_offset;
|
||||
scalar_t tgt_key = key[src_key_idx];
|
||||
scalar_t tgt_value = value[src_value_idx];
|
||||
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
|
||||
@ -396,16 +397,16 @@ void reshape_and_cache(
|
||||
// KV_T is the data type of key and value tensors.
|
||||
// CACHE_T is the stored data type of kv-cache.
|
||||
// KV_DTYPE is the real data type of kv-cache.
|
||||
#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE) \
|
||||
vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<KV_T*>(key.data_ptr()), \
|
||||
reinterpret_cast<KV_T*>(value.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
|
||||
slot_mapping.data_ptr<int64_t>(), block_stride, key_stride, \
|
||||
value_stride, num_heads, head_size, block_size, \
|
||||
reinterpret_cast<const float*>(k_scale.data_ptr()), \
|
||||
#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE) \
|
||||
vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<KV_T*>(key.data_ptr()), \
|
||||
reinterpret_cast<KV_T*>(value.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
|
||||
slot_mapping.data_ptr<int64_t>(), block_stride, page_stride, \
|
||||
head_stride, key_stride, value_stride, num_heads, head_size, \
|
||||
block_size, reinterpret_cast<const float*>(k_scale.data_ptr()), \
|
||||
reinterpret_cast<const float*>(v_scale.data_ptr()));
|
||||
|
||||
void reshape_and_cache_flash(
|
||||
@ -432,9 +433,11 @@ void reshape_and_cache_flash(
|
||||
int head_size = key.size(2);
|
||||
int block_size = key_cache.size(1);
|
||||
|
||||
int key_stride = key.stride(0);
|
||||
int value_stride = value.stride(0);
|
||||
int block_stride = key_cache.stride(0);
|
||||
int64_t key_stride = key.stride(0);
|
||||
int64_t value_stride = value.stride(0);
|
||||
int64_t block_stride = key_cache.stride(0);
|
||||
int64_t page_stride = key_cache.stride(1);
|
||||
int64_t head_stride = key_cache.stride(2);
|
||||
TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0));
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
|
||||
@ -7,3 +7,22 @@ inline constexpr uint32_t next_pow_2(uint32_t const num) {
|
||||
if (num <= 1) return num;
|
||||
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
|
||||
}
|
||||
|
||||
template <typename A, typename B>
|
||||
static inline constexpr auto div_ceil(A a, B b) {
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
|
||||
// Round a down to the next multiple of b. The caller is responsible for making
|
||||
// sure that b is non-zero
|
||||
template <typename T>
|
||||
inline constexpr T round_to_previous_multiple_of(T a, T b) {
|
||||
return a % b == 0 ? a : (a / b) * b;
|
||||
}
|
||||
|
||||
// Round a up to the next multiple of b. The caller is responsible for making
|
||||
// sure that b is non-zero
|
||||
template <typename T>
|
||||
inline constexpr T round_to_next_multiple_of(T a, T b) {
|
||||
return a % b == 0 ? a : ((a / b) + 1) * b;
|
||||
}
|
||||
|
||||
@ -4,6 +4,11 @@
|
||||
#include <string>
|
||||
#include <sched.h>
|
||||
#endif
|
||||
#if __GLIBC__ == 2 && __GLIBC_MINOR__ < 30
|
||||
#include <unistd.h>
|
||||
#include <sys/syscall.h>
|
||||
#define gettid() syscall(SYS_gettid)
|
||||
#endif
|
||||
|
||||
#include "cpu_types.hpp"
|
||||
|
||||
|
||||
@ -375,7 +375,7 @@ class CustomAllreduce {
|
||||
bool fully_connected_;
|
||||
|
||||
RankSignals sg_;
|
||||
// Stores an map from a pointer to its peer pointers from all ranks.
|
||||
// Stores a map from a pointer to its peer pointers from all ranks.
|
||||
std::unordered_map<void*, RankData*> buffers_;
|
||||
Signal* self_sg_;
|
||||
|
||||
|
||||
@ -422,7 +422,7 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
|
||||
int final_state_position = ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize);
|
||||
// in case the final state is separated between the last "smem_exchange" and
|
||||
// and the one before it (chunk = n_chunks - 1 and chunk = n_chunks - 2),
|
||||
// (which occurs when `final_state_position` is a non-positivie index)
|
||||
// (which occurs when `final_state_position` is a non-positive index)
|
||||
// we load the correct data from smem_exchange from both chunks, the last chunk iteration and the one before it
|
||||
if (conv_states != nullptr && final_state_position < 0 && seqlen > kWidth){
|
||||
input_t vals_load[kNElts] = {0};
|
||||
|
||||
@ -138,8 +138,8 @@ __device__ inline FragB dequant<vllm::kU4B8.id()>(int q) {
|
||||
const int HI = 0x00f000f0;
|
||||
const int EX = 0x64006400;
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
|
||||
int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
|
||||
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
|
||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
|
||||
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
|
||||
// directly into `SUB` and `ADD`.
|
||||
const int SUB = 0x64086408;
|
||||
@ -182,8 +182,8 @@ __device__ inline FragB dequant<vllm::kU4.id()>(int q) {
|
||||
const int HI = 0x00f000f0;
|
||||
const int EX = 0x64006400;
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
|
||||
int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
|
||||
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
|
||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
|
||||
|
||||
const int SUB = 0x64006400;
|
||||
const int MUL = 0x2c002c00;
|
||||
|
||||
103
csrc/moe/marlin_moe_wna16/generate_kernels.py
Normal file
103
csrc/moe/marlin_moe_wna16/generate_kernels.py
Normal file
@ -0,0 +1,103 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import glob
|
||||
import itertools
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
import jinja2
|
||||
|
||||
FILE_HEAD = """
|
||||
// auto generated by generate.py
|
||||
// clang-format off
|
||||
|
||||
#include "kernel.h"
|
||||
#include "marlin_template.h"
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
""".strip()
|
||||
|
||||
TEMPLATE = ("template __global__ void Marlin<"
|
||||
"{{scalar_t}}, "
|
||||
"{{w_type_id}}, "
|
||||
"{{threads}}, "
|
||||
"{{thread_m_blocks}}, "
|
||||
"{{thread_n_blocks}}, "
|
||||
"{{thread_k_blocks}}, "
|
||||
"{{'true' if m_block_size_8 else 'false'}}, "
|
||||
"{{stages}}, "
|
||||
"{{'true' if has_act_order else 'false'}}, "
|
||||
"{{'true' if has_zp else 'false'}}, "
|
||||
"{{group_blocks}}, "
|
||||
"{{'true' if is_zp_float else 'false'}}>"
|
||||
"( MARLIN_KERNEL_PARAMS );")
|
||||
|
||||
# int8 with zero point case (vllm::kU8) is also supported,
|
||||
# we don't add it to reduce wheel size.
|
||||
SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128"]
|
||||
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]
|
||||
|
||||
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
|
||||
# group_blocks:
|
||||
# = 0 : act order case
|
||||
# = -1 : channelwise quantization
|
||||
# > 0 : group_size=16*group_blocks
|
||||
GROUP_BLOCKS = [0, -1, 2, 4, 8]
|
||||
DTYPES = ["fp16", "bf16"]
|
||||
|
||||
|
||||
def remove_old_kernels():
|
||||
for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"):
|
||||
subprocess.call(["rm", "-f", filename])
|
||||
|
||||
|
||||
def generate_new_kernels():
|
||||
for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
|
||||
has_zp = "B" not in scalar_type
|
||||
all_template_str_list = []
|
||||
|
||||
for group_blocks, m_blocks, thread_configs in itertools.product(
|
||||
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS):
|
||||
|
||||
has_act_order = group_blocks == 0
|
||||
if has_zp and has_act_order:
|
||||
continue
|
||||
if thread_configs[2] == 256:
|
||||
if m_blocks <= 1 and thread_configs[0] != 128:
|
||||
continue
|
||||
if m_blocks > 1 and thread_configs[0] != 64:
|
||||
continue
|
||||
|
||||
k_blocks = thread_configs[0] // 16
|
||||
n_blocks = thread_configs[1] // 16
|
||||
threads = thread_configs[2]
|
||||
|
||||
c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"
|
||||
|
||||
template_str = jinja2.Template(TEMPLATE).render(
|
||||
scalar_t=c_dtype,
|
||||
w_type_id=scalar_type + ".id()",
|
||||
threads=threads,
|
||||
thread_m_blocks=max(m_blocks, 1),
|
||||
thread_n_blocks=n_blocks,
|
||||
thread_k_blocks=k_blocks,
|
||||
m_block_size_8=m_blocks == 0.5,
|
||||
stages="pipe_stages",
|
||||
has_act_order=has_act_order,
|
||||
has_zp=has_zp,
|
||||
group_blocks=group_blocks,
|
||||
is_zp_float=False,
|
||||
)
|
||||
|
||||
all_template_str_list.append(template_str)
|
||||
|
||||
file_content = FILE_HEAD + "\n\n"
|
||||
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
|
||||
filename = f"kernel_{dtype}_{scalar_type[6:].lower()}.cu"
|
||||
|
||||
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
|
||||
f.write(file_content)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
remove_old_kernels()
|
||||
generate_new_kernels()
|
||||
44
csrc/moe/marlin_moe_wna16/kernel.h
Normal file
44
csrc/moe/marlin_moe_wna16/kernel.h
Normal file
@ -0,0 +1,44 @@
|
||||
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
|
||||
#endif
|
||||
|
||||
#include "quantization/gptq_marlin/marlin.cuh"
|
||||
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
|
||||
#include "core/scalar_type.hpp"
|
||||
|
||||
#define MARLIN_KERNEL_PARAMS \
|
||||
const int4 *__restrict__ A, const int4 *__restrict__ B, \
|
||||
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
|
||||
const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr, \
|
||||
const int *__restrict__ g_idx, \
|
||||
const int32_t *__restrict__ sorted_token_ids_ptr, \
|
||||
const int32_t *__restrict__ expert_ids_ptr, \
|
||||
const int32_t *__restrict__ num_tokens_past_padded_ptr, \
|
||||
const float *__restrict__ topk_weights_ptr, int top_k, \
|
||||
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
|
||||
int prob_n, int prob_k, int *locks, bool use_atomic_add, \
|
||||
bool use_fp32_reduce
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
template <typename scalar_t, // compute dtype, half or nv_float16
|
||||
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
|
||||
const int threads, // number of threads in a threadblock
|
||||
const int thread_m_blocks, // number of 16x16 blocks in the m
|
||||
// dimension (batchsize) of the
|
||||
// threadblock
|
||||
const int thread_n_blocks, // same for n dimension (output)
|
||||
const int thread_k_blocks, // same for k dimension (reduction)
|
||||
const bool m_block_size_8, // whether m_block_size == 8
|
||||
// only works when thread_m_blocks == 1
|
||||
const int stages, // number of stages for the async global->shared
|
||||
// fetch pipeline
|
||||
const bool has_act_order, // whether act_order is enabled
|
||||
const bool has_zp, // whether zero-points are enabled
|
||||
const int group_blocks, // number of consecutive 16x16 blocks
|
||||
// with a separate quantization scale
|
||||
const bool is_zp_float // is zero point of float16 type?
|
||||
>
|
||||
__global__ void Marlin(MARLIN_KERNEL_PARAMS);
|
||||
|
||||
}
|
||||
1917
csrc/moe/marlin_moe_wna16/marlin_template.h
Normal file
1917
csrc/moe/marlin_moe_wna16/marlin_template.h
Normal file
File diff suppressed because it is too large
Load Diff
927
csrc/moe/marlin_moe_wna16/ops.cu
Normal file
927
csrc/moe/marlin_moe_wna16/ops.cu
Normal file
@ -0,0 +1,927 @@
|
||||
/*
|
||||
* Modified by Neural Magic
|
||||
* Copyright (C) Marlin.2024 Elias Frantar
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/*
|
||||
* Adapted from https://github.com/IST-DASLab/marlin
|
||||
*/
|
||||
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
|
||||
#endif
|
||||
|
||||
#include "kernel.h"
|
||||
#include "core/registration.h"
|
||||
|
||||
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
|
||||
static_assert(std::is_same<scalar_t, half>::value || \
|
||||
std::is_same<scalar_t, nv_bfloat16>::value, \
|
||||
"only float16 and bfloat16 is supported");
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
__global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){};
|
||||
|
||||
using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS);
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
|
||||
template <int moe_block_size>
|
||||
__global__ void permute_cols_kernel(
|
||||
int4 const* __restrict__ a_int4_ptr, int const* __restrict__ perm_int_ptr,
|
||||
int4* __restrict__ out_int4_ptr,
|
||||
const int32_t* __restrict__ sorted_token_ids_ptr,
|
||||
const int32_t* __restrict__ expert_ids_ptr,
|
||||
const int32_t* __restrict__ num_tokens_past_padded_ptr, int size_m,
|
||||
int size_k, int top_k) {};
|
||||
|
||||
} // namespace marlin
|
||||
|
||||
torch::Tensor moe_wna16_marlin_gemm(
|
||||
torch::Tensor& a, std::optional<torch::Tensor> const& c_or_none,
|
||||
torch::Tensor& b_q_weight, torch::Tensor& b_scales,
|
||||
std::optional<torch::Tensor> const& b_zeros_or_none,
|
||||
std::optional<torch::Tensor> const& g_idx_or_none,
|
||||
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
|
||||
torch::Tensor& sorted_token_ids, torch::Tensor& expert_ids,
|
||||
torch::Tensor& num_tokens_past_padded, torch::Tensor& topk_weights,
|
||||
int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, bool is_ep,
|
||||
vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n,
|
||||
int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
|
||||
bool is_zp_float) {
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false,
|
||||
"marlin_gemm(..) requires CUDA_ARCH >= 8.0");
|
||||
return torch::empty({1, 1});
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
// For a given "a" of size [M,K] performs a permutation of the K columns based
|
||||
// on the given "perm" indices.
|
||||
template <int moe_block_size>
|
||||
__global__ void permute_cols_kernel(
|
||||
int4 const* __restrict__ a_int4_ptr, int const* __restrict__ perm_int_ptr,
|
||||
int4* __restrict__ out_int4_ptr,
|
||||
const int32_t* __restrict__ sorted_token_ids_ptr,
|
||||
const int32_t* __restrict__ expert_ids_ptr,
|
||||
const int32_t* __restrict__ num_tokens_past_padded_ptr, int size_m,
|
||||
int size_k, int top_k) {
|
||||
int num_tokens_past_padded = num_tokens_past_padded_ptr[0];
|
||||
int num_moe_blocks = div_ceil(num_tokens_past_padded, moe_block_size);
|
||||
int32_t block_sorted_ids[moe_block_size];
|
||||
int block_num_valid_tokens = 0;
|
||||
int64_t old_expert_id = 0;
|
||||
int64_t expert_id = 0;
|
||||
int row_stride = size_k * sizeof(half) / 16;
|
||||
|
||||
auto read_moe_block_data = [&](int block_id) {
|
||||
block_num_valid_tokens = moe_block_size;
|
||||
int4* tmp_block_sorted_ids = reinterpret_cast<int4*>(block_sorted_ids);
|
||||
for (int i = 0; i < moe_block_size / 4; i++) {
|
||||
tmp_block_sorted_ids[i] =
|
||||
((int4*)sorted_token_ids_ptr)[block_id * moe_block_size / 4 + i];
|
||||
}
|
||||
for (int i = 0; i < moe_block_size; i++) {
|
||||
if (block_sorted_ids[i] >= size_m * top_k) {
|
||||
block_num_valid_tokens = i;
|
||||
break;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
auto permute_row = [&](int row) {
|
||||
int iters = size_k / default_threads;
|
||||
int rest = size_k % default_threads;
|
||||
|
||||
int in_offset = (row / top_k) * row_stride;
|
||||
int out_offset = row * row_stride;
|
||||
|
||||
half const* a_row_half =
|
||||
reinterpret_cast<half const*>(a_int4_ptr + in_offset);
|
||||
half* out_half = reinterpret_cast<half*>(out_int4_ptr + out_offset);
|
||||
|
||||
int base_k = 0;
|
||||
|
||||
for (int i = 0; i < iters; i++) {
|
||||
int cur_k = base_k + threadIdx.x;
|
||||
int src_pos = perm_int_ptr[cur_k];
|
||||
|
||||
out_half[cur_k] = a_row_half[src_pos];
|
||||
|
||||
base_k += default_threads;
|
||||
}
|
||||
|
||||
if (rest) {
|
||||
if (threadIdx.x < rest) {
|
||||
int cur_k = base_k + threadIdx.x;
|
||||
int src_pos = perm_int_ptr[cur_k];
|
||||
|
||||
out_half[cur_k] = a_row_half[src_pos];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
for (int index = blockIdx.x; index < num_moe_blocks; index += gridDim.x) {
|
||||
old_expert_id = expert_id;
|
||||
int tmp_expert_id = expert_ids_ptr[index];
|
||||
if (tmp_expert_id == -1) continue;
|
||||
expert_id = tmp_expert_id;
|
||||
perm_int_ptr += (expert_id - old_expert_id) * size_k;
|
||||
read_moe_block_data(index);
|
||||
|
||||
for (int i = 0; i < block_num_valid_tokens; i++)
|
||||
permute_row(block_sorted_ids[i]);
|
||||
}
|
||||
}
|
||||
|
||||
typedef struct {
|
||||
int thread_k;
|
||||
int thread_n;
|
||||
int num_threads;
|
||||
} thread_config_t;
|
||||
|
||||
thread_config_t small_batch_thread_configs[] = {
|
||||
// Ordered by priority
|
||||
|
||||
// thread_k, thread_n, num_threads
|
||||
{128, 128, 256},
|
||||
{64, 128, 128}};
|
||||
|
||||
thread_config_t large_batch_thread_configs[] = {
|
||||
// Ordered by priority
|
||||
|
||||
// thread_k, thread_n, num_threads
|
||||
{64, 256, 256},
|
||||
{64, 128, 128}};
|
||||
|
||||
typedef struct {
|
||||
int blocks_per_sm;
|
||||
thread_config_t tb_cfg;
|
||||
} exec_config_t;
|
||||
|
||||
int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
|
||||
int prob_n, int prob_k, int num_bits, int group_size,
|
||||
bool has_act_order, bool is_k_full) {
|
||||
bool cache_scales_chunk = has_act_order && !is_k_full;
|
||||
|
||||
int tb_n = th_config.thread_n;
|
||||
int tb_k = th_config.thread_k;
|
||||
|
||||
// Get max scale groups per thread-block
|
||||
int tb_groups;
|
||||
if (group_size == -1) {
|
||||
tb_groups = 1;
|
||||
} else if (group_size == 0) {
|
||||
tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size
|
||||
} else {
|
||||
tb_groups = div_ceil(tb_k, group_size);
|
||||
}
|
||||
|
||||
if (cache_scales_chunk) {
|
||||
int load_groups =
|
||||
tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K
|
||||
load_groups = max(load_groups, 32); // We load at least 32 scale groups
|
||||
return load_groups * tb_n * 2;
|
||||
|
||||
} else {
|
||||
int tb_scales = tb_groups * tb_n * 2;
|
||||
|
||||
return tb_scales * pipe_stages;
|
||||
}
|
||||
}
|
||||
|
||||
int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
|
||||
int prob_m, int prob_n, int prob_k, int num_bits,
|
||||
int group_size, bool has_act_order, bool is_k_full,
|
||||
int has_zp, int is_zp_float) {
|
||||
int pack_factor = 32 / num_bits;
|
||||
|
||||
// Get B size
|
||||
int tb_k = th_config.thread_k;
|
||||
int tb_n = th_config.thread_n;
|
||||
int tb_m = thread_m_blocks * 16;
|
||||
|
||||
// shm size for block_sorted_ids/block_topk_weights
|
||||
// both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32)
|
||||
int sh_block_meta_size = tb_m * 4 * 2;
|
||||
int sh_a_size = pipe_stages * (tb_m * tb_k) * 2;
|
||||
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
|
||||
int sh_s_size =
|
||||
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
|
||||
group_size, has_act_order, is_k_full);
|
||||
int sh_g_idx_size = has_act_order && !is_k_full ? pipe_stages * tb_k / 4 : 0;
|
||||
int sh_zp_size = 0;
|
||||
if (has_zp) {
|
||||
if (is_zp_float)
|
||||
sh_zp_size = sh_s_size;
|
||||
else if (num_bits == 4)
|
||||
sh_zp_size = sh_s_size / 4;
|
||||
else if (num_bits == 8)
|
||||
sh_zp_size = sh_s_size / 2;
|
||||
}
|
||||
|
||||
int total_size = sh_a_size + sh_b_size + sh_s_size + sh_zp_size +
|
||||
sh_g_idx_size + sh_block_meta_size;
|
||||
|
||||
return total_size;
|
||||
}
|
||||
|
||||
bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
|
||||
int prob_m, int prob_n, int prob_k, int num_bits,
|
||||
int group_size, bool has_act_order, bool is_k_full,
|
||||
int has_zp, int is_zp_float, int max_shared_mem) {
|
||||
// Sanity
|
||||
if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
|
||||
th_config.num_threads == -1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Verify K/N are divisible by thread K/N
|
||||
if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Verify min for thread K/N
|
||||
if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// num_threads must be at least 128 (= 4 warps)
|
||||
if (th_config.num_threads < 128) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check that pipeline fits into cache
|
||||
int cache_size = get_kernel_cache_size(
|
||||
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size,
|
||||
has_act_order, is_k_full, has_zp, is_zp_float);
|
||||
return cache_size <= max_shared_mem;
|
||||
}
|
||||
|
||||
#define __GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
|
||||
M_BLOCK_SIZE_8, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \
|
||||
NUM_THREADS, IS_ZP_FLOAT) \
|
||||
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
|
||||
thread_n_blocks == THREAD_N_BLOCKS && \
|
||||
thread_k_blocks == THREAD_K_BLOCKS && \
|
||||
m_block_size_8 == M_BLOCK_SIZE_8 && \
|
||||
has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \
|
||||
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
|
||||
is_zp_float == IS_ZP_FLOAT) { \
|
||||
kernel = Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
|
||||
THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, \
|
||||
pipe_stages, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \
|
||||
IS_ZP_FLOAT>; \
|
||||
}
|
||||
|
||||
#define GPTQ_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, true, false, 0, NUM_THREADS, \
|
||||
false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
|
||||
NUM_THREADS, false) \
|
||||
\
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 2, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 4, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 8, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
|
||||
NUM_THREADS, false)
|
||||
|
||||
#define GPTQ_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
|
||||
NUM_THREADS, false) \
|
||||
\
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
|
||||
NUM_THREADS, false) \
|
||||
\
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
|
||||
NUM_THREADS, false) \
|
||||
\
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
|
||||
NUM_THREADS, false)
|
||||
|
||||
#define AWQ_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 2, NUM_THREADS, \
|
||||
false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 4, NUM_THREADS, \
|
||||
false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 8, NUM_THREADS, \
|
||||
false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
|
||||
NUM_THREADS, false)
|
||||
|
||||
#define AWQ_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
|
||||
NUM_THREADS, false) \
|
||||
\
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
|
||||
NUM_THREADS, false) \
|
||||
\
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
|
||||
NUM_THREADS, false)
|
||||
|
||||
// We currently have 4-bit models only with group_blocks == 4
|
||||
#define HQQ_GET_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 4, NUM_THREADS, \
|
||||
true) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
|
||||
NUM_THREADS, true) \
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
|
||||
NUM_THREADS, true) \
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
|
||||
NUM_THREADS, true) \
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
|
||||
NUM_THREADS, true)
|
||||
|
||||
template <typename scalar_t>
|
||||
MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
|
||||
int thread_m_blocks, int thread_n_blocks,
|
||||
int thread_k_blocks, bool m_block_size_8,
|
||||
bool has_act_order, bool has_zp,
|
||||
int group_blocks, int num_threads,
|
||||
bool is_zp_float) {
|
||||
int num_bits = q_type.size_bits();
|
||||
auto kernel = MarlinDefault;
|
||||
if (false) {
|
||||
}
|
||||
GPTQ_GET_IF_M1(vllm::kU4B8, 8, 8, 256)
|
||||
GPTQ_GET_IF_M1(vllm::kU4B8, 8, 4, 128)
|
||||
|
||||
GPTQ_GET_IF_M234(vllm::kU4B8, 16, 4, 256)
|
||||
GPTQ_GET_IF_M234(vllm::kU4B8, 8, 4, 128)
|
||||
|
||||
GPTQ_GET_IF_M1(vllm::kU8B128, 8, 8, 256)
|
||||
GPTQ_GET_IF_M1(vllm::kU8B128, 8, 4, 128)
|
||||
|
||||
GPTQ_GET_IF_M234(vllm::kU8B128, 16, 4, 256)
|
||||
GPTQ_GET_IF_M234(vllm::kU8B128, 8, 4, 128)
|
||||
|
||||
AWQ_GET_IF_M1(vllm::kU4, 8, 8, 256)
|
||||
AWQ_GET_IF_M1(vllm::kU4, 8, 4, 128)
|
||||
|
||||
AWQ_GET_IF_M234(vllm::kU4, 16, 4, 256)
|
||||
AWQ_GET_IF_M234(vllm::kU4, 8, 4, 128)
|
||||
|
||||
return kernel;
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
|
||||
int prob_n, int prob_k, int thread_m_blocks,
|
||||
bool m_block_size_8, int num_bits,
|
||||
int group_size, bool has_act_order,
|
||||
bool is_k_full, bool has_zp,
|
||||
bool is_zp_float, int max_shared_mem) {
|
||||
exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}};
|
||||
thread_config_t* thread_configs = thread_m_blocks > 1
|
||||
? large_batch_thread_configs
|
||||
: small_batch_thread_configs;
|
||||
int thread_configs_size =
|
||||
thread_m_blocks > 1
|
||||
? sizeof(large_batch_thread_configs) / sizeof(thread_config_t)
|
||||
: sizeof(small_batch_thread_configs) / sizeof(thread_config_t);
|
||||
|
||||
int count = 0;
|
||||
constexpr int device_max_reg_size = 255 * 1024;
|
||||
for (int i = 0; i < thread_configs_size; i++) {
|
||||
thread_config_t th_config = thread_configs[i];
|
||||
|
||||
if (!is_valid_config(th_config, thread_m_blocks, prob_m, prob_n, prob_k,
|
||||
num_bits, group_size, has_act_order, is_k_full, has_zp,
|
||||
is_zp_float, max_shared_mem)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
int cache_size = get_kernel_cache_size(
|
||||
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits,
|
||||
group_size, has_act_order, is_k_full, has_zp, is_zp_float);
|
||||
|
||||
int group_blocks = 0;
|
||||
if (!has_act_order) {
|
||||
group_blocks = group_size == -1 ? -1 : group_size / 16;
|
||||
}
|
||||
|
||||
auto kernel = get_marlin_kernel<scalar_t>(
|
||||
q_type, thread_m_blocks, th_config.thread_n / 16,
|
||||
th_config.thread_k / 16, m_block_size_8, has_act_order, has_zp,
|
||||
group_blocks, th_config.num_threads, is_zp_float);
|
||||
|
||||
if (kernel == MarlinDefault) continue;
|
||||
|
||||
if (thread_m_blocks > 1) {
|
||||
exec_cfg = {1, th_config};
|
||||
break;
|
||||
} else {
|
||||
cudaFuncAttributes attr;
|
||||
cudaFuncGetAttributes(&attr, kernel);
|
||||
int reg_size = max(attr.numRegs, 1) * th_config.num_threads * 4;
|
||||
int allow_count = min(device_max_reg_size / reg_size,
|
||||
max_shared_mem / (cache_size + 1024));
|
||||
allow_count = max(min(allow_count, 4), 1);
|
||||
if (allow_count > count) {
|
||||
count = allow_count;
|
||||
exec_cfg = {count, th_config};
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
return exec_cfg;
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
||||
void* zp, void* g_idx, void* perm, void* a_tmp,
|
||||
void* sorted_token_ids, void* expert_ids,
|
||||
void* num_tokens_past_padded, void* topk_weights,
|
||||
int moe_block_size, int top_k, bool mul_topk_weights, bool is_ep,
|
||||
int prob_m, int prob_n, int prob_k, void* workspace,
|
||||
vllm::ScalarType const& q_type, bool has_act_order,
|
||||
bool is_k_full, bool has_zp, int num_groups, int group_size,
|
||||
int dev, cudaStream_t stream, int thread_k, int thread_n,
|
||||
int sms, bool use_atomic_add, bool use_fp32_reduce,
|
||||
bool is_zp_float) {
|
||||
int thread_m_blocks = div_ceil(moe_block_size, 16);
|
||||
bool m_block_size_8 = moe_block_size == 8;
|
||||
|
||||
if (has_zp) {
|
||||
TORCH_CHECK(
|
||||
q_type == vllm::kU4 || q_type == vllm::kU8,
|
||||
"q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str());
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
q_type == vllm::kU4B8 || q_type == vllm::kU8B128,
|
||||
"q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
|
||||
q_type.str());
|
||||
}
|
||||
|
||||
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
|
||||
", ", prob_n, ", ", prob_k, "]");
|
||||
|
||||
int group_blocks = 0;
|
||||
if (has_act_order) {
|
||||
if (is_k_full) {
|
||||
TORCH_CHECK(group_size != -1);
|
||||
group_blocks = group_size / 16;
|
||||
TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
|
||||
" is not divisible by group_blocks = ", group_blocks);
|
||||
} else {
|
||||
TORCH_CHECK(group_size == 0);
|
||||
group_blocks = 0;
|
||||
}
|
||||
} else {
|
||||
if (group_size == -1) {
|
||||
group_blocks = -1;
|
||||
} else {
|
||||
group_blocks = group_size / 16;
|
||||
TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
|
||||
" is not divisible by group_blocks = ", group_blocks);
|
||||
}
|
||||
}
|
||||
|
||||
int num_bits = q_type.size_bits();
|
||||
const int4* A_ptr = (const int4*)A;
|
||||
const int4* B_ptr = (const int4*)B;
|
||||
int4* C_ptr = (int4*)C;
|
||||
int4* C_tmp_ptr = (int4*)C_tmp;
|
||||
const int4* s_ptr = (const int4*)s;
|
||||
const int4* zp_ptr = (const int4*)zp;
|
||||
const int* g_idx_ptr = (const int*)g_idx;
|
||||
const int* perm_ptr = (const int*)perm;
|
||||
int4* a_tmp_ptr = (int4*)a_tmp;
|
||||
const int32_t* sorted_token_ids_ptr = (const int32_t*)sorted_token_ids;
|
||||
const int32_t* expert_ids_ptr = (const int32_t*)expert_ids;
|
||||
const int32_t* num_tokens_past_padded_ptr =
|
||||
(const int32_t*)num_tokens_past_padded;
|
||||
const float* topk_weights_ptr = (const float*)topk_weights;
|
||||
int* locks = (int*)workspace;
|
||||
|
||||
if (has_act_order) {
|
||||
// Permute A columns
|
||||
auto kernel = permute_cols_kernel<8>;
|
||||
if (moe_block_size == 8) {
|
||||
} else if (moe_block_size == 16)
|
||||
kernel = permute_cols_kernel<16>;
|
||||
else if (moe_block_size == 32)
|
||||
kernel = permute_cols_kernel<32>;
|
||||
else if (moe_block_size == 48)
|
||||
kernel = permute_cols_kernel<48>;
|
||||
else if (moe_block_size == 64)
|
||||
kernel = permute_cols_kernel<64>;
|
||||
else
|
||||
TORCH_CHECK(false, "unsupported moe_block_size ", moe_block_size);
|
||||
|
||||
// avoid ">>>" being formatted to "> > >"
|
||||
// clang-format off
|
||||
kernel<<<sms, default_threads, 0, stream>>>(
|
||||
A_ptr, perm_ptr, a_tmp_ptr, sorted_token_ids_ptr, expert_ids_ptr,
|
||||
num_tokens_past_padded_ptr, prob_m, prob_k, top_k);
|
||||
// clang-format on
|
||||
A_ptr = a_tmp_ptr;
|
||||
prob_m = prob_m * top_k;
|
||||
top_k = 1;
|
||||
|
||||
// If we have a full K, then we can run the non-act-order version of Marlin
|
||||
// (since the weight rows are reordered by increasing group ids, and by
|
||||
// having a full K, we have full original groups)
|
||||
if (is_k_full) has_act_order = false;
|
||||
}
|
||||
|
||||
int max_shared_mem = 0;
|
||||
cudaDeviceGetAttribute(&max_shared_mem,
|
||||
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
|
||||
TORCH_CHECK(max_shared_mem > 0);
|
||||
|
||||
// Set thread config
|
||||
exec_config_t exec_cfg;
|
||||
thread_config_t thread_tfg;
|
||||
if (thread_k != -1 && thread_n != -1) {
|
||||
thread_tfg = thread_config_t{thread_k, thread_n, default_threads};
|
||||
exec_cfg = exec_config_t{1, thread_tfg};
|
||||
TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
|
||||
" is not divisible by thread_n = ", thread_n);
|
||||
TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
|
||||
" is not divisible by thread_k = ", thread_k);
|
||||
} else {
|
||||
// Auto config
|
||||
exec_cfg = determine_exec_config<scalar_t>(
|
||||
q_type, prob_m, prob_n, prob_k, thread_m_blocks, m_block_size_8,
|
||||
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float,
|
||||
max_shared_mem);
|
||||
thread_tfg = exec_cfg.tb_cfg;
|
||||
}
|
||||
|
||||
int num_threads = thread_tfg.num_threads;
|
||||
thread_k = thread_tfg.thread_k;
|
||||
thread_n = thread_tfg.thread_n;
|
||||
int blocks = sms * exec_cfg.blocks_per_sm;
|
||||
if (exec_cfg.blocks_per_sm > 1)
|
||||
max_shared_mem = max_shared_mem / exec_cfg.blocks_per_sm - 1024;
|
||||
|
||||
int thread_k_blocks = thread_k / 16;
|
||||
int thread_n_blocks = thread_n / 16;
|
||||
|
||||
TORCH_CHECK(is_valid_config(thread_tfg, thread_m_blocks, prob_m, prob_n,
|
||||
prob_k, num_bits, group_size, has_act_order,
|
||||
is_k_full, has_zp, is_zp_float, max_shared_mem),
|
||||
"Invalid thread config: thread_m_blocks = ", thread_m_blocks,
|
||||
", thread_k = ", thread_tfg.thread_k,
|
||||
", thread_n = ", thread_tfg.thread_n,
|
||||
", num_threads = ", thread_tfg.num_threads, " for MKN = [",
|
||||
prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
|
||||
", group_size = ", group_size,
|
||||
", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full,
|
||||
", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float,
|
||||
", max_shared_mem = ", max_shared_mem);
|
||||
|
||||
auto kernel = get_marlin_kernel<scalar_t>(
|
||||
q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks, m_block_size_8,
|
||||
has_act_order, has_zp, group_blocks, num_threads, is_zp_float);
|
||||
|
||||
if (kernel == MarlinDefault) {
|
||||
TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
|
||||
", ", prob_k, "]", ", has_act_order = ", has_act_order,
|
||||
", num_groups = ", num_groups, ", group_size = ", group_size,
|
||||
", thread_m_blocks = ", thread_m_blocks,
|
||||
", thread_n_blocks = ", thread_n_blocks,
|
||||
", thread_k_blocks = ", thread_k_blocks,
|
||||
", num_bits = ", num_bits);
|
||||
}
|
||||
|
||||
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
max_shared_mem);
|
||||
// avoid ">>>" being formatted to "> > >"
|
||||
// clang-format off
|
||||
kernel<<<blocks, num_threads, max_shared_mem, stream>>>(
|
||||
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr,
|
||||
sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr,
|
||||
topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m,
|
||||
prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce);
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
} // namespace MARLIN_NAMESPACE_NAME
|
||||
|
||||
torch::Tensor moe_wna16_marlin_gemm(
|
||||
torch::Tensor& a, std::optional<torch::Tensor> const& c_or_none,
|
||||
torch::Tensor& b_q_weight, torch::Tensor& b_scales,
|
||||
std::optional<torch::Tensor> const& b_zeros_or_none,
|
||||
std::optional<torch::Tensor> const& g_idx_or_none,
|
||||
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
|
||||
torch::Tensor& sorted_token_ids, torch::Tensor& expert_ids,
|
||||
torch::Tensor& num_tokens_past_padded, torch::Tensor& topk_weights,
|
||||
int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, bool is_ep,
|
||||
vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n,
|
||||
int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
|
||||
bool is_zp_float) {
|
||||
vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
|
||||
int pack_factor = 32 / b_q_type.size_bits();
|
||||
|
||||
if (moe_block_size != 8) {
|
||||
TORCH_CHECK(moe_block_size % 16 == 0,
|
||||
"unsupported moe_block_size=", moe_block_size);
|
||||
TORCH_CHECK(moe_block_size >= 16 && moe_block_size <= 64,
|
||||
"unsupported moe_block_size=", moe_block_size);
|
||||
}
|
||||
|
||||
// Verify A
|
||||
TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0),
|
||||
", size_m = ", size_m);
|
||||
TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1),
|
||||
", size_k = ", size_k);
|
||||
|
||||
// Verify B
|
||||
TORCH_CHECK(
|
||||
size_k % MARLIN_NAMESPACE_NAME::tile_size == 0, "size_k = ", size_k,
|
||||
" is not divisible by tile_size = ", MARLIN_NAMESPACE_NAME::tile_size);
|
||||
TORCH_CHECK((size_k / MARLIN_NAMESPACE_NAME::tile_size) == b_q_weight.size(1),
|
||||
"Shape mismatch: b_q_weight.size(1) = ", b_q_weight.size(1),
|
||||
", size_k = ", size_k,
|
||||
", tile_size = ", MARLIN_NAMESPACE_NAME::tile_size);
|
||||
TORCH_CHECK(
|
||||
b_q_weight.size(2) % MARLIN_NAMESPACE_NAME::tile_size == 0,
|
||||
"b_q_weight.size(2) = ", b_q_weight.size(2),
|
||||
" is not divisible by tile_size = ", MARLIN_NAMESPACE_NAME::tile_size);
|
||||
int actual_size_n =
|
||||
(b_q_weight.size(2) / MARLIN_NAMESPACE_NAME::tile_size) * pack_factor;
|
||||
TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n,
|
||||
", actual_size_n = ", actual_size_n);
|
||||
|
||||
// Verify device and strides
|
||||
TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
|
||||
TORCH_CHECK(a.is_contiguous(), "A is not contiguous");
|
||||
|
||||
TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
|
||||
TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
|
||||
|
||||
TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
|
||||
TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
|
||||
|
||||
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as
|
||||
// auto -1)
|
||||
int thread_k = -1;
|
||||
// thread_n: `n` size of a thread_tile in `weights` (can usually be left as
|
||||
// auto -1)
|
||||
int thread_n = -1;
|
||||
// sms: number of SMs to use for the kernel
|
||||
int sms = -1;
|
||||
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device());
|
||||
|
||||
// Alloc buffers
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
|
||||
auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
|
||||
torch::Tensor c;
|
||||
if (c_or_none.has_value()) {
|
||||
c = c_or_none.value();
|
||||
TORCH_CHECK(c.device().is_cuda(), "c is not on GPU");
|
||||
TORCH_CHECK(c.is_contiguous(), "c is not contiguous");
|
||||
TORCH_CHECK(c.size(0) == size_m * top_k,
|
||||
"Shape mismatch: c.size(0) = ", c.size(0),
|
||||
", size_m * topk = ", size_m * top_k);
|
||||
TORCH_CHECK(c.size(1) == size_n, "Shape mismatch: c.size(1) = ", c.size(1),
|
||||
", size_n = ", size_n);
|
||||
} else {
|
||||
c = torch::empty({size_m * top_k, size_n}, options);
|
||||
}
|
||||
|
||||
// Alloc C tmp buffer that is going to be used for the global reduce
|
||||
torch::Tensor c_tmp;
|
||||
auto options_fp32 =
|
||||
torch::TensorOptions().dtype(at::kFloat).device(a.device());
|
||||
if (use_fp32_reduce && !use_atomic_add) {
|
||||
// max num of threadblocks is sms * 4
|
||||
long max_c_tmp_size = min(
|
||||
(long)size_n * sorted_token_ids.size(0),
|
||||
(long)sms * 4 * moe_block_size * MARLIN_NAMESPACE_NAME::max_thread_n);
|
||||
if (moe_block_size == 8) max_c_tmp_size *= 2;
|
||||
c_tmp = torch::empty({max_c_tmp_size}, options_fp32);
|
||||
} else {
|
||||
c_tmp = torch::empty({0}, options_fp32);
|
||||
}
|
||||
|
||||
// Detect groupsize and act_order
|
||||
int num_groups = -1;
|
||||
int group_size = -1;
|
||||
|
||||
int rank = b_scales.sizes().size();
|
||||
TORCH_CHECK(rank == 3, "b_scales rank = ", rank, " is not 3");
|
||||
TORCH_CHECK(b_scales.size(2) == size_n, "b_scales dim 2 = ", b_scales.size(2),
|
||||
" is not size_n = ", size_n);
|
||||
num_groups = b_scales.size(1);
|
||||
|
||||
torch::Tensor g_idx, perm, a_tmp;
|
||||
;
|
||||
if (g_idx_or_none.has_value() && perm_or_none.has_value()) {
|
||||
g_idx = g_idx_or_none.value();
|
||||
perm = perm_or_none.value();
|
||||
|
||||
TORCH_CHECK(g_idx.device().is_cuda(), "g_idx is not on GPU");
|
||||
TORCH_CHECK(g_idx.is_contiguous(), "g_idx is not contiguous");
|
||||
TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU");
|
||||
TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous");
|
||||
|
||||
// Verify g_idx and perm
|
||||
TORCH_CHECK((g_idx.size(-1) == 0 && perm.size(-1) == 0) ||
|
||||
(g_idx.size(-1) == size_k && perm.size(-1) == size_k),
|
||||
"Unexpected g_idx.size(-1) = ", g_idx.size(-1),
|
||||
" and perm.size(-1) = ", perm.size(-1),
|
||||
", where size_k = ", size_k);
|
||||
} else {
|
||||
g_idx = torch::empty({0}, options);
|
||||
perm = torch::empty({0}, options);
|
||||
a_tmp = torch::empty({0}, options);
|
||||
}
|
||||
bool has_act_order = g_idx.size(-1) > 0 && perm.size(-1) > 0;
|
||||
|
||||
if (has_act_order) {
|
||||
a_tmp = torch::empty({size_m * top_k, size_k}, options);
|
||||
if (is_k_full) {
|
||||
TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1");
|
||||
TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k,
|
||||
", is not divisible by num_groups = ", num_groups);
|
||||
group_size = size_k / num_groups;
|
||||
} else {
|
||||
group_size = 0;
|
||||
}
|
||||
|
||||
} else {
|
||||
a_tmp = torch::empty({0}, options);
|
||||
if (num_groups > 1) {
|
||||
TORCH_CHECK(
|
||||
size_k % num_groups == 0, "size_k = ", size_k,
|
||||
", is not divisible by b_scales.size(1) = ", b_scales.size(1));
|
||||
group_size = size_k / num_groups;
|
||||
} else {
|
||||
group_size = -1;
|
||||
}
|
||||
}
|
||||
|
||||
torch::Tensor b_zeros;
|
||||
if (b_zeros_or_none.has_value()) {
|
||||
b_zeros = b_zeros_or_none.value();
|
||||
TORCH_CHECK(b_zeros.device().is_cuda(), "b_zeros is not on GPU");
|
||||
TORCH_CHECK(b_zeros.is_contiguous(), "b_zeros is not contiguous");
|
||||
} else {
|
||||
b_zeros = torch::empty({0}, options);
|
||||
}
|
||||
bool has_zp = b_zeros.size(-1) > 0;
|
||||
|
||||
if (has_zp) {
|
||||
TORCH_CHECK(
|
||||
b_q_type == vllm::kU4,
|
||||
"b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str());
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128,
|
||||
"b_q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
|
||||
b_q_type.str());
|
||||
}
|
||||
|
||||
if (has_zp && is_zp_float) {
|
||||
TORCH_CHECK(a.scalar_type() == at::ScalarType::Half,
|
||||
"Computation type must be float16 (half) when using float zero "
|
||||
"points.");
|
||||
}
|
||||
|
||||
// Verify b_zeros
|
||||
if (has_zp) {
|
||||
int rank = b_zeros.sizes().size();
|
||||
TORCH_CHECK(rank == 3, "b_zeros rank = ", rank, " is not 3");
|
||||
if (is_zp_float) {
|
||||
TORCH_CHECK(b_zeros.size(2) == size_n,
|
||||
"b_zeros dim 2 = ", b_zeros.size(2),
|
||||
" is not size_n = ", size_n);
|
||||
TORCH_CHECK(num_groups == b_zeros.size(1),
|
||||
"b_zeros dim 1 = ", b_zeros.size(1),
|
||||
" is not num_groups = ", num_groups);
|
||||
TORCH_CHECK(num_groups != -1, "num_groups must be != -1");
|
||||
} else {
|
||||
TORCH_CHECK(b_zeros.size(1) == num_groups,
|
||||
"b_zeros dim 1 = ", b_zeros.size(1),
|
||||
" is not num_groups = ", num_groups);
|
||||
TORCH_CHECK(b_zeros.size(2) == size_n / pack_factor,
|
||||
"b_zeros dim 2 = ", b_zeros.size(2),
|
||||
" is not size_n / pack_factor = ", size_n / pack_factor);
|
||||
}
|
||||
}
|
||||
|
||||
// Verify workspace size
|
||||
TORCH_CHECK(size_n % MARLIN_NAMESPACE_NAME::min_thread_n == 0,
|
||||
"size_n = ", size_n, ", is not divisible by min_thread_n = ",
|
||||
MARLIN_NAMESPACE_NAME::min_thread_n);
|
||||
|
||||
int max_n_tiles = size_n / MARLIN_NAMESPACE_NAME::min_thread_n;
|
||||
int min_workspace_size = min(
|
||||
max_n_tiles * (int)(sorted_token_ids.size(0) / moe_block_size), sms * 4);
|
||||
TORCH_CHECK(workspace.numel() >= min_workspace_size,
|
||||
"workspace.numel = ", workspace.numel(),
|
||||
" is below min_workspace_size = ", min_workspace_size);
|
||||
|
||||
int dev = a.get_device();
|
||||
if (a.scalar_type() == at::ScalarType::Half) {
|
||||
MARLIN_NAMESPACE_NAME::marlin_mm<half>(
|
||||
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
|
||||
c_tmp.data_ptr<float>(), b_scales.data_ptr<at::Half>(),
|
||||
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
|
||||
a_tmp.data_ptr<at::Half>(), sorted_token_ids.data_ptr(),
|
||||
expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(),
|
||||
topk_weights.data_ptr(), moe_block_size, top_k, mul_topk_weights, is_ep,
|
||||
size_m, size_n, size_k, workspace.data_ptr(), b_q_type, has_act_order,
|
||||
is_k_full, has_zp, num_groups, group_size, dev,
|
||||
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
|
||||
use_atomic_add, use_fp32_reduce, is_zp_float);
|
||||
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
|
||||
MARLIN_NAMESPACE_NAME::marlin_mm<nv_bfloat16>(
|
||||
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
|
||||
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
|
||||
b_scales.data_ptr<at::BFloat16>(), b_zeros.data_ptr(), g_idx.data_ptr(),
|
||||
perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
|
||||
sorted_token_ids.data_ptr(), expert_ids.data_ptr(),
|
||||
num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(),
|
||||
moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k,
|
||||
workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
|
||||
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
|
||||
thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float);
|
||||
} else {
|
||||
TORCH_CHECK(false,
|
||||
"moe_wna16_marlin_gemm only supports bfloat16 and float16");
|
||||
}
|
||||
|
||||
return c;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("moe_wna16_marlin_gemm", &moe_wna16_marlin_gemm);
|
||||
}
|
||||
133
csrc/moe/moe_permute_unpermute_op.cu
Normal file
133
csrc/moe/moe_permute_unpermute_op.cu
Normal file
@ -0,0 +1,133 @@
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <torch/all.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include "permute_unpermute_kernels/moe_permute_unpermute_kernel.h"
|
||||
#include "permute_unpermute_kernels/dispatch.h"
|
||||
#include "core/registration.h"
|
||||
|
||||
void moe_permute(
|
||||
const torch::Tensor& input, // [n_token, hidden]
|
||||
const torch::Tensor& topk_weights, //[n_token, topk]
|
||||
torch::Tensor& topk_ids, // [n_token, topk]
|
||||
const torch::Tensor& token_expert_indicies, // [n_token, topk]
|
||||
const std::optional<torch::Tensor>& expert_map, // [n_expert]
|
||||
int64_t n_expert, int64_t n_local_expert, int64_t topk,
|
||||
const std::optional<int64_t>& align_block_size,
|
||||
torch::Tensor&
|
||||
permuted_input, // [topk * n_token/align_block_size_m, hidden]
|
||||
torch::Tensor& expert_first_token_offset, // [n_local_expert + 1]
|
||||
torch::Tensor& src_row_id2dst_row_id_map, // [n_token, topk]
|
||||
torch::Tensor& m_indices) { // [align_expand_m]
|
||||
TORCH_CHECK(topk_weights.scalar_type() == at::ScalarType::Float,
|
||||
"topk_weights must be float32");
|
||||
TORCH_CHECK(expert_first_token_offset.scalar_type() == at::ScalarType::Long,
|
||||
"expert_first_token_offset must be int64");
|
||||
TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int,
|
||||
"topk_ids must be int32");
|
||||
TORCH_CHECK(token_expert_indicies.scalar_type() == at::ScalarType::Int,
|
||||
"token_expert_indicies must be int32");
|
||||
TORCH_CHECK(src_row_id2dst_row_id_map.scalar_type() == at::ScalarType::Int,
|
||||
"src_row_id2dst_row_id_map must be int32");
|
||||
TORCH_CHECK(expert_first_token_offset.size(0) == n_local_expert + 1,
|
||||
"expert_first_token_offset shape != n_local_expert+1")
|
||||
TORCH_CHECK(
|
||||
src_row_id2dst_row_id_map.sizes() == token_expert_indicies.sizes(),
|
||||
"token_expert_indicies shape must be same as src_row_id2dst_row_id_map");
|
||||
auto n_token = input.sizes()[0];
|
||||
auto n_hidden = input.sizes()[1];
|
||||
auto align_block_size_value =
|
||||
align_block_size.has_value() ? align_block_size.value() : -1;
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
const long sorter_size =
|
||||
CubKeyValueSorter::getWorkspaceSize(n_token * topk, n_expert);
|
||||
auto sort_workspace = torch::empty(
|
||||
{sorter_size},
|
||||
torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
|
||||
auto permuted_experts_id = torch::empty_like(topk_ids);
|
||||
auto dst_row_id2src_row_id_map = torch::empty_like(src_row_id2dst_row_id_map);
|
||||
auto align_expert_first_token_offset =
|
||||
torch::zeros_like(expert_first_token_offset);
|
||||
|
||||
CubKeyValueSorter sorter{};
|
||||
int64_t* valid_num_ptr = nullptr;
|
||||
// pre-process kernel for expert-parallelism:
|
||||
// no local expert id plus "n_expert" offset for priority to local expert
|
||||
// map local expert id [n, .., n+n_local_expert-1] to [0, n_local_expert -1]
|
||||
// For example, 4 expert with ep_size=2. ep_rank=1 owns global expert id
|
||||
// [2,3] with expert_map[-1, -1, 0, 1], preprocess_topk_id process topk_ids
|
||||
// and map global expert id [2, 3] to local_expert id [0, 1] and map global
|
||||
// expert id [0, 1] ( not in ep rank=1) to [4, 5] by plus n_expert. This map
|
||||
// operation is to make local expert high priority in following sort topk_ids
|
||||
// and scan local expert_first_token_offset for each ep rank for next group
|
||||
// gemm.
|
||||
if (expert_map.has_value()) {
|
||||
const int* expert_map_ptr = get_ptr<int>(expert_map.value());
|
||||
valid_num_ptr =
|
||||
get_ptr<int64_t>(expert_first_token_offset) + n_local_expert;
|
||||
preprocessTopkIdLauncher(get_ptr<int>(topk_ids), n_token * topk,
|
||||
expert_map_ptr, n_expert, stream);
|
||||
}
|
||||
// expert sort topk expert id and scan expert id get expert_first_token_offset
|
||||
sortAndScanExpert(get_ptr<int>(topk_ids), get_ptr<int>(token_expert_indicies),
|
||||
get_ptr<int>(permuted_experts_id),
|
||||
get_ptr<int>(dst_row_id2src_row_id_map),
|
||||
get_ptr<int64_t>(expert_first_token_offset), n_token,
|
||||
n_expert, n_local_expert, topk, sorter,
|
||||
get_ptr<int>(sort_workspace), stream);
|
||||
|
||||
// dispatch expandInputRowsKernelLauncher
|
||||
MOE_DISPATCH(input.scalar_type(), [&] {
|
||||
expandInputRowsKernelLauncher<scalar_t>(
|
||||
get_ptr<scalar_t>(input), get_ptr<scalar_t>(permuted_input),
|
||||
get_ptr<float>(topk_weights), get_ptr<int>(permuted_experts_id),
|
||||
get_ptr<int>(dst_row_id2src_row_id_map),
|
||||
get_ptr<int>(src_row_id2dst_row_id_map),
|
||||
get_ptr<int64_t>(expert_first_token_offset), n_token, valid_num_ptr,
|
||||
n_hidden, topk, n_local_expert, align_block_size_value, stream);
|
||||
});
|
||||
|
||||
// get m_indices and update expert_first_token_offset with align block
|
||||
getMIndices(get_ptr<int64_t>(expert_first_token_offset),
|
||||
get_ptr<int64_t>(align_expert_first_token_offset),
|
||||
get_ptr<int>(m_indices), n_local_expert, align_block_size_value,
|
||||
stream);
|
||||
if (align_block_size.has_value()) {
|
||||
// update align_expert_first_token_offset
|
||||
expert_first_token_offset.copy_(align_expert_first_token_offset);
|
||||
}
|
||||
}
|
||||
|
||||
void moe_unpermute(
|
||||
const torch::Tensor& permuted_hidden_states, // [n_token * topk, hidden]
|
||||
const torch::Tensor& topk_weights, //[n_token, topk]
|
||||
const torch::Tensor& topk_ids, // [n_token, topk]
|
||||
const torch::Tensor& src_row_id2dst_row_id_map, // [n_token, topk]
|
||||
const torch::Tensor& expert_first_token_offset, // [n_local_expert+1]
|
||||
int64_t n_expert, int64_t n_local_expert, int64_t topk,
|
||||
torch::Tensor& hidden_states // [n_token, hidden]
|
||||
) {
|
||||
TORCH_CHECK(src_row_id2dst_row_id_map.sizes() == topk_ids.sizes(),
|
||||
"topk_ids shape must be same as src_row_id2dst_row_id_map");
|
||||
TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int,
|
||||
"topk_ids must be int32");
|
||||
TORCH_CHECK(
|
||||
permuted_hidden_states.scalar_type() == hidden_states.scalar_type(),
|
||||
"topk_ids dtype must be same as src_row_id2dst_row_id_map");
|
||||
auto n_token = hidden_states.size(0);
|
||||
auto n_hidden = hidden_states.size(1);
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
const int64_t* valid_ptr =
|
||||
get_ptr<int64_t>(expert_first_token_offset) + n_local_expert;
|
||||
MOE_DISPATCH(hidden_states.scalar_type(), [&] {
|
||||
finalizeMoeRoutingKernelLauncher<scalar_t, scalar_t>(
|
||||
get_ptr<scalar_t>(permuted_hidden_states),
|
||||
get_ptr<scalar_t>(hidden_states), get_ptr<float>(topk_weights),
|
||||
get_ptr<int>(src_row_id2dst_row_id_map), get_ptr<int>(topk_ids),
|
||||
n_token, n_hidden, topk, valid_ptr, stream);
|
||||
});
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("moe_permute", &moe_permute);
|
||||
m.impl("moe_unpermute", &moe_unpermute);
|
||||
}
|
||||
@ -13,7 +13,6 @@
|
||||
template <typename scalar_t, int bit, int GROUPS>
|
||||
__global__ void moe_wna16_gemm_kernel(
|
||||
const scalar_t* __restrict__ input, scalar_t* __restrict__ output,
|
||||
|
||||
const uint32_t* __restrict__ qweight, const scalar_t* __restrict__ scales,
|
||||
const uint32_t* __restrict__ qzeros,
|
||||
|
||||
@ -54,8 +53,6 @@ __global__ void moe_wna16_gemm_kernel(
|
||||
if (token_index / top_k >= size_m) break;
|
||||
|
||||
num_valid_tokens = m + 1;
|
||||
if (blockIdx.z == 0 && offset_n < size_n)
|
||||
output[token_index * size_n + offset_n] = Dtype::int2num(0);
|
||||
|
||||
if (expert_id != -1) {
|
||||
int k_per_thread = DIVIDE(BLOCK_SIZE_K, BLOCK_SIZE_N);
|
||||
@ -284,8 +281,7 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
|
||||
int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
|
||||
int64_t BLOCK_SIZE_K, int64_t bit) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
auto options =
|
||||
torch::TensorOptions().dtype(input.dtype()).device(input.device());
|
||||
output.zero_();
|
||||
|
||||
const int num_experts = b_qweight.size(0);
|
||||
const int size_m = input.size(0);
|
||||
@ -302,9 +298,9 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
|
||||
const uint32_t* b_qzeros_ptr;
|
||||
if (b_qzeros.has_value())
|
||||
b_qzeros_ptr = (const uint32_t*)b_qzeros.value().data_ptr<uint8_t>();
|
||||
const float* topk_weights_ptr;
|
||||
const float* topk_weights_ptr = nullptr;
|
||||
if (topk_weights.has_value())
|
||||
topk_weights_ptr = (const float*)topk_weights.value().data_ptr();
|
||||
topk_weights_ptr = (const float*)topk_weights.value().data_ptr<float>();
|
||||
|
||||
int groups_per_block_row = BLOCK_SIZE_K / group_size;
|
||||
TORCH_CHECK(bit == 4 || bit == 8, "bit must be 4 or 8");
|
||||
|
||||
@ -108,11 +108,11 @@ __device__ inline void dequant<half2, 4>(int q, half2* res) {
|
||||
const int MUL = 0x2c002c00;
|
||||
const int ADD = 0xd400d400;
|
||||
|
||||
int lo0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
|
||||
int hi0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
|
||||
int lo0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
|
||||
int hi0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
|
||||
q >>= 8;
|
||||
int lo1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
|
||||
int hi1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
|
||||
int lo1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
|
||||
int hi1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
|
||||
|
||||
res[0] = __hsub2(*reinterpret_cast<half2*>(&lo0),
|
||||
*reinterpret_cast<const half2*>(&SUB));
|
||||
@ -149,13 +149,13 @@ __device__ inline void dequant<nv_bfloat162, 4>(int q, nv_bfloat162* res) {
|
||||
static constexpr uint32_t MASK = 0x000f000f;
|
||||
static constexpr uint32_t EX = 0x43004300;
|
||||
|
||||
int lo0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
|
||||
int lo0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||
q >>= 4;
|
||||
int hi0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
|
||||
int hi0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||
q >>= 4;
|
||||
int lo1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
|
||||
int lo1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||
q >>= 4;
|
||||
int hi1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
|
||||
int hi1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||
|
||||
static constexpr uint32_t MUL = 0x3F803F80;
|
||||
static constexpr uint32_t ADD = 0xC300C300;
|
||||
|
||||
53
csrc/moe/permute_unpermute_kernels/dispatch.h
Normal file
53
csrc/moe/permute_unpermute_kernels/dispatch.h
Normal file
@ -0,0 +1,53 @@
|
||||
#pragma once
|
||||
#include <cuda_fp8.h>
|
||||
#define MOE_SWITCH(TYPE, ...) \
|
||||
at::ScalarType _st = ::detail::scalar_type(TYPE); \
|
||||
switch (_st) { \
|
||||
__VA_ARGS__ \
|
||||
default: \
|
||||
TORCH_CHECK(false, "[moe permute]data type dispatch fail!") \
|
||||
}
|
||||
|
||||
#define MOE_DISPATCH_CASE(enum_type, ...) \
|
||||
case enum_type: { \
|
||||
using scalar_t = ScalarType2CudaType<enum_type>::type; \
|
||||
__VA_ARGS__(); \
|
||||
break; \
|
||||
}
|
||||
#define MOE_DISPATCH_FLOAT_CASE(...) \
|
||||
MOE_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
MOE_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
|
||||
MOE_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
|
||||
MOE_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__) \
|
||||
MOE_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__)
|
||||
|
||||
#define MOE_DISPATCH(TYPE, ...) \
|
||||
MOE_SWITCH(TYPE, MOE_DISPATCH_FLOAT_CASE(__VA_ARGS__))
|
||||
|
||||
template <at::ScalarType type>
|
||||
struct ScalarType2CudaType;
|
||||
|
||||
template <>
|
||||
struct ScalarType2CudaType<at::ScalarType::Float> {
|
||||
using type = float;
|
||||
};
|
||||
template <>
|
||||
struct ScalarType2CudaType<at::ScalarType::Half> {
|
||||
using type = half;
|
||||
};
|
||||
template <>
|
||||
struct ScalarType2CudaType<at::ScalarType::BFloat16> {
|
||||
using type = __nv_bfloat16;
|
||||
};
|
||||
|
||||
// #if __CUDA_ARCH__ >= 890
|
||||
// fp8
|
||||
template <>
|
||||
struct ScalarType2CudaType<at::ScalarType::Float8_e5m2> {
|
||||
using type = __nv_fp8_e5m2;
|
||||
};
|
||||
template <>
|
||||
struct ScalarType2CudaType<at::ScalarType::Float8_e4m3fn> {
|
||||
using type = __nv_fp8_e4m3;
|
||||
};
|
||||
// #endif
|
||||
@ -0,0 +1,229 @@
|
||||
|
||||
#include "moe_permute_unpermute_kernel.h"
|
||||
|
||||
// CubKeyValueSorter definition begin
|
||||
CubKeyValueSorter::CubKeyValueSorter()
|
||||
: num_experts_(0), num_bits_(sizeof(int) * 8) {}
|
||||
|
||||
int CubKeyValueSorter::expertsToBits(int num_experts) {
|
||||
// Max value we represent is V = num_experts + (num_experts - 1) = 2 *
|
||||
// num_experts - 1 The maximum number of bits is therefore floor(log2(V)) + 1
|
||||
return static_cast<int>(log2(2 * num_experts - 1)) + 1;
|
||||
}
|
||||
|
||||
CubKeyValueSorter::CubKeyValueSorter(int const num_experts)
|
||||
: num_experts_(num_experts), num_bits_(expertsToBits(num_experts)) {}
|
||||
|
||||
void CubKeyValueSorter::updateNumExperts(int const num_experts) {
|
||||
num_experts_ = num_experts;
|
||||
num_bits_ = expertsToBits(num_experts);
|
||||
}
|
||||
|
||||
size_t CubKeyValueSorter::getWorkspaceSize(size_t const num_key_value_pairs,
|
||||
int const num_experts) {
|
||||
int num_bits = expertsToBits(num_experts);
|
||||
size_t required_storage = 0;
|
||||
int* null_int = nullptr;
|
||||
cub::DeviceRadixSort::SortPairs(nullptr, required_storage, null_int, null_int,
|
||||
null_int, null_int, num_key_value_pairs, 0,
|
||||
num_bits);
|
||||
|
||||
// when num_key_value_pairs, num_experts, num_bits, required_storage = 64,
|
||||
// 4, 3, 0 The required_storage seems to vary between 0 and 1 for the same
|
||||
// inputs
|
||||
if (required_storage == 0) {
|
||||
required_storage = 1;
|
||||
}
|
||||
return required_storage;
|
||||
}
|
||||
|
||||
void CubKeyValueSorter::run(void* workspace, size_t const workspace_size,
|
||||
int const* keys_in, int* keys_out,
|
||||
int const* values_in, int* values_out,
|
||||
size_t const num_key_value_pairs,
|
||||
cudaStream_t stream) {
|
||||
size_t expected_ws_size = getWorkspaceSize(num_key_value_pairs, num_experts_);
|
||||
size_t actual_ws_size = workspace_size;
|
||||
|
||||
TORCH_CHECK(expected_ws_size <= workspace_size,
|
||||
"[CubKeyValueSorter::run] The allocated workspace is too small "
|
||||
"to run this problem.");
|
||||
cub::DeviceRadixSort::SortPairs(workspace, actual_ws_size, keys_in, keys_out,
|
||||
values_in, values_out, num_key_value_pairs, 0,
|
||||
num_bits_, stream);
|
||||
}
|
||||
// CubKeyValueSorter definition end
|
||||
|
||||
static inline size_t pad_to_multiple_of_16(size_t const& input) {
|
||||
static constexpr int ALIGNMENT = 16;
|
||||
return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT);
|
||||
}
|
||||
template <class T>
|
||||
__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices,
|
||||
int64_t const arr_length,
|
||||
T const target) {
|
||||
int64_t low = 0, high = arr_length - 1, target_location = -1;
|
||||
while (low <= high) {
|
||||
int64_t mid = (low + high) / 2;
|
||||
|
||||
if (sorted_indices[mid] >= target) {
|
||||
high = mid - 1;
|
||||
} else {
|
||||
low = mid + 1;
|
||||
target_location = mid;
|
||||
}
|
||||
}
|
||||
return target_location + 1;
|
||||
}
|
||||
|
||||
// Calculates the start offset of the tokens for a given expert. The last
|
||||
// element is the total number of valid tokens
|
||||
__global__ void computeExpertFirstTokenOffsetKernel(
|
||||
int const* sorted_experts, int64_t const sorted_experts_len,
|
||||
int const num_experts, int64_t* expert_first_token_offset) {
|
||||
// First, compute the global tid. We only need 1 thread per expert.
|
||||
int const expert = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
// Note that expert goes [0, num_experts] (inclusive) because we want a count
|
||||
// for the total number of active tokens at the end of the scan.
|
||||
if (expert >= num_experts + 1) {
|
||||
return;
|
||||
}
|
||||
expert_first_token_offset[expert] =
|
||||
findTotalEltsLessThanTarget(sorted_experts, sorted_experts_len, expert);
|
||||
}
|
||||
|
||||
void computeExpertFirstTokenOffset(int const* sorted_indices,
|
||||
int const total_indices,
|
||||
int const num_experts,
|
||||
int64_t* expert_first_token_offset,
|
||||
cudaStream_t stream) {
|
||||
int const num_entries = num_experts + 1;
|
||||
int const threads = std::min(1024, num_entries);
|
||||
int const blocks = (num_entries + threads - 1) / threads;
|
||||
|
||||
computeExpertFirstTokenOffsetKernel<<<blocks, threads, 0, stream>>>(
|
||||
sorted_indices, total_indices, num_experts, expert_first_token_offset);
|
||||
}
|
||||
|
||||
void sortAndScanExpert(int* expert_for_source_row, const int* source_rows,
|
||||
int* permuted_experts, int* permuted_rows,
|
||||
int64_t* expert_first_token_offset, int num_rows,
|
||||
int num_experts, int num_experts_per_node, int k,
|
||||
CubKeyValueSorter& sorter, void* sorter_ws,
|
||||
cudaStream_t stream) {
|
||||
int64_t const expanded_num_rows = static_cast<int64_t>(k) * num_rows;
|
||||
// We need to use the full num_experts because that is the sentinel value used
|
||||
// by topk for disabled experts
|
||||
sorter.updateNumExperts(num_experts);
|
||||
size_t const sorter_ws_size_bytes = pad_to_multiple_of_16(
|
||||
sorter.getWorkspaceSize(expanded_num_rows, num_experts));
|
||||
sorter.run((void*)sorter_ws, sorter_ws_size_bytes, expert_for_source_row,
|
||||
permuted_experts, source_rows, permuted_rows, expanded_num_rows,
|
||||
stream);
|
||||
computeExpertFirstTokenOffset(permuted_experts, expanded_num_rows,
|
||||
num_experts_per_node, expert_first_token_offset,
|
||||
stream);
|
||||
}
|
||||
|
||||
__global__ void preprocessTopkIdKernel(int* topk_id_ptr, int size,
|
||||
const int* expert_map_ptr,
|
||||
int num_experts) {
|
||||
auto tidx = threadIdx.x;
|
||||
auto bidx = blockIdx.x;
|
||||
auto lidx = tidx & 31;
|
||||
auto widx = tidx >> 5;
|
||||
auto warp_count = (blockDim.x + 31) >> 5;
|
||||
auto offset = bidx * blockDim.x;
|
||||
auto bound = min(offset + blockDim.x, size);
|
||||
extern __shared__ int smem_expert_map[];
|
||||
// store expert_map in smem
|
||||
for (int i = tidx; i < num_experts; i += blockDim.x) {
|
||||
smem_expert_map[i] = expert_map_ptr[i];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// query global expert id in expert map.
|
||||
// if global expert id = -1 in exert map, plus n_expert
|
||||
// else set global expert id = exert map[global expert id]
|
||||
if (offset + tidx < bound) {
|
||||
auto topk_id = topk_id_ptr[offset + tidx];
|
||||
auto local_expert_idx = smem_expert_map[topk_id];
|
||||
if (local_expert_idx == -1) {
|
||||
topk_id += num_experts;
|
||||
} else {
|
||||
topk_id = local_expert_idx;
|
||||
}
|
||||
__syncwarp();
|
||||
topk_id_ptr[offset + tidx] = topk_id;
|
||||
}
|
||||
}
|
||||
void preprocessTopkIdLauncher(int* topk_id_ptr, int size,
|
||||
const int* expert_map_ptr, int num_experts,
|
||||
cudaStream_t stream) {
|
||||
int block = std::min(size, 1024);
|
||||
int grid = (size + block - 1) / block;
|
||||
int smem_size = (num_experts) * sizeof(int);
|
||||
preprocessTopkIdKernel<<<grid, block, smem_size, stream>>>(
|
||||
topk_id_ptr, size, expert_map_ptr, num_experts);
|
||||
}
|
||||
|
||||
template <bool ALIGN_BLOCK_SIZE>
|
||||
__global__ void getMIndicesKernel(int64_t* expert_first_token_offset,
|
||||
int64_t* align_expert_first_token_offset,
|
||||
int* m_indices, const int num_local_expert,
|
||||
const int align_block_size) {
|
||||
int eidx = blockIdx.x;
|
||||
int tidx = threadIdx.x;
|
||||
extern __shared__ int64_t smem_expert_first_token_offset[];
|
||||
for (int i = tidx; i <= num_local_expert; i += blockDim.x) {
|
||||
smem_expert_first_token_offset[tidx] = __ldg(expert_first_token_offset + i);
|
||||
}
|
||||
__syncthreads();
|
||||
auto last_token_offset = smem_expert_first_token_offset[eidx + 1];
|
||||
auto first_token_offset = smem_expert_first_token_offset[eidx];
|
||||
int n_token_in_expert = last_token_offset - first_token_offset;
|
||||
|
||||
if constexpr (ALIGN_BLOCK_SIZE) {
|
||||
n_token_in_expert = (n_token_in_expert + align_block_size - 1) /
|
||||
align_block_size * align_block_size;
|
||||
// round up to ALIGN_BLOCK_SIZE
|
||||
int64_t accumulate_align_offset = 0;
|
||||
for (int i = 1; i <= eidx + 1; i++) {
|
||||
int n_token = smem_expert_first_token_offset[i] -
|
||||
smem_expert_first_token_offset[i - 1];
|
||||
accumulate_align_offset =
|
||||
accumulate_align_offset + (n_token + align_block_size - 1) /
|
||||
align_block_size * align_block_size;
|
||||
if (i == eidx) {
|
||||
first_token_offset = accumulate_align_offset;
|
||||
}
|
||||
// last block store align_expert_first_token_offset
|
||||
if (eidx == num_local_expert - 1 && threadIdx.x == 0) {
|
||||
align_expert_first_token_offset[i] = accumulate_align_offset;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int idx = tidx; idx < n_token_in_expert; idx += blockDim.x) {
|
||||
// update m_indice with expert id
|
||||
m_indices[first_token_offset + idx] = eidx;
|
||||
}
|
||||
}
|
||||
|
||||
void getMIndices(int64_t* expert_first_token_offset,
|
||||
int64_t* align_expert_first_token_offset, int* m_indices,
|
||||
int num_local_expert, const int align_block_size,
|
||||
cudaStream_t stream) {
|
||||
int block = 256;
|
||||
int grid = num_local_expert;
|
||||
int smem_size = sizeof(int64_t) * (num_local_expert + 1);
|
||||
if (align_block_size == -1) {
|
||||
getMIndicesKernel<false><<<grid, block, smem_size, stream>>>(
|
||||
expert_first_token_offset, align_expert_first_token_offset, m_indices,
|
||||
num_local_expert, align_block_size);
|
||||
} else {
|
||||
getMIndicesKernel<true><<<grid, block, smem_size, stream>>>(
|
||||
expert_first_token_offset, align_expert_first_token_offset, m_indices,
|
||||
num_local_expert, align_block_size);
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,95 @@
|
||||
#pragma once
|
||||
// reference from tensorrt_llm moe kernel implementation archive in
|
||||
// https://github.com/BBuf/tensorrt-llm-moe/tree/master
|
||||
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <torch/all.h>
|
||||
#include "dispatch.h"
|
||||
#include <cub/cub.cuh>
|
||||
#include <cub/device/device_radix_sort.cuh>
|
||||
#include <cub/util_type.cuh>
|
||||
#include "cutlass/numeric_size.h"
|
||||
#include "cutlass/array.h"
|
||||
|
||||
template <typename T>
|
||||
inline T* get_ptr(torch::Tensor& t) {
|
||||
return reinterpret_cast<T*>(t.data_ptr());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline const T* get_ptr(const torch::Tensor& t) {
|
||||
return reinterpret_cast<const T*>(t.data_ptr());
|
||||
}
|
||||
|
||||
class CubKeyValueSorter {
|
||||
public:
|
||||
CubKeyValueSorter();
|
||||
|
||||
CubKeyValueSorter(int const num_experts);
|
||||
|
||||
void updateNumExperts(int const num_experts);
|
||||
|
||||
static size_t getWorkspaceSize(size_t const num_key_value_pairs,
|
||||
int const num_experts);
|
||||
|
||||
void run(void* workspace, size_t const workspace_size, int const* keys_in,
|
||||
int* keys_out, int const* values_in, int* values_out,
|
||||
size_t const num_key_value_pairs, cudaStream_t stream);
|
||||
|
||||
private:
|
||||
static int expertsToBits(int experts);
|
||||
int num_experts_;
|
||||
int num_bits_;
|
||||
};
|
||||
|
||||
void computeExpertFirstTokenOffset(int const* sorted_indices,
|
||||
int const total_indices,
|
||||
int const num_experts,
|
||||
int64_t* expert_first_token_offset,
|
||||
cudaStream_t stream);
|
||||
|
||||
void sortAndScanExpert(int* expert_for_source_row, const int* source_rows,
|
||||
int* permuted_experts, int* permuted_rows,
|
||||
int64_t* expert_first_token_offset, int num_rows,
|
||||
int num_experts, int num_experts_per_node, int k,
|
||||
CubKeyValueSorter& sorter, void* sorter_ws,
|
||||
cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
void expandInputRowsKernelLauncher(
|
||||
T const* unpermuted_input, T* permuted_output,
|
||||
const float* unpermuted_scales, int* sorted_experts,
|
||||
int const* expanded_dest_row_to_expanded_source_row,
|
||||
int* expanded_source_row_to_expanded_dest_row,
|
||||
int64_t* expert_first_token_offset, int64_t const num_rows,
|
||||
int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k,
|
||||
int num_local_experts, const int& align_block_size, cudaStream_t stream);
|
||||
|
||||
// Final kernel to unpermute and scale
|
||||
// This kernel unpermutes the original data, does the k-way reduction and
|
||||
// performs the final skip connection.
|
||||
template <typename T, typename OutputType, bool CHECK_SKIPPED>
|
||||
__global__ void finalizeMoeRoutingKernel(
|
||||
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
|
||||
float const* scales, int const* expanded_source_row_to_expanded_dest_row,
|
||||
int const* expert_for_source_row, int64_t const orig_cols, int64_t const k,
|
||||
int64_t const* num_valid_ptr);
|
||||
|
||||
template <class T, class OutputType>
|
||||
void finalizeMoeRoutingKernelLauncher(
|
||||
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
|
||||
float const* scales, int const* expanded_source_row_to_expanded_dest_row,
|
||||
int const* expert_for_source_row, int64_t const num_rows,
|
||||
int64_t const cols, int64_t const k, int64_t const* num_valid_ptr,
|
||||
cudaStream_t stream);
|
||||
|
||||
void preprocessTopkIdLauncher(int* topk_id_ptr, int size,
|
||||
const int* expert_map_ptr, int num_experts,
|
||||
cudaStream_t stream);
|
||||
|
||||
void getMIndices(int64_t* expert_first_token_offset,
|
||||
int64_t* align_expert_first_token_offset, int* m_indices,
|
||||
int num_local_expert, const int align_block_size,
|
||||
cudaStream_t stream);
|
||||
|
||||
#include "moe_permute_unpermute_kernel.inl"
|
||||
@ -0,0 +1,211 @@
|
||||
#pragma once
|
||||
|
||||
template <typename T, bool CHECK_SKIPPED, bool ALIGN_BLOCK_SIZE>
|
||||
__global__ void expandInputRowsKernel(
|
||||
T const* unpermuted_input, T* permuted_output,
|
||||
const float* unpermuted_scales, int* sorted_experts,
|
||||
int const* expanded_dest_row_to_expanded_source_row,
|
||||
int* expanded_source_row_to_expanded_dest_row,
|
||||
int64_t* expert_first_token_offset, int64_t const num_rows,
|
||||
int64_t const* num_dest_rows, int64_t const cols, int64_t k,
|
||||
int num_local_experts, int align_block_size) {
|
||||
// Reverse permutation map.
|
||||
// I do this so that later, we can use the source -> dest map to do the k-way
|
||||
// reduction and unpermuting. I need the reverse map for that reduction to
|
||||
// allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1
|
||||
// thread block will be responsible for all k summations.
|
||||
int64_t expanded_dest_row = blockIdx.x;
|
||||
int64_t const expanded_source_row =
|
||||
expanded_dest_row_to_expanded_source_row[expanded_dest_row];
|
||||
int expert_id = sorted_experts[expanded_dest_row];
|
||||
|
||||
extern __shared__ int64_t smem_expert_first_token_offset[];
|
||||
int64_t align_expanded_row_accumulate = 0;
|
||||
if constexpr (ALIGN_BLOCK_SIZE) {
|
||||
// load g2s
|
||||
for (int idx = threadIdx.x; idx < num_local_experts + 1;
|
||||
idx += blockDim.x) {
|
||||
smem_expert_first_token_offset[idx] =
|
||||
__ldg(expert_first_token_offset + idx);
|
||||
}
|
||||
__syncthreads();
|
||||
int lane_idx = threadIdx.x & 31;
|
||||
|
||||
if (lane_idx == 0) {
|
||||
// set token_offset_in_expert = 0 if this expert is not local expert
|
||||
int token_offset_in_expert =
|
||||
expert_id >= num_local_experts
|
||||
? 0
|
||||
: expanded_dest_row - smem_expert_first_token_offset[expert_id];
|
||||
int64_t accumulate_align_offset = 0;
|
||||
#pragma unroll 1
|
||||
for (int eidx = 1; eidx <= min(expert_id, num_local_experts); eidx++) {
|
||||
auto n_token_in_expert = smem_expert_first_token_offset[eidx] -
|
||||
smem_expert_first_token_offset[eidx - 1];
|
||||
accumulate_align_offset += (n_token_in_expert + align_block_size - 1) /
|
||||
align_block_size * align_block_size;
|
||||
}
|
||||
expanded_dest_row = accumulate_align_offset + token_offset_in_expert;
|
||||
}
|
||||
// lane0 shuffle broadcast align_expanded_dest_row
|
||||
expanded_dest_row = __shfl_sync(0xffffffff, expanded_dest_row, 0);
|
||||
}
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
assert(expanded_dest_row <= INT32_MAX);
|
||||
expanded_source_row_to_expanded_dest_row[expanded_source_row] =
|
||||
static_cast<int>(expanded_dest_row);
|
||||
}
|
||||
|
||||
if (!CHECK_SKIPPED || blockIdx.x < *num_dest_rows) {
|
||||
// Load 128-bits per thread
|
||||
constexpr int64_t ELEM_PER_THREAD = 128 / cutlass::sizeof_bits<T>::value;
|
||||
using DataElem = cutlass::Array<T, ELEM_PER_THREAD>;
|
||||
|
||||
// Duplicate and permute rows
|
||||
int64_t const source_k_rank = expanded_source_row / num_rows;
|
||||
int64_t const source_row = expanded_source_row % num_rows;
|
||||
|
||||
auto const* source_row_ptr =
|
||||
reinterpret_cast<DataElem const*>(unpermuted_input + source_row * cols);
|
||||
auto* dest_row_ptr =
|
||||
reinterpret_cast<DataElem*>(permuted_output + expanded_dest_row * cols);
|
||||
|
||||
int64_t const start_offset = threadIdx.x;
|
||||
int64_t const stride = blockDim.x;
|
||||
int64_t const num_elems_in_col = cols / ELEM_PER_THREAD;
|
||||
|
||||
for (int elem_index = start_offset; elem_index < num_elems_in_col;
|
||||
elem_index += stride) {
|
||||
dest_row_ptr[elem_index] = source_row_ptr[elem_index];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void expandInputRowsKernelLauncher(
|
||||
T const* unpermuted_input, T* permuted_output,
|
||||
const float* unpermuted_scales, int* sorted_experts,
|
||||
int const* expanded_dest_row_to_expanded_source_row,
|
||||
int* expanded_source_row_to_expanded_dest_row,
|
||||
int64_t* expert_first_token_offset, int64_t const num_rows,
|
||||
int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k,
|
||||
int num_local_experts, const int& align_block_size, cudaStream_t stream) {
|
||||
int64_t const blocks = num_rows * k;
|
||||
int64_t const threads = 256;
|
||||
using FuncPtr = decltype(&expandInputRowsKernel<T, true, true>);
|
||||
FuncPtr func_map[2][2] = {
|
||||
{&expandInputRowsKernel<T, false, false>,
|
||||
&expandInputRowsKernel<T, false, true>},
|
||||
{&expandInputRowsKernel<T, true, false>,
|
||||
&expandInputRowsKernel<T, true, true>},
|
||||
};
|
||||
bool is_check_skip = num_valid_tokens_ptr != nullptr;
|
||||
bool is_align_block_size = align_block_size != -1;
|
||||
auto func = func_map[is_check_skip][is_align_block_size];
|
||||
|
||||
int64_t smem_size = sizeof(int64_t) * (num_local_experts + 1);
|
||||
|
||||
func<<<blocks, threads, smem_size, stream>>>(
|
||||
unpermuted_input, permuted_output, unpermuted_scales, sorted_experts,
|
||||
expanded_dest_row_to_expanded_source_row,
|
||||
expanded_source_row_to_expanded_dest_row, expert_first_token_offset,
|
||||
num_rows, num_valid_tokens_ptr, cols, k, num_local_experts,
|
||||
align_block_size);
|
||||
}
|
||||
|
||||
template <class T, class U>
|
||||
__host__ __device__ constexpr static U arrayConvert(T const& input) {
|
||||
using Type = typename U::Element;
|
||||
static_assert(T::kElements == U::kElements);
|
||||
U u;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < U::kElements; i++) {
|
||||
u[i] = static_cast<Type>(input[i]);
|
||||
}
|
||||
return u;
|
||||
}
|
||||
|
||||
template <typename T, typename OutputType, bool CHECK_SKIPPED>
|
||||
__global__ void finalizeMoeRoutingKernel(
|
||||
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
|
||||
float const* scales, int const* expanded_source_row_to_expanded_dest_row,
|
||||
int const* expert_for_source_row, int64_t const orig_cols, int64_t const k,
|
||||
int64_t const* num_valid_ptr) {
|
||||
assert(orig_cols % 4 == 0);
|
||||
int64_t const original_row = blockIdx.x;
|
||||
int64_t const num_rows = gridDim.x;
|
||||
auto const offset = original_row * orig_cols;
|
||||
OutputType* reduced_row_ptr = reduced_unpermuted_output + offset;
|
||||
int64_t const num_valid = *num_valid_ptr;
|
||||
|
||||
// Load 128-bits per thread, according to the smallest data type we read/write
|
||||
constexpr int64_t FINALIZE_ELEM_PER_THREAD =
|
||||
128 / std::min(cutlass::sizeof_bits<OutputType>::value,
|
||||
cutlass::sizeof_bits<T>::value);
|
||||
|
||||
int64_t const start_offset = threadIdx.x;
|
||||
int64_t const stride = blockDim.x;
|
||||
int64_t const num_elems_in_col = orig_cols / FINALIZE_ELEM_PER_THREAD;
|
||||
|
||||
using InputElem = cutlass::Array<T, FINALIZE_ELEM_PER_THREAD>;
|
||||
using OutputElem = cutlass::Array<OutputType, FINALIZE_ELEM_PER_THREAD>;
|
||||
using ComputeElem = cutlass::Array<float, FINALIZE_ELEM_PER_THREAD>;
|
||||
auto const* expanded_permuted_rows_v =
|
||||
reinterpret_cast<InputElem const*>(expanded_permuted_rows);
|
||||
auto* reduced_row_ptr_v = reinterpret_cast<OutputElem*>(reduced_row_ptr);
|
||||
|
||||
#pragma unroll
|
||||
for (int elem_index = start_offset; elem_index < num_elems_in_col;
|
||||
elem_index += stride) {
|
||||
ComputeElem thread_output;
|
||||
thread_output.fill(0);
|
||||
float row_rescale{0.f};
|
||||
for (int k_idx = 0; k_idx < k; ++k_idx) {
|
||||
int64_t const expanded_original_row = original_row + k_idx * num_rows;
|
||||
int64_t const expanded_permuted_row =
|
||||
expanded_source_row_to_expanded_dest_row[expanded_original_row];
|
||||
|
||||
int64_t const k_offset = original_row * k + k_idx;
|
||||
float const row_scale = scales[k_offset];
|
||||
|
||||
// Check after row_rescale has accumulated
|
||||
if (CHECK_SKIPPED && expanded_permuted_row >= num_valid) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto const* expanded_permuted_rows_row_ptr =
|
||||
expanded_permuted_rows_v + expanded_permuted_row * num_elems_in_col;
|
||||
|
||||
int64_t const expert_idx = expert_for_source_row[k_offset];
|
||||
|
||||
ComputeElem expert_result = arrayConvert<InputElem, ComputeElem>(
|
||||
expanded_permuted_rows_row_ptr[elem_index]);
|
||||
thread_output = thread_output + row_scale * (expert_result);
|
||||
}
|
||||
|
||||
OutputElem output_elem =
|
||||
arrayConvert<ComputeElem, OutputElem>(thread_output);
|
||||
reduced_row_ptr_v[elem_index] = output_elem;
|
||||
}
|
||||
}
|
||||
|
||||
template <class T, class OutputType>
|
||||
void finalizeMoeRoutingKernelLauncher(
|
||||
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
|
||||
float const* scales, int const* expanded_source_row_to_expanded_dest_row,
|
||||
int const* expert_for_source_row, int64_t const num_rows,
|
||||
int64_t const cols, int64_t const k, int64_t const* num_valid_ptr,
|
||||
cudaStream_t stream) {
|
||||
int64_t const blocks = num_rows;
|
||||
int64_t const threads = 256;
|
||||
bool const check_finished = num_valid_ptr != nullptr;
|
||||
using FuncPtr = decltype(&finalizeMoeRoutingKernel<T, OutputType, false>);
|
||||
FuncPtr func_map[2] = {&finalizeMoeRoutingKernel<T, OutputType, false>,
|
||||
&finalizeMoeRoutingKernel<T, OutputType, true>};
|
||||
auto* const kernel = func_map[check_finished];
|
||||
kernel<<<blocks, threads, 0, stream>>>(
|
||||
expanded_permuted_rows, reduced_unpermuted_output, scales,
|
||||
expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k,
|
||||
num_valid_ptr);
|
||||
}
|
||||
@ -42,6 +42,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
||||
|
||||
m.impl("moe_wna16_gemm", torch::kCUDA, &moe_wna16_gemm);
|
||||
|
||||
m.def(
|
||||
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
|
||||
"Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none,"
|
||||
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
|
||||
"Tensor sorted_token_ids,"
|
||||
"Tensor! expert_ids, Tensor! num_tokens_past_padded,"
|
||||
"Tensor! topk_weights, int moe_block_size, int top_k, "
|
||||
"bool mul_topk_weights, bool is_ep, int b_q_type_id,"
|
||||
"int size_m, int size_n, int size_k,"
|
||||
"bool is_full_k, bool use_atomic_add,"
|
||||
"bool use_fp32_reduce, bool is_zp_float) -> Tensor");
|
||||
m.def(
|
||||
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
|
||||
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
|
||||
@ -51,6 +62,20 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
||||
"topk, "
|
||||
"int moe_block_size, bool replicate_input, bool apply_weights)"
|
||||
" -> Tensor");
|
||||
|
||||
m.def(
|
||||
"moe_permute(Tensor input, Tensor topk_weight, Tensor! topk_ids,"
|
||||
"Tensor token_expert_indicies, Tensor? expert_map, int n_expert,"
|
||||
"int n_local_expert,"
|
||||
"int topk, int? align_block_size,Tensor! permuted_input, Tensor! "
|
||||
"expert_first_token_offset, Tensor! src_row_id2dst_row_id_map, Tensor! "
|
||||
"m_indices)->()");
|
||||
|
||||
m.def(
|
||||
"moe_unpermute(Tensor permuted_hidden_states, Tensor topk_weights,"
|
||||
"Tensor topk_ids,Tensor src_row_id2dst_row_id_map, Tensor "
|
||||
"expert_first_token_offset, int n_expert, int n_local_expert,int "
|
||||
"topk, Tensor! hidden_states)->()");
|
||||
// conditionally compiled so impl registration is in source file
|
||||
|
||||
#endif
|
||||
|
||||
18
csrc/ops.h
18
csrc/ops.h
@ -52,6 +52,15 @@ void paged_attention_v2(
|
||||
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||
const int64_t blocksparse_head_sliding_step);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
void merge_attn_states(torch::Tensor& output,
|
||||
std::optional<torch::Tensor> output_lse,
|
||||
const torch::Tensor& prefix_output,
|
||||
const torch::Tensor& prefix_lse,
|
||||
const torch::Tensor& suffix_output,
|
||||
const torch::Tensor& suffix_lse);
|
||||
#endif
|
||||
|
||||
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
|
||||
double epsilon);
|
||||
|
||||
@ -88,6 +97,9 @@ void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
|
||||
|
||||
void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
|
||||
|
||||
void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input,
|
||||
torch::Tensor& scale);
|
||||
|
||||
void mul_and_silu(torch::Tensor& out, torch::Tensor& input);
|
||||
|
||||
void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);
|
||||
@ -119,6 +131,12 @@ void advance_step_flashinfer(
|
||||
torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
|
||||
torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds);
|
||||
|
||||
void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope,
|
||||
torch::Tensor const& q_pe,
|
||||
torch::Tensor const& kv_c_and_k_pe_cache,
|
||||
torch::Tensor const& seq_lens,
|
||||
torch::Tensor const& page_table, double scale);
|
||||
|
||||
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
|
||||
120
csrc/quantization/activation_kernels.cu
Normal file
120
csrc/quantization/activation_kernels.cu
Normal file
@ -0,0 +1,120 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <torch/all.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include <cmath>
|
||||
#include "core/math.hpp"
|
||||
#include "cuda_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
#include "quantization/fp8/common.cuh"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ T silu_kernel(const T& x) {
|
||||
// x * sigmoid(x)
|
||||
return (T)(((float)x) / (1.0f + expf((float)-x)));
|
||||
}
|
||||
|
||||
// Activation and gating kernel template.
|
||||
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
|
||||
typename fp8_type>
|
||||
__global__ void act_and_mul_quant_kernel(
|
||||
fp8_type* __restrict__ out, // [..., d]
|
||||
const scalar_t* __restrict__ input, // [..., 2, d]
|
||||
const float* scale, const int d) {
|
||||
const int32_t blocks_per_token = gridDim.y;
|
||||
|
||||
const int32_t elems_per_128bit_load = (128 / 8) / sizeof(scalar_t);
|
||||
|
||||
// We don't expect the hidden dimension to exceed 32 bits so int32 should
|
||||
// be safe here.
|
||||
const int32_t tgt_elems_per_block = div_ceil(d, blocks_per_token);
|
||||
const int32_t elems_per_block =
|
||||
round_to_next_multiple_of(tgt_elems_per_block, elems_per_128bit_load);
|
||||
const int32_t block_start = blockIdx.y * elems_per_block;
|
||||
int32_t block_end = block_start + elems_per_block;
|
||||
block_end = block_end > d ? d : block_end;
|
||||
|
||||
// token_idx is 64 bit to prevent 32 bit overflow when the number of tokens
|
||||
// is very large
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
const scalar_t* __restrict__ x_ptr = input + token_idx * 2 * d;
|
||||
const scalar_t* __restrict__ y_ptr = input + token_idx * 2 * d + d;
|
||||
fp8_type* __restrict__ out_ptr = out + token_idx * d;
|
||||
|
||||
// 128-bit vectorized code
|
||||
const int32_t vec_loop_end =
|
||||
round_to_previous_multiple_of(elems_per_128bit_load, block_end);
|
||||
const int32_t vec_end_idx = vec_loop_end / elems_per_128bit_load;
|
||||
const int32_t vec_start_idx = block_start / elems_per_128bit_load;
|
||||
|
||||
const int4* __restrict__ x_128bit_ptr = reinterpret_cast<const int4*>(x_ptr);
|
||||
const int4* __restrict__ y_128bit_ptr = reinterpret_cast<const int4*>(y_ptr);
|
||||
int2* __restrict__ out_128bit_ptr = reinterpret_cast<int2*>(out_ptr);
|
||||
|
||||
float inverted_scale = 1 / *scale;
|
||||
#pragma unroll
|
||||
for (int32_t vec_idx = vec_start_idx + threadIdx.x; vec_idx < vec_end_idx;
|
||||
vec_idx += blockDim.x) {
|
||||
const int4 x_128bit = VLLM_LDG(&x_128bit_ptr[vec_idx]);
|
||||
const int4 y_128bit = VLLM_LDG(&y_128bit_ptr[vec_idx]);
|
||||
using scalar_128bit_vec_t = std::array<scalar_t, elems_per_128bit_load>;
|
||||
using scalar_64bit_vec_t = std::array<fp8_type, elems_per_128bit_load>;
|
||||
|
||||
scalar_64bit_vec_t out_vec;
|
||||
const auto x_vec = reinterpret_cast<scalar_128bit_vec_t const&>(x_128bit);
|
||||
const auto y_vec = reinterpret_cast<scalar_128bit_vec_t const&>(y_128bit);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < elems_per_128bit_load; i++) {
|
||||
out_vec[i] = scaled_fp8_conversion<true, fp8_type>(
|
||||
ACT_FN(x_vec[i]) * y_vec[i], inverted_scale);
|
||||
}
|
||||
|
||||
out_128bit_ptr[vec_idx] = reinterpret_cast<const int2&>(out_vec);
|
||||
}
|
||||
|
||||
// Scalar cleanup code
|
||||
if (block_end > vec_loop_end) {
|
||||
for (int64_t idx = vec_loop_end + threadIdx.x; idx < block_end;
|
||||
idx += blockDim.x) {
|
||||
const scalar_t x = VLLM_LDG(&x_ptr[idx]);
|
||||
const scalar_t y = VLLM_LDG(&y_ptr[idx]);
|
||||
out_ptr[idx] =
|
||||
scaled_fp8_conversion<true, fp8_type>(ACT_FN(x) * y, inverted_scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace vllm
|
||||
|
||||
// Launch activation, gating, and quantize kernel.
|
||||
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
|
||||
int d = input.size(-1) / 2; \
|
||||
int64_t num_tokens = input.numel() / input.size(-1); \
|
||||
dim3 grid(num_tokens, num_tokens > 16 ? num_tokens > 32 ? 1 : 2 : 4); \
|
||||
dim3 block(std::min(d, 512)); \
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
|
||||
VLLM_DISPATCH_FLOATING_TYPES( \
|
||||
input.scalar_type(), "act_and_mul_kernel", [&] { \
|
||||
VLLM_DISPATCH_FP8_TYPES( \
|
||||
out.scalar_type(), "fused_add_rms_norm_kernel_fp8_type", [&] { \
|
||||
vllm::act_and_mul_quant_kernel<scalar_t, KERNEL<scalar_t>, \
|
||||
fp8_t> \
|
||||
<<<grid, block, 0, stream>>>(out.data_ptr<fp8_t>(), \
|
||||
input.data_ptr<scalar_t>(), \
|
||||
scale.data_ptr<float>(), d); \
|
||||
}); \
|
||||
});
|
||||
|
||||
void silu_and_mul_quant(torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input, // [..., 2 * d]
|
||||
torch::Tensor& scale) {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(input.dtype() == torch::kFloat16 ||
|
||||
input.dtype() == torch::kBFloat16);
|
||||
TORCH_CHECK(input.size(-1) % 2 == 0);
|
||||
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
|
||||
}
|
||||
@ -46,14 +46,26 @@ __global__ void compute_expert_offsets(
|
||||
}
|
||||
|
||||
__global__ void compute_arg_sorts(const int* __restrict__ topk_ids,
|
||||
const int32_t* __restrict__ expert_offsets,
|
||||
int32_t* input_permutation,
|
||||
int32_t* output_permutation,
|
||||
int32_t* atomic_buffer, const int topk_length,
|
||||
const int topk) {
|
||||
int expert_id = blockIdx.x;
|
||||
int const blk_expert_id = blockIdx.x;
|
||||
int const num_experts = gridDim.x;
|
||||
int32_t const num_tokens = expert_offsets[num_experts];
|
||||
|
||||
for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
|
||||
if (topk_ids[i] == expert_id) {
|
||||
int const expert_id = topk_ids[i];
|
||||
if (expert_id == -1 && blockIdx.x == 0) {
|
||||
// output_permutation is used to re-order the moe outputs. It is
|
||||
// used as c2 = c2[c_map], where c2 is a torch.tensor that is the
|
||||
// output of the cutlass kernels and c_map is the output_permutation.
|
||||
// c2 is initialized to zeros, therefore by setting the output_permutation
|
||||
// to num_tokens, we are guaranteed to fill the moe outputs to zero
|
||||
// for "invalid" topk_ids.
|
||||
output_permutation[i] = num_tokens;
|
||||
} else if (expert_id == blk_expert_id) {
|
||||
int start = atomicAdd(&atomic_buffer[expert_id], 1);
|
||||
input_permutation[start] = i / topk;
|
||||
output_permutation[i] = start;
|
||||
@ -83,6 +95,7 @@ void get_cutlass_moe_mm_data_caller(
|
||||
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
|
||||
compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
|
||||
static_cast<const int32_t*>(topk_ids.data_ptr()),
|
||||
static_cast<const int32_t*>(expert_offsets.data_ptr()),
|
||||
static_cast<int32_t*>(input_permutation.data_ptr()),
|
||||
static_cast<int32_t*>(output_permutation.data_ptr()),
|
||||
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(),
|
||||
|
||||
@ -336,7 +336,7 @@ inline void cutlass_gemm_sm89_fp8_dispatch(torch::Tensor& out,
|
||||
|
||||
uint32_t const m = a.size(0);
|
||||
uint32_t const mp2 =
|
||||
std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
|
||||
std::max(static_cast<uint32_t>(16), next_pow_2(m)); // next power of 2
|
||||
|
||||
if (mp2 <= 16) {
|
||||
// M in [1, 16]
|
||||
|
||||
@ -321,7 +321,7 @@ inline void cutlass_gemm_sm89_int8_dispatch(torch::Tensor& out,
|
||||
|
||||
uint32_t const m = a.size(0);
|
||||
uint32_t const mp2 =
|
||||
std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
|
||||
std::max(static_cast<uint32_t>(16), next_pow_2(m)); // next power of 2
|
||||
|
||||
if (mp2 <= 16) {
|
||||
// M in [1, 16]
|
||||
|
||||
@ -134,7 +134,7 @@ typename T::Gemm::Arguments args_from_options(
|
||||
using StrideB = typename T::StrideB;
|
||||
using StrideD = typename T::StrideD;
|
||||
using Sm100BlkScaledConfig =
|
||||
typename T::Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;
|
||||
typename T::Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
|
||||
|
||||
int m = static_cast<int>(M);
|
||||
int n = static_cast<int>(N);
|
||||
|
||||
@ -96,7 +96,7 @@ void rms_norm_dynamic_per_token_quant_dispatch(
|
||||
std::optional<at::Tensor> const& scale_ub,
|
||||
std::optional<at::Tensor>& residual) {
|
||||
int32_t hidden_size = input.size(-1);
|
||||
int32_t num_tokens = input.numel() / hidden_size;
|
||||
auto num_tokens = input.numel() / hidden_size;
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(hidden_size, 1024));
|
||||
|
||||
@ -129,7 +129,7 @@ static __device__ __forceinline__ void moe_q(
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MOE_X_Q4_0 64
|
||||
#define MOE_X_Q4_0 8
|
||||
#define MOE_Y_Q4_0 128
|
||||
#define NWARPS_Q4_0 8
|
||||
#else
|
||||
@ -190,7 +190,7 @@ static void ggml_moe_q4_0_q8_1_cuda(
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MOE_X_Q4_1 64
|
||||
#define MOE_X_Q4_1 8
|
||||
#define MOE_Y_Q4_1 128
|
||||
#define NWARPS_Q4_1 8
|
||||
#else
|
||||
@ -251,7 +251,7 @@ static void ggml_moe_q4_1_q8_1_cuda(
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MOE_X_Q5_0 64
|
||||
#define MOE_X_Q5_0 8
|
||||
#define MOE_Y_Q5_0 128
|
||||
#define NWARPS_Q5_0 8
|
||||
#else
|
||||
@ -312,7 +312,7 @@ static void ggml_moe_q5_0_q8_1_cuda(
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MOE_X_Q5_1 64
|
||||
#define MOE_X_Q5_1 8
|
||||
#define MOE_Y_Q5_1 128
|
||||
#define NWARPS_Q5_1 8
|
||||
#else
|
||||
@ -373,7 +373,7 @@ static void ggml_moe_q5_1_q8_1_cuda(
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MOE_X_Q8_0 64
|
||||
#define MOE_X_Q8_0 8
|
||||
#define MOE_Y_Q8_0 128
|
||||
#define NWARPS_Q8_0 8
|
||||
#else
|
||||
@ -434,7 +434,7 @@ static void ggml_moe_q8_0_q8_1_cuda(
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MOE_X_Q2_K 64
|
||||
#define MOE_X_Q2_K 8
|
||||
#define MOE_Y_Q2_K 128
|
||||
#define NWARPS_Q2_K 8
|
||||
#else
|
||||
@ -495,7 +495,7 @@ static void ggml_moe_q2_K_q8_1_cuda(
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MOE_X_Q3_K 64
|
||||
#define MOE_X_Q3_K 8
|
||||
#define MOE_Y_Q3_K 128
|
||||
#define NWARPS_Q3_K 8
|
||||
#else
|
||||
@ -556,7 +556,7 @@ static void ggml_moe_q3_K_q8_1_cuda(
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MOE_X_Q4_K 64
|
||||
#define MOE_X_Q4_K 8
|
||||
#define MOE_Y_Q4_K 128
|
||||
#define NWARPS_Q4_K 8
|
||||
#else
|
||||
@ -617,7 +617,7 @@ static void ggml_moe_q4_K_q8_1_cuda(
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MOE_X_Q5_K 64
|
||||
#define MOE_X_Q5_K 8
|
||||
#define MOE_Y_Q5_K 128
|
||||
#define NWARPS_Q5_K 8
|
||||
#else
|
||||
@ -678,7 +678,7 @@ static void ggml_moe_q5_K_q8_1_cuda(
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MOE_X_Q6_K 64
|
||||
#define MOE_X_Q6_K 8
|
||||
#define MOE_Y_Q6_K 128
|
||||
#define NWARPS_Q6_K 8
|
||||
#else
|
||||
|
||||
@ -347,7 +347,7 @@ struct ComputeTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK {
|
||||
for (int n_idx = 0; n_idx < WARP_NITER; ++n_idx) {
|
||||
hmma16816_f32<FType>(
|
||||
C_frag[m_idx][n_idx], A_frag[reg_buf_idx][m_idx],
|
||||
reinterpret_cast<uint32_t(&)[2]>(BF_frag[reg_buf_idx][n_idx]));
|
||||
reinterpret_cast<uint32_t (&)[2]>(BF_frag[reg_buf_idx][n_idx]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -173,8 +173,8 @@ dequant<half, vllm::kU4B8.id()>(int q) {
|
||||
const int HI = 0x00f000f0;
|
||||
const int EX = 0x64006400;
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
|
||||
int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
|
||||
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
|
||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
|
||||
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
|
||||
// directly into `SUB` and `ADD`.
|
||||
const int SUB = 0x64086408;
|
||||
@ -197,9 +197,9 @@ dequant<nv_bfloat16, vllm::kU4B8.id()>(int q) {
|
||||
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
|
||||
int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
|
||||
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||
q >>= 4;
|
||||
int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
|
||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||
|
||||
typename ScalarType<nv_bfloat16>::FragB frag_b;
|
||||
static constexpr uint32_t MUL = 0x3F803F80;
|
||||
@ -221,8 +221,8 @@ dequant<half, vllm::kU4.id()>(int q) {
|
||||
const int HI = 0x00f000f0;
|
||||
const int EX = 0x64006400;
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
|
||||
int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
|
||||
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
|
||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
|
||||
|
||||
const int SUB = 0x64006400;
|
||||
const int MUL = 0x2c002c00;
|
||||
@ -244,9 +244,9 @@ dequant<nv_bfloat16, vllm::kU4.id()>(int q) {
|
||||
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
|
||||
int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
|
||||
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||
q >>= 4;
|
||||
int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
|
||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||
|
||||
typename ScalarType<nv_bfloat16>::FragB frag_b;
|
||||
static constexpr uint32_t MUL = 0x3F803F80;
|
||||
|
||||
@ -9,7 +9,11 @@
|
||||
#include <cuda_runtime.h>
|
||||
#include <iostream>
|
||||
|
||||
namespace marlin {
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin
|
||||
#endif
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
// Marlin params
|
||||
|
||||
@ -23,6 +27,7 @@ static constexpr int pipe_stages =
|
||||
|
||||
static constexpr int min_thread_n = 64;
|
||||
static constexpr int min_thread_k = 64;
|
||||
static constexpr int max_thread_n = 256;
|
||||
|
||||
static constexpr int tile_size = 16;
|
||||
static constexpr int max_par = 16;
|
||||
@ -84,4 +89,4 @@ __device__ inline void cp_async_wait() {
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace marlin
|
||||
} // namespace MARLIN_NAMESPACE_NAME
|
||||
|
||||
@ -5,7 +5,11 @@
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
namespace marlin {
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin
|
||||
#endif
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
template <typename scalar_t>
|
||||
class ScalarType {};
|
||||
@ -54,7 +58,7 @@ class ScalarType<nv_bfloat16> {
|
||||
using FragS = Vec<nv_bfloat162, 1>;
|
||||
using FragZP = Vec<nv_bfloat162, 4>;
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
|
||||
static __device__ float inline num2float(const nv_bfloat16 x) {
|
||||
return __bfloat162float(x);
|
||||
}
|
||||
@ -74,6 +78,6 @@ class ScalarType<nv_bfloat16> {
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace marlin
|
||||
} // namespace MARLIN_NAMESPACE_NAME
|
||||
|
||||
#endif
|
||||
|
||||
@ -96,8 +96,8 @@ __device__ inline FragB dequant(int q) {
|
||||
const int HI = 0x00f000f0;
|
||||
const int EX = 0x64006400;
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
|
||||
int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
|
||||
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
|
||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
|
||||
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
|
||||
// directly into `SUB` and `ADD`.
|
||||
const int SUB = 0x64086408;
|
||||
|
||||
@ -141,8 +141,8 @@ __device__ inline FragB dequant_per_group(int q, FragS_GROUP& frag_s, int i) {
|
||||
static constexpr uint32_t HI = 0x00f000f0;
|
||||
static constexpr uint32_t EX = 0x64006400;
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
uint32_t t0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
|
||||
uint32_t t1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
|
||||
uint32_t t0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
|
||||
uint32_t t1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
|
||||
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
|
||||
// directly into `SUB` and `ADD`.
|
||||
static constexpr uint32_t SUB = 0x64086408;
|
||||
|
||||
@ -127,8 +127,8 @@ __device__ inline FragB dequant_4bit(int q) {
|
||||
const int HI = 0x00f000f0;
|
||||
const int EX = 0x64006400;
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
|
||||
int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
|
||||
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
|
||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
|
||||
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
|
||||
// directly into `SUB` and `ADD`.
|
||||
const int SUB = 0x64086408;
|
||||
|
||||
@ -25,8 +25,9 @@
|
||||
#include "../attention/dtype_fp8.cuh"
|
||||
#include "../quantization/fp8/amd/quant_utils.cuh"
|
||||
|
||||
#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx942__))
|
||||
#define __HIP__MI300_MI250__
|
||||
#if defined(__HIPCC__) && \
|
||||
(defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__))
|
||||
#define __HIP__GFX9__
|
||||
#endif
|
||||
|
||||
#if defined(NDEBUG)
|
||||
@ -42,7 +43,7 @@
|
||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
|
||||
|
||||
#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support
|
||||
#if defined(__HIP__GFX9__) // TODO: Add NAVI support
|
||||
|
||||
#define GCN_MFMA_INSTR1 __builtin_amdgcn_mfma_f32_16x16x4f32
|
||||
#define GCN_MFMA_INSTR __builtin_amdgcn_mfma_f32_4x4x4f16
|
||||
@ -1479,7 +1480,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
|
||||
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
|
||||
|
||||
// clang-format off
|
||||
template <typename scalar_t, typename cache_t,
|
||||
@ -1552,7 +1553,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
|
||||
}
|
||||
// clang-format on
|
||||
|
||||
#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
|
||||
#endif // defined(__HIP__GFX9__) TODO: Add NAVI support
|
||||
|
||||
#define LAUNCH_CUSTOM_ATTENTION_MFMA16(GQA_RATIO) \
|
||||
paged_attention_ll4mi_QKV_mfma16_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE, \
|
||||
|
||||
@ -2,6 +2,15 @@
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b,
|
||||
const int64_t rows_per_block);
|
||||
|
||||
torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
|
||||
const int64_t CuCount);
|
||||
|
||||
void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
|
||||
at::Tensor& scale_a, at::Tensor& scale_b, const int64_t CuCount);
|
||||
|
||||
void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums,
|
||||
torch::Tensor& max_logits, torch::Tensor& tmp_out,
|
||||
torch::Tensor& query, torch::Tensor& key_cache,
|
||||
|
||||
1600
csrc/rocm/skinny_gemms.cu
Normal file
1600
csrc/rocm/skinny_gemms.cu
Normal file
File diff suppressed because it is too large
Load Diff
@ -14,6 +14,24 @@
|
||||
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
|
||||
// vLLM custom ops for rocm
|
||||
|
||||
// Custom gemm op for matrix-vector multiplication
|
||||
rocm_ops.def(
|
||||
"LLMM1(Tensor in_a, Tensor in_b, int rows_per_block) -> "
|
||||
"Tensor");
|
||||
rocm_ops.impl("LLMM1", torch::kCUDA, &LLMM1);
|
||||
|
||||
// Custom gemm op for skinny matrix-matrix multiplication
|
||||
rocm_ops.def(
|
||||
"wvSplitK(Tensor in_a, Tensor in_b, int CuCount) -> "
|
||||
"Tensor");
|
||||
rocm_ops.impl("wvSplitK", torch::kCUDA, &wvSplitK);
|
||||
|
||||
// wvSplitK for fp8
|
||||
rocm_ops.def(
|
||||
"wvSplitKQ(Tensor in_a, Tensor in_b, Tensor! out_c, Tensor scale_a, "
|
||||
" Tensor scale_b, int CuCount) -> ()");
|
||||
rocm_ops.impl("wvSplitKQ", torch::kCUDA, &wvSplitKQ);
|
||||
|
||||
// Custom attention op
|
||||
// Compute the attention between an input query and the cached
|
||||
// keys/values using PagedAttention.
|
||||
|
||||
@ -64,11 +64,30 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
" int blocksparse_head_sliding_step) -> ()");
|
||||
ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
// Merge attn states
|
||||
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
|
||||
// can be used to combine partial attention results (in the split-KV case)
|
||||
ops.def(
|
||||
"merge_attn_states("
|
||||
" Tensor! output,"
|
||||
" Tensor!? output_lse,"
|
||||
" Tensor prefix_output,"
|
||||
" Tensor prefix_lse,"
|
||||
" Tensor suffix_output,"
|
||||
" Tensor suffix_lse) -> ()");
|
||||
ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states);
|
||||
#endif
|
||||
|
||||
// Activation ops
|
||||
// Activation function used in SwiGLU.
|
||||
ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
|
||||
ops.def("silu_and_mul(Tensor! result, Tensor input) -> ()");
|
||||
ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
|
||||
|
||||
ops.def(
|
||||
"silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
|
||||
ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);
|
||||
|
||||
ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()");
|
||||
ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu);
|
||||
|
||||
@ -428,6 +447,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.def("cutlass_sparse_compress(Tensor a) -> Tensor[]");
|
||||
ops.impl("cutlass_sparse_compress", &cutlass_sparse_compress);
|
||||
|
||||
// CUTLASS MLA decode
|
||||
ops.def(
|
||||
"cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe,"
|
||||
" Tensor kv_c_and_k_pe_cache, Tensor seq_lens,"
|
||||
" Tensor page_table, float scale) -> ()");
|
||||
ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
|
||||
|
||||
// Mamba selective scan kernel
|
||||
ops.def(
|
||||
"selective_scan_fwd(Tensor! u, Tensor! delta,"
|
||||
|
||||
@ -5,11 +5,11 @@
|
||||
# docs/source/contributing/dockerfile/dockerfile.md and
|
||||
# docs/source/assets/contributing/dockerfile-stages-dependency.png
|
||||
|
||||
ARG CUDA_VERSION=12.4.1
|
||||
ARG CUDA_VERSION=12.8.1
|
||||
#################### BASE BUILD IMAGE ####################
|
||||
# prepare basic build environment
|
||||
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 AS base
|
||||
ARG CUDA_VERSION=12.4.1
|
||||
ARG CUDA_VERSION=12.8.1
|
||||
ARG PYTHON_VERSION=3.12
|
||||
ARG TARGETPLATFORM
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
@ -19,7 +19,10 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
|
||||
&& echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
|
||||
&& apt-get update -y \
|
||||
&& apt-get install -y ccache software-properties-common git curl sudo \
|
||||
&& add-apt-repository ppa:deadsnakes/ppa \
|
||||
&& for i in 1 2 3; do \
|
||||
add-apt-repository -y ppa:deadsnakes/ppa && break || \
|
||||
{ echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \
|
||||
done \
|
||||
&& apt-get update -y \
|
||||
&& apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
|
||||
&& update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \
|
||||
@ -34,6 +37,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
|
||||
# Reference: https://github.com/astral-sh/uv/pull/1694
|
||||
ENV UV_HTTP_TIMEOUT=500
|
||||
ENV UV_INDEX_STRATEGY="unsafe-best-match"
|
||||
|
||||
# Upgrade to GCC 10 to avoid https://gcc.gnu.org/bugzilla/show_bug.cgi?id=92519
|
||||
# as it was causing spam when compiling the CUTLASS kernels
|
||||
@ -66,7 +70,8 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
COPY requirements/common.txt requirements/common.txt
|
||||
COPY requirements/cuda.txt requirements/cuda.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system -r requirements/cuda.txt
|
||||
uv pip install --system -r requirements/cuda.txt \
|
||||
--extra-index-url https://download.pytorch.org/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||
|
||||
# cuda arch list used by torch
|
||||
# can be useful for both `dev` and `test`
|
||||
@ -89,9 +94,11 @@ COPY requirements/build.txt requirements/build.txt
|
||||
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
|
||||
# Reference: https://github.com/astral-sh/uv/pull/1694
|
||||
ENV UV_HTTP_TIMEOUT=500
|
||||
ENV UV_INDEX_STRATEGY="unsafe-best-match"
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system -r requirements/build.txt
|
||||
uv pip install --system -r requirements/build.txt \
|
||||
--extra-index-url https://download.pytorch.org/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||
|
||||
COPY . .
|
||||
ARG GIT_REPO_CHECK=0
|
||||
@ -158,19 +165,25 @@ FROM base as dev
|
||||
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
|
||||
# Reference: https://github.com/astral-sh/uv/pull/1694
|
||||
ENV UV_HTTP_TIMEOUT=500
|
||||
ENV UV_INDEX_STRATEGY="unsafe-best-match"
|
||||
|
||||
# Workaround for #17068
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system --no-build-isolation "git+https://github.com/state-spaces/mamba@v2.2.4"
|
||||
|
||||
COPY requirements/lint.txt requirements/lint.txt
|
||||
COPY requirements/test.txt requirements/test.txt
|
||||
COPY requirements/dev.txt requirements/dev.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system -r requirements/dev.txt
|
||||
uv pip install --system -r requirements/dev.txt \
|
||||
--extra-index-url https://download.pytorch.org/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||
#################### DEV IMAGE ####################
|
||||
|
||||
#################### vLLM installation IMAGE ####################
|
||||
# image with vLLM installed
|
||||
# TODO: Restore to base image after FlashInfer AOT wheel fixed
|
||||
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 AS vllm-base
|
||||
ARG CUDA_VERSION=12.4.1
|
||||
ARG CUDA_VERSION=12.8.1
|
||||
ARG PYTHON_VERSION=3.12
|
||||
WORKDIR /vllm-workspace
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
@ -185,7 +198,10 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
|
||||
&& apt-get update -y \
|
||||
&& apt-get install -y ccache software-properties-common git curl wget sudo vim python3-pip \
|
||||
&& apt-get install -y ffmpeg libsm6 libxext6 libgl1 \
|
||||
&& add-apt-repository ppa:deadsnakes/ppa \
|
||||
&& for i in 1 2 3; do \
|
||||
add-apt-repository -y ppa:deadsnakes/ppa && break || \
|
||||
{ echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \
|
||||
done \
|
||||
&& apt-get update -y \
|
||||
&& apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv libibverbs-dev \
|
||||
&& update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \
|
||||
@ -200,6 +216,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
|
||||
# Reference: https://github.com/astral-sh/uv/pull/1694
|
||||
ENV UV_HTTP_TIMEOUT=500
|
||||
ENV UV_INDEX_STRATEGY="unsafe-best-match"
|
||||
|
||||
# Workaround for https://github.com/openai/triton/issues/2507 and
|
||||
# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
|
||||
@ -220,7 +237,8 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
# Install vllm wheel first, so that torch etc will be installed.
|
||||
RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \
|
||||
--mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system dist/*.whl --verbose
|
||||
uv pip install --system dist/*.whl --verbose \
|
||||
--extra-index-url https://download.pytorch.org/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||
|
||||
# If we need to build FlashInfer wheel before its release:
|
||||
# $ export FLASHINFER_ENABLE_AOT=1
|
||||
@ -237,9 +255,17 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
. /etc/environment && \
|
||||
if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \
|
||||
uv pip install --system https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.1.post2/flashinfer_python-0.2.1.post2+cu124torch2.6-cp38-abi3-linux_x86_64.whl ; \
|
||||
# TESTING: install FlashInfer from source to test 2.7.0 final RC
|
||||
FLASHINFER_ENABLE_AOT=1 TORCH_CUDA_ARCH_LIST='7.5 8.0 8.6 8.9 9.0+PTX' \
|
||||
uv pip install --system --no-build-isolation "git+https://github.com/flashinfer-ai/flashinfer@v0.2.2.post1" ; \
|
||||
fi
|
||||
COPY examples examples
|
||||
COPY benchmarks benchmarks
|
||||
COPY ./vllm/collect_env.py .
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
. /etc/environment && \
|
||||
uv pip list
|
||||
|
||||
# Although we build Flashinfer with AOT mode, there's still
|
||||
# some issues w.r.t. JIT compilation. Therefore we need to
|
||||
@ -247,7 +273,8 @@ COPY examples examples
|
||||
# TODO: Remove this once FlashInfer AOT wheel is fixed
|
||||
COPY requirements/build.txt requirements/build.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system -r requirements/build.txt
|
||||
uv pip install --system -r requirements/build.txt \
|
||||
--extra-index-url https://download.pytorch.org/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||
|
||||
#################### vLLM installation IMAGE ####################
|
||||
|
||||
@ -261,6 +288,11 @@ ADD . /vllm-workspace/
|
||||
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
|
||||
# Reference: https://github.com/astral-sh/uv/pull/1694
|
||||
ENV UV_HTTP_TIMEOUT=500
|
||||
ENV UV_INDEX_STRATEGY="unsafe-best-match"
|
||||
|
||||
# Workaround for #17068
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system --no-build-isolation "git+https://github.com/state-spaces/mamba@v2.2.4"
|
||||
|
||||
# install development dependencies (for testing)
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
@ -289,6 +321,7 @@ RUN mv vllm test_docs/
|
||||
#################### OPENAI API SERVER ####################
|
||||
# base openai image with additional requirements, for any subsequent openai-style images
|
||||
FROM vllm-base AS vllm-openai-base
|
||||
ARG TARGETPLATFORM
|
||||
|
||||
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
|
||||
# Reference: https://github.com/astral-sh/uv/pull/1694
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user