Compare commits
597 Commits
bench-late...disable-sd
| SHA1 | Author | Date | |
|---|---|---|---|
| b73fdb927a | |||
| a0304dc504 | |||
| c7941cca18 | |||
| b6dd32aa07 | |||
| f94886946e | |||
| 72dfe4c74f | |||
| 8b464d9660 | |||
| 889ebb2638 | |||
| 3ad986c28b | |||
| 344e193b7d | |||
| fb1c933ade | |||
| 72c5b97231 | |||
| fa93cd9f60 | |||
| aec9674dbe | |||
| 7fcc4223dc | |||
| 8262a3e23b | |||
| f211331c48 | |||
| 9053d0b134 | |||
| cb3f2d8d10 | |||
| c12df53b60 | |||
| d1aeea7553 | |||
| d8bccde686 | |||
| 20e489eaa1 | |||
| 4213475ec7 | |||
| d92879baf6 | |||
| 690fe019f0 | |||
| ed7a29d9f8 | |||
| 756848e79e | |||
| 18445edd0f | |||
| 30215ca61f | |||
| 838cedade7 | |||
| 4283a28c2f | |||
| 93a126fbc7 | |||
| 8e4b351a0c | |||
| 9869453c42 | |||
| 3642c59aa8 | |||
| 43eea2953b | |||
| de7eb10ce4 | |||
| fd11a325b8 | |||
| 4d17e20310 | |||
| 10fd1d7380 | |||
| 52b4f4a8d7 | |||
| e782e0a170 | |||
| dc2ceca5c5 | |||
| f8acd01ff7 | |||
| c48334d405 | |||
| 909fdaf152 | |||
| 8c1c926d00 | |||
| df6f3ce883 | |||
| 513f074766 | |||
| b07bf83c7d | |||
| 53e8cf53a4 | |||
| 54271bb766 | |||
| 9e96f56efb | |||
| b278911229 | |||
| 7bd0c7745c | |||
| 1cf0719ebd | |||
| 537d5ee025 | |||
| c8e5be35f7 | |||
| a6e72e1e4f | |||
| 5e83a7277f | |||
| 68af5f6c5c | |||
| 8de2901fea | |||
| c53e0730cb | |||
| a0e619e62a | |||
| 70116459c3 | |||
| 65e262b93b | |||
| 43faa0461a | |||
| 48cb2109b6 | |||
| a5450f11c9 | |||
| 9d98ab5ec6 | |||
| df5c879527 | |||
| 423e9f1cbe | |||
| 0bd7f8fca5 | |||
| d5615af9ae | |||
| 19dcc02a72 | |||
| 7feae92c1f | |||
| f851b84266 | |||
| fc966e9cc6 | |||
| ef19e67d2c | |||
| a41351f363 | |||
| 6aae216b4e | |||
| b22980a1dc | |||
| 881f735827 | |||
| 2f54045508 | |||
| 5aa6efb9a5 | |||
| 6ca0234478 | |||
| 649818995f | |||
| 7a0a9da72b | |||
| 69bff9bc89 | |||
| 41ca7eb491 | |||
| eef364723c | |||
| 0d6e187e88 | |||
| 9420a1fc30 | |||
| 583e900996 | |||
| 05e1fbfc52 | |||
| fe92176321 | |||
| 6d0df0ebeb | |||
| 0fa939e2d1 | |||
| 0422ce109f | |||
| 47bdee409c | |||
| 49f189439d | |||
| 5adf6f6b7f | |||
| 4115f19958 | |||
| 340d7b1b21 | |||
| 1bcbcbf574 | |||
| 82e43b2d7e | |||
| 67309a1cb5 | |||
| b724afe343 | |||
| 21f4f1c9a4 | |||
| b0c1f6202d | |||
| c0dfd97519 | |||
| a9138e85b1 | |||
| 0a05ed57e6 | |||
| 14288d1332 | |||
| b411418ff0 | |||
| 2bc0f72ae5 | |||
| 9c1244de57 | |||
| db2f8d915c | |||
| 6167c0e5d2 | |||
| ed2e464653 | |||
| 2c8ed8ee48 | |||
| ed50f46641 | |||
| 46e678bcff | |||
| 6b2427f995 | |||
| b07d741661 | |||
| 41fb013d29 | |||
| 32d4b669d0 | |||
| 3cde34a4a4 | |||
| bdb3660312 | |||
| f3a21e9c68 | |||
| 8e630d680e | |||
| af869f6dff | |||
| 53c0fa1e25 | |||
| f7912cba3d | |||
| 6317a5174a | |||
| aa72d9a4ea | |||
| ce17db8085 | |||
| 8c87a9ad46 | |||
| ec69124eb4 | |||
| d0da99fb70 | |||
| b2f195c429 | |||
| 047797ef90 | |||
| eb8ef4224d | |||
| 56a735261c | |||
| e1cf90e099 | |||
| 6bc1e30ef9 | |||
| 7e081ba7ca | |||
| 1e013fa388 | |||
| bc7c4d206b | |||
| f67e9e9f22 | |||
| 36fe78769f | |||
| 83d933718c | |||
| 5175b884f7 | |||
| 5536b30a4c | |||
| 7f58fb9718 | |||
| 30bc3e0f66 | |||
| f34410715f | |||
| 68d4c33202 | |||
| f961d7f6ef | |||
| d059110498 | |||
| 571e8dd65e | |||
| 4b91c927f6 | |||
| 0e237f0035 | |||
| 8f7bace7c3 | |||
| e4d6144232 | |||
| 8d32dc603d | |||
| c4ab9f3e71 | |||
| 2689d5c027 | |||
| acba33a0f1 | |||
| a114bf20a3 | |||
| 3097ce3a32 | |||
| d6da9322c8 | |||
| 71ce44047f | |||
| 188b7f9b8c | |||
| b9b4746950 | |||
| 7b8a2ab76f | |||
| c9acbf1141 | |||
| 5b794cae8d | |||
| 0e4254492f | |||
| 1311913f55 | |||
| 29f395c97c | |||
| fa3bba2a53 | |||
| 986537f1c3 | |||
| 210207525e | |||
| 71eda0bb76 | |||
| 471fe65630 | |||
| 3a0fba5cf4 | |||
| 299ebb62b2 | |||
| f728ab8e35 | |||
| 63e26fff78 | |||
| fe3462c774 | |||
| 3b34fd5273 | |||
| 55d6d3fdb8 | |||
| 7272bfae77 | |||
| d9ac9e3dc5 | |||
| d41faaf9df | |||
| b34f33438a | |||
| 26c0406555 | |||
| 4c41278b77 | |||
| bb3605db85 | |||
| fe742aef5a | |||
| 4b07d36891 | |||
| 87aaadef73 | |||
| 682e0b6d2f | |||
| d6195a748b | |||
| 205d84aaa9 | |||
| 5124f5bf51 | |||
| 83f3c3bd91 | |||
| d9737ca1c6 | |||
| 9d4ca19d50 | |||
| 2ef0dc53b8 | |||
| 1d4680fad2 | |||
| 2c1bd848a6 | |||
| 5c9121203c | |||
| 490b1698a5 | |||
| 5a5e29de88 | |||
| 3d3ab3689f | |||
| 686623c5e7 | |||
| aadb656562 | |||
| 87e067de41 | |||
| 26507f8973 | |||
| 9c1d5b456d | |||
| e31045f95c | |||
| aaec845f8e | |||
| 7bdfd29a35 | |||
| e78587a64c | |||
| 7eb4255628 | |||
| 6a0f547561 | |||
| 30ed81b7ca | |||
| 7a4a5de729 | |||
| c16fb5dae8 | |||
| e37073efd7 | |||
| 183dad7a85 | |||
| 3408e47159 | |||
| 0377b8310b | |||
| e4755f7fac | |||
| 92edf35826 | |||
| eb5819b2d9 | |||
| 5989f4684d | |||
| 5125d72f02 | |||
| a018e555fd | |||
| 6211b92273 | |||
| 05fcd1b430 | |||
| 7c02d6a137 | |||
| 11c3b98491 | |||
| dbe7f07001 | |||
| c69bf4ee06 | |||
| d27ea94034 | |||
| 99ed526101 | |||
| 207da28186 | |||
| 5b1aca2ae3 | |||
| d8e557b5e5 | |||
| 61a44a0b22 | |||
| a6481525b8 | |||
| 8cac35ba43 | |||
| 9dbf7a2dc1 | |||
| 607029e515 | |||
| cb072ce93b | |||
| 95aca283b4 | |||
| 2b05b8ce69 | |||
| 3c776dcefb | |||
| 2cbd4d2999 | |||
| 3092375e27 | |||
| 3cd91dc955 | |||
| 8a7368e069 | |||
| 93e561ec4d | |||
| e1b004839a | |||
| ee378f3d49 | |||
| e82ee40de3 | |||
| facbe2a114 | |||
| 7168920491 | |||
| 21378a2323 | |||
| 976711d9db | |||
| 44fa4d556c | |||
| 3ac98edcb1 | |||
| 966c742ed2 | |||
| 0d7d05f4b6 | |||
| 96bb8aa68b | |||
| 3badb0213b | |||
| fdcb850f14 | |||
| 54a66e5fee | |||
| 280d62b8a2 | |||
| 1666e66443 | |||
| 1575c1701a | |||
| 6ae996a873 | |||
| b590adfdc1 | |||
| b4fe16c75b | |||
| bc5dd4f669 | |||
| dbb036cf61 | |||
| 70e7ed841d | |||
| d06ba4ed3f | |||
| 6b40996ae8 | |||
| d2020acac7 | |||
| 1eb3c2ed48 | |||
| c64ee87267 | |||
| b1308b84a3 | |||
| 7b5ecf79bd | |||
| 9883a18859 | |||
| b3f2fddd17 | |||
| aa29841ede | |||
| 6bf27affb6 | |||
| 1dd23386ec | |||
| 7cbfc10943 | |||
| ce4ddd2d1a | |||
| e51929ebca | |||
| dc1b4a6f13 | |||
| 63d2705edb | |||
| d085a44082 | |||
| f49e5aff11 | |||
| 6c11ecf8d3 | |||
| 93e5f3c5fb | |||
| 70363bccfa | |||
| 3cdc57669f | |||
| 68bb122eb4 | |||
| d9fc8cd9da | |||
| f069f3ea74 | |||
| c5bc0e7fcc | |||
| 4a3a518722 | |||
| fbf722c6e6 | |||
| e92d7085bf | |||
| bd6028d6b0 | |||
| 802329dee9 | |||
| 41cc883c29 | |||
| 57504a4bcf | |||
| ed4792c990 | |||
| 87b836ba77 | |||
| 56c76c2e0e | |||
| c09632a66c | |||
| a3bf8d4a2b | |||
| 16eda8c43a | |||
| cd77382ac1 | |||
| 71b9cde010 | |||
| 5285589f37 | |||
| f41647ee6b | |||
| 4d022cbc75 | |||
| 70de35a881 | |||
| 34b2cf3b33 | |||
| 9e90c9f73f | |||
| e9528f6dc6 | |||
| 51baa9c333 | |||
| 35e076b3a8 | |||
| a26f59ccbc | |||
| aa3b3d76e0 | |||
| f7030df3be | |||
| 905e91e9ac | |||
| f8f9c0ba62 | |||
| dda811021a | |||
| 93195146ea | |||
| ed37599544 | |||
| 99ef59cf7f | |||
| d544d141ec | |||
| 3e397a9484 | |||
| 268c325078 | |||
| 3cc9af88ff | |||
| 7cd0bd7212 | |||
| 56d4aefa33 | |||
| dd143ef541 | |||
| daefed052c | |||
| 5fbab20e02 | |||
| e8224f3dca | |||
| 9665313c39 | |||
| 0c54fc7273 | |||
| c1b57855ec | |||
| 83b824c8b4 | |||
| 7678fcd5b6 | |||
| 8661c0241d | |||
| ce8d6b75fc | |||
| 61de3ef74b | |||
| ec1f9c8c91 | |||
| 65e09094c4 | |||
| c70cf0fe06 | |||
| a5d11a54dc | |||
| 3d4c87758e | |||
| a9bd832fc5 | |||
| 417bcefbae | |||
| baada0e737 | |||
| 82eb61dd4c | |||
| 0d4d06fe2f | |||
| 4aed0ca6a2 | |||
| 1621b25288 | |||
| a564797151 | |||
| 1da6a09274 | |||
| 1e44ffc3ff | |||
| a454748544 | |||
| 1bff42c4b7 | |||
| cb391d85dc | |||
| fee5b8d37f | |||
| b2ce859bd2 | |||
| 566f10a929 | |||
| c3b5189137 | |||
| a25866ac8d | |||
| 098900d7c2 | |||
| 98d01d3ce2 | |||
| d55244df31 | |||
| 04149cce27 | |||
| 24834f4894 | |||
| ec7da6fcf3 | |||
| 819d548e8a | |||
| 477d2a8aa2 | |||
| e484e02857 | |||
| 24f6b9a713 | |||
| 9cdde47289 | |||
| b1eb4ca152 | |||
| 87b4ac56c2 | |||
| cb84e45ac7 | |||
| 4716377fbc | |||
| 4e9cf8c1dd | |||
| 2976dc27e9 | |||
| 102bf967f0 | |||
| 1f4b09b525 | |||
| 86c3369eb8 | |||
| 2755c34a8f | |||
| db10422184 | |||
| e1a2c699dd | |||
| 0115ccd5c0 | |||
| 40b4284fe3 | |||
| 4ebc0b9640 | |||
| dc96fd54c6 | |||
| 1f5d13ab9f | |||
| 90cb44eb02 | |||
| e11880deea | |||
| 9351f91be9 | |||
| 5a1e1c8353 | |||
| 69ecaa7c79 | |||
| 7f00899ff7 | |||
| 995e3d1f41 | |||
| b4ac449a83 | |||
| 8e5314a468 | |||
| 87918e40c4 | |||
| f6b32efb7f | |||
| b99733d092 | |||
| 05a015d6a5 | |||
| ad971af8c7 | |||
| f2ebb6f541 | |||
| 1d01211264 | |||
| f94ab12f79 | |||
| a865bc1ca6 | |||
| 21802c4b6d | |||
| 652907b354 | |||
| 24f1c01e0f | |||
| fad6e2538e | |||
| 7f6d47c1a2 | |||
| 3147586ebd | |||
| ed636d99ca | |||
| 090c856d76 | |||
| ad434d4cfe | |||
| 66d433b94f | |||
| 027b204ff1 | |||
| 55dcce91df | |||
| 8017c8db7f | |||
| dc3529dbf6 | |||
| 7699258ef0 | |||
| e9ba99f296 | |||
| 7c80368710 | |||
| 95d63f38c0 | |||
| bb8dab821e | |||
| fc0f87768a | |||
| 0a57386721 | |||
| 3749e28774 | |||
| 86fc2321ff | |||
| 2549c0dfef | |||
| b10e519895 | |||
| 9bde5ba127 | |||
| 72c8f1ad04 | |||
| da224daaa9 | |||
| 3a100b9278 | |||
| 242a637aea | |||
| c2a9671510 | |||
| d5ae4f7f42 | |||
| b6c502a150 | |||
| 9ca710e525 | |||
| eb07c8cb5b | |||
| ba10801961 | |||
| 620fc2d09e | |||
| 29283eaa7e | |||
| 2fa66ef713 | |||
| 13affc432d | |||
| d8f094a92a | |||
| 97ae6d777f | |||
| 6baeee70d1 | |||
| d2517a4939 | |||
| 6342adc438 | |||
| 0adba91547 | |||
| 4285e423a6 | |||
| 63375f0cdb | |||
| 70ad3f9e98 | |||
| d6fc629f4d | |||
| af51d80fa1 | |||
| f5722a5052 | |||
| 651cf0fec1 | |||
| 4dc52e1c53 | |||
| 4708f13a9c | |||
| a6d042df0a | |||
| 40a36ccfeb | |||
| ef608c37a7 | |||
| 2386803f2a | |||
| 95862f7b4d | |||
| 230b131b54 | |||
| 0812d8dd41 | |||
| bf7e3c51ae | |||
| a35a8a8392 | |||
| 4ef0bb1fcf | |||
| fadc59c0e6 | |||
| 86cbd2eee9 | |||
| 092475f738 | |||
| dcc56d62da | |||
| f15e70d906 | |||
| b6be6f8d1e | |||
| 03a70eacaf | |||
| 45b1ff7a25 | |||
| 15ba07ef25 | |||
| d2b58ca203 | |||
| 82e7e19a6e | |||
| 421c462948 | |||
| 84884cd9ac | |||
| a43aa183dc | |||
| 463bbb1835 | |||
| 5e125e74d1 | |||
| 06f21ce7a5 | |||
| 57a810db9c | |||
| 8b664706aa | |||
| 37bfee92bf | |||
| e73ff24e31 | |||
| bd7599d34a | |||
| 01b6113659 | |||
| 1b84eff03a | |||
| 55acf86bf8 | |||
| f021b97993 | |||
| 1cab43c2d2 | |||
| 8bd651b318 | |||
| 58e234a754 | |||
| e86c414d6a | |||
| 550b2801ad | |||
| cefb9e5a28 | |||
| 98d7367b61 | |||
| 594a8b9030 | |||
| 44f990515b | |||
| 252937806c | |||
| 51826d51fa | |||
| 14e53ed11f | |||
| ddb94c2605 | |||
| 90969fb39a | |||
| 101f1481f9 | |||
| 2edc87b161 | |||
| 4203926f10 | |||
| cdb57015a7 | |||
| aa557e6422 | |||
| 0e00d40e4f | |||
| c920e01242 | |||
| 274d8e8818 | |||
| 2039c6305b | |||
| 6efb195a6e | |||
| 24b7fb455a | |||
| 58f5a59769 | |||
| db9dfcfa6a | |||
| 9ef98d527e | |||
| 93491aefc7 | |||
| 7acd539cd7 | |||
| e75a6301bd | |||
| a79cc68b3a | |||
| 7e3f7a4ee7 | |||
| 9ec8257914 | |||
| 38327cf454 | |||
| dfa82e2a3d | |||
| e59ca942f5 | |||
| a57a3044aa | |||
| 4e5a0f6ae2 | |||
| b63bd14999 | |||
| 2041c0e360 | |||
| 085cbc4f9f | |||
| 2b93162fb0 | |||
| 2e45bd29fe | |||
| 51d7c6a2b2 | |||
| f3aca1ee30 | |||
| 8dd41d6bcc | |||
| 0a298ea418 | |||
| d330558bab | |||
| 656fd72976 | |||
| 79455cf421 | |||
| 30d6a015e0 | |||
| 8af5a5c4e5 | |||
| 3a5f0afcd2 | |||
| c7e63aa4d8 | |||
| 4a9ce1784c | |||
| 7e4e709b43 | |||
| 63d8eabed0 | |||
| e830b01383 | |||
| ff6473980d | |||
| a164aea35d | |||
| a76f547e11 | |||
| b7b7676d67 | |||
| e6e3c55ef2 | |||
| f98a4920f9 | |||
| d4bfc23ef0 | |||
| 9a2160fa55 | |||
| 2de4118243 |
@@ -1,3 +1,4 @@
# For vllm script, with -t option (tensor parallel size).
# bash ./run-lm-eval-gsm-vllm-baseline.sh -m deepseek-ai/DeepSeek-V2-Lite-Chat -b "auto" -l 1000 -f 5 -t 2
model_name: "deepseek-ai/DeepSeek-V2-Lite-Chat"
tasks:

@@ -1,3 +1,4 @@
# For hf script, without -t option (tensor parallel size).
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m nm-testing/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform -b auto -l 1000 -f 5
model_name: "nm-testing/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform"
tasks:

@@ -1,3 +1,4 @@
# For hf script, without -t option (tensor parallel size).
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m meta-llama/Meta-Llama-3-70B-Instruct -b 32 -l 250 -f 5
model_name: "meta-llama/Meta-Llama-3-70B-Instruct"
tasks:

@@ -1,3 +1,4 @@
# For vllm script, with -t option (tensor parallel size).
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-W8A8-FP8-Channelwise-compressed-tensors -b auto -l 1000 -f 5 -t 1
model_name: "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-FP8-Channelwise-compressed-tensors"
tasks:

@@ -1,3 +1,4 @@
# For vllm script, with -t option (tensor parallel size).
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform -b auto -l 1000 -f 5 -t 1
model_name: "nm-testing/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform"
tasks:

@@ -1,3 +1,4 @@
# For vllm script, with -t option (tensor parallel size).
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test -b 32 -l 1000 -f 5 -t 1
model_name: "nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test"
tasks:

@@ -1,3 +1,4 @@
# For vllm script, with -t option (tensor parallel size).
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Meta-Llama-3-8B-Instruct-FP8 -b 32 -l 250 -f 5 -t 1
model_name: "neuralmagic/Meta-Llama-3-8B-Instruct-FP8"
tasks:

@@ -1,3 +1,4 @@
# For vllm script, with -t option (tensor parallel size).
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Asym-Per-Token-Test -b "auto" -l 250 -f 5 -t 1
model_name: "nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Asym-Per-Token-Test"
tasks:

@@ -1,3 +1,4 @@
# For vllm script, with -t option (tensor parallel size).
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Per-Token-Test -b "auto" -l 250 -f 5 -t 1
model_name: "nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Per-Token-Test"
tasks:

@@ -1,3 +1,4 @@
# For vllm script, with -t option (tensor parallel size).
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-nonuniform-test -b auto -l 1000 -f 5 -t 1
model_name: "nm-testing/Meta-Llama-3-8B-Instruct-nonuniform-test"
tasks:

@@ -1,4 +1,5 @@
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m meta-llama/Meta-Llama-3-8B-Instruct -b 32 -l 250 -f 5 -t 1
# For hf script, without -t option (tensor parallel size).
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m meta-llama/Meta-Llama-3-8B-Instruct -b 32 -l 250 -f 5
model_name: "meta-llama/Meta-Llama-3-8B-Instruct"
tasks:
- name: "gsm8k"

@@ -1,3 +1,4 @@
# For vllm script, with -t option (tensor parallel size).
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m HandH1998/QQQ-Llama-3-8b-g128 -b 32 -l 1000 -f 5 -t 1
model_name: "HandH1998/QQQ-Llama-3-8b-g128"
tasks:

@@ -1,3 +1,4 @@
# For vllm script, with -t option (tensor parallel size).
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8 -b "auto" -l 1000 -f 5 -t 1
model_name: "neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8"
tasks:

@@ -1,3 +1,4 @@
# For vllm script, with -t option (tensor parallel size).
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m mgoin/Minitron-4B-Base-FP8 -b auto -l 1000 -f 5 -t 1
model_name: "mgoin/Minitron-4B-Base-FP8"
tasks:

@@ -1,3 +1,4 @@
# For vllm script, with -t option (tensor parallel size).
# bash ./run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Mixtral-8x22B-Instruct-v0.1-FP8-dynamic -b "auto" -l 250 -f 5 -t 8
model_name: "neuralmagic/Mixtral-8x22B-Instruct-v0.1-FP8-dynamic"
tasks:

@@ -1,3 +1,4 @@
# For vllm script, with -t option (tensor parallel size).
# bash ./run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8 -b "auto" -l 250 -f 5 -t 4
model_name: "neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8"
tasks:

@@ -1,4 +1,5 @@
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m neuralmagic/Mixtral-8x7B-Instruct-v0.1 -b 32 -l 250 -f 5 -t 4
# For hf script, without -t option (tensor parallel size).
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m neuralmagic/Mixtral-8x7B-Instruct-v0.1 -b 32 -l 250 -f 5
model_name: "mistralai/Mixtral-8x7B-Instruct-v0.1"
tasks:
- name: "gsm8k"

@@ -0,0 +1,12 @@
# For vllm script, with -t option (tensor parallel size).
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16 -b auto -l 1319 -f 5 -t 1
model_name: "nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16"
tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.30
- name: "exact_match,flexible-extract"
value: 0.465
limit: 1319
num_fewshot: 5

@@ -1,3 +1,4 @@
# For vllm script, with -t option (tensor parallel size).
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Qwen2-1.5B-Instruct-FP8W8 -b auto -l 1000 -f 5 -t 1
model_name: "nm-testing/Qwen2-1.5B-Instruct-FP8W8"
tasks:

@@ -1,3 +1,4 @@
# For vllm script, with -t option (tensor parallel size).
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Qwen2-1.5B-Instruct-quantized.w8a8 -b "auto" -l 1000 -f 5 -t 1
model_name: "neuralmagic/Qwen2-1.5B-Instruct-quantized.w8a8"
tasks:

@@ -1,3 +1,4 @@
# For vllm script, with -t option (tensor parallel size).
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Qwen2-1.5B-Instruct-W8A16-Channelwise -b "auto" -l 1000 -f 5 -t 1
model_name: "nm-testing/Qwen2-1.5B-Instruct-W8A16-Channelwise"
tasks:

@@ -1,3 +1,4 @@
# For vllm script, with -t option (tensor parallel size).
# bash ./run-lm-eval-gsm-vllm-baseline.sh -m Qwen/Qwen2-57B-A14B-Instruct -b "auto" -l 250 -f 5 -t 4
model_name: "Qwen/Qwen2-57B-A14B-Instruct"
tasks:

@@ -1,3 +1,4 @@
# For vllm script, with -t option (tensor parallel size).
# bash ./run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM -b "auto" -t 2
model_name: "nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM"
tasks:
@@ -4,7 +4,7 @@ Meta-Llama-3.2-1B-Instruct-INT8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml
Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
Minitron-4B-Base-FP8.yaml
Qwen1.5-MoE-W4A16-compressed-tensors.yaml
Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
Qwen2-1.5B-Instruct-FP8W8.yaml
Meta-Llama-3-8B-QQQ.yaml

@@ -16,7 +16,7 @@ import numpy
import pytest
import yaml

RTOL = 0.05
RTOL = 0.08
TEST_DATA_FILE = os.environ.get(
"LM_EVAL_TEST_DATA_FILE",
".buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct.yaml")
@@ -10,15 +10,24 @@ set -x
set -o pipefail

check_gpus() {
# check the number of GPUs and GPU type.
declare -g gpu_count=$(nvidia-smi --list-gpus | wc -l)
if command -v nvidia-smi; then
# check the number of GPUs and GPU type.
declare -g gpu_count=$(nvidia-smi --list-gpus | wc -l)
elif command -v amd-smi; then
declare -g gpu_count=$(amd-smi list | grep 'GPU' | wc -l)
fi

if [[ $gpu_count -gt 0 ]]; then
echo "GPU found."
else
echo "Need at least 1 GPU to run benchmarking."
exit 1
fi
declare -g gpu_type=$(nvidia-smi --query-gpu=name --format=csv,noheader | awk '{print $2}')
if command -v nvidia-smi; then
declare -g gpu_type=$(nvidia-smi --query-gpu=name --format=csv,noheader | awk '{print $2}')
elif command -v amd-smi; then
declare -g gpu_type=$(amd-smi static -g 0 -a | grep 'MARKET_NAME' | awk '{print $2}')
fi
echo "GPU type is $gpu_type"
}

@@ -90,9 +99,15 @@ kill_gpu_processes() {

# wait until GPU memory usage smaller than 1GB
while [ "$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits | head -n 1)" -ge 1000 ]; do
sleep 1
done
if command -v nvidia-smi; then
while [ "$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits | head -n 1)" -ge 1000 ]; do
sleep 1
done
elif command -v amd-smi; then
while [ "$(amd-smi metric -g 0 | grep 'USED_VRAM' | awk '{print $2}')" -ge 1000 ]; do
sleep 1
done
fi

# remove vllm config file
rm -rf ~/.config/vllm
@@ -3,10 +3,10 @@ steps:
agents:
queue: cpu_queue_postmerge
commands:
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.4.0 --tag vllm-ci:build-image --target build --progress plain ."
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.4.0 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
- "mkdir artifacts"
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
- "bash .buildkite/upload-wheels.sh"
- "bash .buildkite/scripts/upload-wheels.sh"
env:
DOCKER_BUILDKIT: "1"

@@ -14,10 +14,10 @@ steps:
agents:
queue: cpu_queue_postmerge
commands:
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.1.0 --tag vllm-ci:build-image --target build --progress plain ."
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.1.0 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
- "mkdir artifacts"
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
- "bash .buildkite/upload-wheels.sh"
- "bash .buildkite/scripts/upload-wheels.sh"
env:
DOCKER_BUILDKIT: "1"

@@ -31,10 +31,10 @@ steps:
agents:
queue: cpu_queue_postmerge
commands:
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=11.8.0 --tag vllm-ci:build-image --target build --progress plain ."
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=11.8.0 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
- "mkdir artifacts"
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
- "bash .buildkite/upload-wheels.sh"
- "bash .buildkite/scripts/upload-wheels.sh"
env:
DOCKER_BUILDKIT: "1"

@@ -48,7 +48,7 @@ steps:
queue: cpu_queue_postmerge
commands:
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.4.0 --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain ."
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.4.0 --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain -f docker/Dockerfile ."
- "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT"

- label: "Build and publish TPU release image"
@@ -57,7 +57,7 @@ steps:
agents:
queue: tpu_queue_postmerge
commands:
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --tag vllm/vllm-tpu:nightly --tag vllm/vllm-tpu:$BUILDKITE_COMMIT --progress plain -f Dockerfile.tpu ."
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --tag vllm/vllm-tpu:nightly --tag vllm/vllm-tpu:$BUILDKITE_COMMIT --progress plain -f docker/Dockerfile.tpu ."
- "docker push vllm/vllm-tpu:nightly"
- "docker push vllm/vllm-tpu:$BUILDKITE_COMMIT"
plugins:
@@ -82,7 +82,22 @@ steps:
queue: cpu_queue_postmerge
commands:
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest --progress plain --target vllm-openai -f Dockerfile.cpu ."
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest --progress plain --target vllm-openai -f docker/Dockerfile.cpu ."
- "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version)"
env:
DOCKER_BUILDKIT: "1"

- block: "Build Neuron release image"
key: block-neuron-release-image-build
depends_on: ~

- label: "Build and publish Neuron release image"
depends_on: block-neuron-release-image-build
agents:
queue: neuron-postmerge
commands:
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:latest --progress plain -f docker/Dockerfile.neuron ."
- "docker push public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:$(buildkite-agent meta-data get release-version)"
env:
DOCKER_BUILDKIT: "1"
@@ -98,6 +98,13 @@ if [[ $commands == *" kernels "* ]]; then
--ignore=kernels/test_machete_mm.py \
--ignore=kernels/test_mha_attn.py \
--ignore=kernels/test_block_fp8.py \
--ignore=kernels/test_cutlass_moe.py \
--ignore=kernels/test_mamba_ssm_ssd.py \
--ignore=kernels/test_attention.py \
--ignore=kernels/test_block_int8.py \
--ignore=kernels/test_fused_quant_layernorm.py \
--ignore=kernels/test_int8_kernel.py \
--ignore=kernels/test_triton_moe_ptpc_fp8.py \
--ignore=kernels/test_permute_cols.py"
fi

@@ -105,19 +112,33 @@ fi
if [[ $commands == *" entrypoints/openai "* ]]; then
commands=${commands//" entrypoints/openai "/" entrypoints/openai \
--ignore=entrypoints/openai/test_audio.py \
--ignore=entrypoints/openai/test_chat.py \
--ignore=entrypoints/openai/test_shutdown.py \
--ignore=entrypoints/openai/test_completion.py \
--ignore=entrypoints/openai/test_sleep.py \
--ignore=entrypoints/openai/test_models.py \
--ignore=entrypoints/openai/test_lora_adapters.py \
--ignore=entrypoints/openai/test_return_tokens_as_ids.py \
--ignore=entrypoints/openai/test_root_path.py \
--ignore=entrypoints/openai/test_tokenization.py \
--ignore=entrypoints/openai/test_prompt_validation.py "}
fi

#ignore certain Entrypoints/llm tests
if [[ $commands == *" && pytest -v -s entrypoints/llm/test_guided_generate.py"* ]]; then
commands=${commands//" && pytest -v -s entrypoints/llm/test_guided_generate.py"/" "}
if [[ $commands == *" entrypoints/llm "* ]]; then
commands=${commands//" entrypoints/llm "/" entrypoints/llm \
--ignore=entrypoints/llm/test_chat.py \
--ignore=entrypoints/llm/test_accuracy.py \
--ignore=entrypoints/llm/test_init.py \
--ignore=entrypoints/llm/test_generate_multiple_loras.py \
--ignore=entrypoints/llm/test_prompt_validation.py "}
fi

#Obsolete currently
##ignore certain Entrypoints/llm tests
#if [[ $commands == *" && pytest -v -s entrypoints/llm/test_guided_generate.py"* ]]; then
# commands=${commands//" && pytest -v -s entrypoints/llm/test_guided_generate.py"/" "}
#fi

# --ignore=entrypoints/openai/test_encoder_decoder.py \
# --ignore=entrypoints/openai/test_embedding.py \
# --ignore=entrypoints/openai/test_oot_registration.py
45 .buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh (Executable file)

@@ -0,0 +1,45 @@
#!/bin/bash

# This script build the CPU docker image and run the offline inference inside the container.
# It serves a sanity check for compilation and basic model usage.
set -ex

# Setup cleanup
remove_docker_container() {
if [[ -n "$container_id" ]]; then
podman rm -f "$container_id" || true
fi
podman system prune -f
}
trap remove_docker_container EXIT
remove_docker_container

# Try building the docker image
podman build -t cpu-test-ubi9-ppc -f docker/Dockerfile.ppc64le .

# Run the image
container_id=$(podman run -itd --entrypoint /bin/bash -v /tmp/:/root/.cache/huggingface --privileged=true --network host -e HF_TOKEN cpu-test-ubi9-ppc)

function cpu_tests() {

# offline inference
podman exec -it "$container_id" bash -c "
set -e
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m"

# Run basic model test
podman exec -it "$container_id" bash -c "
set -e
pip install pytest pytest-asyncio einops peft Pillow soundfile transformers_stream_generator matplotlib
pip install sentence-transformers datamodel_code_generator
pytest -v -s tests/models/embedding/language/test_cls_models.py::test_classification_models[float-jason9693/Qwen2.5-1.5B-apeach]
pytest -v -s tests/models/embedding/language/test_embedding.py::test_models[half-BAAI/bge-base-en-v1.5]
pytest -v -s tests/models/encoder_decoder/language -m cpu_model"
}

# All of CPU tests are expected to be finished less than 40 mins.

export container_id
export -f cpu_tests
timeout 40m bash -c cpu_tests
@@ -10,5 +10,4 @@ trap remove_docker_container EXIT
remove_docker_container

# Try building the docker image
docker build -t cpu-test -f Dockerfile.ppc64le .

docker build -t cpu-test -f docker/Dockerfile.s390x .

@@ -18,8 +18,8 @@ trap remove_docker_container EXIT
remove_docker_container

# Try building the docker image
numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --tag cpu-test-"$BUILDKITE_BUILD_NUMBER" --target vllm-test -f Dockerfile.cpu .
numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" --tag cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2 --target vllm-test -f Dockerfile.cpu .
numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --tag cpu-test-"$BUILDKITE_BUILD_NUMBER" --target vllm-test -f docker/Dockerfile.cpu .
numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" --tag cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2 --target vllm-test -f docker/Dockerfile.cpu .

# Run the image, setting --shm-size=4g for tensor parallel.
docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus="$CORE_RANGE" \

@@ -9,6 +9,7 @@ python3 use_existing_torch.py

# Try building the docker image
DOCKER_BUILDKIT=1 docker build . \
--file docker/Dockerfile \
--target vllm-openai \
--platform "linux/arm64" \
-t gh200-test \

@@ -5,7 +5,7 @@
set -ex

# Try building the docker image
docker build -t hpu-test-env -f Dockerfile.hpu .
docker build -t hpu-test-env -f docker/Dockerfile.hpu .

# Setup cleanup
# certain versions of HPU software stack have a bug that can

@@ -35,7 +35,7 @@ else
date "+%s" > /tmp/neuron-docker-build-timestamp
fi

docker build -t "${image_name}" -f Dockerfile.neuron .
docker build -t "${image_name}" -f docker/Dockerfile.neuron .

# Setup cleanup
remove_docker_container() {

@@ -1,9 +1,9 @@
#!/bin/bash

set -e
set -xue

# Build the docker image.
docker build -f Dockerfile.tpu -t vllm-tpu .
docker build -f docker/Dockerfile.tpu -t vllm-tpu .

# Set up cleanup.
remove_docker_container() { docker rm -f tpu-test || true; }

@@ -17,10 +17,15 @@ source /etc/environment
docker run --privileged --net host --shm-size=16G -it \
-e "HF_TOKEN=$HF_TOKEN" --name tpu-test \
vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git \
&& python3 -m pip install pytest \
&& python3 -m pip install pytest pytest-asyncio tpu-info \
&& python3 -m pip install lm_eval[api]==0.4.4 \
&& export VLLM_XLA_CACHE_PATH= \
&& export VLLM_USE_V1=1 \
&& export VLLM_XLA_CHECK_RECOMPILATION=1 \
&& echo HARDWARE \
&& tpu-info \
&& echo TEST_0 \
&& pytest -v -s /workspace/vllm/tests/v1/tpu/test_perf.py \
&& echo TEST_1 \
&& pytest -v -s /workspace/vllm/tests/tpu/test_compilation.py \
&& echo TEST_2 \

@@ -34,7 +39,15 @@ docker run --privileged --net host --shm-size=16G -it \
&& echo TEST_6 \
&& pytest -s -v /workspace/vllm/tests/v1/tpu/worker/test_tpu_model_runner.py \
&& echo TEST_7 \
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py" \
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py \
&& echo TEST_8 \
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_topk_topp_sampler.py \
&& echo TEST_9 \
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_multimodal.py \
&& echo TEST_10 \
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py \
&& echo TEST_11 \
&& pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py" \

# TODO: This test fails because it uses RANDOM_SEED sampling

@@ -8,7 +8,7 @@ image_name="xpu/vllm-ci:${BUILDKITE_COMMIT}"
container_name="xpu_${BUILDKITE_COMMIT}_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo)"

# Try building the docker image
docker build -t ${image_name} -f Dockerfile.xpu .
docker build -t ${image_name} -f docker/Dockerfile.xpu .

# Setup cleanup
remove_docker_container() {

@@ -5,8 +5,8 @@
set -ex
set -o pipefail

# cd into parent directory of this file
cd "$(dirname "${BASH_SOURCE[0]}")/.."
# cd 2 levels into the working directory
cd "$(dirname "${BASH_SOURCE[0]}")/../.."

(which wget && which curl) || (apt-get update && apt-get install -y wget curl)

@@ -3,7 +3,7 @@
set -euox pipefail

if [[ $# -lt 4 ]]; then
echo "Usage: .buildkite/run-multi-node-test.sh WORKING_DIR NUM_NODES NUM_GPUS DOCKER_IMAGE COMMAND1 COMMAND2 ... COMMANDN"
echo "Usage: .buildkite/scripts/run-multi-node-test.sh WORKING_DIR NUM_NODES NUM_GPUS DOCKER_IMAGE COMMAND1 COMMAND2 ... COMMANDN"
exit 1
fi
@@ -8,6 +8,7 @@
# Documentation
# label(str): the name of the test. emoji allowed.
# fast_check(bool): whether to run this on each commit on fastcheck pipeline.
# torch_nightly(bool): whether to run this on vllm against torch nightly pipeline.
# fast_check_only(bool): run this test on fastcheck pipeline only
# optional(bool): never run this test by default (i.e. need to unblock manually) unless it's scheduled nightly run.
# command(str): the single command to run for tests. incompatible with commands.

@@ -70,6 +71,7 @@ steps:
- label: Basic Correctness Test # 30min
#mirror_hardwares: [amd]
fast_check: true
torch_nightly: true
source_file_dependencies:
- vllm/
- tests/basic_correctness/test_basic_correctness

@@ -104,7 +106,8 @@ steps:
- label: Entrypoints Test # 40min
working_dir: "/vllm-workspace/tests"
fast_check: true
mirror_hardwares: [amd]
torch_nightly: true
#mirror_hardwares: [amd]
source_file_dependencies:
- vllm/
- tests/entrypoints/llm

@@ -118,7 +121,7 @@ steps:
- pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
- pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process
- VLLM_USE_V1=0 pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/correctness/
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/test_openai_schema.py
- pytest -v -s entrypoints/test_chat_utils.py
- VLLM_USE_V1=0 pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests

@@ -155,6 +158,7 @@ steps:
- popd

- label: Metrics, Tracing Test # 10min
mirror_hardwares: [amd]
num_gpus: 2
source_file_dependencies:
- vllm/

@@ -162,18 +166,13 @@ steps:
- tests/tracing
commands:
- pytest -v -s metrics
- "pip install \
'opentelemetry-sdk>=1.26.0,<1.27.0' \
'opentelemetry-api>=1.26.0,<1.27.0' \
'opentelemetry-exporter-otlp>=1.26.0,<1.27.0' \
'opentelemetry-semantic-conventions-ai>=0.4.1,<0.5.0'"
- pytest -v -s tracing

##### fast check tests #####
##### 1 GPU test #####

- label: Regression Test # 5min
mirror_hardwares: [amd]
#mirror_hardwares: [amd]
source_file_dependencies:
- vllm/
- tests/test_regression

@@ -204,12 +203,13 @@ steps:
commands:
# split the test to avoid interference
- pytest -v -s v1/core
- pytest -v -s v1/entrypoints
- pytest -v -s v1/engine
- pytest -v -s v1/entrypoints
- pytest -v -s v1/sample
- pytest -v -s v1/worker
- pytest -v -s v1/structured_output
- pytest -v -s v1/spec_decode
- pytest -v -s v1/test_serial_utils.py
- pytest -v -s v1/test_stats.py
- pytest -v -s v1/test_utils.py
- pytest -v -s v1/test_oracle.py

@@ -285,13 +285,22 @@ steps:
- pytest -v -s spec_decode/e2e/test_eagle_correctness.py

- label: LoRA Test %N # 15min each
mirror_hardwares: [amd]
#mirror_hardwares: [amd]
source_file_dependencies:
- vllm/lora
- tests/lora
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py --ignore=lora/test_minicpmv_tp.py --ignore=lora/test_transfomers_model.py
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py
parallelism: 4

- label: PyTorch Compilation Unit Tests
source_file_dependencies:
- vllm/
- tests/compile
commands:
- pytest -v -s compile/test_pass_manager.py
- pytest -v -s compile/test_fusion.py
- pytest -v -s compile/test_sequence_parallelism.py

- label: PyTorch Fullgraph Smoke Test # 9min
source_file_dependencies:
- vllm/

@@ -301,7 +310,6 @@ steps:
# these tests need to be separated, cannot combine
- pytest -v -s compile/piecewise/test_simple.py
- pytest -v -s compile/piecewise/test_toy_llama.py
- pytest -v -s compile/test_pass_manager.py

- label: PyTorch Fullgraph Test # 18min
source_file_dependencies:

@@ -310,18 +318,49 @@ steps:
commands:
- pytest -v -s compile/test_full_graph.py

- label: Kernels Test %N # 1h each
mirror_hardwares: [amd]
- label: Kernels Core Operation Test
source_file_dependencies:
- csrc/
- vllm/attention
- tests/kernels
- tests/kernels/core
commands:
- pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
parallelism: 4
- pytest -v -s kernels/core

- label: Kernels Attention Test %N
source_file_dependencies:
- csrc/attention/
- vllm/attention
- vllm/v1/attention
- tests/kernels/attention
commands:
- pytest -v -s kernels/attention --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
parallelism: 2

- label: Kernels Quantization Test %N
source_file_dependencies:
- csrc/quantization/
- vllm/model_executor/layers/quantization
- tests/kernels/quantization
commands:
- pytest -v -s kernels/quantization --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
parallelism: 2

- label: Kernels MoE Test
source_file_dependencies:
- csrc/moe/
- tests/kernels/moe
- vllm/model_executor/layers/fused_moe/
commands:
- pytest -v -s kernels/moe

- label: Kernels Mamba Test
source_file_dependencies:
- csrc/mamba/
- tests/kernels/mamba
commands:
- pytest -v -s kernels/mamba

- label: Tensorizer Test # 11min
mirror_hardwares: [amd]
# mirror_hardwares: [amd]
soft_fail: true
source_file_dependencies:
- vllm/model_executor/model_loader

@@ -337,7 +376,14 @@ steps:
source_file_dependencies:
- benchmarks/
commands:
- bash run-benchmarks.sh
- bash scripts/run-benchmarks.sh

- label: Benchmarks CLI Test # 10min
source_file_dependencies:
- vllm/
- tests/benchmarks/
commands:
- pytest -v -s benchmarks/

- label: Quantization Test # 33min
source_file_dependencies:

@@ -372,12 +418,14 @@ steps:

- label: OpenAI-Compatible Tool Use # 20 min
fast_check: false
mirror_hardwares: [ amd ]
#mirror_hardwares: [ amd ]
source_file_dependencies:
- vllm/
- tests/tool_use
- tests/mistral_tool_use
commands:
- pytest -v -s tool_use
- pytest -v -s mistral_tool_use

##### models test #####

@@ -389,7 +437,9 @@ steps:
- pytest -v -s models/test_transformers.py
- pytest -v -s models/test_registry.py
# V1 Test: https://github.com/vllm-project/vllm/issues/14531
- VLLM_USE_V1=0 pytest -v -s models/test_initialization.py
- VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2'
- VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'llama4'
- VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'plamo2'

- label: Language Models Test (Standard) # 32min
#mirror_hardwares: [amd]

@@ -399,6 +449,8 @@ steps:
- tests/models/embedding/language
- tests/models/encoder_decoder/language
commands:
# Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile.
- pip install causal-conv1d
- pytest -v -s models/decoder_only/language -m 'core_model or quant_model'
- pytest -v -s models/embedding/language -m core_model

@@ -410,6 +462,8 @@ steps:
- tests/models/embedding/language
- tests/models/encoder_decoder/language
commands:
# Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile.
- pip install causal-conv1d
- pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model'
- pytest -v -s models/embedding/language -m 'not core_model'

@@ -426,7 +480,7 @@ steps:
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
- pytest -v -s models/multimodal
- pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model'
- pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model or quant_model'
- pytest -v -s models/decoder_only/vision_language -m 'core_model or quant_model'
- pytest -v -s models/embedding/vision_language -m core_model
- pytest -v -s models/encoder_decoder/audio_language -m core_model
- pytest -v -s models/encoder_decoder/language -m core_model

@@ -445,10 +499,7 @@ steps:
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
- pytest -v -s models/decoder_only/audio_language -m 'not core_model and not quant_model'
- pytest -v -s models/decoder_only/vision_language/test_models.py -m 'split(group=0) and not core_model and not quant_model'
# HACK - run phi3v tests separately to sidestep this transformers bug
# https://github.com/huggingface/transformers/issues/34307
- pytest -v -s models/decoder_only/vision_language/test_phi3v.py
- pytest -v -s --ignore models/decoder_only/vision_language/test_models.py --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'not core_model and not quant_model'
- pytest -v -s --ignore models/decoder_only/vision_language/test_models.py models/decoder_only/vision_language -m 'not core_model and not quant_model'
- pytest -v -s models/embedding/vision_language -m 'not core_model'
- pytest -v -s models/encoder_decoder/language -m 'not core_model'
- pytest -v -s models/encoder_decoder/vision_language -m 'not core_model'

@@ -464,6 +515,7 @@ steps:

# This test is used only in PR development phase to test individual models and should never run on main
- label: Custom Models Test
mirror_hardwares: [amd]
optional: true
commands:
- echo 'Testing custom models...'

@@ -475,6 +527,7 @@ steps:
##### multi gpus test #####

- label: Distributed Comm Ops Test # 7min
mirror_hardwares: [amd]
working_dir: "/vllm-workspace/tests"
num_gpus: 2
source_file_dependencies:

@@ -531,11 +584,14 @@ steps:
- pytest models/encoder_decoder/language/test_bart.py -v -s -m 'distributed(num_gpus=2)'
- pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m 'distributed(num_gpus=2)'
- pytest models/decoder_only/vision_language/test_models.py -v -s -m 'distributed(num_gpus=2)'
# test sequence parallel
- pytest -v -s distributed/test_sequence_parallel.py
# this test fails consistently.
# TODO: investigate and fix
# - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
- VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
- VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/test_disagg.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown

- label: Plugin Tests (2 GPUs) # 40min
working_dir: "/vllm-workspace/tests"

@@ -602,8 +658,6 @@ steps:
# requires multi-GPU testing for validation.
- pytest -v -s -x lora/test_chatglm3_tp.py
- pytest -v -s -x lora/test_llama_tp.py
- pytest -v -s -x lora/test_minicpmv_tp.py
- pytest -v -s -x lora/test_transfomers_model.py

- label: Weight Loading Multiple GPU Test # 33min
1
.github/CODEOWNERS
vendored
1
.github/CODEOWNERS
vendored
@ -12,6 +12,7 @@
|
||||
/vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth
|
||||
/vllm/model_executor/guided_decoding @mgoin @russellb
|
||||
/vllm/multimodal @DarkLight1337 @ywang96
|
||||
/vllm/vllm_flash_attn @LucasWilkinson
|
||||
CMakeLists.txt @tlrmchlsmth
|
||||
|
||||
# vLLM V1
|
||||
|
||||
2
.github/ISSUE_TEMPLATE/200-installation.yml
vendored
2
.github/ISSUE_TEMPLATE/200-installation.yml
vendored
@ -14,7 +14,7 @@ body:
|
||||
description: |
|
||||
Please run the following and paste the output below.
|
||||
```sh
|
||||
wget https://raw.githubusercontent.com/vllm-project/vllm/main/collect_env.py
|
||||
wget https://raw.githubusercontent.com/vllm-project/vllm/main/vllm/collect_env.py
|
||||
# For security purposes, please feel free to check the contents of collect_env.py before running it.
|
||||
python collect_env.py
|
||||
```
|
||||
|
||||
2
.github/ISSUE_TEMPLATE/300-usage.yml
vendored
2
.github/ISSUE_TEMPLATE/300-usage.yml
vendored
@ -14,7 +14,7 @@ body:
|
||||
description: |
|
||||
Please run the following and paste the output below.
|
||||
```sh
|
||||
wget https://raw.githubusercontent.com/vllm-project/vllm/main/collect_env.py
|
||||
wget https://raw.githubusercontent.com/vllm-project/vllm/main/vllm/collect_env.py
|
||||
# For security purposes, please feel free to check the contents of collect_env.py before running it.
|
||||
python collect_env.py
|
||||
```
|
||||
|
||||
2
.github/ISSUE_TEMPLATE/400-bug-report.yml
vendored
2
.github/ISSUE_TEMPLATE/400-bug-report.yml
vendored
@ -14,7 +14,7 @@ body:
|
||||
description: |
|
||||
Please run the following and paste the output below.
|
||||
```sh
|
||||
wget https://raw.githubusercontent.com/vllm-project/vllm/main/collect_env.py
|
||||
wget https://raw.githubusercontent.com/vllm-project/vllm/main/vllm/collect_env.py
|
||||
# For security purposes, please feel free to check the contents of collect_env.py before running it.
|
||||
python collect_env.py
|
||||
```
|
||||
|
||||
2
.github/ISSUE_TEMPLATE/600-new-model.yml
vendored
2
.github/ISSUE_TEMPLATE/600-new-model.yml
vendored
@ -9,7 +9,7 @@ body:
|
||||
value: >
|
||||
#### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/vllm-project/vllm/issues?q=is%3Aissue+sort%3Acreated-desc+).
|
||||
|
||||
#### We also highly recommend you read https://docs.vllm.ai/en/latest/contributing/model/adding_model.html first to understand how to add a new model.
|
||||
#### We also highly recommend you read https://docs.vllm.ai/en/latest/contributing/model/index.html first to understand how to add a new model.
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: The model to consider.
|
||||
|
||||
@ -35,7 +35,7 @@ body:
|
||||
description: |
|
||||
Please run the following and paste the output below.
|
||||
```sh
|
||||
wget https://raw.githubusercontent.com/vllm-project/vllm/main/collect_env.py
|
||||
wget https://raw.githubusercontent.com/vllm-project/vllm/main/vllm/collect_env.py
|
||||
# For security purposes, please feel free to check the contents of collect_env.py before running it.
|
||||
python collect_env.py
|
||||
```
|
||||
|
||||
2
.github/PULL_REQUEST_TEMPLATE.md
vendored
2
.github/PULL_REQUEST_TEMPLATE.md
vendored
@ -3,4 +3,4 @@ FILL IN THE PR DESCRIPTION HERE
|
||||
FIX #xxxx (*link existing issues this PR will resolve*)
|
||||
|
||||
<!--- pyml disable-next-line no-emphasis-as-heading -->
|
||||
**BEFORE SUBMITTING, PLEASE READ <https://docs.vllm.ai/en/latest/contributing/overview.html>**
|
||||
**BEFORE SUBMITTING, PLEASE READ <https://docs.vllm.ai/en/latest/contributing/overview.html>** (anything written below this line will be removed by GitHub Actions)
|
||||
|
||||
36  .github/mergify.yml  (vendored)

@@ -19,7 +19,7 @@ pull_request_rules:
- files~=\.buildkite/
- files~=^cmake/
- files=CMakeLists.txt
- files~=^Dockerfile
- files~=^docker/Dockerfile
- files~=^requirements.*\.txt
- files=setup.py
actions:
@@ -55,11 +55,19 @@ pull_request_rules:
description: Automatically apply structured-output label
conditions:
- or:
- files~=^benchmarks/structured_schemas/
- files=benchmarks/benchmark_serving_structured_output.py
- files=benchmarks/run_structured_output_benchmark.sh
- files=docs/source/features/structured_outputs.md
- files=examples/offline_inference/structured_outputs.py
- files=examples/online_serving/openai_chat_completion_structured_outputs.py
- files=examples/online_serving/openai_chat_completion_structured_outputs_with_reasoning.py
- files~=^vllm/model_executor/guided_decoding/
- files=tests/model_executor/test_guided_processors.py
- files=tests/entrypoints/llm/test_guided_generate.py
- files=benchmarks/benchmark_serving_guided.py
- files=benchmarks/benchmark_guided.py
- files~=^tests/v1/structured_output/
- files=tests/v1/entrypoints/llm/test_guided_generate.py
- files~=^vllm/v1/structured_output/
actions:
label:
add:
@@ -118,6 +126,28 @@ pull_request_rules:
remove:
- tpu

- name: label-tool-calling
description: Automatically add tool-calling label
conditions:
- or:
- files~=^tests/tool_use/
- files~=^tests/mistral_tool_use/
- files~=^tests/entrypoints/openai/tool_parsers/
- files=tests/entrypoints/openai/test_chat_with_tool_reasoning.py
- files~=^vllm/entrypoints/openai/tool_parsers/
- files=docs/source/features/tool_calling.md
- files=docs/source/getting_started/examples/openai_chat_completion_client_with_tools.md
- files=docs/source/getting_started/examples/chat_with_tools.md
- files~=^examples/tool_chat_*
- files=examples/offline_inference/chat_with_tools.py
- files=examples/online_serving/openai_chat_completion_client_with_tools_required.py
- files=examples/online_serving/openai_chat_completion_tool_calls_with_reasoning.py
- files=examples/online_serving/openai_chat_completion_client_with_tools.py
actions:
label:
add:
- tool-calling

- name: ping author on conflicts and add 'needs-rebase' label
conditions:
- conflict

2  .github/workflows/lint-and-deploy.yaml  (vendored)

@@ -50,7 +50,7 @@ jobs:
uses: helm/kind-action@a1b0e391336a6ee6713a0583f8c6240d70863de3 # v1.12.0

- name: Build the Docker image vllm cpu
run: docker buildx build -f Dockerfile.cpu -t vllm-cpu-env .
run: docker buildx build -f docker/Dockerfile.cpu -t vllm-cpu-env .

- name: Configuration of docker images, network and namespace for the kind cluster
run: |

4  .gitignore  (vendored)

@@ -3,7 +3,6 @@

# vllm-flash-attn built from source
vllm/vllm_flash_attn/*
!vllm/vllm_flash_attn/fa_utils.py

# Byte-compiled / optimized / DLL files
__pycache__/
@@ -203,3 +202,6 @@ benchmarks/**/*.json
# Linting
actionlint
shellcheck*/

# Ingore moe/marlin_moe gen code
csrc/moe/marlin_moe_wna16/kernel_*

@@ -1,3 +1,6 @@
default_install_hook_types:
- pre-commit
- commit-msg
default_stages:
- pre-commit # Run locally
- manual # Run in CI
@@ -8,7 +11,6 @@ repos:
hooks:
- id: yapf
args: [--in-place, --verbose]
additional_dependencies: [toml] # TODO: Remove when yapf is upgraded
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.9.3
hooks:
@@ -119,6 +121,12 @@ repos:
language: system
always_run: true
pass_filenames: false
- id: update-dockerfile-graph
name: Update Dockerfile dependency graph
entry: tools/update-dockerfile-graph.sh
language: script
files: ^docker/Dockerfile$
pass_filenames: false
# Keep `suggestion` last
- id: suggestion
name: Suggestion

@ -44,7 +44,7 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1
|
||||
#
|
||||
# Note: the CUDA torch version is derived from pyproject.toml and various
|
||||
# requirements.txt files and should be kept consistent. The ROCm torch
|
||||
# versions are derived from Dockerfile.rocm
|
||||
# versions are derived from docker/Dockerfile.rocm
|
||||
#
|
||||
set(TORCH_SUPPORTED_VERSION_CUDA "2.6.0")
|
||||
set(TORCH_SUPPORTED_VERSION_ROCM "2.6.0")
|
||||
@ -230,6 +230,7 @@ set(VLLM_EXT_SRC
|
||||
"csrc/cache_kernels.cu"
|
||||
"csrc/attention/paged_attention_v1.cu"
|
||||
"csrc/attention/paged_attention_v2.cu"
|
||||
"csrc/attention/merge_attn_states.cu"
|
||||
"csrc/pos_encoding_kernels.cu"
|
||||
"csrc/activation_kernels.cu"
|
||||
"csrc/layernorm_kernels.cu"
|
||||
@ -242,6 +243,7 @@ set(VLLM_EXT_SRC
|
||||
"csrc/quantization/gguf/gguf_kernel.cu"
|
||||
"csrc/cuda_utils_kernels.cu"
|
||||
"csrc/prepare_inputs/advance_step.cu"
|
||||
"csrc/custom_all_reduce.cu"
|
||||
"csrc/torch_bindings.cpp")
|
||||
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
@ -249,7 +251,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
|
||||
# Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case.
|
||||
# Please keep this in sync with FetchContent_Declare line below.
|
||||
set(CUTLASS_REVISION "v3.8.0" CACHE STRING "CUTLASS revision to use")
|
||||
set(CUTLASS_REVISION "v3.9.0" CACHE STRING "CUTLASS revision to use")
|
||||
|
||||
# Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
|
||||
if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR})
|
||||
@ -267,7 +269,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
cutlass
|
||||
GIT_REPOSITORY https://github.com/nvidia/cutlass.git
|
||||
# Please keep this in sync with CUTLASS_REVISION line above.
|
||||
GIT_TAG v3.8.0
|
||||
GIT_TAG v3.9.0
|
||||
GIT_PROGRESS TRUE
|
||||
|
||||
# Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
|
||||
@ -283,13 +285,13 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
"csrc/mamba/causal_conv1d/causal_conv1d.cu"
|
||||
"csrc/quantization/aqlm/gemm_kernels.cu"
|
||||
"csrc/quantization/awq/gemm_kernels.cu"
|
||||
"csrc/custom_all_reduce.cu"
|
||||
"csrc/permute_cols.cu"
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
|
||||
"csrc/quantization/fp4/nvfp4_quant_entry.cu"
|
||||
"csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
|
||||
"csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
|
||||
"csrc/cutlass_extensions/common.cpp")
|
||||
"csrc/cutlass_extensions/common.cpp"
|
||||
"csrc/attention/mla/cutlass_mla_entry.cu")
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${VLLM_EXT_SRC}"
|
||||
@ -462,7 +464,26 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
set(FP4_ARCHS)
|
||||
endif()
|
||||
|
||||
#
|
||||
# CUTLASS MLA Archs and flags
|
||||
cuda_archs_loose_intersection(MLA_ARCHS "10.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND MLA_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/attention/mla/cutlass_mla_kernels.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${MLA_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MLA=1")
|
||||
# Add MLA-specific include directories only to MLA source files
|
||||
set_source_files_properties(${SRCS}
|
||||
PROPERTIES INCLUDE_DIRECTORIES "${CUTLASS_DIR}/examples/77_blackwell_fmha;${CUTLASS_DIR}/examples/common")
|
||||
message(STATUS "Building CUTLASS MLA for archs: ${MLA_ARCHS}")
|
||||
else()
|
||||
message(STATUS "Not building CUTLASS MLA as no compatible archs were found.")
|
||||
# clear MLA_ARCHS
|
||||
set(MLA_ARCHS)
|
||||
endif()
|
||||
|
||||
# CUTLASS MoE kernels
|
||||
|
||||
# The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and only works
|
||||
@ -608,21 +629,51 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}")
|
||||
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
|
||||
if (MARLIN_MOE_ARCHS)
|
||||
set(MARLIN_MOE_SRC
|
||||
"csrc/moe/marlin_kernels/marlin_moe_kernel.h"
|
||||
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h"
|
||||
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu"
|
||||
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h"
|
||||
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu"
|
||||
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h"
|
||||
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu"
|
||||
"csrc/moe/marlin_moe_ops.cu")
|
||||
|
||||
#
|
||||
# For the Marlin MOE kernels we automatically generate sources for various
|
||||
# preselected input type pairs and schedules.
|
||||
# Generate sources:
|
||||
set(MOE_MARLIN_GEN_SCRIPT
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/csrc/moe/marlin_moe_wna16/generate_kernels.py)
|
||||
file(MD5 ${MOE_MARLIN_GEN_SCRIPT} MOE_MARLIN_GEN_SCRIPT_HASH)
|
||||
|
||||
message(STATUS "Marlin MOE generation script hash: ${MOE_MARLIN_GEN_SCRIPT_HASH}")
|
||||
message(STATUS "Last run Marlin MOE generate script hash: $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH}")
|
||||
|
||||
if (NOT DEFINED CACHE{MOE_MARLIN_GEN_SCRIPT_HASH}
|
||||
OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH} STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH})
|
||||
execute_process(
|
||||
COMMAND ${CMAKE_COMMAND} -E env
|
||||
PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$PYTHONPATH
|
||||
${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT}
|
||||
RESULT_VARIABLE moe_marlin_generation_result
|
||||
OUTPUT_VARIABLE moe_marlin_generation_output
|
||||
OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log
|
||||
ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log
|
||||
)
|
||||
|
||||
if (NOT moe_marlin_generation_result EQUAL 0)
|
||||
message(FATAL_ERROR "Marlin MOE generation failed."
|
||||
" Result: \"${moe_marlin_generation_result}\""
|
||||
"\nCheck the log for details: "
|
||||
"${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log")
|
||||
else()
|
||||
set(MOE_MARLIN_GEN_SCRIPT_HASH ${MOE_MARLIN_GEN_SCRIPT_HASH}
|
||||
CACHE STRING "Last run Marlin MOE generate script hash" FORCE)
|
||||
message(STATUS "Marlin MOE generation completed successfully.")
|
||||
endif()
|
||||
else()
|
||||
message(STATUS "Marlin MOE generation script has not changed, skipping generation.")
|
||||
endif()
|
||||
|
||||
file(GLOB MOE_WNAA16_MARLIN_SRC "csrc/moe/marlin_moe_wna16/*.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_MOE_SRC}"
|
||||
SRCS "${MOE_WNAA16_MARLIN_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_MOE_ARCHS}")
|
||||
|
||||
list(APPEND VLLM_MOE_EXT_SRC "${MARLIN_MOE_SRC}")
|
||||
list(APPEND VLLM_MOE_EXT_SRC ${MOE_WNAA16_MARLIN_SRC})
|
||||
|
||||
message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_ARCHS}")
|
||||
else()
|
||||
message(STATUS "Not building Marlin MOE kernels as no compatible archs found"
|
||||
@ -647,6 +698,7 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
|
||||
#
|
||||
set(VLLM_ROCM_EXT_SRC
|
||||
"csrc/rocm/torch_bindings.cpp"
|
||||
"csrc/rocm/skinny_gemms.cu"
|
||||
"csrc/rocm/attention.cu")
|
||||
|
||||
define_gpu_extension_target(
|
||||
|
||||
14  README.md

@@ -10,19 +10,14 @@ Easy, fast, and cheap LLM serving for everyone
</h3>

<p align="center">
| <a href="https://docs.vllm.ai"><b>Documentation</b></a> | <a href="https://vllm.ai"><b>Blog</b></a> | <a href="https://arxiv.org/abs/2309.06180"><b>Paper</b></a> | <a href="https://x.com/vllm_project"><b>Twitter/X</b></a> | <a href="https://discuss.vllm.ai"><b>User Forum</b></a> | <a href="https://slack.vllm.ai"><b>Developer Slack</b></a> |
| <a href="https://docs.vllm.ai"><b>Documentation</b></a> | <a href="https://blog.vllm.ai/"><b>Blog</b></a> | <a href="https://arxiv.org/abs/2309.06180"><b>Paper</b></a> | <a href="https://x.com/vllm_project"><b>Twitter/X</b></a> | <a href="https://discuss.vllm.ai"><b>User Forum</b></a> | <a href="https://slack.vllm.ai"><b>Developer Slack</b></a> |
</p>

---

[2025/03] We are collaborating with Ollama to host an [Inference Night](https://lu.ma/vllm-ollama) at Y Combinator in San Francisco on Thursday, March 27, at 6 PM. Discuss all things inference local or data center!

[2025/04] We're hosting our first-ever *vLLM Asia Developer Day* in Singapore on *April 3rd*! This is a full-day event (9 AM - 9 PM SGT) in partnership with SGInnovate, AMD, and Embedded LLM. Meet the vLLM team and learn about LLM inference for RL, MI300X, and more! [Register Now](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day)

---

*Latest News* 🔥

- [2025/04] We hosted [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing).
- [2025/03] We hosted [vLLM x Ollama Inference Night](https://lu.ma/vllm-ollama)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/16T2PDD1YwRnZ4Tu8Q5r6n53c5Lr5c73UV9Vd2_eBo4U/edit?usp=sharing).
- [2025/03] We hosted [the first vLLM China Meetup](https://mp.weixin.qq.com/s/n77GibL2corAtQHtVEAzfg)! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1REHvfQMKGnvz6p3Fd23HhSO4c8j5WPGZV0bKYLwnHyQ/edit?usp=sharing).
- [2025/03] We hosted [the East Coast vLLM Meetup](https://lu.ma/7mu4k4xx)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1NHiv8EUFF1NLd3fEYODm56nDmL26lEeXCaDgyDlTsRs/edit#slide=id.g31441846c39_0_0).
- [2025/02] We hosted [the ninth vLLM meetup](https://lu.ma/h7g3kuj9) with Meta! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1jzC_PZVXrVNSFVCW-V4cFXb6pn7zZ2CyP_Flwo05aqg/edit?usp=sharing) and AMD [here](https://drive.google.com/file/d/1Zk5qEJIkTmlQ2eQcXQZlljAx3m9s7nwn/view?usp=sharing). The slides from Meta will not be posted.
@@ -103,7 +98,7 @@ Visit our [documentation](https://docs.vllm.ai/en/latest/) to learn more.
## Contributing

We welcome and value any contributions and collaborations.
Please check out [CONTRIBUTING.md](./CONTRIBUTING.md) for how to get involved.
Please check out [Contributing to vLLM](https://docs.vllm.ai/en/stable/contributing/overview.html) for how to get involved.

## Sponsors

@@ -126,6 +121,7 @@ Compute Resources:
- Databricks
- DeepInfra
- Google Cloud
- Intel
- Lambda Lab
- Nebius
- Novita AI

@@ -51,6 +51,12 @@ become available.
<td style="text-align: center;">✅</td>
<td style="text-align: center;">✅</td>
<td><code>likaixin/InstructCoder</code></td>
</tr>
<tr>
<td><strong>HuggingFace-AIMO</strong></td>
<td style="text-align: center;">✅</td>
<td style="text-align: center;">✅</td>
<td><code>AI-MO/aimo-validation-aime</code> , <code>AI-MO/NuminaMath-1.5</code>, <code>AI-MO/NuminaMath-CoT</code></td>
</tr>
<tr>
<td><strong>HuggingFace-Other</strong></td>
@@ -187,6 +193,35 @@ python3 vllm/benchmarks/benchmark_serving.py \
    --num-prompts 10
```

**`AI-MO/aimo-validation-aime`**

``` bash
python3 vllm/benchmarks/benchmark_serving.py \
    --model Qwen/QwQ-32B \
    --dataset-name hf \
    --dataset-path AI-MO/aimo-validation-aime \
    --num-prompts 10 \
    --seed 42
```

### Running With Sampling Parameters

When using OpenAI-compatible backends such as `vllm`, optional sampling
parameters can be specified. Example client command:

```bash
python3 vllm/benchmarks/benchmark_serving.py \
    --backend vllm \
    --model NousResearch/Hermes-3-Llama-3.1-8B \
    --endpoint /v1/completions \
    --dataset-name sharegpt \
    --dataset-path <your data path>/ShareGPT_V3_unfiltered_cleaned_split.json \
    --top-k 10 \
    --top-p 0.9 \
    --temperature 0.5 \
    --num-prompts 10
```

---
## Example - Offline Throughput Benchmark

@@ -278,6 +313,18 @@ python3 vllm/benchmarks/benchmark_throughput.py \
    --num-prompts 10
```

**`AI-MO/aimo-validation-aime`**

```bash
python3 benchmarks/benchmark_throughput.py \
    --model Qwen/QwQ-32B \
    --backend vllm \
    --dataset-name hf \
    --dataset-path AI-MO/aimo-validation-aime \
    --hf-split train \
    --num-prompts 10
```

### Benchmark with LoRA Adapters

``` bash
@ -1,5 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
@ -32,6 +33,7 @@ class RequestFuncInput:
|
||||
extra_body: Optional[dict] = None
|
||||
multi_modal_content: Optional[dict] = None
|
||||
ignore_eos: bool = False
|
||||
language: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -219,7 +221,15 @@ async def async_request_deepspeed_mii(
|
||||
if response.status == 200:
|
||||
parsed_resp = await response.json()
|
||||
output.latency = time.perf_counter() - st
|
||||
output.generated_text = parsed_resp["text"][0]
|
||||
if "choices" in parsed_resp:
|
||||
output.generated_text = parsed_resp["choices"][0][
|
||||
"text"]
|
||||
elif "text" in parsed_resp:
|
||||
output.generated_text = parsed_resp["text"][0]
|
||||
else:
|
||||
output.error = ("Unexpected response format: "
|
||||
"neither 'choices' nor 'text' found")
|
||||
output.success = False
|
||||
output.success = True
|
||||
else:
|
||||
output.error = response.reason or ""
|
||||
@ -428,6 +438,110 @@ async def async_request_openai_chat_completions(
|
||||
return output
|
||||
|
||||
|
||||
async def async_request_openai_audio(
|
||||
request_func_input: RequestFuncInput,
|
||||
pbar: Optional[tqdm] = None,
|
||||
) -> RequestFuncOutput:
|
||||
# Lazy import without PlaceholderModule to avoid vllm dep.
|
||||
import soundfile
|
||||
api_url = request_func_input.api_url
|
||||
assert api_url.endswith(
|
||||
("transcriptions", "translations"
|
||||
)), "OpenAI Chat Completions API URL must end with 'transcriptions' "
|
||||
"or `translations`."
|
||||
|
||||
async with aiohttp.ClientSession(trust_env=True,
|
||||
timeout=AIOHTTP_TIMEOUT) as session:
|
||||
content = [{"type": "text", "text": request_func_input.prompt}]
|
||||
payload = {
|
||||
"model": request_func_input.model_name \
|
||||
if request_func_input.model_name else request_func_input.model,
|
||||
"temperature": 0.0,
|
||||
"max_completion_tokens": request_func_input.output_len,
|
||||
"stream": True,
|
||||
"language": "en",
|
||||
# Flattened due to multipart/form-data
|
||||
"stream_include_usage": True,
|
||||
"stream_continuous_usage_stats": True
|
||||
}
|
||||
if request_func_input.extra_body:
|
||||
payload.update(request_func_input.extra_body)
|
||||
headers = {
|
||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
|
||||
}
|
||||
|
||||
# Send audio file
|
||||
def to_bytes(y, sr):
|
||||
buffer = io.BytesIO()
|
||||
soundfile.write(buffer, y, sr, format="WAV")
|
||||
buffer.seek(0)
|
||||
return buffer
|
||||
|
||||
with to_bytes(*request_func_input.multi_modal_content['audio']) as f:
|
||||
form = aiohttp.FormData()
|
||||
form.add_field('file', f, content_type='audio/wav')
|
||||
for key, value in payload.items():
|
||||
form.add_field(key, str(value))
|
||||
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
|
||||
generated_text = ""
|
||||
ttft = 0.0
|
||||
st = time.perf_counter()
|
||||
most_recent_timestamp = st
|
||||
try:
|
||||
async with session.post(url=api_url,
|
||||
data=form,
|
||||
headers=headers) as response:
|
||||
if response.status == 200:
|
||||
async for chunk_bytes in response.content:
|
||||
chunk_bytes = chunk_bytes.strip()
|
||||
if not chunk_bytes:
|
||||
continue
|
||||
|
||||
chunk = chunk_bytes.decode("utf-8").removeprefix(
|
||||
"data: ")
|
||||
if chunk != "[DONE]":
|
||||
timestamp = time.perf_counter()
|
||||
data = json.loads(chunk)
|
||||
|
||||
if choices := data.get("choices"):
|
||||
content = choices[0]["delta"].get(
|
||||
"content")
|
||||
# First token
|
||||
if ttft == 0.0:
|
||||
ttft = timestamp - st
|
||||
output.ttft = ttft
|
||||
|
||||
# Decoding phase
|
||||
else:
|
||||
output.itl.append(
|
||||
timestamp - most_recent_timestamp)
|
||||
|
||||
generated_text += content or ""
|
||||
elif usage := data.get("usage"):
|
||||
output.output_tokens = usage.get(
|
||||
"completion_tokens")
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
|
||||
output.generated_text = generated_text
|
||||
output.success = True
|
||||
output.latency = most_recent_timestamp - st
|
||||
else:
|
||||
output.error = response.reason or ""
|
||||
output.success = False
|
||||
except Exception:
|
||||
output.success = False
|
||||
exc_info = sys.exc_info()
|
||||
output.error = "".join(traceback.format_exception(*exc_info))
|
||||
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
return output
|
||||
|
||||
|
||||
def get_model(pretrained_model_name_or_path: str) -> str:
|
||||
if os.getenv('VLLM_USE_MODELSCOPE', 'False').lower() == 'true':
|
||||
from modelscope import snapshot_download
|
||||
@ -485,7 +599,14 @@ ASYNC_REQUEST_FUNCS = {
|
||||
"deepspeed-mii": async_request_deepspeed_mii,
|
||||
"openai": async_request_openai_completions,
|
||||
"openai-chat": async_request_openai_chat_completions,
|
||||
"openai-audio": async_request_openai_audio,
|
||||
"tensorrt-llm": async_request_trt_llm,
|
||||
"scalellm": async_request_openai_completions,
|
||||
"sglang": async_request_openai_completions,
|
||||
}
|
||||
|
||||
OPENAI_COMPATIBLE_BACKENDS = [
|
||||
k for k, v in ASYNC_REQUEST_FUNCS.items()
|
||||
if v in (async_request_openai_completions,
|
||||
async_request_openai_chat_completions)
|
||||
]
|
||||
|
||||
@ -64,6 +64,7 @@ class SampleRequest:
|
||||
|
||||
class BenchmarkDataset(ABC):
|
||||
DEFAULT_SEED = 0
|
||||
IS_MULTIMODAL = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -288,7 +289,7 @@ def process_image(image: Any) -> Mapping[str, Any]:
class RandomDataset(BenchmarkDataset):
# Default values copied from benchmark_serving.py for the random dataset.
DEFAULT_PREFIX_LEN = 0
DEFAULT_RANGE_RATIO = 1.0
DEFAULT_RANGE_RATIO = 0.0
DEFAULT_INPUT_LEN = 1024
DEFAULT_OUTPUT_LEN = 128

@@ -308,19 +309,32 @@ class RandomDataset(BenchmarkDataset):
output_len: int = DEFAULT_OUTPUT_LEN,
**kwargs,
) -> list[SampleRequest]:
# Enforce range_ratio < 1
assert range_ratio < 1.0, (
"random_range_ratio must be < 1.0 to ensure a valid sampling range"
)

vocab_size = tokenizer.vocab_size

prefix_token_ids = (np.random.randint(
0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else [])

input_low = int(input_len * range_ratio)
output_low = int(output_len * range_ratio)
# New sampling logic: [X * (1 - b), X * (1 + b)]
input_low = int(input_len * (1 - range_ratio))
input_high = int(input_len * (1 + range_ratio))
output_low = int(output_len * (1 - range_ratio))
output_high = int(output_len * (1 + range_ratio))

# Add logging for debugging
logger.info("Sampling input_len from [%s, %s]", input_low, input_high)
logger.info("Sampling output_len from [%s, %s]", output_low,
output_high)

input_lens = np.random.randint(input_low,
input_len + 1,
input_high + 1,
size=num_requests)
output_lens = np.random.randint(output_low,
output_len + 1,
output_high + 1,
size=num_requests)
offsets = np.random.randint(0, vocab_size, size=num_requests)

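The hunk above replaces the old one-sided range `[input_len * range_ratio, input_len]` with a range symmetric around the target length, and changes the default ratio from 1.0 to 0.0. A minimal standalone sketch of the new sampling bounds (not part of the diff; the 0.25 ratio and request count are arbitrary example values):

```python
import numpy as np

# Symmetric sampling range [X * (1 - b), X * (1 + b)], as in the new logic above.
input_len, output_len, range_ratio = 1024, 128, 0.25
num_requests = 5

input_low = int(input_len * (1 - range_ratio))     # 768
input_high = int(input_len * (1 + range_ratio))    # 1280
output_low = int(output_len * (1 - range_ratio))   # 96
output_high = int(output_len * (1 + range_ratio))  # 160

# np.random.randint's upper bound is exclusive, hence the "+ 1" mirroring the diff.
input_lens = np.random.randint(input_low, input_high + 1, size=num_requests)
output_lens = np.random.randint(output_low, output_high + 1, size=num_requests)
print(input_lens, output_lens)
```

With the new default `DEFAULT_RANGE_RATIO = 0.0`, both bounds collapse to the target value, so every request gets exactly `input_len` input tokens and `output_len` output tokens unless `--random-range-ratio` is set.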
@ -472,11 +486,11 @@ class SonnetDataset(BenchmarkDataset):
|
||||
|
||||
# Determine how many poem lines to use.
|
||||
num_input_lines = round((input_len - base_offset) / avg_len)
|
||||
num_prefix_lines = round((prefix_len - base_offset) / avg_len)
|
||||
num_prefix_lines = max(round((prefix_len - base_offset) / avg_len), 0)
|
||||
prefix_lines = self.data[:num_prefix_lines]
|
||||
|
||||
samples = []
|
||||
for _ in range(num_requests):
|
||||
while len(samples) < num_requests:
|
||||
extra_lines = random.choices(self.data,
|
||||
k=num_input_lines - num_prefix_lines)
|
||||
prompt = f"{base_prompt}{''.join(prefix_lines + extra_lines)}"
|
||||
@ -484,13 +498,14 @@ class SonnetDataset(BenchmarkDataset):
|
||||
prompt_formatted = tokenizer.apply_chat_template(
|
||||
msg, add_generation_prompt=True, tokenize=False)
|
||||
prompt_len = len(tokenizer(prompt_formatted).input_ids)
|
||||
samples.append(
|
||||
SampleRequest(
|
||||
prompt=prompt_formatted
|
||||
if return_prompt_formatted else prompt,
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=output_len,
|
||||
))
|
||||
if prompt_len <= input_len:
|
||||
samples.append(
|
||||
SampleRequest(
|
||||
prompt=prompt_formatted
|
||||
if return_prompt_formatted else prompt,
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=output_len,
|
||||
))
|
||||
return samples
|
||||
|
||||
|
||||
@ -582,15 +597,6 @@ class HuggingFaceDataset(BenchmarkDataset):
|
||||
) -> None:
|
||||
super().__init__(dataset_path=dataset_path, **kwargs)
|
||||
|
||||
# Validate dataset path
|
||||
if self.SUPPORTED_DATASET_PATHS and \
|
||||
self.dataset_path not in self.SUPPORTED_DATASET_PATHS:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} "
|
||||
f"only supports: {', '.join(self.SUPPORTED_DATASET_PATHS)}. "
|
||||
"Please consider contributing if you would "
|
||||
"like to add support for additional dataset formats.")
|
||||
|
||||
self.dataset_split = dataset_split
|
||||
self.dataset_subset = dataset_subset
|
||||
self.load_data()
|
||||
@ -616,6 +622,7 @@ class ConversationDataset(HuggingFaceDataset):
|
||||
SUPPORTED_DATASET_PATHS = {
|
||||
'lmms-lab/LLaVA-OneVision-Data', 'Aeala/ShareGPT_Vicuna_unfiltered'
|
||||
}
|
||||
IS_MULTIMODAL = True
|
||||
|
||||
def sample(self,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
@ -680,6 +687,7 @@ class VisionArenaDataset(HuggingFaceDataset):
|
||||
"lmarena-ai/vision-arena-bench-v0.1":
|
||||
lambda x: x["turns"][0][0]["content"]
|
||||
}
|
||||
IS_MULTIMODAL = True
|
||||
|
||||
def sample(
|
||||
self,
|
||||
@ -761,3 +769,129 @@ class InstructCoderDataset(HuggingFaceDataset):
|
||||
))
|
||||
self.maybe_oversample_requests(sampled_requests, num_requests)
|
||||
return sampled_requests
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# AIMO Dataset Implementation
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class AIMODataset(HuggingFaceDataset):
|
||||
"""
|
||||
Dataset class for processing a AIMO dataset with reasoning questions.
|
||||
"""
|
||||
SUPPORTED_DATASET_PATHS = {
|
||||
"AI-MO/aimo-validation-aime", "AI-MO/NuminaMath-1.5",
|
||||
"AI-MO/NuminaMath-CoT"
|
||||
}
|
||||
|
||||
def sample(self,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
num_requests: int,
|
||||
output_len: Optional[int] = None,
|
||||
**kwargs) -> list:
|
||||
sampled_requests = []
|
||||
dynamic_output = output_len is None
|
||||
|
||||
for item in self.data:
|
||||
if len(sampled_requests) >= num_requests:
|
||||
break
|
||||
prompt, completion = item['problem'], item["solution"]
|
||||
|
||||
prompt_ids = tokenizer(prompt).input_ids
|
||||
completion_ids = tokenizer(completion).input_ids
|
||||
prompt_len = len(prompt_ids)
|
||||
completion_len = len(completion_ids)
|
||||
output_len = completion_len if dynamic_output else output_len
|
||||
assert isinstance(output_len, int) and output_len > 0
|
||||
if dynamic_output and not is_valid_sequence(prompt_len,
|
||||
completion_len,
|
||||
max_prompt_len=2048,
|
||||
max_total_len=32000):
|
||||
continue
|
||||
sampled_requests.append(
|
||||
SampleRequest(
|
||||
prompt=prompt,
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=output_len,
|
||||
multi_modal_data=None,
|
||||
))
|
||||
self.maybe_oversample_requests(sampled_requests, num_requests)
|
||||
return sampled_requests
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# ASR Dataset Implementation
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ASRDataset(HuggingFaceDataset):
|
||||
"""
|
||||
Dataset class for processing a ASR dataset for transcription.
|
||||
Tested on the following set:
|
||||
|
||||
+----------------+----------------------------------------+--------------------------+-----------------------------+
|
||||
| Dataset | Domain | Speaking Style | hf-subset |
|
||||
+----------------+----------------------------------------+--------------------------+-----------------------------+
|
||||
| TED-LIUM | TED talks | Oratory | release1, release2, release3|
|
||||
| | | | release3-speaker-adaptation |
|
||||
| VoxPopuli | European Parliament | Oratory | en, de, it, fr, ... |
|
||||
| LibriSpeech | Audiobook | Narrated | "LIUM/tedlium" |
|
||||
| GigaSpeech | Audiobook, podcast, YouTube | Narrated, spontaneous | xs, s, m, l, xl, dev, test |
|
||||
| SPGISpeech | Financial meetings | Oratory, spontaneous | S, M, L, dev, test |
|
||||
| AMI | Meetings | Spontaneous | ihm, sdm |
|
||||
+----------------+----------------------------------------+--------------------------+-----------------------------+
|
||||
|
||||
""" # noqa: E501
|
||||
SUPPORTED_DATASET_PATHS = {
|
||||
"openslr/librispeech_asr", "facebook/voxpopuli", "LIUM/tedlium",
|
||||
"edinburghcstr/ami", "speechcolab/gigaspeech", "kensho/spgispeech"
|
||||
}
|
||||
|
||||
DEFAULT_OUTPUT_LEN = 128
|
||||
IS_MULTIMODAL = True
|
||||
|
||||
# TODO Whisper-specific. Abstract interface when more models are supported.
|
||||
TRANSCRIPTION_PREAMBLE = "<|startoftranscript|><|en|><|transcribe|>"\
|
||||
"<|notimestamps|>"
|
||||
skip_long_audios: bool = True
|
||||
|
||||
def sample(
|
||||
self,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
num_requests: int,
|
||||
output_len: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> list:
|
||||
import librosa
|
||||
output_len = (output_len
|
||||
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
|
||||
prompt = ASRDataset.TRANSCRIPTION_PREAMBLE
|
||||
prompt_len = len(tokenizer(prompt).input_ids)
|
||||
sampled_requests = []
|
||||
skipped = 0
|
||||
for item in self.data:
|
||||
if len(sampled_requests) >= num_requests:
|
||||
break
|
||||
audio = item["audio"]
|
||||
y, sr = audio["array"], audio["sampling_rate"]
|
||||
duration_s = librosa.get_duration(y=y, sr=sr)
|
||||
# Whisper max supported duration
|
||||
if self.skip_long_audios and duration_s > 30:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
mm_content = {"audio": (y, sr)}
|
||||
sampled_requests.append(
|
||||
SampleRequest(
|
||||
prompt=prompt,
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=output_len,
|
||||
multi_modal_data=mm_content,
|
||||
))
|
||||
if skipped:
|
||||
logger.warning("%d samples discarded from dataset due to" \
|
||||
" their length being greater than" \
|
||||
" what Whisper supports.", skipped)
|
||||
self.maybe_oversample_requests(sampled_requests, num_requests)
|
||||
return sampled_requests
|
||||
|
||||
@ -5,21 +5,18 @@ import argparse
|
||||
import dataclasses
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from benchmark_utils import (convert_to_pytorch_benchmark_format, get_requests,
|
||||
validate_dataset, write_to_json)
|
||||
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.inputs import TextPrompt, TokensPrompt
|
||||
from vllm.inputs import PromptType
|
||||
from vllm.sampling_params import BeamSearchParams
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
@ -51,34 +48,28 @@ def main(args: argparse.Namespace):
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
n=args.n,
|
||||
temperature=0,
|
||||
temperature=1.0,
|
||||
top_p=1.0,
|
||||
ignore_eos=True,
|
||||
max_tokens=args.output_len,
|
||||
detokenize=not args.disable_detokenize,
|
||||
)
|
||||
print(sampling_params)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.tokenizer, trust_remote_code=args.trust_remote_code)
|
||||
requests = get_requests(args.batch_size, args, tokenizer)
|
||||
prompts: list[Union[TextPrompt, TokensPrompt]] = []
|
||||
for request in requests:
|
||||
prompts.append(
|
||||
TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"],
|
||||
multi_modal_data=request.multi_modal_data)
|
||||
if "prompt_token_ids" in request.prompt else \
|
||||
TextPrompt(prompt=request.prompt,
|
||||
multi_modal_data=request.multi_modal_data))
|
||||
dummy_prompt_token_ids = np.random.randint(10000,
|
||||
size=(args.batch_size,
|
||||
args.input_len))
|
||||
dummy_prompts: list[PromptType] = [{
|
||||
"prompt_token_ids": batch
|
||||
} for batch in dummy_prompt_token_ids.tolist()]
|
||||
|
||||
def llm_generate():
|
||||
if not args.use_beam_search:
|
||||
llm.generate(prompts,
|
||||
llm.generate(dummy_prompts,
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=False)
|
||||
else:
|
||||
llm.beam_search(
|
||||
prompts,
|
||||
dummy_prompts,
|
||||
BeamSearchParams(
|
||||
beam_width=args.n,
|
||||
max_tokens=args.output_len,
|
||||
@ -189,44 +180,7 @@ if __name__ == "__main__":
|
||||
help=("Do not detokenize responses (i.e. do not include "
|
||||
"detokenization time in the latency measurement)"),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-name",
|
||||
type=str,
|
||||
choices=["sharegpt", "random", "sonnet", "burstgpt", "hf"],
|
||||
help="Name of the dataset to benchmark on.",
|
||||
default="sharegpt")
|
||||
# random dataset
|
||||
parser.add_argument(
|
||||
"--random-range-ratio",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Range of sampled ratio of input/output length, "
|
||||
"used only for RandomDataSet.",
|
||||
)
|
||||
parser.add_argument("--dataset-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the dataset")
|
||||
|
||||
# LoRA
|
||||
parser.add_argument(
|
||||
"--lora-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the lora adapters to use. This can be an absolute path, "
|
||||
"a relative path, or a Hugging Face model identifier.")
|
||||
|
||||
parser.add_argument("--prefix-len",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Number of prefix tokens per request."
|
||||
"This is for the RandomDataset and SonnetDataset")
|
||||
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
args = parser.parse_args()
|
||||
if args.tokenizer is None:
|
||||
args.tokenizer = args.model
|
||||
args.backend = "vllm"
|
||||
validate_dataset(args)
|
||||
random.seed(0)
|
||||
main(args)
|
||||
|
||||
@ -63,14 +63,16 @@ class Request:
|
||||
output_len: int
|
||||
|
||||
|
||||
def sample_tokens(tokenizer: PreTrainedTokenizerBase, length: int) -> str:
|
||||
def sample_tokens(tokenizer: PreTrainedTokenizerBase,
|
||||
length: int) -> list[int]:
|
||||
vocab = tokenizer.get_vocab()
|
||||
all_special_ids = set(tokenizer.all_special_ids)
|
||||
|
||||
# Remove the special tokens.
|
||||
vocab = {
|
||||
k: v
|
||||
for k, v in vocab.items() if k not in tokenizer.all_special_ids
|
||||
}
|
||||
return random.choices(list(vocab.values()), k=length)
|
||||
return random.choices(
|
||||
[v for k, v in vocab.items() if k not in all_special_ids],
|
||||
k=length,
|
||||
)
|
||||
|
||||
|
||||
def sample_requests_from_dataset(
|
||||
|
||||
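The rewrite of `sample_tokens` above makes it return a list of token ids and actually filters out special tokens: the old comprehension compared vocabulary strings against `all_special_ids`, which are integers, so nothing was ever excluded. A minimal usage sketch, assuming the helper is in scope; the tokenizer name is only an example:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

token_ids = sample_tokens(tokenizer, length=32)  # list[int] with no special tokens
assert all(t not in tokenizer.all_special_ids for t in token_ids)
prompt = tokenizer.decode(token_ids)
print(len(token_ids), repr(prompt[:60]))
```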
@ -34,7 +34,8 @@ from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput,
|
||||
from backend_request_func import (ASYNC_REQUEST_FUNCS,
|
||||
OPENAI_COMPATIBLE_BACKENDS, RequestFuncInput,
|
||||
RequestFuncOutput)
|
||||
from tqdm.asyncio import tqdm
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
@ -49,7 +50,8 @@ try:
|
||||
except ImportError:
|
||||
from argparse import ArgumentParser as FlexibleArgumentParser
|
||||
|
||||
from benchmark_dataset import (BurstGPTDataset, ConversationDataset,
|
||||
from benchmark_dataset import (AIMODataset, ASRDataset, BurstGPTDataset,
|
||||
ConversationDataset, HuggingFaceDataset,
|
||||
InstructCoderDataset, RandomDataset,
|
||||
SampleRequest, ShareGPTDataset, SonnetDataset,
|
||||
VisionArenaDataset)
|
||||
@ -154,7 +156,7 @@ def calculate_metrics(
|
||||
if outputs[i].success:
|
||||
output_len = outputs[i].output_tokens
|
||||
|
||||
if output_len is None:
|
||||
if not output_len:
|
||||
# We use the tokenizer to count the number of output tokens
|
||||
# for some serving backends instead of looking at
|
||||
# len(outputs[i].itl) since multiple output tokens may be
|
||||
@ -259,6 +261,7 @@ async def benchmark(
|
||||
goodput_config_dict: dict[str, float],
|
||||
max_concurrency: Optional[int],
|
||||
lora_modules: Optional[Iterable[str]],
|
||||
extra_body: Optional[dict],
|
||||
):
|
||||
if backend in ASYNC_REQUEST_FUNCS:
|
||||
request_func = ASYNC_REQUEST_FUNCS[backend]
|
||||
@ -271,10 +274,6 @@ async def benchmark(
|
||||
input_requests[0].expected_output_len, \
|
||||
input_requests[0].multi_modal_data
|
||||
|
||||
if backend != "openai-chat" and test_mm_content is not None:
|
||||
# multi-modal benchmark is only available on OpenAI Chat backend.
|
||||
raise ValueError(
|
||||
"Multi-modal content is only supported on 'openai-chat' backend.")
|
||||
assert test_mm_content is None or isinstance(test_mm_content, dict)
|
||||
test_input = RequestFuncInput(
|
||||
model=model_id,
|
||||
@ -286,6 +285,7 @@ async def benchmark(
|
||||
logprobs=logprobs,
|
||||
multi_modal_content=test_mm_content,
|
||||
ignore_eos=ignore_eos,
|
||||
extra_body=extra_body,
|
||||
)
|
||||
|
||||
test_output = await request_func(request_func_input=test_input)
|
||||
@ -312,7 +312,8 @@ async def benchmark(
|
||||
output_len=test_output_len,
|
||||
logprobs=logprobs,
|
||||
multi_modal_content=test_mm_content,
|
||||
ignore_eos=ignore_eos)
|
||||
ignore_eos=ignore_eos,
|
||||
extra_body=extra_body)
|
||||
profile_output = await request_func(request_func_input=profile_input)
|
||||
if profile_output.success:
|
||||
print("Profiler started")
|
||||
@ -362,7 +363,8 @@ async def benchmark(
|
||||
output_len=output_len,
|
||||
logprobs=logprobs,
|
||||
multi_modal_content=mm_content,
|
||||
ignore_eos=ignore_eos)
|
||||
ignore_eos=ignore_eos,
|
||||
extra_body=extra_body)
|
||||
tasks.append(
|
||||
asyncio.create_task(
|
||||
limited_request_func(request_func_input=request_func_input,
|
||||
@ -595,14 +597,38 @@ def main(args: argparse.Namespace):
|
||||
args.hf_split = "train"
|
||||
elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_class = ConversationDataset
|
||||
elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_class = AIMODataset
|
||||
args.hf_split = "train"
|
||||
elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_class = ASRDataset
|
||||
args.hf_split = "train"
|
||||
else:
|
||||
supported_datasets = set([
|
||||
dataset_name for cls in HuggingFaceDataset.__subclasses__()
|
||||
for dataset_name in cls.SUPPORTED_DATASET_PATHS
|
||||
])
|
||||
raise ValueError(
|
||||
f"Unsupported dataset path: {args.dataset_path}. "
|
||||
"Huggingface dataset only supports dataset_path"
|
||||
f" from one of following: {supported_datasets}. "
|
||||
"Please consider contributing if you would "
|
||||
"like to add support for additional dataset formats.")
|
||||
|
||||
if (dataset_class.IS_MULTIMODAL and backend not in \
|
||||
["openai-chat", "openai-audio"]):
|
||||
# multi-modal benchmark is only available on OpenAI Chat backend.
|
||||
raise ValueError(
|
||||
"Multi-modal content is only supported on 'openai-chat' and " \
|
||||
"'openai-audio' backend.")
|
||||
input_requests = dataset_class(
|
||||
dataset_path=args.dataset_path,
|
||||
dataset_subset=args.hf_subset,
|
||||
dataset_split=args.hf_split,
|
||||
random_seed=args.seed,
|
||||
).sample(
|
||||
num_requests=args.num_prompts,
|
||||
tokenizer=tokenizer,
|
||||
random_seed=args.seed,
|
||||
output_len=args.hf_output_len,
|
||||
)
|
||||
|
||||
@ -637,6 +663,26 @@ def main(args: argparse.Namespace):
|
||||
raise ValueError(f"Unknown dataset: {args.dataset_name}") from err
|
||||
goodput_config_dict = check_goodput_args(args)
|
||||
|
||||
# Collect the sampling parameters.
|
||||
sampling_params = {
|
||||
k: v
|
||||
for k, v in {
|
||||
"top_p": args.top_p,
|
||||
"top_k": args.top_k,
|
||||
"min_p": args.min_p,
|
||||
"temperature": args.temperature
|
||||
}.items() if v is not None
|
||||
}
|
||||
|
||||
# Sampling parameters are only supported by openai-compatible backend.
|
||||
if sampling_params and args.backend not in OPENAI_COMPATIBLE_BACKENDS:
|
||||
raise ValueError(
|
||||
"Sampling parameters are only supported by openai-compatible "
|
||||
"backends.")
|
||||
|
||||
if "temperature" not in sampling_params:
|
||||
sampling_params["temperature"] = 0.0 # Default to greedy decoding.
|
||||
|
||||
# Avoid GC processing "static" data - reduce pause times.
|
||||
gc.collect()
|
||||
gc.freeze()
|
||||
@ -663,10 +709,11 @@ def main(args: argparse.Namespace):
|
||||
goodput_config_dict=goodput_config_dict,
|
||||
max_concurrency=args.max_concurrency,
|
||||
lora_modules=args.lora_modules,
|
||||
extra_body=sampling_params,
|
||||
))
|
||||
|
||||
# Save config and results to json
|
||||
if args.save_result:
|
||||
if args.save_result or args.append_result:
|
||||
result_json: dict[str, Any] = {}
|
||||
|
||||
# Setup
|
||||
@ -687,6 +734,14 @@ def main(args: argparse.Namespace):
|
||||
raise ValueError(
|
||||
"Invalid metadata format. Please use KEY=VALUE format."
|
||||
)
|
||||
# Traffic
|
||||
result_json["request_rate"] = (args.request_rate if args.request_rate
|
||||
< float("inf") else "inf")
|
||||
result_json["burstiness"] = args.burstiness
|
||||
result_json["max_concurrency"] = args.max_concurrency
|
||||
|
||||
# Merge with benchmark result
|
||||
result_json = {**result_json, **benchmark_result}
|
||||
|
||||
if not args.save_detailed:
|
||||
# Remove fields with too many data points
|
||||
@ -697,15 +752,6 @@ def main(args: argparse.Namespace):
|
||||
if field in result_json:
|
||||
del result_json[field]
|
||||
|
||||
# Traffic
|
||||
result_json["request_rate"] = (args.request_rate if args.request_rate
|
||||
< float("inf") else "inf")
|
||||
result_json["burstiness"] = args.burstiness
|
||||
result_json["max_concurrency"] = args.max_concurrency
|
||||
|
||||
# Merge with benchmark result
|
||||
result_json = {**result_json, **benchmark_result}
|
||||
|
||||
# Save to file
base_model_id = model_id.split("/")[-1]
max_concurrency_str = (f"-concurrency{args.max_concurrency}"
@@ -715,7 +761,12 @@ def main(args: argparse.Namespace):
file_name = args.result_filename
if args.result_dir:
file_name = os.path.join(args.result_dir, file_name)
with open(file_name, "w", encoding='utf-8') as outfile:
with open(file_name,
mode="a+" if args.append_result else "w",
encoding='utf-8') as outfile:
# Append a newline.
if args.append_result and outfile.tell() != 0:
outfile.write("\n")
json.dump(result_json, outfile)
save_to_pytorch_benchmark_format(args, result_json, file_name)

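The new `--append-result` mode opens the results file in `a+` mode and writes a newline before each appended record, so repeated runs produce one JSON object per line (effectively JSON Lines). A small sketch for reading such a file back, assuming results were written this way; the file name is a placeholder:

```python
import json

results = []
with open("benchmark-results.json", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if line:  # ignore any empty lines
            results.append(json.loads(line))

for r in results:
    # "request_rate", "burstiness" and "max_concurrency" are written by the
    # benchmark script itself (see the hunks earlier in this diff).
    print(r.get("request_rate"), r.get("max_concurrency"))
```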
@ -847,6 +898,11 @@ if __name__ == "__main__":
|
||||
help="When saving the results, whether to include per request "
|
||||
"information such as response, error, ttfs, tpots, etc.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--append-result",
|
||||
action="store_true",
|
||||
help="Append the benchmark result to the existing json file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--metadata",
|
||||
metavar="KEY=VALUE",
|
||||
@ -880,7 +936,7 @@ if __name__ == "__main__":
|
||||
"--percentile-metrics",
|
||||
type=str,
|
||||
default="ttft,tpot,itl",
|
||||
help="Comma-seperated list of selected metrics to report percentils. "
|
||||
help="Comma-separated list of selected metrics to report percentils. "
|
||||
"This argument specifies the metrics to report percentiles. "
|
||||
"Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". "
|
||||
"Default value is \"ttft,tpot,itl\".")
|
||||
@ -888,7 +944,7 @@ if __name__ == "__main__":
|
||||
"--metric-percentiles",
|
||||
type=str,
|
||||
default="99",
|
||||
help="Comma-seperated list of percentiles for selected metrics. "
|
||||
help="Comma-separated list of percentiles for selected metrics. "
|
||||
"To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". "
|
||||
"Default value is \"99\". "
|
||||
"Use \"--percentile-metrics\" to select metrics.",
|
||||
@ -955,18 +1011,23 @@ if __name__ == "__main__":
|
||||
random_group.add_argument(
|
||||
"--random-range-ratio",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Range of sampled ratio of input/output length, "
|
||||
"used only for random sampling.",
|
||||
default=0.0,
|
||||
help="Range ratio for sampling input/output length, "
|
||||
"used only for random sampling. Must be in the range [0, 1) to define "
|
||||
"a symmetric sampling range"
|
||||
"[length * (1 - range_ratio), length * (1 + range_ratio)].",
|
||||
)
|
||||
random_group.add_argument(
|
||||
"--random-prefix-len",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Number of fixed prefix tokens before random "
|
||||
" context. The length range of context in a random "
|
||||
" request is [random-prefix-len, "
|
||||
" random-prefix-len + random-prefix-len * random-range-ratio).")
|
||||
help=("Number of fixed prefix tokens before the random context "
|
||||
"in a request. "
|
||||
"The total input length is the sum of `random-prefix-len` and "
|
||||
"a random "
|
||||
"context length sampled from [input_len * (1 - range_ratio), "
|
||||
"input_len * (1 + range_ratio)]."),
|
||||
)
|
||||
|
||||
hf_group = parser.add_argument_group("hf dataset options")
|
||||
hf_group.add_argument("--hf-subset",
|
||||
@ -985,6 +1046,33 @@ if __name__ == "__main__":
|
||||
"from the sampled HF dataset.",
|
||||
)
|
||||
|
||||
sampling_group = parser.add_argument_group("sampling parameters")
|
||||
sampling_group.add_argument(
|
||||
"--top-p",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Top-p sampling parameter. Only has effect on openai-compatible "
|
||||
"backends.")
|
||||
sampling_group.add_argument(
|
||||
"--top-k",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Top-k sampling parameter. Only has effect on openai-compatible "
|
||||
"backends.")
|
||||
sampling_group.add_argument(
|
||||
"--min-p",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Min-p sampling parameter. Only has effect on openai-compatible "
|
||||
"backends.")
|
||||
sampling_group.add_argument(
|
||||
"--temperature",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Temperature sampling parameter. Only has effect on "
|
||||
"openai-compatible backends. If not specified, default to greedy "
|
||||
"decoding (i.e. temperature==0.0).")
|
||||
|
||||
parser.add_argument(
|
||||
'--tokenizer-mode',
|
||||
type=str,
|
||||
|
||||
@ -11,7 +11,7 @@ On the client side, run:
|
||||
--model <your_model> \
|
||||
--dataset json \
|
||||
--structured-output-ratio 1.0 \
|
||||
--structured-output-backend xgrammar \
|
||||
--structured-output-backend auto \
|
||||
--request-rate 10 \
|
||||
--num-prompts 1000
|
||||
|
||||
@ -51,7 +51,7 @@ try:
|
||||
except ImportError:
|
||||
from argparse import ArgumentParser as FlexibleArgumentParser
|
||||
|
||||
from vllm.v1.structured_output.utils import (
|
||||
from vllm.v1.structured_output.backend_xgrammar import (
|
||||
has_xgrammar_unsupported_json_features)
|
||||
|
||||
MILLISECONDS_TO_SECONDS_CONVERSION = 1000
|
||||
@ -130,10 +130,11 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
|
||||
"description":
|
||||
"An unique optional field to avoid cached schemas"
|
||||
}
|
||||
else:
|
||||
json_schemas = [schema] * args.num_prompts
|
||||
|
||||
def gen_prompt(index: int):
|
||||
schema = json_schemas[index % len(json_schemas)]
|
||||
return f"Generate an example of a user profile given the following schema: {json.dumps(schema)}" # noqa: E501
|
||||
return f"Generate an example of a user profile given the following schema: {json.dumps(get_schema(index))}" # noqa: E501
|
||||
|
||||
def get_schema(index: int):
|
||||
return json_schemas[index % len(json_schemas)]
|
||||
@@ -149,17 +150,17 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,

elif args.dataset == "grammar":
schema = """
?start: select_statement
root ::= select_statement

?select_statement: "SELECT " column_list " FROM " table_name
select_statement ::= "SELECT " column " from " table " where " condition

?column_list: column_name ("," column_name)*
column ::= "col_1 " | "col_2 "

?table_name: identifier
table ::= "table_1 " | "table_2 "

?column_name: identifier
condition ::= column "= " number

?identifier: /[a-zA-Z_][a-zA-Z0-9_]*/
number ::= "1 " | "2 "
"""
prompt = "Generate an SQL query to show the 'username' \
and 'email' from the 'users' table."

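The hunk above swaps the open-ended Lark-style SQL grammar for a much more constrained GBNF-style one that can only derive a handful of fixed queries. As a sanity check (not part of the diff), the language of the new grammar can be mirrored with a regular expression; note that every terminal in the grammar carries a trailing space:

```python
import re

# Each alternative in the grammar ends with a space, so derived strings do too.
pattern = re.compile(
    r"SELECT (col_1 |col_2 )from (table_1 |table_2 )"
    r"where (col_1 |col_2 )= (1 |2 )"
)

assert pattern.fullmatch("SELECT col_1 from table_1 where col_2 = 1 ")
assert not pattern.fullmatch("SELECT username FROM users")  # old-style query no longer fits
```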
@ -963,7 +964,7 @@ if __name__ == "__main__":
|
||||
"--percentile-metrics",
|
||||
type=str,
|
||||
default="ttft,tpot,itl",
|
||||
help="Comma-seperated list of selected metrics to report percentils. "
|
||||
help="Comma-separated list of selected metrics to report percentils. "
|
||||
"This argument specifies the metrics to report percentiles. "
|
||||
"Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". "
|
||||
"Default value is \"ttft,tpot,itl\".")
|
||||
@ -971,7 +972,7 @@ if __name__ == "__main__":
|
||||
"--metric-percentiles",
|
||||
type=str,
|
||||
default="99",
|
||||
help="Comma-seperated list of percentiles for selected metrics. "
|
||||
help="Comma-separated list of percentiles for selected metrics. "
|
||||
"To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". "
|
||||
"Default value is \"99\". "
|
||||
"Use \"--percentile-metrics\" to select metrics.",
|
||||
@ -996,12 +997,14 @@ if __name__ == "__main__":
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Ratio of Structured Outputs requests")
|
||||
parser.add_argument(
|
||||
"--structured-output-backend",
|
||||
type=str,
|
||||
choices=["outlines", "lm-format-enforcer", "xgrammar", "guidance"],
|
||||
default="xgrammar",
|
||||
help="Backend to use for structured outputs")
|
||||
parser.add_argument("--structured-output-backend",
|
||||
type=str,
|
||||
choices=[
|
||||
"outlines", "lm-format-enforcer", "xgrammar",
|
||||
"guidance", "auto"
|
||||
],
|
||||
default="auto",
|
||||
help="Backend to use for structured outputs")
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
||||
@ -11,9 +11,11 @@ from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
import uvloop
|
||||
from benchmark_dataset import SampleRequest
|
||||
from benchmark_utils import (convert_to_pytorch_benchmark_format, get_requests,
|
||||
validate_dataset, write_to_json)
|
||||
from benchmark_dataset import (AIMODataset, BurstGPTDataset,
|
||||
ConversationDataset, InstructCoderDataset,
|
||||
RandomDataset, SampleRequest, ShareGPTDataset,
|
||||
SonnetDataset, VisionArenaDataset)
|
||||
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
|
||||
from tqdm import tqdm
|
||||
from transformers import (AutoModelForCausalLM, AutoTokenizer,
|
||||
PreTrainedTokenizerBase)
|
||||
@ -211,14 +213,17 @@ def run_hf(
|
||||
max_prompt_len = 0
|
||||
max_output_len = 0
|
||||
for i in range(len(requests)):
|
||||
prompt, prompt_len, output_len = requests[i]
|
||||
prompt = requests[i].prompt
|
||||
prompt_len = requests[i].prompt_len
|
||||
output_len = requests[i].expected_output_len
|
||||
# Add the prompt to the batch.
|
||||
batch.append(prompt)
|
||||
max_prompt_len = max(max_prompt_len, prompt_len)
|
||||
max_output_len = max(max_output_len, output_len)
|
||||
if len(batch) < max_batch_size and i != len(requests) - 1:
|
||||
# Check if we can add more requests to the batch.
|
||||
_, next_prompt_len, next_output_len = requests[i + 1]
|
||||
next_prompt_len = requests[i + 1].prompt_len
|
||||
next_output_len = requests[i + 1].expected_output_len
|
||||
if (max(max_prompt_len, next_prompt_len) +
|
||||
max(max_output_len, next_output_len)) <= 2048:
|
||||
# We can add more requests to the batch.
|
||||
@ -285,6 +290,62 @@ def save_to_pytorch_benchmark_format(args: argparse.Namespace,
|
||||
write_to_json(pt_file, pt_records)
|
||||
|
||||
|
||||
def get_requests(args, tokenizer):
|
||||
# Common parameters for all dataset types.
|
||||
common_kwargs = {
|
||||
"dataset_path": args.dataset_path,
|
||||
"random_seed": args.seed,
|
||||
}
|
||||
sample_kwargs = {
|
||||
"tokenizer": tokenizer,
|
||||
"lora_path": args.lora_path,
|
||||
"max_loras": args.max_loras,
|
||||
"num_requests": args.num_prompts,
|
||||
"input_len": args.input_len,
|
||||
"output_len": args.output_len,
|
||||
}
|
||||
|
||||
if args.dataset_path is None or args.dataset_name == "random":
|
||||
sample_kwargs["range_ratio"] = args.random_range_ratio
|
||||
sample_kwargs["prefix_len"] = args.prefix_len
|
||||
dataset_cls = RandomDataset
|
||||
elif args.dataset_name == "sharegpt":
|
||||
dataset_cls = ShareGPTDataset
|
||||
if args.backend == "vllm-chat":
|
||||
sample_kwargs["enable_multimodal_chat"] = True
|
||||
elif args.dataset_name == "sonnet":
|
||||
assert tokenizer.chat_template or tokenizer.default_chat_template, (
|
||||
"Tokenizer/model must have chat template for sonnet dataset.")
|
||||
dataset_cls = SonnetDataset
|
||||
sample_kwargs["prefix_len"] = args.prefix_len
|
||||
sample_kwargs["return_prompt_formatted"] = True
|
||||
elif args.dataset_name == "burstgpt":
|
||||
dataset_cls = BurstGPTDataset
|
||||
elif args.dataset_name == "hf":
|
||||
if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_cls = VisionArenaDataset
|
||||
common_kwargs['dataset_subset'] = None
|
||||
common_kwargs['dataset_split'] = "train"
|
||||
sample_kwargs["enable_multimodal_chat"] = True
|
||||
elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_cls = InstructCoderDataset
|
||||
common_kwargs['dataset_split'] = "train"
|
||||
elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_cls = ConversationDataset
|
||||
common_kwargs['dataset_subset'] = args.hf_subset
|
||||
common_kwargs['dataset_split'] = args.hf_split
|
||||
sample_kwargs["enable_multimodal_chat"] = True
|
||||
elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_cls = AIMODataset
|
||||
common_kwargs['dataset_subset'] = None
|
||||
common_kwargs['dataset_split'] = "train"
|
||||
else:
|
||||
raise ValueError(f"Unknown dataset name: {args.dataset_name}")
|
||||
# Remove None values
|
||||
sample_kwargs = {k: v for k, v in sample_kwargs.items() if v is not None}
|
||||
return dataset_cls(**common_kwargs).sample(**sample_kwargs)
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
if args.seed is None:
|
||||
args.seed = 0
|
||||
@ -293,7 +354,7 @@ def main(args: argparse.Namespace):
|
||||
# Sample the requests.
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.tokenizer, trust_remote_code=args.trust_remote_code)
|
||||
requests = get_requests(args.num_prompts, args, tokenizer)
|
||||
requests = get_requests(args, tokenizer)
|
||||
is_multi_modal = any(request.multi_modal_data is not None
|
||||
for request in requests)
|
||||
request_outputs: Optional[list[RequestOutput]] = None
|
||||
@ -394,8 +455,48 @@ def validate_args(args):
|
||||
if args.backend not in valid_backends:
|
||||
raise ValueError(f"Unsupported backend: {args.backend}")
|
||||
|
||||
# === Dataset Validation ===
|
||||
validate_dataset(args)
|
||||
# === Dataset Configuration ===
|
||||
if not args.dataset and not args.dataset_path:
|
||||
print(
|
||||
"When dataset path is not set, it will default to random dataset")
|
||||
args.dataset_name = 'random'
|
||||
if args.input_len is None:
|
||||
raise ValueError("input_len must be provided for a random dataset")
|
||||
|
||||
# === Dataset Name Specific Checks ===
|
||||
# --hf-subset and --hf-split: only used
|
||||
# when dataset_name is 'hf'
|
||||
if args.dataset_name != "hf" and (
|
||||
getattr(args, "hf_subset", None) is not None
|
||||
or getattr(args, "hf_split", None) is not None):
|
||||
warnings.warn("--hf-subset and --hf-split will be ignored \
|
||||
since --dataset-name is not 'hf'.",
|
||||
stacklevel=2)
|
||||
elif args.dataset_name == "hf":
|
||||
if args.dataset_path in (
|
||||
VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys()
|
||||
| ConversationDataset.SUPPORTED_DATASET_PATHS):
|
||||
assert args.backend == "vllm-chat", f"{args.dataset_path} needs to use vllm-chat as the backend." #noqa: E501
|
||||
elif args.dataset_path in (InstructCoderDataset.SUPPORTED_DATASET_PATHS
|
||||
| AIMODataset.SUPPORTED_DATASET_PATHS):
|
||||
assert args.backend == "vllm", f"{args.dataset_path} needs to use vllm as the backend." #noqa: E501
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{args.dataset_path} is not supported by hf dataset.")
|
||||
|
||||
# --random-range-ratio: only used when dataset_name is 'random'
|
||||
if args.dataset_name != 'random' and args.random_range_ratio is not None:
|
||||
warnings.warn("--random-range-ratio will be ignored since \
|
||||
--dataset-name is not 'random'.",
|
||||
stacklevel=2)
|
||||
|
||||
# --prefix-len: only used when dataset_name is 'random', 'sonnet', or not
|
||||
# set.
|
||||
if args.dataset_name not in {"random", "sonnet", None
|
||||
} and args.prefix_len is not None:
|
||||
warnings.warn("--prefix-len will be ignored since --dataset-name\
|
||||
is not 'random', 'sonnet', or not set.",
|
||||
stacklevel=2)
|
||||
|
||||
# === LoRA Settings ===
|
||||
if getattr(args, "enable_lora", False) and args.backend != "vllm":
|
||||
@ -422,6 +523,13 @@ def validate_args(args):
|
||||
raise ValueError(
|
||||
"Tokenizer must be the same as the model for MII backend.")
|
||||
|
||||
# --data-parallel is not supported currently.
|
||||
# https://github.com/vllm-project/vllm/issues/16222
|
||||
if args.data_parallel_size > 1:
|
||||
raise ValueError(
|
||||
"Data parallel is not supported in offline benchmark, \
|
||||
please use benchmark serving instead")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser(description="Benchmark the throughput.")
|
||||
@ -435,6 +543,14 @@ if __name__ == "__main__":
|
||||
choices=["sharegpt", "random", "sonnet", "burstgpt", "hf"],
|
||||
help="Name of the dataset to benchmark on.",
|
||||
default="sharegpt")
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the ShareGPT dataset, will be deprecated in\
|
||||
the next release. The dataset is expected to "
|
||||
"be a json in form of list[dict[..., conversations: "
|
||||
"list[dict[..., value: <prompt_or_response>]]]]")
|
||||
parser.add_argument("--dataset-path",
|
||||
type=str,
|
||||
default=None,
|
||||
@ -485,18 +601,30 @@ if __name__ == "__main__":
|
||||
default=None,
|
||||
help="Path to the lora adapters to use. This can be an absolute path, "
|
||||
"a relative path, or a Hugging Face model identifier.")
|
||||
parser.add_argument("--prefix-len",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Number of prefix tokens per request."
|
||||
"This is for the RandomDataset and SonnetDataset")
|
||||
parser.add_argument(
|
||||
"--prefix-len",
|
||||
type=int,
|
||||
default=None,
|
||||
help=f"Number of prefix tokens to be used in RandomDataset "
|
||||
"and SonnetDataset. For RandomDataset, the total input "
|
||||
"length is the sum of prefix-len (default: "
|
||||
f"{RandomDataset.DEFAULT_PREFIX_LEN}) and a random context length "
|
||||
"sampled from [input_len * (1 - range_ratio), "
|
||||
"input_len * (1 + range_ratio)]. For SonnetDataset, "
|
||||
f"prefix_len (default: {SonnetDataset.DEFAULT_PREFIX_LEN}) "
|
||||
"controls how much of the input is fixed lines versus "
|
||||
"random lines, but the total input length remains approximately "
|
||||
"input_len tokens.")
|
||||
# random dataset
|
||||
parser.add_argument(
|
||||
"--random-range-ratio",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Range of sampled ratio of input/output length, "
|
||||
"used only for RandomDataSet.",
|
||||
help=f"Range ratio (default : {RandomDataset.DEFAULT_RANGE_RATIO}) "
|
||||
"for sampling input/output length, "
|
||||
"used only for RandomDataset. Must be in the range [0, 1) to "
|
||||
"define a symmetric sampling range "
|
||||
"[length * (1 - range_ratio), length * (1 + range_ratio)].",
|
||||
)
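For readers skimming the new help text above, a minimal sketch (not part of the diff) of the length-sampling scheme it describes, using the documented defaults; the real RandomDataset implementation lives in benchmark_dataset.py and may differ in detail:

```python
import random

def sample_random_lengths(input_len: int, output_len: int,
                          range_ratio: float, prefix_len: int) -> tuple[int, int]:
    # Total input = fixed prefix + a context length drawn from the symmetric
    # range [input_len * (1 - r), input_len * (1 + r)], as described above.
    context = random.randint(int(input_len * (1 - range_ratio)),
                             int(input_len * (1 + range_ratio)))
    out = random.randint(int(output_len * (1 - range_ratio)),
                         int(output_len * (1 + range_ratio)))
    return prefix_len + context, out
```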
|
||||
|
||||
# hf dataset
|
||||
|
||||
@ -4,14 +4,8 @@ import argparse
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import warnings
|
||||
from typing import Any
|
||||
|
||||
from benchmark_dataset import (BurstGPTDataset, ConversationDataset,
|
||||
InstructCoderDataset, RandomDataset,
|
||||
SampleRequest, ShareGPTDataset, SonnetDataset,
|
||||
VisionArenaDataset)
|
||||
|
||||
|
||||
def convert_to_pytorch_benchmark_format(args: argparse.Namespace,
|
||||
metrics: dict[str, list],
|
||||
@ -73,113 +67,3 @@ class InfEncoder(json.JSONEncoder):
|
||||
def write_to_json(filename: str, records: list) -> None:
|
||||
with open(filename, "w") as f:
|
||||
json.dump(records, f, cls=InfEncoder)
|
||||
|
||||
|
||||
def get_requests(num_requests: int, args: argparse.Namespace,
|
||||
tokenizer: Any) -> list[SampleRequest]:
|
||||
"""
|
||||
Sample the requests for the benchmark.
|
||||
"""
|
||||
# Common parameters for all dataset types.
|
||||
common_kwargs = {
|
||||
"dataset_path": args.dataset_path,
|
||||
"random_seed": args.seed,
|
||||
}
|
||||
sample_kwargs = {
|
||||
"tokenizer": tokenizer,
|
||||
"lora_path": args.lora_path,
|
||||
"max_loras": args.max_loras,
|
||||
"num_requests": num_requests,
|
||||
"input_len": args.input_len,
|
||||
"output_len": args.output_len,
|
||||
}
|
||||
|
||||
if args.dataset_path is None or args.dataset_name == "random":
|
||||
sample_kwargs["range_ratio"] = args.random_range_ratio
|
||||
sample_kwargs["prefix_len"] = args.prefix_len
|
||||
dataset_cls = RandomDataset
|
||||
elif args.dataset_name == "sharegpt":
|
||||
dataset_cls = ShareGPTDataset
|
||||
if getattr(args, "backend", False) and args.backend == "vllm-chat":
|
||||
sample_kwargs["enable_multimodal_chat"] = True
|
||||
elif args.dataset_name == "sonnet":
|
||||
assert tokenizer.chat_template or tokenizer.default_chat_template, (
|
||||
"Tokenizer/model must have chat template for sonnet dataset.")
|
||||
dataset_cls = SonnetDataset
|
||||
sample_kwargs["prefix_len"] = args.prefix_len
|
||||
sample_kwargs["return_prompt_formatted"] = True
|
||||
elif args.dataset_name == "burstgpt":
|
||||
dataset_cls = BurstGPTDataset
|
||||
elif args.dataset_name == "hf":
|
||||
if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_cls = VisionArenaDataset
|
||||
common_kwargs['dataset_subset'] = None
|
||||
common_kwargs['dataset_split'] = "train"
|
||||
sample_kwargs["enable_multimodal_chat"] = True
|
||||
elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_cls = InstructCoderDataset
|
||||
common_kwargs['dataset_split'] = "train"
|
||||
elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_cls = ConversationDataset
|
||||
common_kwargs['dataset_subset'] = args.hf_subset
|
||||
common_kwargs['dataset_split'] = args.hf_split
|
||||
sample_kwargs["enable_multimodal_chat"] = True
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown dataset name: {args.dataset_name}")
|
||||
# Remove None values
|
||||
sample_kwargs = {k: v for k, v in sample_kwargs.items() if v is not None}
|
||||
return dataset_cls(**common_kwargs).sample(**sample_kwargs)
|
||||
|
||||
|
||||
def validate_dataset(args: argparse.Namespace):
|
||||
"""
|
||||
Validate the dataset arguments.
|
||||
"""
|
||||
# === Dataset Configuration ===
|
||||
if not args.dataset_path:
|
||||
print(
|
||||
"When dataset path is not set, it will default to random dataset")
|
||||
args.dataset_name = 'random'
|
||||
if args.input_len is None:
|
||||
raise ValueError("input_len must be provided for a random dataset")
|
||||
|
||||
# === Dataset Name Specific Checks ===
|
||||
# --hf-subset and --hf-split: only used
|
||||
# when dataset_name is 'hf'
|
||||
if args.dataset_name != "hf" and (
|
||||
getattr(args, "hf_subset", None) is not None
|
||||
or getattr(args, "hf_split", None) is not None):
|
||||
warnings.warn("--hf-subset and --hf-split will be ignored \
|
||||
since --dataset-name is not 'hf'.",
|
||||
stacklevel=2)
|
||||
elif args.dataset_name == "hf":
|
||||
if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
|
||||
assert getattr(
|
||||
args, 'backend', None
|
||||
) and args.backend == "vllm-chat", "VisionArenaDataset needs to use vllm-chat as the backend." #noqa: E501
|
||||
elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
|
||||
assert getattr(
|
||||
args, 'backend', None
|
||||
) and args.backend == "vllm", "InstructCoder dataset needs to use vllm as the backend." #noqa: E501
|
||||
elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
|
||||
assert getattr(
|
||||
args, 'backend', None
|
||||
) and args.backend == "vllm-chat", "ConversationDataset needs to use vllm-chat as the backend." #noqa: E501
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{args.dataset_path} is not supported by hf dataset.")
|
||||
|
||||
# --random-range-ratio: only used when dataset_name is 'random'
|
||||
if args.dataset_name != 'random' and args.random_range_ratio is not None:
|
||||
warnings.warn("--random-range-ratio will be ignored since \
|
||||
--dataset-name is not 'random'.",
|
||||
stacklevel=2)
|
||||
|
||||
# --prefix-len: only used when dataset_name is 'random', 'sonnet', or not
|
||||
# set.
|
||||
if args.dataset_name not in {"random", "sonnet", None
|
||||
} and args.prefix_len is not None:
|
||||
warnings.warn("--prefix-len will be ignored since --dataset-name\
|
||||
is not 'random', 'sonnet', or not set.",
|
||||
stacklevel=2)
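The assertions above encode a simple mapping from HF dataset family to required backend; restated as data for quick reference (an illustrative summary, not code from this PR):

```python
# Backend required by each HF dataset family, per the assertions in
# validate_dataset above.
HF_DATASET_BACKENDS = {
    "VisionArenaDataset": "vllm-chat",
    "ConversationDataset": "vllm-chat",
    "InstructCoderDataset": "vllm",
}
```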
|
||||
|
||||
236
benchmarks/kernels/benchmark_bitblas.py
Normal file
@ -0,0 +1,236 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
|
||||
MINIMUM_BITBLAS_VERSION)
|
||||
|
||||
try:
|
||||
import bitblas
|
||||
if bitblas.__version__ < MINIMUM_BITBLAS_VERSION:
|
||||
raise ImportError("bitblas version is wrong. Please "
|
||||
f"install bitblas>={MINIMUM_BITBLAS_VERSION}")
|
||||
except ImportError as e:
|
||||
bitblas_import_exception = e
|
||||
raise ValueError("Trying to use the bitblas backend, but could not import"
|
||||
f"with the following error: {bitblas_import_exception}. "
|
||||
"Please install bitblas through the following command: "
|
||||
f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`"
|
||||
) from bitblas_import_exception
|
||||
|
||||
from bitblas import Matmul, MatmulConfig, auto_detect_nvidia_target
|
||||
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
parser = FlexibleArgumentParser(
|
||||
description="Benchmark BitBLAS int4 on a specific target.")
|
||||
|
||||
# Add arguments to the parser
|
||||
parser.add_argument(
|
||||
"--target",
|
||||
type=str,
|
||||
default=auto_detect_nvidia_target(),
|
||||
help="Specify the target device for benchmarking.",
|
||||
)
|
||||
parser.add_argument("--group_size",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Group size for grouped quantization.")
|
||||
parser.add_argument(
|
||||
"--A_dtype",
|
||||
type=str,
|
||||
default="float16",
|
||||
choices=["float16", "float32", "float64", "int32", "int8"],
|
||||
help="Data type of activation A.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--W_dtype",
|
||||
type=str,
|
||||
default="int4",
|
||||
choices=[
|
||||
"float16",
|
||||
"float32",
|
||||
"float64",
|
||||
"int32",
|
||||
"int8",
|
||||
"int4",
|
||||
"int2",
|
||||
"int1",
|
||||
"nf4",
|
||||
"fp4_e2m1",
|
||||
],
|
||||
help="Data type of weight W.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--accum_dtype",
|
||||
type=str,
|
||||
default="float16",
|
||||
choices=["float16", "int32"],
|
||||
help="Data type for accumulation.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--out_dtype",
|
||||
type=str,
|
||||
default="float16",
|
||||
choices=["float16", "float32", "int32", "int8"],
|
||||
help="Data type for output.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--layout",
|
||||
type=str,
|
||||
default="nt",
|
||||
choices=["nt", "nn"],
|
||||
help="Matrix layout, 'nt' for non-transpose A and transpose W.",
|
||||
)
|
||||
parser.add_argument("--with_bias",
|
||||
action="store_true",
|
||||
help="Include bias in the benchmark.")
|
||||
parser.add_argument(
|
||||
"--with_scaling",
|
||||
action="store_true",
|
||||
help="Include scaling factor in the quantization.",
|
||||
)
|
||||
parser.add_argument("--with_zeros",
|
||||
action="store_true",
|
||||
help="Include zeros in the quantization.")
|
||||
parser.add_argument(
|
||||
"--zeros_mode",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["original", "rescale", "quantized"],
|
||||
help="Specify the mode for calculating zeros.",
|
||||
)
|
||||
|
||||
# Parse the arguments
|
||||
args = parser.parse_args()
|
||||
|
||||
# Assign arguments to variables
|
||||
target = args.target
|
||||
A_dtype = args.A_dtype
|
||||
W_dtype = args.W_dtype
|
||||
accum_dtype = args.accum_dtype
|
||||
out_dtype = args.out_dtype
|
||||
layout = args.layout
|
||||
with_bias = args.with_bias
|
||||
group_size = args.group_size
|
||||
with_scaling = args.with_scaling
|
||||
with_zeros = args.with_zeros
|
||||
zeros_mode = args.zeros_mode
|
||||
|
||||
# Define a list of shared arguments that repeat in every config
|
||||
shared_args = [
|
||||
A_dtype,
|
||||
W_dtype,
|
||||
out_dtype,
|
||||
accum_dtype,
|
||||
layout,
|
||||
with_bias,
|
||||
group_size,
|
||||
with_scaling,
|
||||
with_zeros,
|
||||
zeros_mode,
|
||||
]
|
||||
|
||||
# Define just the (M, K, N) shapes in a more compact list
|
||||
shapes = [
|
||||
# square test
|
||||
(1, 16384, 16384),
|
||||
# BLOOM-176B
|
||||
(1, 43008, 14336),
|
||||
(1, 14336, 14336),
|
||||
(1, 57344, 14336),
|
||||
(1, 14336, 57344),
|
||||
# OPT-65B
|
||||
(1, 9216, 9216),
|
||||
(1, 36864, 9216),
|
||||
(1, 9216, 36864),
|
||||
(1, 22016, 8192),
|
||||
# LLAMA-70B/65B
|
||||
(1, 8192, 22016),
|
||||
(1, 8192, 8192),
|
||||
(1, 28672, 8192),
|
||||
(1, 8192, 28672),
|
||||
# square test
|
||||
(16384, 16384, 16384),
|
||||
# BLOOM-176B
|
||||
(8192, 43008, 14336),
|
||||
(8192, 14336, 14336),
|
||||
(8192, 57344, 14336),
|
||||
(8192, 14336, 57344),
|
||||
# OPT-65B
|
||||
(8192, 9216, 9216),
|
||||
(8192, 36864, 9216),
|
||||
(8192, 9216, 36864),
|
||||
(8192, 22016, 8192),
|
||||
# LLAMA-70B/65B
|
||||
(8192, 8192, 22016),
|
||||
(8192, 8192, 8192),
|
||||
(8192, 28672, 8192),
|
||||
(8192, 8192, 28672),
|
||||
]
|
||||
|
||||
# Build test shapes with all the shared arguments
|
||||
test_shapes = [(MatmulConfig, Matmul, (*shape, *shared_args))
|
||||
for shape in shapes]
|
||||
|
||||
benchmark_sets = []
|
||||
benchmark_sets.extend(test_shapes)
|
||||
|
||||
benchmark_results = {}
|
||||
for config_class, operator, input_args in benchmark_sets:
|
||||
config = config_class(*input_args)
|
||||
matmul = operator(config, target=target, enable_tuning=True)
|
||||
kernel_latency = matmul.profile_latency()
|
||||
|
||||
print("Time cost is: {:.3f} ms".format(kernel_latency))
|
||||
|
||||
profile_config = {
|
||||
f"{operator.__name__}-{'-'.join([str(i) for i in input_args])}": {
|
||||
"BitBLAS_top20_latency": kernel_latency,
|
||||
}
|
||||
}
|
||||
|
||||
benchmark_results.update(profile_config)
|
||||
|
||||
# Define headers for the table
|
||||
headers = [
|
||||
"PrimFunc",
|
||||
"Input Arguments",
|
||||
"BitBLAS Top20 Latency",
|
||||
]
|
||||
|
||||
# Calculate column widths for pretty printing
|
||||
col_widths = [0, 0, 0]
|
||||
for config_key, values in benchmark_results.items():
|
||||
args_split = config_key.split("-")
|
||||
func_name = args_split[0]
|
||||
input_args_str = "-".join(args_split[1:])
|
||||
col_widths[0] = max(col_widths[0], len(func_name) + 2, len(headers[0]) + 2)
|
||||
col_widths[1] = max(col_widths[1],
|
||||
len(input_args_str) + 2,
|
||||
len(headers[1]) + 2)
|
||||
col_widths[2] = max(col_widths[2],
|
||||
len(f"{values['BitBLAS_top20_latency']:.3f} ms") + 2,
|
||||
len(headers[2]) + 2)
|
||||
# break only if you want to measure widths from a single example;
|
||||
# otherwise, let it loop over all items.
|
||||
|
||||
# Print header
|
||||
for i, header in enumerate(headers):
|
||||
headers[i] = header.ljust(col_widths[i])
|
||||
print("".join(headers))
|
||||
print("-" * sum(col_widths))
|
||||
|
||||
# Print rows
|
||||
for config_key, values in benchmark_results.items():
|
||||
args_split = config_key.split("-")
|
||||
func_name = args_split[0]
|
||||
input_args_str = "-".join(args_split[1:])
|
||||
row = [
|
||||
func_name,
|
||||
input_args_str,
|
||||
f"{values['BitBLAS_top20_latency']:.3f} ms",
|
||||
]
|
||||
row_str = "".join(
|
||||
[str(cell).ljust(col_widths[idx]) for idx, cell in enumerate(row)])
|
||||
print(row_str)
|
||||
@ -17,8 +17,14 @@ from torch.utils.benchmark import Measurement as TMeasurement
|
||||
from utils import ArgPool, Bench, CudaGraphBenchParams
|
||||
from weight_shapes import WEIGHT_SHAPES
|
||||
|
||||
from vllm.lora.ops.triton_ops import LoRAKernelMeta, lora_expand, lora_shrink
|
||||
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
|
||||
if HAS_TRITON:
|
||||
from vllm.lora.ops.triton_ops import (LoRAKernelMeta, lora_expand,
|
||||
lora_shrink)
|
||||
from vllm.lora.ops.triton_ops.utils import (_LORA_A_PTR_DICT,
|
||||
_LORA_B_PTR_DICT)
|
||||
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
|
||||
|
||||
@ -30,19 +30,18 @@ class BenchmarkConfig(TypedDict):
|
||||
num_stages: int
|
||||
|
||||
|
||||
def benchmark_config(
|
||||
config: BenchmarkConfig,
|
||||
num_tokens: int,
|
||||
num_experts: int,
|
||||
shard_intermediate_size: int,
|
||||
hidden_size: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
num_iters: int = 100,
|
||||
block_quant_shape: List[int] = None,
|
||||
) -> float:
|
||||
def benchmark_config(config: BenchmarkConfig,
|
||||
num_tokens: int,
|
||||
num_experts: int,
|
||||
shard_intermediate_size: int,
|
||||
hidden_size: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
num_iters: int = 100,
|
||||
block_quant_shape: List[int] = None,
|
||||
use_deep_gemm: bool = False) -> float:
|
||||
init_dtype = torch.float16 if use_fp8_w8a8 else dtype
|
||||
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
||||
if use_int8_w8a16:
|
||||
@ -115,22 +114,41 @@ def benchmark_config(
|
||||
def run():
|
||||
from vllm.model_executor.layers.fused_moe import override_config
|
||||
with override_config(config):
|
||||
fused_moe(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
input_gating,
|
||||
topk,
|
||||
renormalize=True,
|
||||
inplace=True,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_quant_shape,
|
||||
)
|
||||
if use_deep_gemm:
|
||||
topk_weights, topk_ids = fused_topk(x, input_gating, topk,
|
||||
False)
|
||||
return fused_experts(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace=True,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_quant_shape,
|
||||
allow_deep_gemm=True,
|
||||
)
|
||||
else:
|
||||
fused_moe(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
input_gating,
|
||||
topk,
|
||||
renormalize=True,
|
||||
inplace=True,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_quant_shape,
|
||||
)
|
||||
|
||||
# JIT compilation & warmup
|
||||
run()
|
||||
@ -366,6 +384,7 @@ class BenchmarkWorker:
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
block_quant_shape: List[int] = None,
|
||||
use_deep_gemm: bool = False,
|
||||
) -> tuple[dict[str, int], float]:
|
||||
current_platform.seed_everything(self.seed)
|
||||
dtype_str = get_config_dtype_str(dtype,
|
||||
@ -396,7 +415,8 @@ class BenchmarkWorker:
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a16,
|
||||
num_iters=100,
|
||||
block_quant_shape=block_quant_shape)
|
||||
block_quant_shape=block_quant_shape,
|
||||
use_deep_gemm=use_deep_gemm)
|
||||
return config, kernel_time
|
||||
|
||||
def tune(
|
||||
@ -411,6 +431,7 @@ class BenchmarkWorker:
|
||||
use_int8_w8a16: bool,
|
||||
search_space: list[dict[str, int]],
|
||||
block_quant_shape: list[int],
|
||||
use_deep_gemm: bool,
|
||||
) -> dict[str, int]:
|
||||
best_config = None
|
||||
best_time = float("inf")
|
||||
@ -436,7 +457,8 @@ class BenchmarkWorker:
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a16,
|
||||
num_iters=20,
|
||||
block_quant_shape=block_quant_shape)
|
||||
block_quant_shape=block_quant_shape,
|
||||
use_deep_gemm=use_deep_gemm)
|
||||
except triton.runtime.autotuner.OutOfResources:
|
||||
# Some configurations may be invalid and fail to compile.
|
||||
continue
|
||||
@ -531,6 +553,8 @@ def main(args: argparse.Namespace):
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||
else:
|
||||
# Support for llama4
|
||||
config = config.get_text_config()
|
||||
# Default: Mixtral.
|
||||
E = config.num_local_experts
|
||||
topk = config.num_experts_per_tok
|
||||
@ -550,6 +574,8 @@ def main(args: argparse.Namespace):
|
||||
else:
|
||||
batch_sizes = [args.batch_size]
|
||||
|
||||
use_deep_gemm = bool(args.use_deep_gemm)
|
||||
|
||||
ray.init()
|
||||
num_gpus = int(ray.available_resources()["GPU"])
|
||||
workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]
|
||||
@ -572,10 +598,10 @@ def main(args: argparse.Namespace):
|
||||
|
||||
start = time.time()
|
||||
configs = _distribute(
|
||||
"tune",
|
||||
[(batch_size, E, shard_intermediate_size, hidden_size, topk, dtype,
|
||||
use_fp8_w8a8, use_int8_w8a16, search_space, block_quant_shape)
|
||||
for batch_size in batch_sizes])
|
||||
"tune", [(batch_size, E, shard_intermediate_size, hidden_size,
|
||||
topk, dtype, use_fp8_w8a8, use_int8_w8a16, search_space,
|
||||
block_quant_shape, use_deep_gemm)
|
||||
for batch_size in batch_sizes])
|
||||
best_configs = {
|
||||
M: sort_config(config)
|
||||
for M, config in zip(batch_sizes, configs)
|
||||
@ -589,7 +615,7 @@ def main(args: argparse.Namespace):
|
||||
outputs = _distribute(
|
||||
"benchmark",
|
||||
[(batch_size, E, shard_intermediate_size, hidden_size, topk, dtype,
|
||||
use_fp8_w8a8, use_int8_w8a16, block_quant_shape)
|
||||
use_fp8_w8a8, use_int8_w8a16, block_quant_shape, use_deep_gemm)
|
||||
for batch_size in batch_sizes])
|
||||
|
||||
for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
|
||||
@ -611,6 +637,7 @@ if __name__ == "__main__":
|
||||
type=str,
|
||||
choices=["auto", "fp8_w8a8", "int8_w8a16"],
|
||||
default="auto")
|
||||
parser.add_argument("--use-deep-gemm", action="store_true")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--batch-size", type=int, required=False)
|
||||
parser.add_argument("--tune", action="store_true")
|
||||
|
||||
@ -33,8 +33,6 @@ endif()
|
||||
|
||||
if(MACOSX_FOUND)
|
||||
list(APPEND CXX_COMPILE_FLAGS
|
||||
"-Xpreprocessor"
|
||||
"-fopenmp"
|
||||
"-DVLLM_CPU_EXTENSION")
|
||||
else()
|
||||
list(APPEND CXX_COMPILE_FLAGS
|
||||
@ -197,6 +195,7 @@ set(VLLM_EXT_SRC
|
||||
if (AVX512_FOUND AND NOT AVX512_DISABLED)
|
||||
set(VLLM_EXT_SRC
|
||||
"csrc/cpu/quant.cpp"
|
||||
"csrc/cpu/shm.cpp"
|
||||
${VLLM_EXT_SRC})
|
||||
endif()
|
||||
|
||||
|
||||
@ -38,7 +38,7 @@ else()
|
||||
FetchContent_Declare(
|
||||
vllm-flash-attn
|
||||
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
|
||||
GIT_TAG dc9d410b3e2d6534a4c70724c2515f4def670a22
|
||||
GIT_TAG 8798f27777fb57f447070301bf33a9f9c607f491
|
||||
GIT_PROGRESS TRUE
|
||||
# Don't share the vllm-flash-attn build between build types
|
||||
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
|
||||
|
||||
178
csrc/attention/merge_attn_states.cu
Normal file
@ -0,0 +1,178 @@
|
||||
#include <optional>
|
||||
#include <torch/all.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <algorithm>
|
||||
|
||||
#include "attention_dtypes.h"
|
||||
#include "attention_utils.cuh"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
|
||||
// can be used to combine partial attention results (in the split-KV case)
|
||||
template <typename scalar_t, const uint NUM_THREADS>
|
||||
__global__ void merge_attn_states_kernel(
|
||||
scalar_t* output, float* output_lse, const scalar_t* prefix_output,
|
||||
const float* prefix_lse, const scalar_t* suffix_output,
|
||||
const float* suffix_lse, const uint num_tokens, const uint num_heads,
|
||||
const uint head_size) {
|
||||
using pack_128b_t = uint4;
|
||||
const uint pack_size = 16 / sizeof(scalar_t);
|
||||
const uint threads_per_head = head_size / pack_size;
|
||||
|
||||
const uint global_idx = blockIdx.x * NUM_THREADS + threadIdx.x;
|
||||
const uint token_head_threads = num_tokens * num_heads * threads_per_head;
|
||||
|
||||
if (global_idx >= token_head_threads) return;
|
||||
|
||||
// global_idx -> token_idx + head_idx + pack_idx
|
||||
const uint token_head_idx = global_idx / threads_per_head;
|
||||
const uint pack_idx = global_idx % threads_per_head;
|
||||
|
||||
const uint token_idx = token_head_idx / num_heads;
|
||||
const uint head_idx = token_head_idx % num_heads;
|
||||
|
||||
const uint pack_offset = pack_idx * pack_size; // (0~15)*8, etc.
|
||||
const uint head_offset =
|
||||
token_idx * num_heads * head_size + head_idx * head_size;
|
||||
const scalar_t* prefix_head_ptr = prefix_output + head_offset;
|
||||
const scalar_t* suffix_head_ptr = suffix_output + head_offset;
|
||||
scalar_t* output_head_ptr = output + head_offset;
|
||||
|
||||
float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
|
||||
float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
|
||||
p_lse = std::isinf(p_lse) ? -std::numeric_limits<float>::infinity() : p_lse;
|
||||
s_lse = std::isinf(s_lse) ? -std::numeric_limits<float>::infinity() : s_lse;
|
||||
|
||||
const float max_lse = fmaxf(p_lse, s_lse);
|
||||
p_lse = p_lse - max_lse;
|
||||
s_lse = s_lse - max_lse;
|
||||
const float p_se = expf(p_lse);
|
||||
const float s_se = expf(s_lse);
|
||||
const float out_se = p_se + s_se;
|
||||
const float p_scale = p_se / out_se;
|
||||
const float s_scale = s_se / out_se;
|
||||
|
||||
if (pack_offset < head_size) {
|
||||
// Pack 128b load
|
||||
pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(
|
||||
prefix_head_ptr)[pack_offset / pack_size];
|
||||
pack_128b_t s_out_pack = reinterpret_cast<const pack_128b_t*>(
|
||||
suffix_head_ptr)[pack_offset / pack_size];
|
||||
pack_128b_t o_out_pack;
|
||||
|
||||
#pragma unroll
|
||||
for (uint i = 0; i < pack_size; ++i) {
|
||||
// Always use float for FMA to keep high precision.
|
||||
// half(uint16_t), bfloat16, float -> float.
|
||||
const float p_out_f =
|
||||
vllm::to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
|
||||
const float s_out_f =
|
||||
vllm::to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
|
||||
// fma: a * b + c = p_out_f * p_scale + (s_out_f * s_scale)
|
||||
const float o_out_f = p_out_f * p_scale + (s_out_f * s_scale);
|
||||
// float -> half(uint16_t), bfloat16, float.
|
||||
vllm::from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i], o_out_f);
|
||||
}
|
||||
|
||||
// Pack 128b storage
|
||||
reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] =
|
||||
o_out_pack;
|
||||
}
|
||||
// We only need to write to output_lse once per head.
|
||||
if (output_lse != nullptr && pack_idx == 0) {
|
||||
float out_lse = logf(out_se) + max_lse;
|
||||
output_lse[head_idx * num_tokens + token_idx] = out_lse;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
// The following macro is used to dispatch the conversion function based on
|
||||
// the output data type. The FN is a macro that calls a function with
|
||||
// template<typename scalar_t>.
|
||||
#define DISPATCH_BY_SCALAR_DTYPE(scalar_dtype, fn) \
|
||||
{ \
|
||||
if (scalar_dtype == at::ScalarType::Float) { \
|
||||
fn(float); \
|
||||
} else if (scalar_dtype == at::ScalarType::Half) { \
|
||||
fn(uint16_t); \
|
||||
} else if (scalar_dtype == at::ScalarType::BFloat16) { \
|
||||
fn(__nv_bfloat16); \
|
||||
} else { \
|
||||
TORCH_CHECK(false, "Unsupported data type of O: ", scalar_dtype); \
|
||||
} \
|
||||
}
|
||||
|
||||
#define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS) \
|
||||
{ \
|
||||
vllm::merge_attn_states_kernel<scalar_t, NUM_THREADS> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<scalar_t*>(output.data_ptr()), output_lse_ptr, \
|
||||
reinterpret_cast<scalar_t*>(prefix_output.data_ptr()), \
|
||||
reinterpret_cast<float*>(prefix_lse.data_ptr()), \
|
||||
reinterpret_cast<scalar_t*>(suffix_output.data_ptr()), \
|
||||
reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens, \
|
||||
num_heads, head_size); \
|
||||
}
|
||||
|
||||
/* @brief Merges the attention states from prefix and suffix
|
||||
* into the output tensor. NUM_TOKENS: n, NUM_HEADS: h, HEAD_SIZE: d
|
||||
*
|
||||
* @param output [n,h,d] The output tensor to store the merged attention states.
|
||||
 * @param output_lse [h,n] Optional tensor to store the log-sum-exp values.
|
||||
* @param prefix_output [n,h,d] The prefix attention states.
|
||||
* @param prefix_lse [h,n] The log-sum-exp values for the prefix attention
|
||||
* states.
|
||||
* @param suffix_output [n,h,d] The suffix attention states.
|
||||
* @param suffix_lse [h,n] The log-sum-exp values for the suffix attention
|
||||
* states.
|
||||
*/
|
||||
template <typename scalar_t>
|
||||
void merge_attn_states_launcher(torch::Tensor& output,
|
||||
std::optional<torch::Tensor> output_lse,
|
||||
const torch::Tensor& prefix_output,
|
||||
const torch::Tensor& prefix_lse,
|
||||
const torch::Tensor& suffix_output,
|
||||
const torch::Tensor& suffix_lse) {
|
||||
constexpr uint NUM_THREADS = 128;
|
||||
const uint num_tokens = output.size(0);
|
||||
const uint num_heads = output.size(1);
|
||||
const uint head_size = output.size(2);
|
||||
const uint pack_size = 16 / sizeof(scalar_t);
|
||||
TORCH_CHECK(head_size % pack_size == 0,
|
||||
"headsize must be multiple of pack_size:", pack_size);
|
||||
float* output_lse_ptr = nullptr;
|
||||
if (output_lse.has_value()) {
|
||||
output_lse_ptr = output_lse.value().data_ptr<float>();
|
||||
}
|
||||
// Process one pack of elements per thread: for float the pack_size is 4,
// for half/bf16 the pack_size is 8.
|
||||
const uint threads_per_head = head_size / pack_size;
|
||||
const uint total_threads = num_tokens * num_heads * threads_per_head;
|
||||
|
||||
dim3 block(NUM_THREADS);
|
||||
dim3 grid((total_threads + NUM_THREADS - 1) / NUM_THREADS);
|
||||
|
||||
const c10::cuda::OptionalCUDAGuard device_guard(prefix_output.device());
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS);
|
||||
}
|
||||
|
||||
#define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t) \
|
||||
{ \
|
||||
merge_attn_states_launcher<scalar_t>(output, output_lse, prefix_output, \
|
||||
prefix_lse, suffix_output, \
|
||||
suffix_lse); \
|
||||
}
|
||||
|
||||
void merge_attn_states(torch::Tensor& output,
|
||||
std::optional<torch::Tensor> output_lse,
|
||||
const torch::Tensor& prefix_output,
|
||||
const torch::Tensor& prefix_lse,
|
||||
const torch::Tensor& suffix_output,
|
||||
const torch::Tensor& suffix_lse) {
|
||||
DISPATCH_BY_SCALAR_DTYPE(output.dtype(), CALL_MERGE_ATTN_STATES_LAUNCHER);
|
||||
}
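For reference, the merge performed per (token, head) by the kernel above can be written in a few lines of NumPy; this is a sketch for checking intent, not code from the PR (note the [num_heads, num_tokens] layout of the LSE tensors, matching the kernel's indexing; the kernel additionally maps infinite LSEs to -inf so that empty partitions contribute zero weight):

```python
import numpy as np

def merge_attn_states_ref(prefix_out, prefix_lse, suffix_out, suffix_lse):
    """prefix_out/suffix_out: [num_tokens, num_heads, head_size]
    prefix_lse/suffix_lse:    [num_heads, num_tokens]"""
    p_lse = prefix_lse.T[..., None]            # [n, h, 1]
    s_lse = suffix_lse.T[..., None]
    max_lse = np.maximum(p_lse, s_lse)         # subtract the max for stability
    p_se = np.exp(p_lse - max_lse)
    s_se = np.exp(s_lse - max_lse)
    out = (p_se * prefix_out + s_se * suffix_out) / (p_se + s_se)
    out_lse = (np.log(p_se + s_se) + max_lse)[..., 0].T   # back to [h, n]
    return out, out_lse
```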
|
||||
38
csrc/attention/mla/cutlass_mla_entry.cu
Normal file
@ -0,0 +1,38 @@
|
||||
/*
|
||||
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
#if defined ENABLE_CUTLASS_MLA && ENABLE_CUTLASS_MLA
|
||||
void cutlass_mla_decode_sm100a(torch::Tensor const& out,
|
||||
torch::Tensor const& q_nope,
|
||||
torch::Tensor const& q_pe,
|
||||
torch::Tensor const& kv_c_and_k_pe_cache,
|
||||
torch::Tensor const& seq_lens,
|
||||
torch::Tensor const& page_table, double scale);
|
||||
#endif
|
||||
|
||||
void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope,
|
||||
torch::Tensor const& q_pe,
|
||||
torch::Tensor const& kv_c_and_k_pe_cache,
|
||||
torch::Tensor const& seq_lens,
|
||||
torch::Tensor const& page_table, double scale) {
|
||||
#if defined ENABLE_CUTLASS_MLA && ENABLE_CUTLASS_MLA
|
||||
return cutlass_mla_decode_sm100a(out, q_nope, q_pe, kv_c_and_k_pe_cache,
|
||||
seq_lens, page_table, scale);
|
||||
#endif
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled cutlass MLA");
|
||||
}
|
||||
225
csrc/attention/mla/cutlass_mla_kernels.cu
Normal file
@ -0,0 +1,225 @@
|
||||
/*
|
||||
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/kernel_hardware_info.h"
|
||||
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
|
||||
#include "device/sm100_mla.hpp"
|
||||
#include "kernel/sm100_mla_tile_scheduler.hpp"
|
||||
|
||||
using namespace cute;
|
||||
using namespace cutlass::fmha::kernel;
|
||||
|
||||
template <typename T, bool PersistenceOption = true>
|
||||
struct MlaSm100 {
|
||||
using Element = T;
|
||||
using ElementAcc = float;
|
||||
using ElementOut = T;
|
||||
|
||||
using TileShape = Shape<_128, _128, Shape<_512, _64>>;
|
||||
using TileShapeH = cute::tuple_element_t<0, TileShape>;
|
||||
using TileShapeD = cute::tuple_element_t<2, TileShape>;
|
||||
|
||||
// H K (D_latent D_rope) B
|
||||
using ProblemShape = cute::tuple<TileShapeH, int, TileShapeD, int>;
|
||||
|
||||
using StrideQ = cute::tuple<int64_t, _1, int64_t>; // H D B
|
||||
using StrideK = cute::tuple<int64_t, _1, int64_t>; // K D B
|
||||
using StrideO = StrideK; // H D B
|
||||
using StrideLSE = cute::tuple<_1, int>; // H B
|
||||
|
||||
using TileScheduler =
|
||||
std::conditional_t<PersistenceOption, Sm100MlaPersistentTileScheduler,
|
||||
Sm100MlaIndividualTileScheduler>;
|
||||
|
||||
using FmhaKernel =
|
||||
cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized<
|
||||
TileShape, Element, ElementAcc, ElementOut, ElementAcc, TileScheduler,
|
||||
/*kIsCpAsync=*/true>;
|
||||
using Fmha = cutlass::fmha::device::MLA<FmhaKernel>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
typename T::Fmha::Arguments args_from_options(
|
||||
at::Tensor const& out, at::Tensor const& q_nope, at::Tensor const& q_pe,
|
||||
at::Tensor const& kv_c_and_k_pe_cache, at::Tensor const& seq_lens,
|
||||
at::Tensor const& page_table, double scale) {
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
hw_info.device_id = q_nope.device().index();
|
||||
hw_info.sm_count =
|
||||
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
|
||||
hw_info.device_id);
|
||||
|
||||
int batches = q_nope.sizes()[0];
|
||||
int page_count_per_seq = page_table.sizes()[1];
|
||||
int page_count_total = kv_c_and_k_pe_cache.sizes()[0];
|
||||
int page_size = kv_c_and_k_pe_cache.sizes()[1];
|
||||
int max_seq_len = page_size * page_count_per_seq;
|
||||
using TileShapeH = typename T::TileShapeH;
|
||||
using TileShapeD = typename T::TileShapeD;
|
||||
auto problem_shape =
|
||||
cute::make_tuple(TileShapeH{}, max_seq_len, TileShapeD{}, batches);
|
||||
|
||||
auto [H, K, D, B] = problem_shape;
|
||||
auto [D_latent, D_rope] = D;
|
||||
|
||||
using StrideQ = typename T::StrideQ;
|
||||
using StrideK = typename T::StrideK;
|
||||
using StrideO = typename T::StrideO;
|
||||
using StrideLSE = typename T::StrideLSE;
|
||||
|
||||
StrideQ stride_Q_latent = cute::make_tuple(
|
||||
static_cast<int64_t>(D_latent), _1{}, static_cast<int64_t>(H * D_latent));
|
||||
StrideQ stride_Q_rope = cute::make_tuple(static_cast<int64_t>(D_rope), _1{},
|
||||
static_cast<int64_t>(H * D_rope));
|
||||
StrideK stride_C =
|
||||
cute::make_tuple(static_cast<int64_t>(D_latent + D_rope), _1{},
|
||||
static_cast<int64_t>(page_size * (D_latent + D_rope)));
|
||||
StrideLSE stride_PT = cute::make_stride(_1{}, page_count_per_seq);
|
||||
StrideLSE stride_LSE = cute::make_tuple(_1{}, static_cast<int>(H));
|
||||
StrideO stride_O = cute::make_tuple(static_cast<int64_t>(D_latent), _1{},
|
||||
static_cast<int64_t>(H * D_latent));
|
||||
|
||||
using Element = typename T::Element;
|
||||
using ElementOut = typename T::ElementOut;
|
||||
using ElementAcc = typename T::ElementAcc;
|
||||
auto Q_latent_ptr = static_cast<Element*>(q_nope.data_ptr());
|
||||
auto Q_rope_ptr = static_cast<Element*>(q_pe.data_ptr());
|
||||
auto C_ptr = static_cast<Element*>(kv_c_and_k_pe_cache.data_ptr());
|
||||
auto scale_f = static_cast<float>(scale);
|
||||
typename T::Fmha::Arguments arguments{
|
||||
problem_shape,
|
||||
{scale_f, Q_latent_ptr, stride_Q_latent, Q_rope_ptr, stride_Q_rope, C_ptr,
|
||||
stride_C, C_ptr + D_latent, stride_C,
|
||||
static_cast<int*>(seq_lens.data_ptr()),
|
||||
static_cast<int*>(page_table.data_ptr()), stride_PT, page_count_total,
|
||||
page_size},
|
||||
{static_cast<ElementOut*>(out.data_ptr()), stride_O,
|
||||
static_cast<ElementAcc*>(nullptr), stride_LSE},
|
||||
hw_info,
|
||||
-1, // split_kv
|
||||
nullptr, // is_var_split_kv
|
||||
};
|
||||
// TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
|
||||
// split_kv automatically based on batch size and sequence length to balance
|
||||
// workload across available SMs. Consider using var_split_kv for manual
|
||||
// control if needed.
|
||||
T::Fmha::set_split_kv(arguments);
|
||||
return arguments;
|
||||
}
|
||||
|
||||
template <typename Element>
|
||||
void runMla(at::Tensor const& out, at::Tensor const& q_nope,
|
||||
at::Tensor const& q_pe, at::Tensor const& kv_c_and_k_pe_cache,
|
||||
at::Tensor const& seq_lens, at::Tensor const& page_table,
|
||||
float scale, cudaStream_t stream) {
|
||||
using MlaSm100Type = MlaSm100<Element>;
|
||||
typename MlaSm100Type::Fmha fmha;
|
||||
auto arguments = args_from_options<MlaSm100Type>(
|
||||
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, scale);
|
||||
size_t workspace_size = MlaSm100Type::Fmha::get_workspace_size(arguments);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(q_nope.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
|
||||
CUTLASS_CHECK(fmha.can_implement(arguments));
|
||||
|
||||
CUTLASS_CHECK(fmha.initialize(arguments, workspace.data_ptr(), stream));
|
||||
|
||||
CUTLASS_CHECK(fmha.run(arguments, workspace.data_ptr(), stream));
|
||||
}
|
||||
|
||||
void cutlass_mla_decode_sm100a(torch::Tensor const& out,
|
||||
torch::Tensor const& q_nope,
|
||||
torch::Tensor const& q_pe,
|
||||
torch::Tensor const& kv_c_and_k_pe_cache,
|
||||
torch::Tensor const& seq_lens,
|
||||
torch::Tensor const& page_table, double scale) {
|
||||
TORCH_CHECK(q_nope.device().is_cuda(), "q_nope must be on CUDA");
|
||||
TORCH_CHECK(q_nope.dim() == 3, "q_nope must be a 3D tensor");
|
||||
TORCH_CHECK(q_pe.dim() == 3, "q_pe must be a 3D tensor");
|
||||
TORCH_CHECK(kv_c_and_k_pe_cache.dim() == 3,
|
||||
"kv_c_and_k_pe_cache must be a 3D tensor");
|
||||
TORCH_CHECK(seq_lens.dim() == 1, "seq_lens must be a 1D tensor");
|
||||
TORCH_CHECK(page_table.dim() == 2, "page_table must be a 2D tensor");
|
||||
TORCH_CHECK(out.dim() == 3, "out must be a 3D tensor");
|
||||
|
||||
auto B_q_nope = q_nope.size(0);
|
||||
auto H_q_nope = q_nope.size(1);
|
||||
auto D_q_nope = q_nope.size(2);
|
||||
auto B_q_pe = q_pe.size(0);
|
||||
auto H_q_pe = q_pe.size(1);
|
||||
auto D_q_pe = q_pe.size(2);
|
||||
auto B_pt = page_table.size(0);
|
||||
auto PAGE_NUM = page_table.size(1);
|
||||
auto PAGE_SIZE = kv_c_and_k_pe_cache.size(1);
|
||||
auto D_ckv = kv_c_and_k_pe_cache.size(2);
|
||||
auto B_o = out.size(0);
|
||||
auto H_o = out.size(1);
|
||||
auto D_o = out.size(2);
|
||||
|
||||
TORCH_CHECK(D_q_nope == 512, "D_q_nope must be equal to 512");
|
||||
TORCH_CHECK(D_q_pe == 64, "D_q_pe must be equal to 64");
|
||||
TORCH_CHECK(D_ckv == 576, "D_ckv must be equal to 576");
|
||||
TORCH_CHECK(H_q_nope == H_q_pe && H_q_nope == H_o && H_o == 128,
|
||||
"H_q_nope, H_q_pe, and H_o must be equal to 128");
|
||||
TORCH_CHECK(PAGE_SIZE > 0 && (PAGE_SIZE & (PAGE_SIZE - 1)) == 0,
|
||||
"PAGE_SIZE must be a power of 2");
|
||||
TORCH_CHECK(
|
||||
B_q_nope == B_q_pe && B_q_nope == B_pt && B_q_nope == B_o,
|
||||
"Batch dims must be same for page_table, q_nope and q_pe, and out");
|
||||
TORCH_CHECK(PAGE_NUM % (128 / PAGE_SIZE) == 0,
|
||||
"PAGE_NUM must be divisible by 128 / PAGE_SIZE");
|
||||
TORCH_CHECK(D_o == 512, "D_o must be equal to 512");
|
||||
|
||||
TORCH_CHECK(q_nope.dtype() == at::ScalarType::Half ||
|
||||
q_nope.dtype() == at::ScalarType::BFloat16 ||
|
||||
q_nope.dtype() == at::ScalarType::Float8_e4m3fn,
|
||||
"q_nope must be a half, bfloat16, or float8_e4m3fn tensor");
|
||||
TORCH_CHECK(kv_c_and_k_pe_cache.dtype() == q_nope.dtype() &&
|
||||
q_nope.dtype() == q_pe.dtype(),
|
||||
"kv_c_and_k_pe_cache, q_nope, and q_pe must be the same type");
|
||||
TORCH_CHECK(seq_lens.dtype() == torch::kInt32,
|
||||
"seq_lens must be a 32-bit integer tensor");
|
||||
TORCH_CHECK(page_table.dtype() == torch::kInt32,
|
||||
"page_table must be a 32-bit integer tensor");
|
||||
|
||||
auto in_dtype = q_nope.dtype();
|
||||
at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()};
|
||||
const cudaStream_t stream =
|
||||
at::cuda::getCurrentCUDAStream(q_nope.get_device());
|
||||
if (in_dtype == at::ScalarType::Half) {
|
||||
runMla<cutlass::half_t>(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens,
|
||||
page_table, scale, stream);
|
||||
} else if (in_dtype == at::ScalarType::BFloat16) {
|
||||
runMla<cutlass::bfloat16_t>(out, q_nope, q_pe, kv_c_and_k_pe_cache,
|
||||
seq_lens, page_table, scale, stream);
|
||||
} else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
|
||||
runMla<cutlass::float_e4m3_t>(out, q_nope, q_pe, kv_c_and_k_pe_cache,
|
||||
seq_lens, page_table, scale, stream);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported input data type of MLA");
|
||||
}
|
||||
}
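To make the shape contract above concrete, here is a hypothetical set of tensors that satisfies every TORCH_CHECK; the sizes are illustrative only and are not taken from a real vLLM call site:

```python
import torch

# Hypothetical sizes chosen only to satisfy the checks above:
# H = 128, D_latent = 512, D_rope = 64, D_ckv = 576, PAGE_SIZE a power of 2,
# PAGE_NUM divisible by 128 / PAGE_SIZE, int32 seq_lens / page_table.
B, H, PAGE_SIZE, PAGES_PER_SEQ = 4, 128, 64, 8
D_LATENT, D_ROPE = 512, 64

q_nope = torch.empty(B, H, D_LATENT, dtype=torch.bfloat16, device="cuda")
q_pe = torch.empty(B, H, D_ROPE, dtype=torch.bfloat16, device="cuda")
kv_c_and_k_pe_cache = torch.empty(B * PAGES_PER_SEQ, PAGE_SIZE, D_LATENT + D_ROPE,
                                  dtype=torch.bfloat16, device="cuda")
seq_lens = torch.full((B,), 512, dtype=torch.int32, device="cuda")
page_table = torch.arange(B * PAGES_PER_SEQ, dtype=torch.int32,
                          device="cuda").view(B, PAGES_PER_SEQ)
out = torch.empty(B, H, D_LATENT, dtype=torch.bfloat16, device="cuda")
```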
|
||||
@ -270,9 +270,10 @@ __global__ void reshape_and_cache_flash_kernel(
|
||||
cache_t* __restrict__ value_cache, // [num_blocks, block_size, num_heads,
|
||||
// head_size]
|
||||
const int64_t* __restrict__ slot_mapping, // [num_tokens]
|
||||
const int block_stride, const int key_stride, const int value_stride,
|
||||
const int num_heads, const int head_size, const int block_size,
|
||||
const float* k_scale, const float* v_scale) {
|
||||
const int64_t block_stride, const int64_t page_stride,
|
||||
const int64_t head_stride, const int64_t key_stride,
|
||||
const int64_t value_stride, const int num_heads, const int head_size,
|
||||
const int block_size, const float* k_scale, const float* v_scale) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
const int64_t slot_idx = slot_mapping[token_idx];
|
||||
// NOTE: slot_idx can be -1 if the token is padded
|
||||
@ -288,8 +289,8 @@ __global__ void reshape_and_cache_flash_kernel(
|
||||
const int head_idx = i / head_size;
|
||||
const int head_offset = i % head_size;
|
||||
const int64_t tgt_key_value_idx = block_idx * block_stride +
|
||||
block_offset * num_heads * head_size +
|
||||
head_idx * head_size + head_offset;
|
||||
block_offset * page_stride +
|
||||
head_idx * head_stride + head_offset;
|
||||
scalar_t tgt_key = key[src_key_idx];
|
||||
scalar_t tgt_value = value[src_value_idx];
|
||||
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
|
||||
@ -396,16 +397,16 @@ void reshape_and_cache(
|
||||
// KV_T is the data type of key and value tensors.
|
||||
// CACHE_T is the stored data type of kv-cache.
|
||||
// KV_DTYPE is the real data type of kv-cache.
|
||||
#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE) \
|
||||
vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<KV_T*>(key.data_ptr()), \
|
||||
reinterpret_cast<KV_T*>(value.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
|
||||
slot_mapping.data_ptr<int64_t>(), block_stride, key_stride, \
|
||||
value_stride, num_heads, head_size, block_size, \
|
||||
reinterpret_cast<const float*>(k_scale.data_ptr()), \
|
||||
#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE) \
|
||||
vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<KV_T*>(key.data_ptr()), \
|
||||
reinterpret_cast<KV_T*>(value.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
|
||||
slot_mapping.data_ptr<int64_t>(), block_stride, page_stride, \
|
||||
head_stride, key_stride, value_stride, num_heads, head_size, \
|
||||
block_size, reinterpret_cast<const float*>(k_scale.data_ptr()), \
|
||||
reinterpret_cast<const float*>(v_scale.data_ptr()));
|
||||
|
||||
void reshape_and_cache_flash(
|
||||
@ -432,9 +433,11 @@ void reshape_and_cache_flash(
|
||||
int head_size = key.size(2);
|
||||
int block_size = key_cache.size(1);
|
||||
|
||||
int key_stride = key.stride(0);
|
||||
int value_stride = value.stride(0);
|
||||
int block_stride = key_cache.stride(0);
|
||||
int64_t key_stride = key.stride(0);
|
||||
int64_t value_stride = value.stride(0);
|
||||
int64_t block_stride = key_cache.stride(0);
|
||||
int64_t page_stride = key_cache.stride(1);
|
||||
int64_t head_stride = key_cache.stride(2);
|
||||
TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0));
|
||||
|
||||
dim3 grid(num_tokens);
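A small sketch (not from the PR) of the destination index the kernel now computes from the cache's actual strides; with a contiguous [num_blocks, block_size, num_heads, head_size] cache it reduces to the old num_heads * head_size arithmetic, but it also stays correct for non-contiguous layouts:

```python
def flash_cache_offset(slot_idx, head_idx, head_offset,
                       block_size, block_stride, page_stride, head_stride):
    # Mirrors the index math in reshape_and_cache_flash_kernel after this change.
    block_idx = slot_idx // block_size
    block_offset = slot_idx % block_size
    return (block_idx * block_stride
            + block_offset * page_stride
            + head_idx * head_stride
            + head_offset)
```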
|
||||
|
||||
@ -78,9 +78,14 @@ struct FP16Vec16 : public Vec<FP16Vec16> {
|
||||
|
||||
__m256i reg;
|
||||
|
||||
// normal load
|
||||
explicit FP16Vec16(const void* ptr)
|
||||
: reg((__m256i)_mm256_loadu_si256((__m256i*)ptr)) {}
|
||||
|
||||
// non-temporal load
|
||||
explicit FP16Vec16(bool, void* ptr)
|
||||
: reg(_mm256_stream_load_si256((__m256i*)ptr)) {}
|
||||
|
||||
explicit FP16Vec16(const FP32Vec16&);
|
||||
|
||||
void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; }
|
||||
@ -110,9 +115,14 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
|
||||
|
||||
__m256i reg;
|
||||
|
||||
// normal load
|
||||
explicit BF16Vec16(const void* ptr)
|
||||
: reg((__m256i)_mm256_loadu_si256((__m256i*)ptr)) {}
|
||||
|
||||
// non-temporal load
|
||||
explicit BF16Vec16(bool, void* ptr)
|
||||
: reg(_mm256_stream_load_si256((__m256i*)ptr)) {}
|
||||
|
||||
explicit BF16Vec16(const FP32Vec16&);
|
||||
|
||||
void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; }
|
||||
@ -313,8 +323,13 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
|
||||
explicit FP32Vec16() : reg(_mm512_set1_ps(0.0)) {}
|
||||
|
||||
// normal load
|
||||
explicit FP32Vec16(const float* ptr) : reg(_mm512_loadu_ps(ptr)) {}
|
||||
|
||||
// non-temporal load
|
||||
explicit FP32Vec16(bool, void* ptr)
|
||||
: reg((__m512)_mm512_stream_load_si512(ptr)) {}
|
||||
|
||||
explicit FP32Vec16(__m512 data) : reg(data) {}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec4& data)
|
||||
@ -547,6 +562,33 @@ struct INT8Vec16 : public Vec<INT8Vec16> {
|
||||
_mm_mask_storeu_epi8(ptr, mask, reg);
|
||||
}
|
||||
};
|
||||
|
||||
struct INT8Vec64 : public Vec<INT8Vec64> {
|
||||
constexpr static int VEC_ELEM_NUM = 64;
|
||||
union AliasReg {
|
||||
__m512i reg;
|
||||
int8_t values[VEC_ELEM_NUM];
|
||||
};
|
||||
|
||||
__m512i reg;
|
||||
|
||||
// normal load
|
||||
explicit INT8Vec64(void* ptr) : reg(_mm512_loadu_epi8(ptr)) {}
|
||||
|
||||
// non-temporal load
|
||||
explicit INT8Vec64(bool, void* ptr) : reg(_mm512_stream_load_si512(ptr)) {}
|
||||
|
||||
void save(void* ptr) const { _mm512_storeu_epi8(ptr, reg); }
|
||||
|
||||
void save(int8_t* ptr, const int elem_num) const {
|
||||
constexpr uint64_t M = 0xFFFFFFFFFFFFFFFF;
|
||||
__mmask64 mask = _cvtu64_mask64(M >> (64 - elem_num));
|
||||
_mm512_mask_storeu_epi8(ptr, mask, reg);
|
||||
}
|
||||
|
||||
// non-temporal save
|
||||
void nt_save(int8_t* ptr) { _mm512_stream_si512((__m512i*)ptr, reg); }
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
@ -657,6 +699,22 @@ inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
|
||||
|
||||
inline void prefetch(const void* addr) { _mm_prefetch(addr, _MM_HINT_T1); }
|
||||
|
||||
#ifdef __AVX512F__
|
||||
inline void non_temporal_save(FP16Vec16& vec, void* ptr) {
|
||||
_mm256_stream_si256((__m256i*)ptr, vec.reg);
|
||||
}
|
||||
inline void non_temporal_save(BF16Vec32& vec, void* ptr) {
|
||||
_mm512_stream_si512((__m512i*)ptr, vec.reg);
|
||||
}
|
||||
inline void non_temporal_save(BF16Vec16& vec, void* ptr) {
|
||||
_mm256_stream_si256((__m256i*)ptr, vec.reg);
|
||||
}
|
||||
inline void non_temporal_save(FP32Vec16& vec, void* ptr) {
|
||||
_mm512_stream_ps((float*)ptr, vec.reg);
|
||||
}
|
||||
#endif
|
||||
|
||||
inline void mem_barrier() { _mm_mfence(); }
|
||||
}; // namespace vec_op
|
||||
|
||||
#endif
|
||||
|
||||
781
csrc/cpu/shm.cpp
Normal file
@ -0,0 +1,781 @@
|
||||
#include "cpu/cpu_types.hpp"
|
||||
|
||||
#include <fcntl.h>
|
||||
#include <sys/mman.h>
|
||||
#include <sys/stat.h>
|
||||
#include <unistd.h>
|
||||
|
||||
namespace {
|
||||
#define MAX_SHM_RANK_NUM 8
|
||||
#define MAX_THREAD_NUM 12
|
||||
#define PER_THREAD_SHM_BUFFER_BYTES (4 * 1024 * 1024)
|
||||
#define MIN_THREAD_PROCESS_SIZE (8 * 1024)
|
||||
#define MAX_P2P_SEND_TENSOR_NUM 8
|
||||
|
||||
template <typename scalar_t>
|
||||
struct KernelVecType {
|
||||
using scalar_vec_t = void;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct KernelVecType<float> {
|
||||
using scalar_vec_t = vec_op::FP32Vec16;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct KernelVecType<c10::BFloat16> {
|
||||
using scalar_vec_t = vec_op::BF16Vec16;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct KernelVecType<c10::Half> {
|
||||
using scalar_vec_t = vec_op::FP16Vec16;
|
||||
};
|
||||
|
||||
enum class ThreadSHMStat : char { THREAD_READY = 0, SHM_DATA_READY, DONE };
|
||||
|
||||
struct ThreadSHMContext {
|
||||
volatile ThreadSHMStat thread_stats[MAX_SHM_RANK_NUM];
|
||||
int thread_id;
|
||||
int thread_num;
|
||||
int rank;
|
||||
int group_size;
|
||||
size_t _spinning_count;
|
||||
int swizzled_ranks[MAX_SHM_RANK_NUM];
|
||||
void* thread_shm_ptrs[MAX_SHM_RANK_NUM];
|
||||
ThreadSHMContext* shm_contexts[MAX_SHM_RANK_NUM];
|
||||
|
||||
ThreadSHMContext(const int thread_id, const int thread_num, const int rank,
|
||||
const int group_size, void* thread_shm_ptr)
|
||||
: thread_id(thread_id),
|
||||
thread_num(thread_num),
|
||||
rank(rank),
|
||||
group_size(group_size),
|
||||
_spinning_count(0) {
|
||||
static_assert(sizeof(ThreadSHMContext) % 64 == 0);
|
||||
TORCH_CHECK(group_size <= MAX_SHM_RANK_NUM);
|
||||
TORCH_CHECK((size_t)this % 64 == 0);
|
||||
TORCH_CHECK((size_t)thread_shm_ptr % 64 == 0);
|
||||
for (int i = 0; i < MAX_SHM_RANK_NUM; ++i) {
|
||||
shm_contexts[i] = nullptr;
|
||||
thread_shm_ptrs[i] = nullptr;
|
||||
swizzled_ranks[i] = (i + rank) % group_size;
|
||||
thread_stats[i] = ThreadSHMStat::DONE;
|
||||
}
|
||||
set_context(rank, this, thread_shm_ptr);
|
||||
}
|
||||
|
||||
void set_context(int rank, ThreadSHMContext* ptr, void* thread_shm_ptr) {
|
||||
TORCH_CHECK(rank < MAX_SHM_RANK_NUM);
|
||||
TORCH_CHECK(ptr);
|
||||
TORCH_CHECK(thread_shm_ptr);
|
||||
TORCH_CHECK_EQ(ptr->thread_num, thread_num);
|
||||
TORCH_CHECK_EQ(ptr->thread_id, thread_id);
|
||||
shm_contexts[rank] = ptr;
|
||||
thread_shm_ptrs[rank] = thread_shm_ptr;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T* get_thread_shm_ptr(int rank) {
|
||||
return reinterpret_cast<T*>(thread_shm_ptrs[rank]);
|
||||
}
|
||||
|
||||
int get_swizzled_rank(int idx) { return swizzled_ranks[idx]; }
|
||||
|
||||
void wait_for_all(ThreadSHMStat prev_stat) {
|
||||
for (int idx = 0; idx < group_size; ++idx) {
|
||||
int rank = get_swizzled_rank(idx);
|
||||
while (thread_stats[rank] == prev_stat) {
|
||||
++_spinning_count;
|
||||
_mm_pause();
|
||||
}
|
||||
}
|
||||
vec_op::mem_barrier();
|
||||
}
|
||||
|
||||
void wait_for_one(int rank, ThreadSHMStat prev_stat) {
|
||||
while (thread_stats[rank] == prev_stat) {
|
||||
++_spinning_count;
|
||||
_mm_pause();
|
||||
}
|
||||
vec_op::mem_barrier();
|
||||
}
|
||||
|
||||
void set_thread_stat(ThreadSHMStat stat) {
|
||||
for (int idx = 0; idx < group_size; ++idx) {
|
||||
int rank = get_swizzled_rank(idx);
|
||||
shm_contexts[rank]->thread_stats[this->rank] = stat;
|
||||
}
|
||||
}
|
||||
|
||||
void set_thread_stat(int target_rank, ThreadSHMStat stat) {
|
||||
for (int idx = 0; idx < group_size; ++idx) {
|
||||
int rank = get_swizzled_rank(idx);
|
||||
shm_contexts[rank]->thread_stats[target_rank] = stat;
|
||||
}
|
||||
}
|
||||
|
||||
// barrier for all ranks in the group, used for all2all ops
|
||||
// DONE -> THREAD_READY -> SHM_DATA_READY -> DONE -> ...
|
||||
void barrier(ThreadSHMStat next_stat) {
|
||||
if (next_stat == ThreadSHMStat::THREAD_READY) {
|
||||
set_thread_stat(ThreadSHMStat::THREAD_READY);
|
||||
wait_for_all(ThreadSHMStat::DONE);
|
||||
} else if (next_stat == ThreadSHMStat::SHM_DATA_READY) {
|
||||
set_thread_stat(ThreadSHMStat::SHM_DATA_READY);
|
||||
wait_for_all(ThreadSHMStat::THREAD_READY);
|
||||
} else if (next_stat == ThreadSHMStat::DONE) {
|
||||
set_thread_stat(ThreadSHMStat::DONE);
|
||||
wait_for_all(ThreadSHMStat::SHM_DATA_READY);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Invalid next_stat to barrier.");
|
||||
}
|
||||
}
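A sketch (not from the PR) of the three-state cycle the barrier above walks through; each call publishes the caller's next state to every peer and then spins until no peer is still sitting in the previous state:

```python
# DONE -> THREAD_READY -> SHM_DATA_READY -> DONE -> ...
NEXT_STATE = {
    "DONE": "THREAD_READY",
    "THREAD_READY": "SHM_DATA_READY",
    "SHM_DATA_READY": "DONE",
}

def barrier(ctx, next_stat):
    # Mirrors set_thread_stat + wait_for_all in ThreadSHMContext: announce our
    # new state to every rank, then wait until every rank has left the state
    # that preceded next_stat in the cycle.
    prev_stat = {v: k for k, v in NEXT_STATE.items()}[next_stat]
    ctx.set_thread_stat(next_stat)
    ctx.wait_for_all(prev_stat)
```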
|
||||
|
||||
std::string to_string() const {
|
||||
std::stringstream ss;
|
||||
ss << "SHMContext:";
|
||||
ss << "\nrank: " << rank;
|
||||
ss << "\ngroup_size: " << group_size;
|
||||
ss << "\nthread_num: " << thread_num;
|
||||
ss << "\nthread_id: " << thread_id;
|
||||
|
||||
ss << "\nshm_ctx_stat_loop_seq: [";
|
||||
for (int i = 0; i < group_size; ++i) {
|
||||
ss << swizzled_ranks[i] << ", ";
|
||||
}
|
||||
ss << "]";
|
||||
|
||||
ss << "\nshm_contexts: [";
|
||||
for (int i = 0; i < group_size; ++i) {
|
||||
if (shm_contexts[i]) {
|
||||
ss << shm_contexts[i]->rank << ", ";
|
||||
}
|
||||
}
|
||||
ss << "]";
|
||||
|
||||
return ss.str();
|
||||
}
|
||||
};
|
||||
|
||||
class SHMManager {
|
||||
public:
|
||||
explicit SHMManager(const std::string& name, const int rank,
|
||||
const int group_size)
|
||||
: _rank(rank),
|
||||
_group_size(group_size),
|
||||
_thread_num(std::min(torch::get_num_threads(), MAX_THREAD_NUM)),
|
||||
_shm_names({""}),
|
||||
_shared_mem_ptrs({nullptr}),
|
||||
_shm_ctx(nullptr) {
|
||||
_shm_names[rank] = get_shm_name(name, rank);
|
||||
_shared_mem_ptrs[rank] = init_shm(rank);
|
||||
_shm_ctx = reinterpret_cast<ThreadSHMContext*>(_shared_mem_ptrs[rank]);
|
||||
|
||||
for (int i = 0; i < _thread_num; ++i) {
|
||||
ThreadSHMContext* ctx = new (_shm_ctx + i)
|
||||
ThreadSHMContext(i, _thread_num, _rank, _group_size,
|
||||
compute_thread_shm_ptr(_shm_ctx, i));
|
||||
}
|
||||
}
|
||||
|
||||
void join(const std::string& name) {
|
||||
for (int rank_idx = 0; rank_idx < _group_size; ++rank_idx) {
|
||||
if (rank_idx != _rank) {
|
||||
TORCH_CHECK(_shm_names[rank_idx].empty());
|
||||
TORCH_CHECK(_shared_mem_ptrs[rank_idx] == nullptr);
|
||||
_shm_names[rank_idx] = get_shm_name(name, rank_idx);
|
||||
_shared_mem_ptrs[rank_idx] = init_shm(rank_idx);
|
||||
ThreadSHMContext* target_ctx =
|
||||
reinterpret_cast<ThreadSHMContext*>(_shared_mem_ptrs[rank_idx]);
|
||||
for (int thread_idx = 0; thread_idx < _thread_num; ++thread_idx) {
|
||||
_shm_ctx[thread_idx].set_context(
|
||||
rank_idx, target_ctx + thread_idx,
|
||||
compute_thread_shm_ptr(target_ctx, thread_idx));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
~SHMManager() { destroy_shm(); }
|
||||
|
||||
ThreadSHMContext* get_shm_ctx() const { return _shm_ctx; }
|
||||
|
||||
static std::string get_shm_name(const std::string& name, int rank) {
|
||||
return name + "_" + std::to_string(rank);
|
||||
}
|
||||
|
||||
static int64_t create_singleton_instance(const std::string& name,
|
||||
const int group_size,
|
||||
const int rank) {
|
||||
std::lock_guard<std::mutex> guard(SingletonInstancesLock);
|
||||
SingletonInstances.emplace_back(
|
||||
std::make_unique<SHMManager>(name, rank, group_size));
|
||||
return static_cast<int64_t>(SingletonInstances.size() - 1);
|
||||
}
|
||||
|
||||
static SHMManager* get_singleton_instance(int64_t handle) {
|
||||
return SingletonInstances[handle].get();
|
||||
}
|
||||
|
||||
protected:
|
||||
static std::vector<std::unique_ptr<SHMManager>> SingletonInstances;
|
||||
static std::mutex SingletonInstancesLock;
|
||||
|
||||
private:
|
||||
static size_t round_to_alignment(size_t num) {
|
||||
return ((num + 63) / 64) * 64;
|
||||
}
|
||||
|
||||
int8_t* compute_thread_shm_ptr(ThreadSHMContext* ctx, int thread_id) {
|
||||
int8_t* thread_shm_ptr =
|
||||
reinterpret_cast<int8_t*>(ctx) +
|
||||
round_to_alignment(_thread_num * sizeof(ThreadSHMContext));
|
||||
return thread_shm_ptr +
|
||||
thread_id * round_to_alignment(PER_THREAD_SHM_BUFFER_BYTES);
|
||||
}
|
||||
|
||||
size_t compute_shm_size() {
|
||||
const size_t rounded_rank_buffer_size =
|
||||
round_to_alignment(PER_THREAD_SHM_BUFFER_BYTES) * _thread_num;
|
||||
const size_t rounded_thread_shm_ctx_size =
|
||||
round_to_alignment(_thread_num * sizeof(ThreadSHMContext));
|
||||
const size_t shm_size =
|
||||
rounded_thread_shm_ctx_size + rounded_rank_buffer_size;
|
||||
return shm_size;
|
||||
}
|
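For reference, a standalone sketch of the per-rank segment layout implied by compute_thread_shm_ptr and compute_shm_size: a 64-byte-aligned block of per-thread contexts followed by one aligned scratch buffer per thread. The constants below are assumptions for illustration only; the real values come from sizeof(ThreadSHMContext) and PER_THREAD_SHM_BUFFER_BYTES.

// Standalone layout sketch (assumed sizes, not the build's actual values).
#include <cstddef>
#include <cstdio>

constexpr size_t kThreadCtxBytes = 192;            // assumed sizeof(ThreadSHMContext)
constexpr size_t kPerThreadBufferBytes = 1 << 20;  // assumed PER_THREAD_SHM_BUFFER_BYTES
constexpr int kThreadNum = 8;

size_t round_to_alignment(size_t num) { return ((num + 63) / 64) * 64; }

int main() {
  size_t ctx_region = round_to_alignment(kThreadNum * kThreadCtxBytes);
  size_t buf_region = round_to_alignment(kPerThreadBufferBytes) * kThreadNum;
  std::printf("total shm size per rank: %zu bytes\n", ctx_region + buf_region);
  for (int t = 0; t < kThreadNum; ++t) {
    // Each thread's scratch buffer starts after the context array, at an
    // aligned per-thread offset.
    size_t off = ctx_region + t * round_to_alignment(kPerThreadBufferBytes);
    std::printf("thread %d buffer offset: %zu\n", t, off);
  }
}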
||||
|
||||
void* init_shm(int target_rank) {
|
||||
const std::string& shm_name = _shm_names[target_rank];
|
||||
const int local_rank = _rank;
|
||||
const size_t shm_size = compute_shm_size();
|
||||
|
||||
int fd = -1;
|
||||
if (local_rank == target_rank) {
|
||||
fd = shm_open(shm_name.c_str(), O_CREAT | O_EXCL | O_RDWR,
|
||||
S_IRUSR | S_IWUSR);
|
||||
|
||||
if (fd == -1)
|
||||
TORCH_CHECK(false, "create shm in SHMManager failed. errno: " +
|
||||
std::to_string(errno));
|
||||
|
||||
if (ftruncate(fd, shm_size) == -1)
|
||||
TORCH_CHECK(false, "ftruncate in SHMManager failed. errno: " +
|
||||
std::to_string(errno));
|
||||
} else {
|
||||
fd = shm_open(shm_name.c_str(), O_RDWR, S_IRUSR | S_IWUSR);
|
||||
|
||||
if (fd == -1)
|
||||
TORCH_CHECK(false, "open shm in SHMManager failed. errno: " +
|
||||
std::to_string(errno));
|
||||
}
|
||||
|
||||
void* shm_ptr = mmap(nullptr, shm_size, PROT_READ | PROT_WRITE,
|
||||
MAP_SHARED | MAP_POPULATE, fd, 0);
|
||||
|
||||
if (shm_ptr == MAP_FAILED) {
|
||||
TORCH_CHECK(false,
|
||||
"mmap in SHMManager failed. errno: " + std::to_string(errno));
|
||||
}
|
||||
|
||||
if (close(fd) != 0) {
|
||||
TORCH_CHECK(
|
||||
false, "close in SHMManager failed. errno: " + std::to_string(errno));
|
||||
}
|
||||
|
||||
TORCH_CHECK((size_t)shm_ptr % 64 == 0);
|
||||
|
||||
return shm_ptr;
|
||||
}
|
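init_shm follows the usual POSIX shared-memory pattern: the owning rank creates and sizes the segment (O_CREAT | O_EXCL plus ftruncate), other ranks open the existing name read-write, and both map it with mmap. A reduced, standalone sketch of that pattern is below; the segment name and size are placeholders and error handling is collapsed to perror/exit for brevity.

// Standalone sketch of the create-vs-join POSIX shm pattern used by init_shm().
#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>
#include <cstdio>
#include <cstdlib>

void* map_segment(const char* name, size_t bytes, bool create) {
  int fd = create ? shm_open(name, O_CREAT | O_EXCL | O_RDWR, S_IRUSR | S_IWUSR)
                  : shm_open(name, O_RDWR, S_IRUSR | S_IWUSR);
  if (fd == -1) { perror("shm_open"); std::exit(1); }
  if (create && ftruncate(fd, bytes) == -1) { perror("ftruncate"); std::exit(1); }
  void* ptr = mmap(nullptr, bytes, PROT_READ | PROT_WRITE,
                   MAP_SHARED | MAP_POPULATE, fd, 0);
  close(fd);  // the mapping stays valid after the fd is closed
  if (ptr == MAP_FAILED) { perror("mmap"); std::exit(1); }
  return ptr;
}

int main() {
  const char* name = "/toy_shm_example";  // illustrative segment name
  void* p = map_segment(name, 1 << 20, /*create=*/true);
  // ... a second process would call map_segment(name, 1 << 20, false) ...
  munmap(p, 1 << 20);
  shm_unlink(name);  // remove the name once every process is done with it
}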
||||
|
||||
void destroy_shm() {
|
||||
std::stringstream ss;
|
||||
ss << "local rank " << _rank << ": [";
|
||||
for (int thread_id = 0; thread_id < _thread_num; ++thread_id) {
|
||||
ss << _shm_ctx[thread_id]._spinning_count << ", ";
|
||||
}
|
||||
ss << "]\n";
|
||||
|
||||
for (int i = 0; i < MAX_SHM_RANK_NUM; ++i) {
|
||||
if (_shared_mem_ptrs[i] != nullptr) {
|
||||
munmap(_shared_mem_ptrs[i], compute_shm_size());
|
||||
}
|
||||
|
||||
if (!_shm_names[i].empty()) {
|
||||
shm_unlink(_shm_names[i].c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int _rank;
|
||||
int _group_size;
|
||||
int _thread_num;
|
||||
std::array<std::string, MAX_SHM_RANK_NUM> _shm_names;
|
||||
std::array<void*, MAX_SHM_RANK_NUM> _shared_mem_ptrs;
|
||||
ThreadSHMContext* _shm_ctx;
|
||||
};
|
||||
|
||||
namespace shm_cc_ops {
|
||||
template <typename scalar_t, typename F>
|
||||
void shm_cc_loop(ThreadSHMContext* ctx, int64_t elem_num, F&& inner_func) {
|
||||
int thread_num = ctx->thread_num;
|
||||
int64_t total_bytes = elem_num * sizeof(scalar_t);
|
||||
int64_t total_units_num =
|
||||
(total_bytes + MIN_THREAD_PROCESS_SIZE - 1) / MIN_THREAD_PROCESS_SIZE;
|
||||
int64_t per_thread_units_num =
|
||||
(total_units_num + thread_num - 1) / thread_num;
|
||||
int64_t per_unit_elem_num = MIN_THREAD_PROCESS_SIZE / sizeof(scalar_t);
|
||||
int64_t max_per_thread_iteration_elem_num =
|
||||
PER_THREAD_SHM_BUFFER_BYTES / sizeof(scalar_t);
|
||||
int64_t per_thread_elem_num = per_unit_elem_num * per_thread_units_num;
|
||||
|
||||
#pragma omp parallel for schedule(static, 1)
|
||||
for (int i = 0; i < thread_num; ++i) {
|
||||
int64_t offset = i * per_thread_elem_num;
|
||||
int64_t end = std::min(elem_num, offset + per_thread_elem_num);
|
||||
int64_t curr_elem_num =
|
||||
std::min(max_per_thread_iteration_elem_num, end - offset);
|
||||
ThreadSHMContext* thread_ctx = ctx + i;
|
||||
|
||||
while (curr_elem_num > 0) {
|
||||
inner_func(thread_ctx, offset, curr_elem_num);
|
||||
|
||||
offset += max_per_thread_iteration_elem_num;
|
||||
curr_elem_num = std::min(max_per_thread_iteration_elem_num, end - offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
}; // namespace shm_cc_ops
|
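A standalone sketch of the arithmetic shm_cc_loop uses to split the work: the total byte count is divided into MIN_THREAD_PROCESS_SIZE units, each thread gets a contiguous block of units, and each thread walks its block in chunks bounded by its per-thread SHM buffer. The two constants below are assumed values for illustration only.

// Work-partitioning sketch (assumed constants, not the build's values).
#include <algorithm>
#include <cstdint>
#include <cstdio>

constexpr int64_t kMinThreadProcessSize = 8 * 1024;    // assumed, bytes
constexpr int64_t kPerThreadShmBufferBytes = 1 << 20;  // assumed, bytes

int main() {
  using scalar_t = float;
  const int thread_num = 4;
  const int64_t elem_num = 3'000'000;

  int64_t total_bytes = elem_num * sizeof(scalar_t);
  int64_t total_units = (total_bytes + kMinThreadProcessSize - 1) / kMinThreadProcessSize;
  int64_t per_thread_units = (total_units + thread_num - 1) / thread_num;
  int64_t per_unit_elems = kMinThreadProcessSize / sizeof(scalar_t);
  int64_t per_iter_elems = kPerThreadShmBufferBytes / sizeof(scalar_t);
  int64_t per_thread_elems = per_unit_elems * per_thread_units;

  for (int i = 0; i < thread_num; ++i) {
    int64_t offset = i * per_thread_elems;
    int64_t end = std::min(elem_num, offset + per_thread_elems);
    std::printf("thread %d: elems [%lld, %lld), in chunks of at most %lld\n", i,
                (long long)offset, (long long)end, (long long)per_iter_elems);
  }
}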
||||
|
||||
namespace shm_cc_ops {
|
||||
|
||||
void memcpy_from_shm(void* dst, void* src, const int64_t bytes) {
|
||||
const int64_t aligned_bytes = ((bytes >> 6) << 6); // 64 bytes aligned
|
||||
int64_t i = 0;
|
||||
#pragma GCC unroll 4
|
||||
for (; i < aligned_bytes; i += 64) {
|
||||
vec_op::INT8Vec64 data(
|
||||
true, (int8_t*)src + i); // stream loading shm to avoid caching
|
||||
data.save((int8_t*)dst + i);
|
||||
}
|
||||
if (aligned_bytes < bytes) {
|
||||
vec_op::INT8Vec64 data(true, (int8_t*)src + aligned_bytes);
|
||||
data.save((int8_t*)dst + aligned_bytes, bytes - aligned_bytes);
|
||||
}
|
||||
}
|
||||
|
||||
void memcpy_to_shm(void* dst, void* src, const int64_t bytes) {
|
||||
#pragma GCC unroll 4
|
||||
for (int64_t i = 0; i < bytes; i += 64) {
|
||||
vec_op::INT8Vec64 data((int8_t*)src + i);
|
||||
data.nt_save((int8_t*)dst + i);
|
||||
}
|
||||
}
|
||||
|
||||
void memcpy(void* dst, void* src, const int64_t bytes) {
|
||||
const int64_t aligned_bytes = ((bytes >> 6) << 6); // 64 bytes aligned
|
||||
int64_t i = 0;
|
||||
#pragma GCC unroll 4
|
||||
for (; i < aligned_bytes; i += 64) {
|
||||
vec_op::INT8Vec64 data((int8_t*)src + i);
|
||||
data.save((int8_t*)dst + i);
|
||||
}
|
||||
if (aligned_bytes < bytes) {
|
||||
vec_op::INT8Vec64 data((int8_t*)src + aligned_bytes);
|
||||
data.save((int8_t*)dst + aligned_bytes, bytes - aligned_bytes);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, int RANKS>
|
||||
void all_reduce_sum_impl(ThreadSHMContext* ctx, scalar_t* data,
|
||||
size_t elem_num) {
|
||||
CPU_KERNEL_GUARD_IN(all_reduce_sum_impl)
|
||||
using vec_t = typename KernelVecType<scalar_t>::scalar_vec_t;
|
||||
constexpr int64_t vec_elem_num = vec_t::get_elem_num();
|
||||
const int worldsize = ctx->group_size;
|
||||
|
||||
shm_cc_ops::shm_cc_loop<scalar_t>(
|
||||
ctx, elem_num,
|
||||
[&](ThreadSHMContext* thread_ctx, int64_t data_offset,
|
||||
int64_t data_elem_num) {
|
||||
int rank = thread_ctx->rank;
|
||||
scalar_t* thread_shm_ptr =
|
||||
thread_ctx->get_thread_shm_ptr<scalar_t>(rank);
|
||||
scalar_t* thread_data_ptr = data + data_offset;
|
||||
int64_t thread_data_elem_num = data_elem_num * sizeof(scalar_t);
|
||||
|
||||
scalar_t* remote_data_ptrs[RANKS - 1];
|
||||
vec_op::unroll_loop<int, RANKS - 1>([&](int idx) {
|
||||
remote_data_ptrs[idx] = thread_ctx->get_thread_shm_ptr<scalar_t>(
|
||||
thread_ctx->get_swizzled_rank(idx + 1));
|
||||
});
|
||||
|
||||
thread_ctx->barrier(ThreadSHMStat::THREAD_READY);
|
||||
|
||||
shm_cc_ops::memcpy_to_shm(thread_shm_ptr, thread_data_ptr,
|
||||
thread_data_elem_num);
|
||||
|
||||
thread_ctx->barrier(ThreadSHMStat::SHM_DATA_READY);
|
||||
|
||||
int64_t aligned_data_elem_num =
|
||||
(data_elem_num / vec_elem_num) * vec_elem_num;
|
||||
int64_t i = 0;
|
||||
#pragma GCC unroll 4
|
||||
for (; i < aligned_data_elem_num; i += vec_elem_num) {
|
||||
vec_t local_data(thread_data_ptr + i); // load from cache
|
||||
vec_op::FP32Vec16 local_data_fp32(local_data);
|
||||
vec_op::unroll_loop<int, RANKS - 1>([&](int idx) {
|
||||
vec_t remote_data(
|
||||
true, remote_data_ptrs[idx] + i); // stream load from shm
|
||||
vec_op::FP32Vec16 remote_data_fp32(remote_data);
|
||||
local_data_fp32 = local_data_fp32 + remote_data_fp32; // sum reduce
|
||||
});
|
||||
vec_t reduced_data(local_data_fp32);
|
||||
reduced_data.save(thread_data_ptr + i);
|
||||
}
|
||||
|
||||
if (i < data_elem_num) {
|
||||
vec_t local_data(thread_data_ptr + i); // load from cache
|
||||
vec_op::FP32Vec16 local_data_fp32(local_data);
|
||||
vec_op::unroll_loop<int, RANKS - 1>([&](int idx) {
|
||||
vec_t remote_data(
|
||||
true, remote_data_ptrs[idx] + i); // stream load from shm
|
||||
vec_op::FP32Vec16 remote_data_fp32(remote_data);
|
||||
local_data_fp32 = local_data_fp32 + remote_data_fp32; // sum reduce
|
||||
});
|
||||
vec_t reduced_data(local_data_fp32);
|
||||
reduced_data.save(thread_data_ptr + i,
|
||||
data_elem_num - aligned_data_elem_num);
|
||||
}
|
||||
|
||||
thread_ctx->barrier(ThreadSHMStat::DONE);
|
||||
});
|
||||
|
||||
return;
|
||||
}
|
||||
}; // namespace shm_cc_ops
|
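The vectorized loop in all_reduce_sum_impl reduces each chunk by adding the copies the other ranks staged in shared memory to the local data, accumulating in fp32 before converting back to the storage dtype. A scalar sketch of that step follows; reduce_chunk is an illustrative name, not a function in this file.

// Scalar sketch of the fp32-accumulated sum reduction (illustration only).
#include <cstdio>

template <typename scalar_t, int RANKS>
void reduce_chunk(scalar_t* local, scalar_t* const remote[RANKS - 1], int n) {
  for (int i = 0; i < n; ++i) {
    float acc = static_cast<float>(local[i]);  // accumulate in fp32
    for (int r = 0; r < RANKS - 1; ++r) acc += static_cast<float>(remote[r][i]);
    local[i] = static_cast<scalar_t>(acc);
  }
}

int main() {
  float a[4] = {1, 2, 3, 4};
  float b[4] = {10, 20, 30, 40};
  float c[4] = {100, 200, 300, 400};
  float* remote[2] = {b, c};
  reduce_chunk<float, 3>(a, remote, 4);
  for (float v : a) std::printf("%g ", v);  // prints: 111 222 333 444
  std::printf("\n");
}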
||||
|
||||
std::vector<std::unique_ptr<SHMManager>> SHMManager::SingletonInstances = {};
|
||||
std::mutex SHMManager::SingletonInstancesLock = {};
|
||||
|
||||
template <typename scalar_t>
|
||||
void shm_allreduce_sum(ThreadSHMContext* ctx, scalar_t* data, size_t elem_num) {
|
||||
switch (ctx->group_size) {
|
||||
case 2:
|
||||
shm_cc_ops::all_reduce_sum_impl<scalar_t, 2>(ctx, data, elem_num);
|
||||
break;
|
||||
case 3:
|
||||
shm_cc_ops::all_reduce_sum_impl<scalar_t, 3>(ctx, data, elem_num);
|
||||
break;
|
||||
case 4:
|
||||
shm_cc_ops::all_reduce_sum_impl<scalar_t, 4>(ctx, data, elem_num);
|
||||
break;
|
||||
case 8:
|
||||
shm_cc_ops::all_reduce_sum_impl<scalar_t, 8>(ctx, data, elem_num);
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false,
|
||||
"Invalid world size: " + std::to_string(ctx->group_size));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void shm_gather_impl(ThreadSHMContext* ctx, scalar_t* data, size_t elem_num,
|
||||
scalar_t** outputs, const int dst) {
|
||||
CPU_KERNEL_GUARD_IN(shm_gather_impl)
|
||||
const int worldsize = ctx->group_size;
|
||||
TORCH_CHECK_LT(dst, worldsize);
|
||||
shm_cc_ops::shm_cc_loop<scalar_t>(
|
||||
ctx, elem_num,
|
||||
[&](ThreadSHMContext* thread_ctx, int64_t data_offset,
|
||||
int64_t data_elem_num) {
|
||||
int rank = thread_ctx->rank;
|
||||
scalar_t* thread_shm_ptr =
|
||||
thread_ctx->get_thread_shm_ptr<scalar_t>(rank);
|
||||
|
||||
thread_ctx->barrier(ThreadSHMStat::THREAD_READY);
|
||||
|
||||
shm_cc_ops::memcpy_to_shm(thread_shm_ptr, data + data_offset,
|
||||
data_elem_num * sizeof(scalar_t));
|
||||
|
||||
thread_ctx->barrier(ThreadSHMStat::SHM_DATA_READY);
|
||||
|
||||
if (rank == dst) {
|
||||
shm_cc_ops::memcpy(outputs[rank] + data_offset, data + data_offset,
|
||||
data_elem_num * sizeof(scalar_t));
|
||||
for (int i = 1; i < worldsize; ++i) {
|
||||
int src_rank = thread_ctx->get_swizzled_rank(i);
|
||||
scalar_t* src_ptr =
|
||||
thread_ctx->get_thread_shm_ptr<scalar_t>(src_rank); // shm
|
||||
scalar_t* dst_ptr = outputs[src_rank] + data_offset;
|
||||
shm_cc_ops::memcpy_from_shm(dst_ptr, src_ptr,
|
||||
data_elem_num * sizeof(scalar_t));
|
||||
}
|
||||
}
|
||||
|
||||
thread_ctx->barrier(ThreadSHMStat::DONE);
|
||||
});
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
struct MemPiece {
|
||||
void* ptr;
|
||||
int64_t size;
|
||||
|
||||
template <typename T>
|
||||
T* data_ptr() {
|
||||
return reinterpret_cast<T*>(ptr);
|
||||
}
|
||||
};
|
||||
|
||||
struct TensorListMeta {
|
||||
int64_t tensor_bytes[MAX_P2P_SEND_TENSOR_NUM];
|
||||
torch::ScalarType tensor_types[MAX_P2P_SEND_TENSOR_NUM];
|
||||
int64_t tensor_num;
|
||||
int64_t total_bytes;
|
||||
|
||||
TensorListMeta() : tensor_num(0), total_bytes(0) {
|
||||
static_assert(sizeof(TensorListMeta) % 64 == 0);
|
||||
static_assert(sizeof(TensorListMeta) <
|
||||
MIN_THREAD_PROCESS_SIZE); // To ensure the metadata is always
|
||||
// held by thread 0
|
||||
for (int i = 0; i < MAX_P2P_SEND_TENSOR_NUM; ++i) {
|
||||
tensor_bytes[i] = 0;
|
||||
tensor_ptrs[i] = nullptr;
|
||||
tensor_types[i] = torch::ScalarType::Undefined;
|
||||
}
|
||||
}
|
||||
|
||||
// For send and recv
|
||||
void bind_tensor_list(std::vector<torch::Tensor>& tensor_list) {
|
||||
TORCH_CHECK(tensor_types[0] == torch::ScalarType::Undefined,
|
||||
"Re-bind TensorListMeta is not allowed.")
|
||||
TORCH_CHECK_LE(tensor_list.size(), MAX_P2P_SEND_TENSOR_NUM);
|
||||
tensor_num = tensor_list.size();
|
||||
int64_t bytes_sum = 0;
|
||||
for (int i = 0; i < tensor_list.size(); ++i) {
|
||||
torch::Tensor& t = tensor_list[i];
|
||||
TORCH_CHECK(t.is_contiguous());
|
||||
tensor_bytes[i] = t.nbytes();
|
||||
tensor_types[i] = t.scalar_type();
|
||||
tensor_ptrs[i] = t.data_ptr();
|
||||
bytes_sum += t.nbytes();
|
||||
}
|
||||
total_bytes = bytes_sum;
|
||||
}
|
||||
|
||||
// For recv
|
||||
std::vector<torch::Tensor> generate_tensor_list() {
|
||||
std::vector<torch::Tensor> tensor_list;
|
||||
tensor_list.reserve(tensor_num);
|
||||
|
||||
for (int i = 0; i < tensor_num; ++i) {
|
||||
int64_t bytes = tensor_bytes[i];
|
||||
auto type = tensor_types[i];
|
||||
int64_t elem_bytes = torch::elementSize(type);
|
||||
|
||||
TORCH_CHECK_EQ(bytes % elem_bytes, 0);
|
||||
int64_t elem_num = bytes / elem_bytes;
|
||||
auto options = torch::TensorOptions().dtype(type).device(torch::kCPU);
|
||||
tensor_list.emplace_back(torch::empty({elem_num}, options));
|
||||
}
|
||||
return tensor_list;
|
||||
}
|
||||
|
||||
MemPiece get_data(int64_t offset) {
|
||||
for (int i = 0; i < tensor_num; ++i) {
|
||||
if (offset < tensor_bytes[i]) {
|
||||
return {reinterpret_cast<int8_t*>(tensor_ptrs[i]) + offset,
|
||||
tensor_bytes[i] - offset};
|
||||
}
|
||||
offset -= tensor_bytes[i];
|
||||
}
|
||||
return {nullptr, 0};
|
||||
}
|
||||
|
||||
private:
|
||||
void* tensor_ptrs[MAX_P2P_SEND_TENSOR_NUM];
|
||||
int8_t _padding[40];
|
||||
};
|
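get_data treats the bound tensors as one flat byte stream and maps a flat offset back to a concrete pointer plus the bytes remaining in the tensor that contains it. A standalone sketch of that mapping is below; the names (Piece, flat_offset_to_piece) are illustrative, not from the vLLM sources.

// Sketch of mapping a flat byte offset onto a list of (ptr, size) tensors.
#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

struct Piece { void* ptr; int64_t size; };

Piece flat_offset_to_piece(const std::vector<std::pair<void*, int64_t>>& tensors,
                           int64_t offset) {
  for (auto& [ptr, bytes] : tensors) {
    if (offset < bytes) {
      // The offset lands inside this tensor; return the remaining bytes.
      return {static_cast<int8_t*>(ptr) + offset, bytes - offset};
    }
    offset -= bytes;
  }
  return {nullptr, 0};  // past the end of the concatenation
}

int main() {
  std::vector<int8_t> t0(16), t1(32);
  std::vector<std::pair<void*, int64_t>> tensors = {{t0.data(), 16}, {t1.data(), 32}};
  Piece p = flat_offset_to_piece(tensors, 20);  // lands 4 bytes into t1
  std::printf("remaining bytes in piece: %lld\n", (long long)p.size);  // 28
}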
||||
|
||||
void shm_send_tensor_list_impl(ThreadSHMContext* ctx,
|
||||
const std::vector<torch::Tensor>& tensor_list) {
|
||||
CPU_KERNEL_GUARD_IN(shm_send_tensor_list_impl)
|
||||
std::vector<torch::Tensor> tensor_list_with_metadata;
|
||||
tensor_list_with_metadata.reserve(1 + tensor_list.size());
|
||||
|
||||
auto options = torch::TensorOptions().dtype(torch::kInt8).device(torch::kCPU);
|
||||
tensor_list_with_metadata.emplace_back(
|
||||
torch::empty({sizeof(TensorListMeta)}, options));
|
||||
tensor_list_with_metadata.insert(tensor_list_with_metadata.end(),
|
||||
tensor_list.begin(), tensor_list.end());
|
||||
|
||||
torch::Tensor& metadata_tensor = tensor_list_with_metadata[0];
|
||||
TORCH_CHECK_EQ(metadata_tensor.nbytes(), sizeof(TensorListMeta));
|
||||
|
||||
TensorListMeta* metadata = new (metadata_tensor.data_ptr()) TensorListMeta();
|
||||
metadata->bind_tensor_list(tensor_list_with_metadata);
|
||||
|
||||
shm_cc_ops::shm_cc_loop<int8_t>(
|
||||
ctx, metadata->total_bytes,
|
||||
[&](ThreadSHMContext* thread_ctx, int64_t data_offset,
|
||||
int64_t data_elem_num) {
|
||||
int rank = thread_ctx->rank;
|
||||
// Wait until the receiver sets the stat to DONE
|
||||
thread_ctx->wait_for_one(rank, ThreadSHMStat::SHM_DATA_READY);
|
||||
|
||||
int64_t curr_shm_offset = 0;
|
||||
while (curr_shm_offset < data_elem_num) {
|
||||
MemPiece frag = metadata->get_data(data_offset + curr_shm_offset);
|
||||
frag.size = std::min(frag.size, data_elem_num - curr_shm_offset);
|
||||
shm_cc_ops::memcpy(
|
||||
thread_ctx->get_thread_shm_ptr<int8_t>(rank) + curr_shm_offset,
|
||||
frag.ptr, frag.size);
|
||||
curr_shm_offset += frag.size;
|
||||
}
|
||||
|
||||
thread_ctx->set_thread_stat(rank, ThreadSHMStat::SHM_DATA_READY);
|
||||
});
|
||||
}
|
||||
|
||||
std::vector<torch::Tensor> shm_recv_tensor_list_impl(ThreadSHMContext* ctx,
|
||||
int64_t src) {
|
||||
CPU_KERNEL_GUARD_IN(shm_recv_tensor_list_impl)
|
||||
auto options = torch::TensorOptions().dtype(torch::kInt8).device(torch::kCPU);
|
||||
torch::Tensor metadata_tensor =
|
||||
torch::empty({sizeof(TensorListMeta)}, options);
|
||||
|
||||
// Wait until the sender sets the stat of thread 0 to SHM_DATA_READY
|
||||
ctx->wait_for_one(src, ThreadSHMStat::DONE);
|
||||
shm_cc_ops::memcpy(metadata_tensor.data_ptr(),
|
||||
ctx->get_thread_shm_ptr<void>(src),
|
||||
sizeof(TensorListMeta));
|
||||
TensorListMeta* src_metadata =
|
||||
reinterpret_cast<TensorListMeta*>(metadata_tensor.data_ptr());
|
||||
std::vector<torch::Tensor> tensor_list_with_metadata =
|
||||
src_metadata->generate_tensor_list();
|
||||
|
||||
TensorListMeta metadata;
|
||||
metadata.bind_tensor_list(tensor_list_with_metadata);
|
||||
TORCH_CHECK_EQ(metadata.tensor_num, src_metadata->tensor_num);
|
||||
TORCH_CHECK_EQ(metadata.total_bytes, src_metadata->total_bytes);
|
||||
|
||||
shm_cc_ops::shm_cc_loop<int8_t>(
|
||||
ctx, metadata.total_bytes,
|
||||
[&](ThreadSHMContext* thread_ctx, int64_t data_offset,
|
||||
int64_t data_elem_num) {
|
||||
// Wait until the sender sets the stat to SHM_DATA_READY
|
||||
thread_ctx->wait_for_one(src, ThreadSHMStat::DONE);
|
||||
int64_t curr_shm_offset = 0;
|
||||
while (curr_shm_offset < data_elem_num) {
|
||||
MemPiece frag = metadata.get_data(data_offset + curr_shm_offset);
|
||||
frag.size = std::min(frag.size, data_elem_num - curr_shm_offset);
|
||||
shm_cc_ops::memcpy(
|
||||
frag.ptr,
|
||||
thread_ctx->get_thread_shm_ptr<int8_t>(src) + curr_shm_offset,
|
||||
frag.size);
|
||||
curr_shm_offset += frag.size;
|
||||
}
|
||||
|
||||
thread_ctx->set_thread_stat(src, ThreadSHMStat::DONE);
|
||||
});
|
||||
|
||||
std::vector<torch::Tensor> tensor_list;
|
||||
tensor_list.reserve(metadata.tensor_num - 1);
|
||||
tensor_list.insert(tensor_list.begin(), tensor_list_with_metadata.begin() + 1,
|
||||
tensor_list_with_metadata.end());
|
||||
|
||||
return tensor_list;
|
||||
}
|
||||
} // namespace
|
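Unlike the all2all barrier, send/recv only synchronizes the two endpoints: the sender must observe the slot marked DONE before overwriting it, and the receiver must observe SHM_DATA_READY before reading. A minimal single-process sketch of that handshake, with std::atomic standing in for the shared stat word and illustrative names throughout:

// Two-party producer/consumer handshake over a single status word.
#include <atomic>
#include <cstdio>
#include <thread>

enum class Stat : int { DONE, SHM_DATA_READY };

std::atomic<Stat> slot_stat{Stat::DONE};
int slot_payload = 0;

void sender() {
  for (int msg = 1; msg <= 3; ++msg) {
    while (slot_stat.load(std::memory_order_acquire) == Stat::SHM_DATA_READY) {
    }  // wait for the receiver to drain the previous message
    slot_payload = msg;
    slot_stat.store(Stat::SHM_DATA_READY, std::memory_order_release);
  }
}

void receiver() {
  for (int i = 0; i < 3; ++i) {
    while (slot_stat.load(std::memory_order_acquire) == Stat::DONE) {
    }  // wait for data to arrive
    std::printf("received %d\n", slot_payload);
    slot_stat.store(Stat::DONE, std::memory_order_release);
  }
}

int main() {
  std::thread s(sender), r(receiver);
  s.join();
  r.join();
}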
||||
|
||||
void shm_gather(int64_t handle, torch::Tensor& data,
|
||||
const std::optional<std::vector<torch::Tensor>>& outputs,
|
||||
int64_t dst) {
|
||||
TORCH_CHECK(data.is_contiguous())
|
||||
VLLM_DISPATCH_FLOATING_TYPES(data.scalar_type(), "shm_gather_impl", [&] {
|
||||
CPU_KERNEL_GUARD_IN(shm_gather_impl)
|
||||
|
||||
if (outputs.has_value()) {
|
||||
TORCH_CHECK_LE(outputs->size(), MAX_SHM_RANK_NUM);
|
||||
scalar_t* output_ptrs[MAX_SHM_RANK_NUM] = {nullptr};
|
||||
for (int i = 0; i < outputs->size(); ++i) {
|
||||
output_ptrs[i] = outputs->at(i).data_ptr<scalar_t>();
|
||||
}
|
||||
shm_gather_impl(SHMManager::get_singleton_instance(handle)->get_shm_ctx(),
|
||||
data.data_ptr<scalar_t>(), data.numel(), output_ptrs,
|
||||
dst);
|
||||
} else {
|
||||
shm_gather_impl(SHMManager::get_singleton_instance(handle)->get_shm_ctx(),
|
||||
data.data_ptr<scalar_t>(), data.numel(), (scalar_t**)(0),
|
||||
dst);
|
||||
}
|
||||
|
||||
CPU_KERNEL_GUARD_OUT(shm_gather_impl)
|
||||
});
|
||||
}
|
||||
|
||||
void shm_all_gather(int64_t handle, const torch::Tensor& data,
|
||||
torch::Tensor& output) {
|
||||
TORCH_CHECK(data.is_contiguous())
|
||||
TORCH_CHECK(output.is_contiguous())
|
||||
|
||||
const int64_t input_elem_num = data.numel();
|
||||
const int64_t output_elem_num = output.numel();
|
||||
TORCH_CHECK_EQ(output_elem_num % input_elem_num, 0);
|
||||
const int world_size = output_elem_num / input_elem_num;
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(data.scalar_type(), "shm_all_gather_impl", [&] {
|
||||
CPU_KERNEL_GUARD_IN(shm_all_gather_impl)
|
||||
auto ctx = SHMManager::get_singleton_instance(handle)->get_shm_ctx();
|
||||
TORCH_CHECK_EQ(ctx->group_size, world_size);
|
||||
|
||||
scalar_t* output_ptrs[MAX_SHM_RANK_NUM] = {nullptr};
|
||||
for (int i = 0; i < world_size; ++i) {
|
||||
output_ptrs[i] = output.data_ptr<scalar_t>() + i * input_elem_num;
|
||||
}
|
||||
shm_gather_impl(ctx, data.data_ptr<scalar_t>(), data.numel(), output_ptrs,
|
||||
ctx->rank);
|
||||
CPU_KERNEL_GUARD_OUT(shm_all_gather_impl)
|
||||
});
|
||||
}
|
||||
|
||||
void shm_allreduce(int64_t handle, torch::Tensor& data) {
|
||||
TORCH_CHECK(data.is_contiguous())
|
||||
VLLM_DISPATCH_FLOATING_TYPES(data.scalar_type(), "shm_allreduce_sum", [&] {
|
||||
CPU_KERNEL_GUARD_IN(shm_allreduce_sum)
|
||||
shm_allreduce_sum(SHMManager::get_singleton_instance(handle)->get_shm_ctx(),
|
||||
data.data_ptr<scalar_t>(), data.numel());
|
||||
CPU_KERNEL_GUARD_OUT(shm_allreduce_sum)
|
||||
});
|
||||
}
|
||||
|
||||
void shm_send_tensor_list(int64_t handle,
|
||||
const std::vector<torch::Tensor>& tensor_list,
|
||||
int64_t dst) {
|
||||
CPU_KERNEL_GUARD_IN(shm_send_tensor_list)
|
||||
shm_send_tensor_list_impl(
|
||||
SHMManager::get_singleton_instance(handle)->get_shm_ctx(), tensor_list);
|
||||
CPU_KERNEL_GUARD_OUT(shm_send_tensor_list)
|
||||
}
|
||||
|
||||
std::vector<torch::Tensor> shm_recv_tensor_list(int64_t handle, int64_t src) {
|
||||
CPU_KERNEL_GUARD_IN(shm_recv_tensor_list)
|
||||
auto tensor_list = shm_recv_tensor_list_impl(
|
||||
SHMManager::get_singleton_instance(handle)->get_shm_ctx(), src);
|
||||
CPU_KERNEL_GUARD_OUT(shm_recv_tensor_list)
|
||||
return tensor_list;
|
||||
}
|
||||
|
||||
int64_t init_shm_manager(const std::string& name, const int64_t group_size,
|
||||
const int64_t rank) {
|
||||
return SHMManager::create_singleton_instance(name, group_size, rank);
|
||||
}
|
||||
|
||||
std::string join_shm_manager(int64_t handle, const std::string& name) {
|
||||
auto shm_manager = SHMManager::get_singleton_instance(handle);
|
||||
TORCH_CHECK(shm_manager);
|
||||
shm_manager->join(name);
|
||||
return shm_manager->get_shm_ctx()->to_string();
|
||||
}
|
||||
@ -22,6 +22,26 @@ void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query,
|
||||
torch::Tensor& kv_cache, double scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens);
|
||||
|
||||
int64_t init_shm_manager(const std::string& name, const int64_t group_size,
|
||||
const int64_t rank);
|
||||
|
||||
std::string join_shm_manager(int64_t handle, const std::string& name);
|
||||
|
||||
void shm_allreduce(int64_t handle, torch::Tensor& data);
|
||||
|
||||
void shm_gather(int64_t handle, torch::Tensor& data,
|
||||
const std::optional<std::vector<torch::Tensor>>& outputs,
|
||||
int64_t dst);
|
||||
|
||||
void shm_all_gather(int64_t handle, const torch::Tensor& data,
|
||||
torch::Tensor& output);
|
||||
|
||||
void shm_send_tensor_list(int64_t handle,
|
||||
const std::vector<torch::Tensor>& tensor_list,
|
||||
int64_t dst);
|
||||
|
||||
std::vector<torch::Tensor> shm_recv_tensor_list(int64_t handle, int64_t src);
|
||||
|
||||
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
// vLLM custom ops
|
||||
|
||||
@ -131,6 +151,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
" Tensor? azp, Tensor? bias) -> ()");
|
||||
ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp);
|
||||
#endif
|
||||
|
||||
// SHM CCL
|
||||
#ifdef __AVX512F__
|
||||
ops.def("init_shm_manager(str name, int group_size, int rank) -> int",
|
||||
&init_shm_manager);
|
||||
ops.def("join_shm_manager(int handle, str name) -> str", &join_shm_manager);
|
||||
ops.def("shm_allreduce(int handle, Tensor! data) -> ()");
|
||||
ops.impl("shm_allreduce", torch::kCPU, &shm_allreduce);
|
||||
ops.def(
|
||||
"shm_gather(int handle, Tensor data, Tensor[](a!)? outputs, int dst) -> "
|
||||
"()");
|
||||
ops.impl("shm_gather", torch::kCPU, &shm_gather);
|
||||
ops.def(
|
||||
"shm_all_gather(int handle, Tensor data, Tensor! output) -> "
|
||||
"()");
|
||||
ops.impl("shm_all_gather", torch::kCPU, &shm_all_gather);
|
||||
ops.def(
|
||||
"shm_send_tensor_list(int handle, Tensor[](a) tensor_list, int dst) -> "
|
||||
"()");
|
||||
ops.impl("shm_send_tensor_list", torch::kCPU, &shm_send_tensor_list);
|
||||
ops.def("shm_recv_tensor_list(int handle, int src) -> Tensor[](a)",
|
||||
&shm_recv_tensor_list);
|
||||
#endif
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
|
||||
@ -4,6 +4,11 @@
|
||||
#include <string>
|
||||
#include <sched.h>
|
||||
#endif
|
||||
#if __GLIBC__ == 2 && __GLIBC_MINOR__ < 30
|
||||
#include <unistd.h>
|
||||
#include <sys/syscall.h>
|
||||
#define gettid() syscall(SYS_gettid)
|
||||
#endif
|
||||
|
||||
#include "cpu_types.hpp"
|
||||
|
||||
@ -18,7 +23,7 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) {
|
||||
|
||||
#ifndef VLLM_NUMA_DISABLED
|
||||
std::string init_cpu_threads_env(const std::string& cpu_ids) {
|
||||
bitmask* omp_cpu_mask = numa_parse_cpustring(cpu_ids.c_str());
|
||||
bitmask* omp_cpu_mask = numa_parse_cpustring_all(cpu_ids.c_str());
|
||||
TORCH_CHECK(omp_cpu_mask->size > 0);
|
||||
std::vector<int> omp_cpu_ids;
|
||||
omp_cpu_ids.reserve(omp_cpu_mask->size);
|
||||
|
||||
@ -12,7 +12,7 @@ static_assert(sizeof(void*) == sizeof(fptr_t));
|
||||
|
||||
fptr_t init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs,
|
||||
torch::Tensor& rank_data, int64_t rank,
|
||||
bool full_nvlink) {
|
||||
bool fully_connected) {
|
||||
int world_size = fake_ipc_ptrs.size();
|
||||
if (world_size > 8)
|
||||
throw std::invalid_argument("world size > 8 is not supported");
|
||||
@ -27,7 +27,7 @@ fptr_t init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs,
|
||||
}
|
||||
return (fptr_t) new vllm::CustomAllreduce(ipc_ptrs, rank_data.data_ptr(),
|
||||
rank_data.numel(), rank, world_size,
|
||||
full_nvlink);
|
||||
fully_connected);
|
||||
}
|
||||
|
||||
/**
|
||||
@ -142,3 +142,48 @@ void register_graph_buffers(fptr_t _fa,
|
||||
bytes.reserve(handles.size());
|
||||
fa->register_graph_buffers(bytes, offsets);
|
||||
}
|
||||
|
||||
std::tuple<fptr_t, torch::Tensor> allocate_shared_buffer_and_handle(
|
||||
int64_t size) {
|
||||
auto device_index = c10::cuda::current_device();
|
||||
at::DeviceGuard device_guard(at::Device(at::DeviceType::CUDA, device_index));
|
||||
void* buffer;
|
||||
cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed;
|
||||
auto stream = c10::cuda::getCurrentCUDAStream().stream();
|
||||
AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));
|
||||
|
||||
// Allocate buffer
|
||||
#if defined(USE_ROCM)
|
||||
// data buffers need to be "uncached" for signal on MI200
|
||||
AT_CUDA_CHECK(
|
||||
hipExtMallocWithFlags((void**)&buffer, size, hipDeviceMallocUncached));
|
||||
#else
|
||||
AT_CUDA_CHECK(cudaMalloc((void**)&buffer, size));
|
||||
#endif
|
||||
AT_CUDA_CHECK(cudaMemsetAsync(buffer, 0, size, stream));
|
||||
AT_CUDA_CHECK(cudaStreamSynchronize(stream));
|
||||
AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));
|
||||
|
||||
// Create IPC memhandle for the allocated buffer.
|
||||
// Will use it in open_mem_handle.
|
||||
auto options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
|
||||
auto handle =
|
||||
torch::empty({static_cast<int64_t>(sizeof(cudaIpcMemHandle_t))}, options);
|
||||
AT_CUDA_CHECK(
|
||||
cudaIpcGetMemHandle((cudaIpcMemHandle_t*)handle.data_ptr(), buffer));
|
||||
|
||||
return std::make_tuple(reinterpret_cast<fptr_t>(buffer), handle);
|
||||
}
|
||||
|
||||
fptr_t open_mem_handle(torch::Tensor& mem_handle) {
|
||||
void* ipc_ptr;
|
||||
AT_CUDA_CHECK(cudaIpcOpenMemHandle(
|
||||
(void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)mem_handle.data_ptr()),
|
||||
cudaIpcMemLazyEnablePeerAccess));
|
||||
return reinterpret_cast<fptr_t>(ipc_ptr);
|
||||
}
|
||||
|
||||
void free_shared_buffer(fptr_t buffer) {
|
||||
AT_CUDA_CHECK(cudaFree(reinterpret_cast<void*>(buffer)));
|
||||
}
|
||||
|
||||
@ -5,6 +5,10 @@
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
typedef __hip_bfloat16 nv_bfloat16;
|
||||
#endif
|
||||
|
||||
#include <iostream>
|
||||
#include <array>
|
||||
#include <limits>
|
||||
@ -12,6 +16,7 @@
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
namespace vllm {
|
||||
#define CUDACHECK(cmd) \
|
||||
do { \
|
||||
cudaError_t e = cmd; \
|
||||
@ -22,24 +27,37 @@
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// Maximal number of blocks in allreduce kernel.
|
||||
constexpr int kMaxBlocks = 36;
|
||||
|
||||
// Default number of blocks in allreduce kernel.
|
||||
#ifndef USE_ROCM
|
||||
const int defaultBlockLimit = 36;
|
||||
CUpointer_attribute rangeStartAddrAttr = CU_POINTER_ATTRIBUTE_RANGE_START_ADDR;
|
||||
#else
|
||||
const int defaultBlockLimit = 16;
|
||||
hipPointer_attribute rangeStartAddrAttr =
|
||||
HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR;
|
||||
#endif
|
||||
|
||||
// Counter may overflow, but it's fine since unsigned int overflow is
|
||||
// well-defined behavior.
|
||||
using FlagType = uint32_t;
|
||||
|
||||
// Two sets of peer counters are needed for two syncs: starting and ending an
|
||||
// operation. The reason is that it's possible for peer GPU block to arrive at
|
||||
// the second sync point while the current GPU block haven't passed the first
|
||||
// sync point. Thus, peer GPU may write counter+1 while current GPU is busy
|
||||
// waiting for counter. We use alternating counter array to avoid this
|
||||
// possibility.
|
||||
struct Signal {
|
||||
alignas(128) FlagType self_counter[kMaxBlocks][8];
|
||||
// Two sets of peer counters are needed for two syncs. The reason is that
|
||||
// it's possible for peer GPU block to arrive at the second sync point while
|
||||
// the current GPU block haven't passed the first sync point. Thus, peer GPU
|
||||
// may write counter+1 while current GPU is busy waiting for counter. We use
|
||||
// alternating counter array to avoid this possibility.
|
||||
alignas(128) FlagType peer_counter[2][kMaxBlocks][8];
|
||||
alignas(128) FlagType start[kMaxBlocks][8];
|
||||
alignas(128) FlagType end[kMaxBlocks][8];
|
||||
alignas(128) FlagType _flag[kMaxBlocks]; // incremental flags for each rank
|
||||
};
|
||||
|
||||
struct __align__(16) RankData {
|
||||
const void* __restrict__ ptrs[8];
|
||||
const void* ptrs[8];
|
||||
};
|
||||
|
||||
struct __align__(16) RankSignals {
|
||||
@ -134,27 +152,29 @@ DINLINE O downcast(array_t<float, O::size> val) {
|
||||
}
|
||||
}
|
||||
|
||||
#if !defined(USE_ROCM)
|
||||
|
||||
static DINLINE void st_flag_release(FlagType* flag_addr, FlagType flag) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
|
||||
asm volatile("st.release.sys.global.u32 [%1], %0;" ::"r"(flag),
|
||||
"l"(flag_addr));
|
||||
#else
|
||||
#else
|
||||
asm volatile("membar.sys; st.volatile.global.u32 [%1], %0;" ::"r"(flag),
|
||||
"l"(flag_addr));
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
static DINLINE FlagType ld_flag_acquire(FlagType* flag_addr) {
|
||||
FlagType flag;
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
|
||||
asm volatile("ld.acquire.sys.global.u32 %0, [%1];"
|
||||
: "=r"(flag)
|
||||
: "l"(flag_addr));
|
||||
#else
|
||||
#else
|
||||
asm volatile("ld.volatile.global.u32 %0, [%1]; membar.gl;"
|
||||
: "=r"(flag)
|
||||
: "l"(flag_addr));
|
||||
#endif
|
||||
#endif
|
||||
return flag;
|
||||
}
|
||||
|
||||
@ -170,37 +190,99 @@ static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) {
|
||||
return flag;
|
||||
}
|
||||
|
||||
// is_start: whether this is the very first synchronization barrier.
|
||||
// need_fence: whether a memory fence is needed. If true, a release-acquire
|
||||
// semantic is used to enforce memory access order before and after this
|
||||
// barrier.
|
||||
template <int ngpus, bool is_start, bool need_fence = false>
|
||||
DINLINE void multi_gpu_barrier(const RankSignals& sg, Signal* self_sg,
|
||||
int rank) {
|
||||
if constexpr (!is_start) __syncthreads();
|
||||
static_assert(
|
||||
!(is_start && need_fence)); // Start barrier shouldn't need fence.
|
||||
// This function is meant to be used as the first synchronization in the all
|
||||
// reduce kernel. Thus, it doesn't need to make any visibility guarantees for
|
||||
// prior memory accesses. Note: volatile writes will not be reordered against
|
||||
// other volatile writes.
|
||||
template <int ngpus>
|
||||
DINLINE void barrier_at_start(const RankSignals& sg, Signal* self_sg,
|
||||
int rank) {
|
||||
uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
|
||||
if (threadIdx.x < ngpus) {
|
||||
// Increment the counter. Technically we only need one counter, but we use
|
||||
// multiple per block to eliminate the need to share the counter via smem.
|
||||
auto val = self_sg->self_counter[blockIdx.x][threadIdx.x] += 1;
|
||||
auto peer_counter_ptr = &sg.signals[threadIdx.x]->start[blockIdx.x][rank];
|
||||
auto self_counter_ptr = &self_sg->start[blockIdx.x][threadIdx.x];
|
||||
// Write the expected counter value to peer and wait for correct value
|
||||
// from peer.
|
||||
st_flag_volatile(peer_counter_ptr, flag);
|
||||
while (ld_flag_volatile(self_counter_ptr) != flag);
|
||||
}
|
||||
__syncthreads();
|
||||
// use one thread to update flag
|
||||
if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
|
||||
}
|
||||
|
||||
// This function is meant to be used as the second or the final
|
||||
// synchronization barrier in the all reduce kernel. If it's the final
|
||||
// synchronization barrier, we don't need to make any visibility guarantees
|
||||
// for prior memory accesses.
|
||||
template <int ngpus, bool final_sync = false>
|
||||
DINLINE void barrier_at_end(const RankSignals& sg, Signal* self_sg, int rank) {
|
||||
__syncthreads();
|
||||
uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
|
||||
if (threadIdx.x < ngpus) {
|
||||
auto peer_counter_ptr = &sg.signals[threadIdx.x]->end[blockIdx.x][rank];
|
||||
auto self_counter_ptr = &self_sg->end[blockIdx.x][threadIdx.x];
|
||||
// Write the expected counter value to peer and wait for correct value from
|
||||
// peer.
|
||||
auto peer_counter_ptr =
|
||||
&sg.signals[threadIdx.x]->peer_counter[val % 2][blockIdx.x][rank];
|
||||
auto self_counter_ptr =
|
||||
&self_sg->peer_counter[val % 2][blockIdx.x][threadIdx.x];
|
||||
if constexpr (need_fence) {
|
||||
st_flag_release(peer_counter_ptr, val);
|
||||
while (ld_flag_acquire(self_counter_ptr) != val);
|
||||
if constexpr (!final_sync) {
|
||||
st_flag_release(peer_counter_ptr, flag);
|
||||
while (ld_flag_acquire(self_counter_ptr) != flag);
|
||||
} else {
|
||||
st_flag_volatile(peer_counter_ptr, val);
|
||||
while (ld_flag_volatile(self_counter_ptr) != val);
|
||||
st_flag_volatile(peer_counter_ptr, flag);
|
||||
while (ld_flag_volatile(self_counter_ptr) != flag);
|
||||
}
|
||||
}
|
||||
if constexpr (is_start || need_fence) __syncthreads();
|
||||
if constexpr (!final_sync) __syncthreads();
|
||||
|
||||
// use one thread to update flag
|
||||
if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
template <int ngpus>
|
||||
DINLINE void barrier_at_start(const RankSignals& sg, Signal* self_sg,
|
||||
int rank) {
|
||||
uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
|
||||
if (threadIdx.x < ngpus) {
|
||||
// simultaneously write to the corresponding flag of all ranks.
|
||||
// Latency = 1 p2p write
|
||||
__scoped_atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank],
|
||||
flag, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
|
||||
// wait until we got true from all ranks
|
||||
while (__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
|
||||
__ATOMIC_RELAXED,
|
||||
__MEMORY_SCOPE_DEVICE) < flag);
|
||||
}
|
||||
__syncthreads();
|
||||
// use one thread to update flag
|
||||
if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
|
||||
}
|
||||
|
||||
template <int ngpus, bool final_sync = false>
|
||||
DINLINE void barrier_at_end(const RankSignals& sg, Signal* self_sg, int rank) {
|
||||
__syncthreads();
|
||||
uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
|
||||
if (threadIdx.x < ngpus) {
|
||||
// simultaneously write to the corresponding flag of all ranks.
|
||||
// Latency = 1 p2p write
|
||||
__scoped_atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank],
|
||||
flag,
|
||||
final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE,
|
||||
__MEMORY_SCOPE_SYSTEM);
|
||||
// wait until we got true from all ranks
|
||||
while (
|
||||
__scoped_atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
|
||||
final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE,
|
||||
__MEMORY_SCOPE_DEVICE) < flag);
|
||||
}
|
||||
if constexpr (!final_sync) __syncthreads();
|
||||
// use one thread to update flag
|
||||
if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
template <typename P, int ngpus, typename A>
|
||||
DINLINE P packed_reduce(const P* ptrs[], int idx) {
|
||||
A tmp = upcast(ptrs[0][idx]);
|
||||
@ -220,13 +302,13 @@ __global__ void __launch_bounds__(512, 1)
|
||||
// note: we don't reorder the address so the accumulation order is the same
|
||||
// for all ranks, ensuring bitwise identical results
|
||||
auto dp = *_dp;
|
||||
multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);
|
||||
barrier_at_start<ngpus>(sg, self_sg, rank);
|
||||
// do the actual reduction
|
||||
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
|
||||
idx += gridDim.x * blockDim.x) {
|
||||
((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
|
||||
}
|
||||
multi_gpu_barrier<ngpus, false>(sg, self_sg, rank);
|
||||
barrier_at_end<ngpus, true>(sg, self_sg, rank);
|
||||
}
|
||||
|
||||
template <typename P>
|
||||
@ -255,18 +337,20 @@ __global__ void __launch_bounds__(512, 1)
|
||||
tmps[i] = get_tmp_buf<P>(sg.signals[target]);
|
||||
}
|
||||
auto tmp_out = tmps[0];
|
||||
multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);
|
||||
barrier_at_start<ngpus>(sg, self_sg, rank);
|
||||
|
||||
// stage 1: reduce scatter
|
||||
for (int idx = start + tid; idx < end; idx += stride) {
|
||||
tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
|
||||
}
|
||||
multi_gpu_barrier<ngpus, false, true>(sg, self_sg, rank);
|
||||
barrier_at_end<ngpus>(sg, self_sg, rank);
|
||||
|
||||
// stage 2: allgather. Note: it's important to match the tid between
|
||||
// the two stages, because visibility across devices is only guaranteed
|
||||
// between threads that have the same tid. If thread i computes the sum of
|
||||
// start + i in the first stage, then thread i also gathers start + i from all
|
||||
// ranks.
|
||||
// start + i in the first stage, then thread i also gathers start + i from
|
||||
// all ranks.
|
||||
|
||||
for (int idx = tid; idx < largest_part; idx += stride) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < ngpus; i++) {
|
||||
@ -287,21 +371,22 @@ class CustomAllreduce {
|
||||
public:
|
||||
int rank_;
|
||||
int world_size_;
|
||||
bool full_nvlink_;
|
||||
// Full NVLink or xGMI connection between GPUs.
|
||||
bool fully_connected_;
|
||||
|
||||
RankSignals sg_;
|
||||
// Stores an map from a pointer to its peer pointters from all ranks.
|
||||
// Stores a map from a pointer to its peer pointers from all ranks.
|
||||
std::unordered_map<void*, RankData*> buffers_;
|
||||
Signal* self_sg_;
|
||||
|
||||
// Stores rank data from all ranks. This is mainly for cuda graph purposes.
|
||||
// For cuda graph to work, all kernel arguments must be fixed during graph
|
||||
// capture time. However, the peer pointers are not known during graph capture
|
||||
// time. Therefore, during capture, we increment the rank data pointer and use
|
||||
// that as the argument to the kernel. The kernel arguments are stored in
|
||||
// graph_unreg_buffers_. The actual peer pointers will be filled in at the
|
||||
// memory pointed to by the pointers in graph_unreg_buffers_ when
|
||||
// the IPC handles are exchanged between ranks.
|
||||
// capture time. However, the peer pointers are not known during graph
|
||||
// capture time. Therefore, during capture, we increment the rank data
|
||||
// pointer and use that as the argument to the kernel. The kernel arguments
|
||||
// are stored in graph_unreg_buffers_. The actual peer pointers will be
|
||||
// filled in at the memory pointed to by the pointers in
|
||||
// graph_unreg_buffers_ when the IPC handles are exchanged between ranks.
|
||||
//
|
||||
// The overall process looks like this:
|
||||
// 1. Graph capture.
|
||||
@ -319,17 +404,18 @@ class CustomAllreduce {
|
||||
* Signals are an array of ipc-enabled buffers from all ranks.
|
||||
* For each of the buffer, the layout is as follows:
|
||||
* | -- sizeof(Signal) -- | ------ a few MB ----- |
|
||||
* The first section is for allreduce synchronization, and the second section
|
||||
* is for storing the intermediate results required by some allreduce algos.
|
||||
* The first section is for allreduce synchronization, and the second
|
||||
* section is for storing the intermediate results required by some
|
||||
* allreduce algos.
|
||||
*
|
||||
* Note: this class does not own any device memory. Any required buffers
|
||||
* are passed in from the constructor.
|
||||
*/
|
||||
CustomAllreduce(Signal** signals, void* rank_data, size_t rank_data_sz,
|
||||
int rank, int world_size, bool full_nvlink = true)
|
||||
int rank, int world_size, bool fully_connected = true)
|
||||
: rank_(rank),
|
||||
world_size_(world_size),
|
||||
full_nvlink_(full_nvlink),
|
||||
fully_connected_(fully_connected),
|
||||
self_sg_(signals[rank]),
|
||||
d_rank_data_base_(reinterpret_cast<RankData*>(rank_data)),
|
||||
d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
|
||||
@ -361,8 +447,7 @@ class CustomAllreduce {
|
||||
void* base_ptr;
|
||||
// note: must share the base address of each allocation, or we get wrong
|
||||
// address
|
||||
if (cuPointerGetAttribute(&base_ptr,
|
||||
CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
|
||||
if (cuPointerGetAttribute(&base_ptr, rangeStartAddrAttr,
|
||||
(CUdeviceptr)ptr) != CUDA_SUCCESS)
|
||||
throw std::runtime_error("failed to get pointer attr");
|
||||
CUDACHECK(cudaIpcGetMemHandle(
|
||||
@ -396,11 +481,11 @@ class CustomAllreduce {
|
||||
|
||||
// Note: when registering graph buffers, we intentionally choose to not
|
||||
// deduplicate the addresses. That means if the allocator reuses some
|
||||
// addresses, they will be registered again. This is to account for the remote
|
||||
// possibility of different allocation patterns between ranks. For example,
|
||||
// rank 1 may get the same input address for the second allreduce, but rank 2
|
||||
// got a different address. IPC handles have internal reference counting
|
||||
// mechanism so overhead should be small.
|
||||
// addresses, they will be registered again. This is to account for the
|
||||
// remote possibility of different allocation patterns between ranks. For
|
||||
// example, rank 1 may get the same input address for the second allreduce,
|
||||
// but rank 2 got a different address. IPC handles have internal reference
|
||||
// counting mechanism so overhead should be small.
|
||||
void register_graph_buffers(
|
||||
const std::vector<std::string>& handles,
|
||||
const std::vector<std::vector<int64_t>>& offsets) {
|
||||
@ -431,15 +516,15 @@ class CustomAllreduce {
|
||||
/**
|
||||
* Performs allreduce, assuming input has already been registered.
|
||||
*
|
||||
* Block and grid default configs are results after careful grid search. Using
|
||||
* 36 blocks give the best or close to the best runtime on the devices I
|
||||
* tried: A100, A10, A30, T4, V100. You'll notice that NCCL kernels also only
|
||||
* take a small amount of SMs. Not quite sure the underlying reason, but my
|
||||
* guess is that too many SMs will cause contention on NVLink bus.
|
||||
* Block and grid default configs are results after careful grid search.
|
||||
* Using 36 blocks give the best or close to the best runtime on the devices
|
||||
* I tried: A100, A10, A30, T4, V100. You'll notice that NCCL kernels also
|
||||
* only take a small amount of SMs. Not quite sure the underlying reason,
|
||||
* but my guess is that too many SMs will cause contention on NVLink bus.
|
||||
*/
|
||||
template <typename T>
|
||||
void allreduce(cudaStream_t stream, T* input, T* output, int size,
|
||||
int threads = 512, int block_limit = 36) {
|
||||
int threads = 512, int block_limit = defaultBlockLimit) {
|
||||
auto d = packed_t<T>::P::size;
|
||||
if (size % d != 0)
|
||||
throw std::runtime_error(
|
||||
@ -473,13 +558,11 @@ class CustomAllreduce {
|
||||
#define KL(ngpus, name) \
|
||||
name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
|
||||
rank_, size);
|
||||
// TODO(hanzhi713): Threshold is different for A100 and H100.
|
||||
// Add per device threshold.
|
||||
#define REDUCE_CASE(ngpus) \
|
||||
case ngpus: { \
|
||||
if (world_size_ == 2) { \
|
||||
KL(ngpus, cross_device_reduce_1stage); \
|
||||
} else if (full_nvlink_) { \
|
||||
} else if (fully_connected_) { \
|
||||
if ((world_size_ <= 4 && bytes < 512 * 1024) || \
|
||||
(world_size_ <= 8 && bytes < 256 * 1024)) { \
|
||||
KL(ngpus, cross_device_reduce_1stage); \
|
||||
@ -497,7 +580,8 @@ class CustomAllreduce {
|
||||
REDUCE_CASE(8)
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"custom allreduce only supports num gpus in (2,4,6,8). Actual num "
|
||||
"custom allreduce only supports num gpus in (2,4,6,8). Actual "
|
||||
"num "
|
||||
"gpus = " +
|
||||
std::to_string(world_size_));
|
||||
}
|
||||
@ -511,10 +595,11 @@ class CustomAllreduce {
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* To inspect PTX/SASS, copy paste this header file to compiler explorer and add
|
||||
a template instantiation:
|
||||
* To inspect PTX/SASS, copy paste this header file to compiler explorer and
|
||||
add a template instantiation:
|
||||
* template void vllm::CustomAllreduce::allreduce<half>(cudaStream_t, half *,
|
||||
half *, int, int, int);
|
||||
*/
|
||||
} // namespace vllm
|
||||
} // namespace vllm
|
||||
@ -1,9 +1,9 @@
|
||||
/**
|
||||
* This is a standalone test for custom allreduce.
|
||||
* To compile, make sure you have MPI and NCCL installed in your system.
|
||||
* export MPI_HOME=xxx
|
||||
* export MPI_HOME=XXX
|
||||
* nvcc -O2 -arch=native -std=c++17 custom_all_reduce_test.cu -o
|
||||
* custom_all_reduce_test -lnccl -I${MPI_HOME} -lmpi
|
||||
* custom_all_reduce_test -lnccl -I${MPI_HOME}/include -lmpi
|
||||
*
|
||||
* Warning: this C++ test is not designed to be very readable and was used
|
||||
* during the rapid prototyping process.
|
||||
@ -22,7 +22,15 @@
|
||||
#include "cuda_profiler_api.h"
|
||||
#include "custom_all_reduce.cuh"
|
||||
#include "mpi.h"
|
||||
#include "nccl.h"
|
||||
#ifdef USE_ROCM
|
||||
#include <hip/hip_bf16.h>
|
||||
typedef __hip_bfloat16 nv_bfloat16;
|
||||
#include "rccl/rccl.h"
|
||||
#include "custom_all_reduce_hip.cuh"
|
||||
#else
|
||||
#include "nccl.h"
|
||||
#include "custom_all_reduce.cuh"
|
||||
#endif
|
||||
|
||||
#define MPICHECK(cmd) \
|
||||
do { \
|
||||
@ -43,16 +51,29 @@
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#ifdef USE_ROCM
|
||||
__global__ void dummy_kernel() {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
|
||||
for (int i = 0; i < 100; i++) {
|
||||
uint64_t start = wall_clock64();
|
||||
uint64_t cycles_elapsed;
|
||||
do {
|
||||
cycles_elapsed = wall_clock64() - start;
|
||||
} while (cycles_elapsed < 100);
|
||||
}
|
||||
for (int i = 0; i < 100; i++) __nanosleep(1000000); // 100ms
|
||||
}
|
||||
#else
|
||||
__global__ void dummy_kernel() {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
|
||||
for (int i = 0; i < 100; i++) __nanosleep(1000000); // 100ms
|
||||
#else
|
||||
for (int i = 0; i < 100; i++) {
|
||||
long long int start = clock64();
|
||||
while (clock64() - start < 150000000); // approximately 98.4ms on P40
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
__global__ void set_data(T* data, int size, int myRank) {
|
||||
@ -121,8 +142,14 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
|
||||
* registration, they are allocated and registered together in the test for
|
||||
* convenience.
|
||||
*/
|
||||
#ifdef USE_ROCM
|
||||
CUDACHECK(hipExtMallocWithFlags(
|
||||
(void**)&buffer, 2 * data_size * sizeof(T) + sizeof(vllm::Signal),
|
||||
hipDeviceMallocUncached));
|
||||
#else
|
||||
CUDACHECK(
|
||||
cudaMalloc(&buffer, 2 * data_size * sizeof(T) + sizeof(vllm::Signal)));
|
||||
#endif
|
||||
CUDACHECK(
|
||||
cudaMemset(buffer, 0, 2 * data_size * sizeof(T) + sizeof(vllm::Signal)));
|
||||
CUDACHECK(cudaMalloc(&self_data_copy, data_size * sizeof(T)));
|
||||
@ -311,13 +338,18 @@ int main(int argc, char** argv) {
|
||||
|
||||
bool performance_test = true;
|
||||
cudaProfilerStart();
|
||||
// Uncomment to scan through different block size configs.
|
||||
// for (int threads : {256, 512, 1024}) {
|
||||
// for (int block_limit = 16; block_limit < 112; block_limit += 4) {
|
||||
// run<half>(myRank, nRanks, comm, threads, block_limit, 1024 * 1024,
|
||||
// performance_test);
|
||||
// }
|
||||
// }
|
||||
// Uncomment to scan through different block size configs.
|
||||
// for (int threads : {256, 512, 1024}) {
|
||||
// for (int block_limit = 16; block_limit < 112; block_limit += 4) {
|
||||
// run<half>(myRank, nRanks, comm, threads, block_limit, 1024 * 1024,
|
||||
// performance_test);
|
||||
// }
|
||||
// }
|
||||
#ifdef USE_ROCM
|
||||
const int block_limit = 16;
|
||||
#else
|
||||
const int block_limit = 36;
|
||||
#endif
|
||||
// Scan through different sizes to test performance.
|
||||
for (int sz = 512; sz <= (8 << 20); sz *= 2) {
|
||||
run<half>(myRank, nRanks, comm, 512, 36, sz + 8 * 47, performance_test);
|
||||
@ -326,4 +358,4 @@ int main(int argc, char** argv) {
|
||||
cudaProfilerStop();
|
||||
MPICHECK(MPI_Finalize());
|
||||
return EXIT_SUCCESS;
|
||||
}
|
||||
}
|
||||
@ -422,7 +422,7 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
|
||||
int final_state_position = ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize);
|
||||
// in case the final state is separated between the last "smem_exchange" and
|
||||
// and the one before it (chunk = n_chunks - 1 and chunk = n_chunks - 2),
|
||||
// (which occurs when `final_state_position` is a non-positivie index)
|
||||
// (which occurs when `final_state_position` is a non-positive index)
|
||||
// we load the correct data from smem_exchange from both chunks, the last chunk iteration and the one before it
|
||||
if (conv_states != nullptr && final_state_position < 0 && seqlen > kWidth){
|
||||
input_t vals_load[kNElts] = {0};
|
||||
|
||||
csrc/moe/marlin_moe_wna16/generate_kernels.py (new file, 103 lines)
@ -0,0 +1,103 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import glob
|
||||
import itertools
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
import jinja2
|
||||
|
||||
FILE_HEAD = """
|
||||
// auto generated by generate.py
|
||||
// clang-format off
|
||||
|
||||
#include "kernel.h"
|
||||
#include "marlin_template.h"
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
""".strip()
|
||||
|
||||
TEMPLATE = ("template __global__ void Marlin<"
|
||||
"{{scalar_t}}, "
|
||||
"{{w_type_id}}, "
|
||||
"{{threads}}, "
|
||||
"{{thread_m_blocks}}, "
|
||||
"{{thread_n_blocks}}, "
|
||||
"{{thread_k_blocks}}, "
|
||||
"{{'true' if m_block_size_8 else 'false'}}, "
|
||||
"{{stages}}, "
|
||||
"{{'true' if has_act_order else 'false'}}, "
|
||||
"{{'true' if has_zp else 'false'}}, "
|
||||
"{{group_blocks}}, "
|
||||
"{{'true' if is_zp_float else 'false'}}>"
|
||||
"( MARLIN_KERNEL_PARAMS );")
|
||||
|
||||
# int8 with zero point case (vllm::kU8) is also supported,
|
||||
# we don't add it to reduce wheel size.
|
||||
SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128"]
|
||||
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]
|
||||
|
||||
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
|
||||
# group_blocks:
|
||||
# = 0 : act order case
|
||||
# = -1 : channelwise quantization
|
||||
# > 0 : group_size=16*group_blocks
|
||||
GROUP_BLOCKS = [0, -1, 2, 4, 8]
|
||||
DTYPES = ["fp16", "bf16"]
|
||||
|
||||
|
||||
def remove_old_kernels():
|
||||
for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"):
|
||||
subprocess.call(["rm", "-f", filename])
|
||||
|
||||
|
||||
def generate_new_kernels():
|
||||
for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
|
||||
has_zp = "B" not in scalar_type
|
||||
all_template_str_list = []
|
||||
|
||||
for group_blocks, m_blocks, thread_configs in itertools.product(
|
||||
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS):
|
||||
|
||||
has_act_order = group_blocks == 0
|
||||
if has_zp and has_act_order:
|
||||
continue
|
||||
if thread_configs[2] == 256:
|
||||
if m_blocks <= 1 and thread_configs[0] != 128:
|
||||
continue
|
||||
if m_blocks > 1 and thread_configs[0] != 64:
|
||||
continue
|
||||
|
||||
k_blocks = thread_configs[0] // 16
|
||||
n_blocks = thread_configs[1] // 16
|
||||
threads = thread_configs[2]
|
||||
|
||||
c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"
|
||||
|
||||
template_str = jinja2.Template(TEMPLATE).render(
|
||||
scalar_t=c_dtype,
|
||||
w_type_id=scalar_type + ".id()",
|
||||
threads=threads,
|
||||
thread_m_blocks=max(m_blocks, 1),
|
||||
thread_n_blocks=n_blocks,
|
||||
thread_k_blocks=k_blocks,
|
||||
m_block_size_8=m_blocks == 0.5,
|
||||
stages="pipe_stages",
|
||||
has_act_order=has_act_order,
|
||||
has_zp=has_zp,
|
||||
group_blocks=group_blocks,
|
||||
is_zp_float=False,
|
||||
)
|
||||
|
||||
all_template_str_list.append(template_str)
|
||||
|
||||
file_content = FILE_HEAD + "\n\n"
|
||||
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
|
||||
filename = f"kernel_{dtype}_{scalar_type[6:].lower()}.cu"
|
||||
|
||||
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
|
||||
f.write(file_content)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
remove_old_kernels()
|
||||
generate_new_kernels()
|
||||
csrc/moe/marlin_moe_wna16/kernel.h (new file, 44 lines)
@ -0,0 +1,44 @@
|
||||
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
|
||||
#endif
|
||||
|
||||
#include "quantization/gptq_marlin/marlin.cuh"
|
||||
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
|
||||
#include "core/scalar_type.hpp"
|
||||
|
||||
#define MARLIN_KERNEL_PARAMS \
|
||||
const int4 *__restrict__ A, const int4 *__restrict__ B, \
|
||||
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
|
||||
const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr, \
|
||||
const int *__restrict__ g_idx, \
|
||||
const int32_t *__restrict__ sorted_token_ids_ptr, \
|
||||
const int32_t *__restrict__ expert_ids_ptr, \
|
||||
const int32_t *__restrict__ num_tokens_past_padded_ptr, \
|
||||
const float *__restrict__ topk_weights_ptr, int top_k, \
|
||||
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
|
||||
int prob_n, int prob_k, int *locks, bool use_atomic_add, \
|
||||
bool use_fp32_reduce
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
template <typename scalar_t, // compute dtype, half or nv_float16
|
||||
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
|
||||
const int threads, // number of threads in a threadblock
|
||||
const int thread_m_blocks, // number of 16x16 blocks in the m
|
||||
// dimension (batchsize) of the
|
||||
// threadblock
|
||||
const int thread_n_blocks, // same for n dimension (output)
|
||||
const int thread_k_blocks, // same for k dimension (reduction)
|
||||
const bool m_block_size_8, // whether m_block_size == 8
|
||||
// only works when thread_m_blocks == 1
|
||||
const int stages, // number of stages for the async global->shared
|
||||
// fetch pipeline
|
||||
const bool has_act_order, // whether act_order is enabled
|
||||
const bool has_zp, // whether zero-points are enabled
|
||||
const int group_blocks, // number of consecutive 16x16 blocks
|
||||
// with a separate quantization scale
|
||||
const bool is_zp_float // is zero point of float16 type?
|
||||
>
|
||||
__global__ void Marlin(MARLIN_KERNEL_PARAMS);
|
||||
|
||||
}
|
||||
1917
csrc/moe/marlin_moe_wna16/marlin_template.h
Normal file
1917
csrc/moe/marlin_moe_wna16/marlin_template.h
Normal file
File diff suppressed because it is too large
Load Diff
927
csrc/moe/marlin_moe_wna16/ops.cu
Normal file
927
csrc/moe/marlin_moe_wna16/ops.cu
Normal file
@ -0,0 +1,927 @@
|
||||
/*
|
||||
* Modified by Neural Magic
|
||||
* Copyright (C) Marlin.2024 Elias Frantar
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/*
|
||||
* Adapted from https://github.com/IST-DASLab/marlin
|
||||
*/
|
||||
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
|
||||
#endif
|
||||
|
||||
#include "kernel.h"
|
||||
#include "core/registration.h"
|
||||
|
||||
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
|
||||
static_assert(std::is_same<scalar_t, half>::value || \
|
||||
std::is_same<scalar_t, nv_bfloat16>::value, \
|
||||
"only float16 and bfloat16 is supported");
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
__global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){};
|
||||
|
||||
using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS);
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
|
||||
template <int moe_block_size>
|
||||
__global__ void permute_cols_kernel(
|
||||
int4 const* __restrict__ a_int4_ptr, int const* __restrict__ perm_int_ptr,
|
||||
int4* __restrict__ out_int4_ptr,
|
||||
const int32_t* __restrict__ sorted_token_ids_ptr,
|
||||
const int32_t* __restrict__ expert_ids_ptr,
|
||||
const int32_t* __restrict__ num_tokens_past_padded_ptr, int size_m,
|
||||
int size_k, int top_k) {};
|
||||
|
||||
} // namespace marlin
|
||||
|
||||
torch::Tensor moe_wna16_marlin_gemm(
|
||||
torch::Tensor& a, std::optional<torch::Tensor> const& c_or_none,
|
||||
torch::Tensor& b_q_weight, torch::Tensor& b_scales,
|
||||
std::optional<torch::Tensor> const& b_zeros_or_none,
|
||||
std::optional<torch::Tensor> const& g_idx_or_none,
|
||||
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
|
||||
torch::Tensor& sorted_token_ids, torch::Tensor& expert_ids,
|
||||
torch::Tensor& num_tokens_past_padded, torch::Tensor& topk_weights,
|
||||
int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, bool is_ep,
|
||||
vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n,
|
||||
int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
|
||||
bool is_zp_float) {
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false,
|
||||
"marlin_gemm(..) requires CUDA_ARCH >= 8.0");
|
||||
return torch::empty({1, 1});
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
// For a given "a" of size [M,K] performs a permutation of the K columns based
|
||||
// on the given "perm" indices.
|
||||
template <int moe_block_size>
|
||||
__global__ void permute_cols_kernel(
|
||||
int4 const* __restrict__ a_int4_ptr, int const* __restrict__ perm_int_ptr,
|
||||
int4* __restrict__ out_int4_ptr,
|
||||
const int32_t* __restrict__ sorted_token_ids_ptr,
|
||||
const int32_t* __restrict__ expert_ids_ptr,
|
||||
const int32_t* __restrict__ num_tokens_past_padded_ptr, int size_m,
|
||||
int size_k, int top_k) {
|
||||
int num_tokens_past_padded = num_tokens_past_padded_ptr[0];
|
||||
int num_moe_blocks = div_ceil(num_tokens_past_padded, moe_block_size);
|
||||
int32_t block_sorted_ids[moe_block_size];
|
||||
int block_num_valid_tokens = 0;
|
||||
int64_t old_expert_id = 0;
|
||||
int64_t expert_id = 0;
|
||||
int row_stride = size_k * sizeof(half) / 16;
|
||||
|
||||
auto read_moe_block_data = [&](int block_id) {
|
||||
block_num_valid_tokens = moe_block_size;
|
||||
int4* tmp_block_sorted_ids = reinterpret_cast<int4*>(block_sorted_ids);
|
||||
for (int i = 0; i < moe_block_size / 4; i++) {
|
||||
tmp_block_sorted_ids[i] =
|
||||
((int4*)sorted_token_ids_ptr)[block_id * moe_block_size / 4 + i];
|
||||
}
|
||||
for (int i = 0; i < moe_block_size; i++) {
|
||||
if (block_sorted_ids[i] >= size_m * top_k) {
|
||||
block_num_valid_tokens = i;
|
||||
break;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
auto permute_row = [&](int row) {
|
||||
int iters = size_k / default_threads;
|
||||
int rest = size_k % default_threads;
|
||||
|
||||
int in_offset = (row / top_k) * row_stride;
|
||||
int out_offset = row * row_stride;
|
||||
|
||||
half const* a_row_half =
|
||||
reinterpret_cast<half const*>(a_int4_ptr + in_offset);
|
||||
half* out_half = reinterpret_cast<half*>(out_int4_ptr + out_offset);
|
||||
|
||||
int base_k = 0;
|
||||
|
||||
for (int i = 0; i < iters; i++) {
|
||||
int cur_k = base_k + threadIdx.x;
|
||||
int src_pos = perm_int_ptr[cur_k];
|
||||
|
||||
out_half[cur_k] = a_row_half[src_pos];
|
||||
|
||||
base_k += default_threads;
|
||||
}
|
||||
|
||||
if (rest) {
|
||||
if (threadIdx.x < rest) {
|
||||
int cur_k = base_k + threadIdx.x;
|
||||
int src_pos = perm_int_ptr[cur_k];
|
||||
|
||||
out_half[cur_k] = a_row_half[src_pos];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
for (int index = blockIdx.x; index < num_moe_blocks; index += gridDim.x) {
|
||||
old_expert_id = expert_id;
|
||||
int tmp_expert_id = expert_ids_ptr[index];
|
||||
if (tmp_expert_id == -1) continue;
|
||||
expert_id = tmp_expert_id;
|
||||
perm_int_ptr += (expert_id - old_expert_id) * size_k;
|
||||
read_moe_block_data(index);
|
||||
|
||||
for (int i = 0; i < block_num_valid_tokens; i++)
|
||||
permute_row(block_sorted_ids[i]);
|
||||
}
|
||||
}
|
||||
|
||||
typedef struct {
|
||||
int thread_k;
|
||||
int thread_n;
|
||||
int num_threads;
|
||||
} thread_config_t;
|
||||
|
||||
thread_config_t small_batch_thread_configs[] = {
|
||||
// Ordered by priority
|
||||
|
||||
// thread_k, thread_n, num_threads
|
||||
{128, 128, 256},
|
||||
{64, 128, 128}};
|
||||
|
||||
thread_config_t large_batch_thread_configs[] = {
|
||||
// Ordered by priority
|
||||
|
||||
// thread_k, thread_n, num_threads
|
||||
{64, 256, 256},
|
||||
{64, 128, 128}};
|
||||
|
||||
typedef struct {
|
||||
int blocks_per_sm;
|
||||
thread_config_t tb_cfg;
|
||||
} exec_config_t;
|
||||
|
||||
int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
|
||||
int prob_n, int prob_k, int num_bits, int group_size,
|
||||
bool has_act_order, bool is_k_full) {
|
||||
bool cache_scales_chunk = has_act_order && !is_k_full;
|
||||
|
||||
int tb_n = th_config.thread_n;
|
||||
int tb_k = th_config.thread_k;
|
||||
|
||||
// Get max scale groups per thread-block
|
||||
int tb_groups;
|
||||
if (group_size == -1) {
|
||||
tb_groups = 1;
|
||||
} else if (group_size == 0) {
|
||||
tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size
|
||||
} else {
|
||||
tb_groups = div_ceil(tb_k, group_size);
|
||||
}
|
||||
|
||||
if (cache_scales_chunk) {
|
||||
int load_groups =
|
||||
tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K
|
||||
load_groups = max(load_groups, 32); // We load at least 32 scale groups
|
||||
return load_groups * tb_n * 2;
|
||||
|
||||
} else {
|
||||
int tb_scales = tb_groups * tb_n * 2;
|
||||
|
||||
return tb_scales * pipe_stages;
|
||||
}
|
||||
}
|
||||
|
||||
int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
|
||||
int prob_m, int prob_n, int prob_k, int num_bits,
|
||||
int group_size, bool has_act_order, bool is_k_full,
|
||||
int has_zp, int is_zp_float) {
|
||||
int pack_factor = 32 / num_bits;
|
||||
|
||||
// Get B size
|
||||
int tb_k = th_config.thread_k;
|
||||
int tb_n = th_config.thread_n;
|
||||
int tb_m = thread_m_blocks * 16;
|
||||
|
||||
// shm size for block_sorted_ids/block_topk_weights
|
||||
// both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32)
|
||||
int sh_block_meta_size = tb_m * 4 * 2;
|
||||
int sh_a_size = pipe_stages * (tb_m * tb_k) * 2;
|
||||
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
|
||||
int sh_s_size =
|
||||
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
|
||||
group_size, has_act_order, is_k_full);
|
||||
int sh_g_idx_size = has_act_order && !is_k_full ? pipe_stages * tb_k / 4 : 0;
|
||||
int sh_zp_size = 0;
|
||||
if (has_zp) {
|
||||
if (is_zp_float)
|
||||
sh_zp_size = sh_s_size;
|
||||
else if (num_bits == 4)
|
||||
sh_zp_size = sh_s_size / 4;
|
||||
else if (num_bits == 8)
|
||||
sh_zp_size = sh_s_size / 2;
|
||||
}
|
||||
|
||||
int total_size = sh_a_size + sh_b_size + sh_s_size + sh_zp_size +
|
||||
sh_g_idx_size + sh_block_meta_size;
|
||||
|
||||
return total_size;
|
||||
}
|
||||
|
||||
bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
|
||||
int prob_m, int prob_n, int prob_k, int num_bits,
|
||||
int group_size, bool has_act_order, bool is_k_full,
|
||||
int has_zp, int is_zp_float, int max_shared_mem) {
|
||||
// Sanity
|
||||
if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
|
||||
th_config.num_threads == -1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Verify K/N are divisible by thread K/N
|
||||
if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Verify min for thread K/N
|
||||
if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// num_threads must be at least 128 (= 4 warps)
|
||||
if (th_config.num_threads < 128) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check that pipeline fits into cache
|
||||
int cache_size = get_kernel_cache_size(
|
||||
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size,
|
||||
has_act_order, is_k_full, has_zp, is_zp_float);
|
||||
return cache_size <= max_shared_mem;
|
||||
}
|
||||
|
||||
#define __GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
|
||||
M_BLOCK_SIZE_8, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \
|
||||
NUM_THREADS, IS_ZP_FLOAT) \
|
||||
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
|
||||
thread_n_blocks == THREAD_N_BLOCKS && \
|
||||
thread_k_blocks == THREAD_K_BLOCKS && \
|
||||
m_block_size_8 == M_BLOCK_SIZE_8 && \
|
||||
has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \
|
||||
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
|
||||
is_zp_float == IS_ZP_FLOAT) { \
|
||||
kernel = Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
|
||||
THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, \
|
||||
pipe_stages, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \
|
||||
IS_ZP_FLOAT>; \
|
||||
}
|
||||
|
||||
#define GPTQ_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, true, false, 0, NUM_THREADS, \
|
||||
false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
|
||||
NUM_THREADS, false) \
|
||||
\
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 2, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 4, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 8, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
|
||||
NUM_THREADS, false)
|
||||
|
||||
#define GPTQ_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
|
||||
NUM_THREADS, false) \
|
||||
\
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
|
||||
NUM_THREADS, false) \
|
||||
\
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
|
||||
NUM_THREADS, false) \
|
||||
\
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
|
||||
NUM_THREADS, false)
|
||||
|
||||
#define AWQ_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 2, NUM_THREADS, \
|
||||
false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 4, NUM_THREADS, \
|
||||
false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 8, NUM_THREADS, \
|
||||
false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
|
||||
NUM_THREADS, false)
|
||||
|
||||
#define AWQ_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
|
||||
NUM_THREADS, false) \
|
||||
\
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
|
||||
NUM_THREADS, false) \
|
||||
\
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
|
||||
NUM_THREADS, false)
|
||||
|
||||
// We currently have 4-bit models only with group_blocks == 4
|
||||
#define HQQ_GET_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 4, NUM_THREADS, \
|
||||
true) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
|
||||
NUM_THREADS, true) \
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
|
||||
NUM_THREADS, true) \
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
|
||||
NUM_THREADS, true) \
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
|
||||
NUM_THREADS, true)
|
||||
|
||||
template <typename scalar_t>
|
||||
MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
|
||||
int thread_m_blocks, int thread_n_blocks,
|
||||
int thread_k_blocks, bool m_block_size_8,
|
||||
bool has_act_order, bool has_zp,
|
||||
int group_blocks, int num_threads,
|
||||
bool is_zp_float) {
|
||||
int num_bits = q_type.size_bits();
|
||||
auto kernel = MarlinDefault;
|
||||
if (false) {
|
||||
}
|
||||
GPTQ_GET_IF_M1(vllm::kU4B8, 8, 8, 256)
|
||||
GPTQ_GET_IF_M1(vllm::kU4B8, 8, 4, 128)
|
||||
|
||||
GPTQ_GET_IF_M234(vllm::kU4B8, 16, 4, 256)
|
||||
GPTQ_GET_IF_M234(vllm::kU4B8, 8, 4, 128)
|
||||
|
||||
GPTQ_GET_IF_M1(vllm::kU8B128, 8, 8, 256)
|
||||
GPTQ_GET_IF_M1(vllm::kU8B128, 8, 4, 128)
|
||||
|
||||
GPTQ_GET_IF_M234(vllm::kU8B128, 16, 4, 256)
|
||||
GPTQ_GET_IF_M234(vllm::kU8B128, 8, 4, 128)
|
||||
|
||||
AWQ_GET_IF_M1(vllm::kU4, 8, 8, 256)
|
||||
AWQ_GET_IF_M1(vllm::kU4, 8, 4, 128)
|
||||
|
||||
AWQ_GET_IF_M234(vllm::kU4, 16, 4, 256)
|
||||
AWQ_GET_IF_M234(vllm::kU4, 8, 4, 128)
|
||||
|
||||
return kernel;
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
|
||||
int prob_n, int prob_k, int thread_m_blocks,
|
||||
bool m_block_size_8, int num_bits,
|
||||
int group_size, bool has_act_order,
|
||||
bool is_k_full, bool has_zp,
|
||||
bool is_zp_float, int max_shared_mem) {
|
||||
exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}};
|
||||
thread_config_t* thread_configs = thread_m_blocks > 1
|
||||
? large_batch_thread_configs
|
||||
: small_batch_thread_configs;
|
||||
int thread_configs_size =
|
||||
thread_m_blocks > 1
|
||||
? sizeof(large_batch_thread_configs) / sizeof(thread_config_t)
|
||||
: sizeof(small_batch_thread_configs) / sizeof(thread_config_t);
|
||||
|
||||
int count = 0;
|
||||
constexpr int device_max_reg_size = 255 * 1024;
|
||||
for (int i = 0; i < thread_configs_size; i++) {
|
||||
thread_config_t th_config = thread_configs[i];
|
||||
|
||||
if (!is_valid_config(th_config, thread_m_blocks, prob_m, prob_n, prob_k,
|
||||
num_bits, group_size, has_act_order, is_k_full, has_zp,
|
||||
is_zp_float, max_shared_mem)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
int cache_size = get_kernel_cache_size(
|
||||
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits,
|
||||
group_size, has_act_order, is_k_full, has_zp, is_zp_float);
|
||||
|
||||
int group_blocks = 0;
|
||||
if (!has_act_order) {
|
||||
group_blocks = group_size == -1 ? -1 : group_size / 16;
|
||||
}
|
||||
|
||||
auto kernel = get_marlin_kernel<scalar_t>(
|
||||
q_type, thread_m_blocks, th_config.thread_n / 16,
|
||||
th_config.thread_k / 16, m_block_size_8, has_act_order, has_zp,
|
||||
group_blocks, th_config.num_threads, is_zp_float);
|
||||
|
||||
if (kernel == MarlinDefault) continue;
|
||||
|
||||
if (thread_m_blocks > 1) {
|
||||
exec_cfg = {1, th_config};
|
||||
break;
|
||||
} else {
|
||||
cudaFuncAttributes attr;
|
||||
cudaFuncGetAttributes(&attr, kernel);
|
||||
int reg_size = max(attr.numRegs, 1) * th_config.num_threads * 4;
|
||||
int allow_count = min(device_max_reg_size / reg_size,
|
||||
max_shared_mem / (cache_size + 1024));
|
||||
allow_count = max(min(allow_count, 4), 1);
|
||||
if (allow_count > count) {
|
||||
count = allow_count;
|
||||
exec_cfg = {count, th_config};
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
return exec_cfg;
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
||||
void* zp, void* g_idx, void* perm, void* a_tmp,
|
||||
void* sorted_token_ids, void* expert_ids,
|
||||
void* num_tokens_past_padded, void* topk_weights,
|
||||
int moe_block_size, int top_k, bool mul_topk_weights, bool is_ep,
|
||||
int prob_m, int prob_n, int prob_k, void* workspace,
|
||||
vllm::ScalarType const& q_type, bool has_act_order,
|
||||
bool is_k_full, bool has_zp, int num_groups, int group_size,
|
||||
int dev, cudaStream_t stream, int thread_k, int thread_n,
|
||||
int sms, bool use_atomic_add, bool use_fp32_reduce,
|
||||
bool is_zp_float) {
|
||||
int thread_m_blocks = div_ceil(moe_block_size, 16);
|
||||
bool m_block_size_8 = moe_block_size == 8;
|
||||
|
||||
if (has_zp) {
|
||||
TORCH_CHECK(
|
||||
q_type == vllm::kU4 || q_type == vllm::kU8,
|
||||
"q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str());
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
q_type == vllm::kU4B8 || q_type == vllm::kU8B128,
|
||||
"q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
|
||||
q_type.str());
|
||||
}
|
||||
|
||||
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
|
||||
", ", prob_n, ", ", prob_k, "]");
|
||||
|
||||
int group_blocks = 0;
|
||||
if (has_act_order) {
|
||||
if (is_k_full) {
|
||||
TORCH_CHECK(group_size != -1);
|
||||
group_blocks = group_size / 16;
|
||||
TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
|
||||
" is not divisible by group_blocks = ", group_blocks);
|
||||
} else {
|
||||
TORCH_CHECK(group_size == 0);
|
||||
group_blocks = 0;
|
||||
}
|
||||
} else {
|
||||
if (group_size == -1) {
|
||||
group_blocks = -1;
|
||||
} else {
|
||||
group_blocks = group_size / 16;
|
||||
TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
|
||||
" is not divisible by group_blocks = ", group_blocks);
|
||||
}
|
||||
}
|
||||
|
||||
int num_bits = q_type.size_bits();
|
||||
const int4* A_ptr = (const int4*)A;
|
||||
const int4* B_ptr = (const int4*)B;
|
||||
int4* C_ptr = (int4*)C;
|
||||
int4* C_tmp_ptr = (int4*)C_tmp;
|
||||
const int4* s_ptr = (const int4*)s;
|
||||
const int4* zp_ptr = (const int4*)zp;
|
||||
const int* g_idx_ptr = (const int*)g_idx;
|
||||
const int* perm_ptr = (const int*)perm;
|
||||
int4* a_tmp_ptr = (int4*)a_tmp;
|
||||
const int32_t* sorted_token_ids_ptr = (const int32_t*)sorted_token_ids;
|
||||
const int32_t* expert_ids_ptr = (const int32_t*)expert_ids;
|
||||
const int32_t* num_tokens_past_padded_ptr =
|
||||
(const int32_t*)num_tokens_past_padded;
|
||||
const float* topk_weights_ptr = (const float*)topk_weights;
|
||||
int* locks = (int*)workspace;
|
||||
|
||||
if (has_act_order) {
|
||||
// Permute A columns
|
||||
auto kernel = permute_cols_kernel<8>;
|
||||
if (moe_block_size == 8) {
|
||||
} else if (moe_block_size == 16)
|
||||
kernel = permute_cols_kernel<16>;
|
||||
else if (moe_block_size == 32)
|
||||
kernel = permute_cols_kernel<32>;
|
||||
else if (moe_block_size == 48)
|
||||
kernel = permute_cols_kernel<48>;
|
||||
else if (moe_block_size == 64)
|
||||
kernel = permute_cols_kernel<64>;
|
||||
else
|
||||
TORCH_CHECK(false, "unsupported moe_block_size ", moe_block_size);
|
||||
|
||||
// avoid ">>>" being formatted to "> > >"
|
||||
// clang-format off
|
||||
kernel<<<sms, default_threads, 0, stream>>>(
|
||||
A_ptr, perm_ptr, a_tmp_ptr, sorted_token_ids_ptr, expert_ids_ptr,
|
||||
num_tokens_past_padded_ptr, prob_m, prob_k, top_k);
|
||||
// clang-format on
|
||||
A_ptr = a_tmp_ptr;
|
||||
prob_m = prob_m * top_k;
|
||||
top_k = 1;
|
||||
|
||||
// If we have a full K, then we can run the non-act-order version of Marlin
|
||||
// (since the weight rows are reordered by increasing group ids, and by
|
||||
// having a full K, we have full original groups)
|
||||
if (is_k_full) has_act_order = false;
|
||||
}
|
||||
|
||||
int max_shared_mem = 0;
|
||||
cudaDeviceGetAttribute(&max_shared_mem,
|
||||
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
|
||||
TORCH_CHECK(max_shared_mem > 0);
|
||||
|
||||
// Set thread config
|
||||
exec_config_t exec_cfg;
|
||||
thread_config_t thread_tfg;
|
||||
if (thread_k != -1 && thread_n != -1) {
|
||||
thread_tfg = thread_config_t{thread_k, thread_n, default_threads};
|
||||
exec_cfg = exec_config_t{1, thread_tfg};
|
||||
TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
|
||||
" is not divisible by thread_n = ", thread_n);
|
||||
TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
|
||||
" is not divisible by thread_k = ", thread_k);
|
||||
} else {
|
||||
// Auto config
|
||||
exec_cfg = determine_exec_config<scalar_t>(
|
||||
q_type, prob_m, prob_n, prob_k, thread_m_blocks, m_block_size_8,
|
||||
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float,
|
||||
max_shared_mem);
|
||||
thread_tfg = exec_cfg.tb_cfg;
|
||||
}
|
||||
|
||||
int num_threads = thread_tfg.num_threads;
|
||||
thread_k = thread_tfg.thread_k;
|
||||
thread_n = thread_tfg.thread_n;
|
||||
int blocks = sms * exec_cfg.blocks_per_sm;
|
||||
if (exec_cfg.blocks_per_sm > 1)
|
||||
max_shared_mem = max_shared_mem / exec_cfg.blocks_per_sm - 1024;
|
||||
|
||||
int thread_k_blocks = thread_k / 16;
|
||||
int thread_n_blocks = thread_n / 16;
|
||||
|
||||
TORCH_CHECK(is_valid_config(thread_tfg, thread_m_blocks, prob_m, prob_n,
|
||||
prob_k, num_bits, group_size, has_act_order,
|
||||
is_k_full, has_zp, is_zp_float, max_shared_mem),
|
||||
"Invalid thread config: thread_m_blocks = ", thread_m_blocks,
|
||||
", thread_k = ", thread_tfg.thread_k,
|
||||
", thread_n = ", thread_tfg.thread_n,
|
||||
", num_threads = ", thread_tfg.num_threads, " for MKN = [",
|
||||
prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
|
||||
", group_size = ", group_size,
|
||||
", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full,
|
||||
", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float,
|
||||
", max_shared_mem = ", max_shared_mem);
|
||||
|
||||
auto kernel = get_marlin_kernel<scalar_t>(
|
||||
q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks, m_block_size_8,
|
||||
has_act_order, has_zp, group_blocks, num_threads, is_zp_float);
|
||||
|
||||
if (kernel == MarlinDefault) {
|
||||
TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
|
||||
", ", prob_k, "]", ", has_act_order = ", has_act_order,
|
||||
", num_groups = ", num_groups, ", group_size = ", group_size,
|
||||
", thread_m_blocks = ", thread_m_blocks,
|
||||
", thread_n_blocks = ", thread_n_blocks,
|
||||
", thread_k_blocks = ", thread_k_blocks,
|
||||
", num_bits = ", num_bits);
|
||||
}
|
||||
|
||||
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
max_shared_mem);
|
||||
// avoid ">>>" being formatted to "> > >"
|
||||
// clang-format off
|
||||
kernel<<<blocks, num_threads, max_shared_mem, stream>>>(
|
||||
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr,
|
||||
sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr,
|
||||
topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m,
|
||||
prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce);
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
} // namespace MARLIN_NAMESPACE_NAME
|
||||
|
||||
torch::Tensor moe_wna16_marlin_gemm(
|
||||
torch::Tensor& a, std::optional<torch::Tensor> const& c_or_none,
|
||||
torch::Tensor& b_q_weight, torch::Tensor& b_scales,
|
||||
std::optional<torch::Tensor> const& b_zeros_or_none,
|
||||
std::optional<torch::Tensor> const& g_idx_or_none,
|
||||
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
|
||||
torch::Tensor& sorted_token_ids, torch::Tensor& expert_ids,
|
||||
torch::Tensor& num_tokens_past_padded, torch::Tensor& topk_weights,
|
||||
int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, bool is_ep,
|
||||
vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n,
|
||||
int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
|
||||
bool is_zp_float) {
|
||||
vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
|
||||
int pack_factor = 32 / b_q_type.size_bits();
|
||||
|
||||
if (moe_block_size != 8) {
|
||||
TORCH_CHECK(moe_block_size % 16 == 0,
|
||||
"unsupported moe_block_size=", moe_block_size);
|
||||
TORCH_CHECK(moe_block_size >= 16 && moe_block_size <= 64,
|
||||
"unsupported moe_block_size=", moe_block_size);
|
||||
}
|
||||
|
||||
// Verify A
|
||||
TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0),
|
||||
", size_m = ", size_m);
|
||||
TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1),
|
||||
", size_k = ", size_k);
|
||||
|
||||
// Verify B
|
||||
TORCH_CHECK(
|
||||
size_k % MARLIN_NAMESPACE_NAME::tile_size == 0, "size_k = ", size_k,
|
||||
" is not divisible by tile_size = ", MARLIN_NAMESPACE_NAME::tile_size);
|
||||
TORCH_CHECK((size_k / MARLIN_NAMESPACE_NAME::tile_size) == b_q_weight.size(1),
|
||||
"Shape mismatch: b_q_weight.size(1) = ", b_q_weight.size(1),
|
||||
", size_k = ", size_k,
|
||||
", tile_size = ", MARLIN_NAMESPACE_NAME::tile_size);
|
||||
TORCH_CHECK(
|
||||
b_q_weight.size(2) % MARLIN_NAMESPACE_NAME::tile_size == 0,
|
||||
"b_q_weight.size(2) = ", b_q_weight.size(2),
|
||||
" is not divisible by tile_size = ", MARLIN_NAMESPACE_NAME::tile_size);
|
||||
int actual_size_n =
|
||||
(b_q_weight.size(2) / MARLIN_NAMESPACE_NAME::tile_size) * pack_factor;
|
||||
TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n,
|
||||
", actual_size_n = ", actual_size_n);
|
||||
|
||||
// Verify device and strides
|
||||
TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
|
||||
TORCH_CHECK(a.is_contiguous(), "A is not contiguous");
|
||||
|
||||
TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
|
||||
TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
|
||||
|
||||
TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
|
||||
TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
|
||||
|
||||
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as
|
||||
// auto -1)
|
||||
int thread_k = -1;
|
||||
// thread_n: `n` size of a thread_tile in `weights` (can usually be left as
|
||||
// auto -1)
|
||||
int thread_n = -1;
|
||||
// sms: number of SMs to use for the kernel
|
||||
int sms = -1;
|
||||
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device());
|
||||
|
||||
// Alloc buffers
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
|
||||
auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
|
||||
torch::Tensor c;
|
||||
if (c_or_none.has_value()) {
|
||||
c = c_or_none.value();
|
||||
TORCH_CHECK(c.device().is_cuda(), "c is not on GPU");
|
||||
TORCH_CHECK(c.is_contiguous(), "c is not contiguous");
|
||||
TORCH_CHECK(c.size(0) == size_m * top_k,
|
||||
"Shape mismatch: c.size(0) = ", c.size(0),
|
||||
", size_m * topk = ", size_m * top_k);
|
||||
TORCH_CHECK(c.size(1) == size_n, "Shape mismatch: c.size(1) = ", c.size(1),
|
||||
", size_n = ", size_n);
|
||||
} else {
|
||||
c = torch::empty({size_m * top_k, size_n}, options);
|
||||
}
|
||||
|
||||
// Alloc C tmp buffer that is going to be used for the global reduce
|
||||
torch::Tensor c_tmp;
|
||||
auto options_fp32 =
|
||||
torch::TensorOptions().dtype(at::kFloat).device(a.device());
|
||||
if (use_fp32_reduce && !use_atomic_add) {
|
||||
// max num of threadblocks is sms * 4
|
||||
long max_c_tmp_size = min(
|
||||
(long)size_n * sorted_token_ids.size(0),
|
||||
(long)sms * 4 * moe_block_size * MARLIN_NAMESPACE_NAME::max_thread_n);
|
||||
if (moe_block_size == 8) max_c_tmp_size *= 2;
|
||||
c_tmp = torch::empty({max_c_tmp_size}, options_fp32);
|
||||
} else {
|
||||
c_tmp = torch::empty({0}, options_fp32);
|
||||
}
|
||||
|
||||
// Detect groupsize and act_order
|
||||
int num_groups = -1;
|
||||
int group_size = -1;
|
||||
|
||||
int rank = b_scales.sizes().size();
|
||||
TORCH_CHECK(rank == 3, "b_scales rank = ", rank, " is not 3");
|
||||
TORCH_CHECK(b_scales.size(2) == size_n, "b_scales dim 2 = ", b_scales.size(2),
|
||||
" is not size_n = ", size_n);
|
||||
num_groups = b_scales.size(1);
|
||||
|
||||
torch::Tensor g_idx, perm, a_tmp;
|
||||
;
|
||||
if (g_idx_or_none.has_value() && perm_or_none.has_value()) {
|
||||
g_idx = g_idx_or_none.value();
|
||||
perm = perm_or_none.value();
|
||||
|
||||
TORCH_CHECK(g_idx.device().is_cuda(), "g_idx is not on GPU");
|
||||
TORCH_CHECK(g_idx.is_contiguous(), "g_idx is not contiguous");
|
||||
TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU");
|
||||
TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous");
|
||||
|
||||
// Verify g_idx and perm
|
||||
TORCH_CHECK((g_idx.size(-1) == 0 && perm.size(-1) == 0) ||
|
||||
(g_idx.size(-1) == size_k && perm.size(-1) == size_k),
|
||||
"Unexpected g_idx.size(-1) = ", g_idx.size(-1),
|
||||
" and perm.size(-1) = ", perm.size(-1),
|
||||
", where size_k = ", size_k);
|
||||
} else {
|
||||
g_idx = torch::empty({0}, options);
|
||||
perm = torch::empty({0}, options);
|
||||
a_tmp = torch::empty({0}, options);
|
||||
}
|
||||
bool has_act_order = g_idx.size(-1) > 0 && perm.size(-1) > 0;
|
||||
|
||||
if (has_act_order) {
|
||||
a_tmp = torch::empty({size_m * top_k, size_k}, options);
|
||||
if (is_k_full) {
|
||||
TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1");
|
||||
TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k,
|
||||
", is not divisible by num_groups = ", num_groups);
|
||||
group_size = size_k / num_groups;
|
||||
} else {
|
||||
group_size = 0;
|
||||
}
|
||||
|
||||
} else {
|
||||
a_tmp = torch::empty({0}, options);
|
||||
if (num_groups > 1) {
|
||||
TORCH_CHECK(
|
||||
size_k % num_groups == 0, "size_k = ", size_k,
|
||||
", is not divisible by b_scales.size(1) = ", b_scales.size(1));
|
||||
group_size = size_k / num_groups;
|
||||
} else {
|
||||
group_size = -1;
|
||||
}
|
||||
}
|
||||
|
||||
torch::Tensor b_zeros;
|
||||
if (b_zeros_or_none.has_value()) {
|
||||
b_zeros = b_zeros_or_none.value();
|
||||
TORCH_CHECK(b_zeros.device().is_cuda(), "b_zeros is not on GPU");
|
||||
TORCH_CHECK(b_zeros.is_contiguous(), "b_zeros is not contiguous");
|
||||
} else {
|
||||
b_zeros = torch::empty({0}, options);
|
||||
}
|
||||
bool has_zp = b_zeros.size(-1) > 0;
|
||||
|
||||
if (has_zp) {
|
||||
TORCH_CHECK(
|
||||
b_q_type == vllm::kU4,
|
||||
"b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str());
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128,
|
||||
"b_q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
|
||||
b_q_type.str());
|
||||
}
|
||||
|
||||
if (has_zp && is_zp_float) {
|
||||
TORCH_CHECK(a.scalar_type() == at::ScalarType::Half,
|
||||
"Computation type must be float16 (half) when using float zero "
|
||||
"points.");
|
||||
}
|
||||
|
||||
// Verify b_zeros
|
||||
if (has_zp) {
|
||||
int rank = b_zeros.sizes().size();
|
||||
TORCH_CHECK(rank == 3, "b_zeros rank = ", rank, " is not 3");
|
||||
if (is_zp_float) {
|
||||
TORCH_CHECK(b_zeros.size(2) == size_n,
|
||||
"b_zeros dim 2 = ", b_zeros.size(2),
|
||||
" is not size_n = ", size_n);
|
||||
TORCH_CHECK(num_groups == b_zeros.size(1),
|
||||
"b_zeros dim 1 = ", b_zeros.size(1),
|
||||
" is not num_groups = ", num_groups);
|
||||
TORCH_CHECK(num_groups != -1, "num_groups must be != -1");
|
||||
} else {
|
||||
TORCH_CHECK(b_zeros.size(1) == num_groups,
|
||||
"b_zeros dim 1 = ", b_zeros.size(1),
|
||||
" is not num_groups = ", num_groups);
|
||||
TORCH_CHECK(b_zeros.size(2) == size_n / pack_factor,
|
||||
"b_zeros dim 2 = ", b_zeros.size(2),
|
||||
" is not size_n / pack_factor = ", size_n / pack_factor);
|
||||
}
|
||||
}
|
||||
|
||||
// Verify workspace size
|
||||
TORCH_CHECK(size_n % MARLIN_NAMESPACE_NAME::min_thread_n == 0,
|
||||
"size_n = ", size_n, ", is not divisible by min_thread_n = ",
|
||||
MARLIN_NAMESPACE_NAME::min_thread_n);
|
||||
|
||||
int max_n_tiles = size_n / MARLIN_NAMESPACE_NAME::min_thread_n;
|
||||
int min_workspace_size = min(
|
||||
max_n_tiles * (int)(sorted_token_ids.size(0) / moe_block_size), sms * 4);
|
||||
TORCH_CHECK(workspace.numel() >= min_workspace_size,
|
||||
"workspace.numel = ", workspace.numel(),
|
||||
" is below min_workspace_size = ", min_workspace_size);
|
||||
|
||||
int dev = a.get_device();
|
||||
if (a.scalar_type() == at::ScalarType::Half) {
|
||||
MARLIN_NAMESPACE_NAME::marlin_mm<half>(
|
||||
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
|
||||
c_tmp.data_ptr<float>(), b_scales.data_ptr<at::Half>(),
|
||||
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
|
||||
a_tmp.data_ptr<at::Half>(), sorted_token_ids.data_ptr(),
|
||||
expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(),
|
||||
topk_weights.data_ptr(), moe_block_size, top_k, mul_topk_weights, is_ep,
|
||||
size_m, size_n, size_k, workspace.data_ptr(), b_q_type, has_act_order,
|
||||
is_k_full, has_zp, num_groups, group_size, dev,
|
||||
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
|
||||
use_atomic_add, use_fp32_reduce, is_zp_float);
|
||||
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
|
||||
MARLIN_NAMESPACE_NAME::marlin_mm<nv_bfloat16>(
|
||||
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
|
||||
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
|
||||
b_scales.data_ptr<at::BFloat16>(), b_zeros.data_ptr(), g_idx.data_ptr(),
|
||||
perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
|
||||
sorted_token_ids.data_ptr(), expert_ids.data_ptr(),
|
||||
num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(),
|
||||
moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k,
|
||||
workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
|
||||
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
|
||||
thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float);
|
||||
} else {
|
||||
TORCH_CHECK(false,
|
||||
"moe_wna16_marlin_gemm only supports bfloat16 and float16");
|
||||
}
|
||||
|
||||
return c;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("moe_wna16_marlin_gemm", &moe_wna16_marlin_gemm);
|
||||
}
|
||||
@ -13,7 +13,6 @@
|
||||
template <typename scalar_t, int bit, int GROUPS>
|
||||
__global__ void moe_wna16_gemm_kernel(
|
||||
const scalar_t* __restrict__ input, scalar_t* __restrict__ output,
|
||||
|
||||
const uint32_t* __restrict__ qweight, const scalar_t* __restrict__ scales,
|
||||
const uint32_t* __restrict__ qzeros,
|
||||
|
||||
@ -54,8 +53,6 @@ __global__ void moe_wna16_gemm_kernel(
|
||||
if (token_index / top_k >= size_m) break;
|
||||
|
||||
num_valid_tokens = m + 1;
|
||||
if (blockIdx.z == 0 && offset_n < size_n)
|
||||
output[token_index * size_n + offset_n] = Dtype::int2num(0);
|
||||
|
||||
if (expert_id != -1) {
|
||||
int k_per_thread = DIVIDE(BLOCK_SIZE_K, BLOCK_SIZE_N);
|
||||
@ -284,8 +281,7 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
|
||||
int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
|
||||
int64_t BLOCK_SIZE_K, int64_t bit) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
auto options =
|
||||
torch::TensorOptions().dtype(input.dtype()).device(input.device());
|
||||
output.zero_();
|
||||
|
||||
const int num_experts = b_qweight.size(0);
|
||||
const int size_m = input.size(0);
|
||||
@ -302,9 +298,9 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
|
||||
const uint32_t* b_qzeros_ptr;
|
||||
if (b_qzeros.has_value())
|
||||
b_qzeros_ptr = (const uint32_t*)b_qzeros.value().data_ptr<uint8_t>();
|
||||
const float* topk_weights_ptr;
|
||||
const float* topk_weights_ptr = nullptr;
|
||||
if (topk_weights.has_value())
|
||||
topk_weights_ptr = (const float*)topk_weights.value().data_ptr();
|
||||
topk_weights_ptr = (const float*)topk_weights.value().data_ptr<float>();
|
||||
|
||||
int groups_per_block_row = BLOCK_SIZE_K / group_size;
|
||||
TORCH_CHECK(bit == 4 || bit == 8, "bit must be 4 or 8");
|
||||
|
||||
@ -43,14 +43,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
||||
m.impl("moe_wna16_gemm", torch::kCUDA, &moe_wna16_gemm);
|
||||
|
||||
m.def(
|
||||
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
|
||||
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
|
||||
"b_zeros, Tensor! g_idx, Tensor! perm, Tensor! workspace, "
|
||||
"int b_q_type, SymInt size_m, "
|
||||
"SymInt size_n, SymInt size_k, bool is_k_full, int num_experts, int "
|
||||
"topk, "
|
||||
"int moe_block_size, bool replicate_input, bool apply_weights)"
|
||||
" -> Tensor");
|
||||
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
|
||||
"Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none,"
|
||||
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
|
||||
"Tensor sorted_token_ids,"
|
||||
"Tensor! expert_ids, Tensor! num_tokens_past_padded,"
|
||||
"Tensor! topk_weights, int moe_block_size, int top_k, "
|
||||
"bool mul_topk_weights, bool is_ep, int b_q_type_id,"
|
||||
"int size_m, int size_n, int size_k,"
|
||||
"bool is_full_k, bool use_atomic_add,"
|
||||
"bool use_fp32_reduce, bool is_zp_float) -> Tensor");
|
||||
|
||||
// conditionally compiled so impl registration is in source file
|
||||
|
||||
#endif
|
||||
|
||||
27
csrc/ops.h
27
csrc/ops.h
@ -52,6 +52,15 @@ void paged_attention_v2(
|
||||
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||
const int64_t blocksparse_head_sliding_step);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
void merge_attn_states(torch::Tensor& output,
|
||||
std::optional<torch::Tensor> output_lse,
|
||||
const torch::Tensor& prefix_output,
|
||||
const torch::Tensor& prefix_lse,
|
||||
const torch::Tensor& suffix_output,
|
||||
const torch::Tensor& suffix_lse);
|
||||
#endif
|
||||
|
||||
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
|
||||
double epsilon);
|
||||
|
||||
@ -119,6 +128,12 @@ void advance_step_flashinfer(
|
||||
torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
|
||||
torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds);
|
||||
|
||||
void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope,
|
||||
torch::Tensor const& q_pe,
|
||||
torch::Tensor const& kv_c_and_k_pe_cache,
|
||||
torch::Tensor const& seq_lens,
|
||||
torch::Tensor const& page_table, double scale);
|
||||
|
||||
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
@ -145,7 +160,8 @@ torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm);
|
||||
#endif
|
||||
|
||||
torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m,
|
||||
int64_t n);
|
||||
int64_t n,
|
||||
std::optional<at::ScalarType> const& dtype);
|
||||
|
||||
torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X,
|
||||
int64_t type, int64_t row);
|
||||
@ -267,10 +283,10 @@ void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
|
||||
const std::optional<at::Tensor>& has_initial_state,
|
||||
bool silu_activation, int64_t pad_slot_id);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
using fptr_t = int64_t;
|
||||
fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
|
||||
torch::Tensor& rank_data, int64_t rank, bool full_nvlink);
|
||||
torch::Tensor& rank_data, int64_t rank,
|
||||
bool fully_connected);
|
||||
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
|
||||
fptr_t reg_buffer, int64_t reg_buffer_sz_bytes);
|
||||
void dispose(fptr_t _fa);
|
||||
@ -281,4 +297,7 @@ get_graph_buffer_ipc_meta(fptr_t _fa);
|
||||
void register_graph_buffers(fptr_t _fa,
|
||||
const std::vector<std::vector<int64_t>>& handles,
|
||||
const std::vector<std::vector<int64_t>>& offsets);
|
||||
#endif
|
||||
std::tuple<int64_t, torch::Tensor> allocate_shared_buffer_and_handle(
|
||||
int64_t size);
|
||||
int64_t open_mem_handle(torch::Tensor& mem_handle);
|
||||
void free_shared_buffer(int64_t buffer);
|
||||
|
||||
@ -46,14 +46,26 @@ __global__ void compute_expert_offsets(
|
||||
}
|
||||
|
||||
__global__ void compute_arg_sorts(const int* __restrict__ topk_ids,
|
||||
const int32_t* __restrict__ expert_offsets,
|
||||
int32_t* input_permutation,
|
||||
int32_t* output_permutation,
|
||||
int32_t* atomic_buffer, const int topk_length,
|
||||
const int topk) {
|
||||
int expert_id = blockIdx.x;
|
||||
int const blk_expert_id = blockIdx.x;
|
||||
int const num_experts = gridDim.x;
|
||||
int32_t const num_tokens = expert_offsets[num_experts];
|
||||
|
||||
for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
|
||||
if (topk_ids[i] == expert_id) {
|
||||
int const expert_id = topk_ids[i];
|
||||
if (expert_id == -1 && blockIdx.x == 0) {
|
||||
// output_permutation is used to re-order the moe outputs. It is
|
||||
// used as c2 = c2[c_map], where c2 is a torch.tensor that is the
|
||||
// output of the cutlass kernels and c_map is the output_permutation.
|
||||
// c2 is initialized to zeros, therefore by setting the output_permutation
|
||||
// to num_tokens, we are guaranteed to fill the moe outputs to zero
|
||||
// for "invalid" topk_ids.
|
||||
output_permutation[i] = num_tokens;
|
||||
} else if (expert_id == blk_expert_id) {
|
||||
int start = atomicAdd(&atomic_buffer[expert_id], 1);
|
||||
input_permutation[start] = i / topk;
|
||||
output_permutation[i] = start;
|
||||
@ -83,6 +95,7 @@ void get_cutlass_moe_mm_data_caller(
|
||||
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
|
||||
compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
|
||||
static_cast<const int32_t*>(topk_ids.data_ptr()),
|
||||
static_cast<const int32_t*>(expert_offsets.data_ptr()),
|
||||
static_cast<int32_t*>(input_permutation.data_ptr()),
|
||||
static_cast<int32_t*>(output_permutation.data_ptr()),
|
||||
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(),
|
||||
|
||||
@ -336,7 +336,7 @@ inline void cutlass_gemm_sm89_fp8_dispatch(torch::Tensor& out,
|
||||
|
||||
uint32_t const m = a.size(0);
|
||||
uint32_t const mp2 =
|
||||
std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
|
||||
std::max(static_cast<uint32_t>(16), next_pow_2(m)); // next power of 2
|
||||
|
||||
if (mp2 <= 16) {
|
||||
// M in [1, 16]
|
||||
|
||||
@ -321,7 +321,7 @@ inline void cutlass_gemm_sm89_int8_dispatch(torch::Tensor& out,
|
||||
|
||||
uint32_t const m = a.size(0);
|
||||
uint32_t const mp2 =
|
||||
std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
|
||||
std::max(static_cast<uint32_t>(16), next_pow_2(m)); // next power of 2
|
||||
|
||||
if (mp2 <= 16) {
|
||||
// M in [1, 16]
|
||||
|
||||
@ -134,7 +134,7 @@ typename T::Gemm::Arguments args_from_options(
|
||||
using StrideB = typename T::StrideB;
|
||||
using StrideD = typename T::StrideD;
|
||||
using Sm100BlkScaledConfig =
|
||||
typename T::Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;
|
||||
typename T::Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
|
||||
|
||||
int m = static_cast<int>(M);
|
||||
int n = static_cast<int>(N);
|
||||
|
||||
@ -94,8 +94,8 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
|
||||
dfloat2 v;
|
||||
dequantize_kernel(vx, ib, iqs, v);
|
||||
|
||||
y[iybs + iqs + 0] = v.x;
|
||||
y[iybs + iqs + y_offset] = v.y;
|
||||
y[iybs + iqs + 0] = convert_from_half<dst_t>(v.x);
|
||||
y[iybs + iqs + y_offset] = convert_from_half<dst_t>(v.y);
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
@ -114,10 +114,10 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t
|
||||
|
||||
half dall = __low2half(x[i].dm);
|
||||
half dmin = __high2half(x[i].dm);
|
||||
y[l+ 0] = __hsub(__hmul(dall, __int2half_rn((x[i].scales[is+0] & 0xF) * ((q >> 0) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+0] >> 4)));
|
||||
y[l+32] = __hsub(__hmul(dall, __int2half_rn((x[i].scales[is+2] & 0xF) * ((q >> 2) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+2] >> 4)));
|
||||
y[l+64] = __hsub(__hmul(dall, __int2half_rn((x[i].scales[is+4] & 0xF) * ((q >> 4) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+4] >> 4)));
|
||||
y[l+96] = __hsub(__hmul(dall, __int2half_rn((x[i].scales[is+6] & 0xF) * ((q >> 6) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+6] >> 4)));
|
||||
y[l+ 0] = convert_from_half<dst_t>(__hsub(__hmul(dall, __int2half_rn((x[i].scales[is+0] & 0xF) * ((q >> 0) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+0] >> 4))));
|
||||
y[l+32] = convert_from_half<dst_t>(__hsub(__hmul(dall, __int2half_rn((x[i].scales[is+2] & 0xF) * ((q >> 2) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+2] >> 4))));
|
||||
y[l+64] = convert_from_half<dst_t>(__hsub(__hmul(dall, __int2half_rn((x[i].scales[is+4] & 0xF) * ((q >> 4) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+4] >> 4))));
|
||||
y[l+96] = convert_from_half<dst_t>(__hsub(__hmul(dall, __int2half_rn((x[i].scales[is+6] & 0xF) * ((q >> 6) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+6] >> 4))));
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
@ -148,7 +148,9 @@ static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t
|
||||
const uint8_t * q = x[i].qs + 32*n;
|
||||
const uint8_t * hm = x[i].hmask;
|
||||
|
||||
for (int l = l0; l < l0+4; ++l) y[l] = __hmul(dl, __int2half_rn((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4)));
|
||||
for (int l = l0; l < l0+4; ++l) {
|
||||
y[l] = convert_from_half<dst_t>(__hmul(dl, __int2half_rn((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4))));
|
||||
}
|
||||
}
|
||||
|
||||
static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
|
||||
@ -188,8 +190,8 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t
|
||||
const half d2 = __hmul(dall, __int2half_rn(sc));
|
||||
const half m2 = __hmul(dmin, __int2half_rn(m));
|
||||
for (int l = 0; l < n; ++l) {
|
||||
y[l + 0] = __hsub(__hmul(d1, __int2half_rn(q[l] & 0xF)), m1);
|
||||
y[l +32] = __hsub(__hmul(d2, __int2half_rn(q[l] >> 4)), m2);
|
||||
y[l + 0] = convert_from_half<dst_t>(__hsub(__hmul(d1, __int2half_rn(q[l] & 0xF)), m1));
|
||||
y[l +32] = convert_from_half<dst_t>(__hsub(__hmul(d2, __int2half_rn(q[l] >> 4)), m2));
|
||||
}
|
||||
}
|
||||
|
||||
@ -220,11 +222,11 @@ static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t
|
||||
const half d2 = __hmul(dall, __int2half_rn(sc)); const half m2 = __hmul(dmin, __int2half_rn(m));
|
||||
|
||||
uint8_t hm = 1 << (2*il);
|
||||
y[ 0] = __hsub(__hmul(d1, __int2half_rn((ql[0] & 0xF) + (qh[0] & hm ? 16 : 0))), m1);
|
||||
y[ 1] = __hsub(__hmul(d1, __int2half_rn((ql[1] & 0xF) + (qh[1] & hm ? 16 : 0))), m1);
|
||||
y[ 0] = convert_from_half<dst_t>(__hsub(__hmul(d1, __int2half_rn((ql[0] & 0xF) + (qh[0] & hm ? 16 : 0))), m1));
|
||||
y[ 1] = convert_from_half<dst_t>(__hsub(__hmul(d1, __int2half_rn((ql[1] & 0xF) + (qh[1] & hm ? 16 : 0))), m1));
|
||||
hm <<= 1;
|
||||
y[32] = __hsub(__hmul(d2, __int2half_rn((ql[0] >> 4) + (qh[0] & hm ? 16 : 0))), m2);
|
||||
y[33] = __hsub(__hmul(d2, __int2half_rn((ql[1] >> 4) + (qh[1] & hm ? 16 : 0))), m2);
|
||||
y[32] = convert_from_half<dst_t>(__hsub(__hmul(d2, __int2half_rn((ql[0] >> 4) + (qh[0] & hm ? 16 : 0))), m2));
|
||||
y[33] = convert_from_half<dst_t>(__hsub(__hmul(d2, __int2half_rn((ql[1] >> 4) + (qh[1] & hm ? 16 : 0))), m2));
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
@ -247,10 +249,10 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t
|
||||
const uint8_t qh = x[i].qh[32*ip + il];
|
||||
const int8_t * sc = x[i].scales + is;
|
||||
|
||||
y[ 0] = __hmul(d, __int2half_rn(sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32)));
|
||||
y[32] = __hmul(d, __int2half_rn(sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32)));
|
||||
y[64] = __hmul(d, __int2half_rn(sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32)));
|
||||
y[96] = __hmul(d, __int2half_rn(sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32)));
|
||||
y[ 0] = convert_from_half<dst_t>(__hmul(d, __int2half_rn(sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32))));
|
||||
y[32] = convert_from_half<dst_t>(__hmul(d, __int2half_rn(sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32))));
|
||||
y[64] = convert_from_half<dst_t>(__hmul(d, __int2half_rn(sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32))));
|
||||
y[96] = convert_from_half<dst_t>(__hmul(d, __int2half_rn(sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32))));
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
@ -269,7 +271,7 @@ static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, ds
|
||||
const uint32_t aux32 = q2[2] | (q2[3] << 16);
|
||||
const float d = __half2float(x[i].d) * (0.5f + (aux32 >> 28)) * 0.25f;
|
||||
const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
for (int j = 0; j < 8; ++j) y[j] = __float2half(d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f));
for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
}

template<typename dst_t>
@@ -286,7 +288,7 @@ static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst
const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511));
const float d = __half2float(x[i].d) * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
const uint8_t signs = ksigns_iq2xs[q2[il] >> 9];
for (int j = 0; j < 8; ++j) y[j] = __float2half(d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f));
for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);

}

@@ -303,7 +305,7 @@ static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_
const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300)));
const float d = __half2float(x[i].d) * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
const uint8_t signs = x[i].qs[QK_K/8+4*ib+il];
for (int j = 0; j < 8; ++j) y[j] = __float2half(d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f));
for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
}

template<typename dst_t>
@@ -324,8 +326,8 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds
const float d = __half2float(x[i].d) * (0.5f + (aux32 >> 28)) * 0.5f;
const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
for (int j = 0; j < 4; ++j) {
y[j+0] = __float2half(d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f));
y[j+4] = __float2half(d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f));
y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
}
}

@@ -345,8 +347,8 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_
const float d = __half2float(x[i].d) * (0.5f + ((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf)) * 0.5f;
const uint8_t signs = x[i].signs[4*ib + il];
for (int j = 0; j < 4; ++j) {
y[j+0] = __float2half(d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f));
y[j+4] = __float2half(d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f));
y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
}
}

@@ -367,7 +369,7 @@ static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_
grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
grid32[0] &= 0x0f0f0f0f;
for (int j = 0; j < 8; ++j) {
y[j] = __float2half(d * (q[j] + delta));
y[j] = d * (q[j] + delta);
}
}

@@ -392,7 +394,7 @@ static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_
grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
grid32[0] &= 0x0f0f0f0f;
for (int j = 0; j < 8; ++j) {
y[j] = __float2half(d * (q[j] + delta));
y[j] = d * (q[j] + delta);
}
}

@@ -409,8 +411,8 @@ static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst
const uint8_t * q4 = x[ib].qs + 4*il;
const float d = __half2float(x[ib].d);
for (int j = 0; j < 4; ++j) {
y[j+ 0] = __float2half(d * kvalues_iq4nl[q4[j] & 0xf]);
y[j+16] = __float2half(d * kvalues_iq4nl[q4[j] >> 4]);
y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
y[j+16] = d * kvalues_iq4nl[q4[j] >> 4];
}

}

@@ -427,8 +429,8 @@ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst
const uint8_t * q4 = x[i].qs + 16*ib + 4*il;
const float d = __half2float(x[i].d) * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32);
for (int j = 0; j < 4; ++j) {
y[j+ 0] = __float2half(d * kvalues_iq4nl[q4[j] & 0xf]);
y[j+16] = __float2half(d * kvalues_iq4nl[q4[j] >> 4]);
y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
y[j+16] = d * kvalues_iq4nl[q4[j] >> 4];
}
}

@@ -522,7 +524,8 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k,
dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
}

static to_fp16_cuda_t ggml_get_to_fp16_cuda(int64_t type) {
template<typename dst_t>
static to_cuda_ggml_t<dst_t> ggml_get_to_cuda(int64_t type) {
switch (type) {
case 2:
return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
@@ -1063,7 +1063,8 @@ static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -
typedef half dfloat; // dequantize float
typedef half2 dfloat2;
typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
typedef void (*to_fp16_cuda_t)(const void * __restrict__ x, dfloat * __restrict__ y, int k, cudaStream_t stream);
template<typename dst_t>
using to_cuda_ggml_t = void (*)(const void * __restrict__ x, dst_t * __restrict__ y, int k, cudaStream_t stream);
typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs);
typedef void (*allocate_tiles_cuda_t)(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc);
typedef void (*load_tiles_cuda_t)(
@@ -1075,6 +1076,25 @@ typedef float (*vec_dot_q_mul_mat_cuda_t)(

// Utility function

template<typename dst_t>
static __device__ __forceinline__ dst_t convert_from_half(half val) {
return val;
}

template<>
__device__ __forceinline__ c10::BFloat16 convert_from_half<c10::BFloat16>(half val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
return __float2bfloat16(__half2float(val));
#else
return __half2float(val);
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
}

template<>
__device__ __forceinline__ float convert_from_half<float>(half val) {
return __half2float(val);
}
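The convert_from_half helper exists so kernels that produce half intermediates can store into whatever output type the caller asked for (half, float, or c10::BFloat16, per the specializations above). A minimal illustrative sketch of that store pattern, not taken from the diff, assuming the convert_from_half definitions above are visible:

// Sketch only: narrow a float result once at the store, routed through
// convert_from_half<dst_t> so the same kernel body works for all output types.
template <typename dst_t>
__device__ __forceinline__ void store_scaled(dst_t* y, int j, float d, float q) {
    // Arithmetic stays in float; the final conversion picks the right dst_t path.
    y[j] = convert_from_half<dst_t>(__float2half(d * q));
}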
#if defined(USE_ROCM)

#ifndef __has_builtin
@@ -71,14 +71,19 @@ static void quantize_row_q8_1_cuda(const scalar_t* x, void* vy, const int kx,
}

torch::Tensor ggml_dequantize(torch::Tensor W, // quant weight
int64_t type, int64_t m, int64_t n) {
int64_t type, int64_t m, int64_t n,
std::optional<at::ScalarType> const& dtype) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(W));
auto options =
torch::TensorOptions().dtype(torch::kFloat16).device(W.device());
auto dtype_ = dtype.value_or(torch::kFloat16);
auto options = torch::TensorOptions().dtype(dtype_).device(W.device());
at::Tensor DW = torch::empty({m, n}, options);
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(type);
to_fp16_cuda((void*)W.data_ptr(), (half*)DW.data_ptr(), m * n, stream);

VLLM_DISPATCH_FLOATING_TYPES(DW.scalar_type(), "ggml_dequantize", [&] {
auto to_cuda = ggml_get_to_cuda<scalar_t>(type);
to_cuda((void*)W.data_ptr(), (scalar_t*)DW.data_ptr(), m * n, stream);
});

return DW;
}
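With the new optional dtype argument, callers can ask for the dequantized tensor in a type other than the fp16 default and the dispatch above picks the matching kernel instantiation. A hedged usage sketch from C++ (W, type, m and n are hypothetical example values, not from the diff; type 2 corresponds to the Q4_0 case shown in ggml_get_to_cuda):

// Sketch only: request bf16 output instead of the fp16 default.
at::Tensor W = /* quantized weight blob already on the CUDA device */ at::Tensor();
int64_t type = 2;            // Q4_0, per the switch in ggml_get_to_cuda above
int64_t m = 4096, n = 4096;  // logical output shape
at::Tensor half_out = ggml_dequantize(W, type, m, n, std::nullopt);   // fp16 default
at::Tensor bf16_out = ggml_dequantize(W, type, m, n, at::kBFloat16);  // new dtype argument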
@@ -129,7 +129,7 @@ static __device__ __forceinline__ void moe_q(
}

#if defined(USE_ROCM)
#define MOE_X_Q4_0 64
#define MOE_X_Q4_0 8
#define MOE_Y_Q4_0 128
#define NWARPS_Q4_0 8
#else
@@ -190,7 +190,7 @@ static void ggml_moe_q4_0_q8_1_cuda(
}

#if defined(USE_ROCM)
#define MOE_X_Q4_1 64
#define MOE_X_Q4_1 8
#define MOE_Y_Q4_1 128
#define NWARPS_Q4_1 8
#else
@@ -251,7 +251,7 @@ static void ggml_moe_q4_1_q8_1_cuda(
}

#if defined(USE_ROCM)
#define MOE_X_Q5_0 64
#define MOE_X_Q5_0 8
#define MOE_Y_Q5_0 128
#define NWARPS_Q5_0 8
#else
@@ -312,7 +312,7 @@ static void ggml_moe_q5_0_q8_1_cuda(
}

#if defined(USE_ROCM)
#define MOE_X_Q5_1 64
#define MOE_X_Q5_1 8
#define MOE_Y_Q5_1 128
#define NWARPS_Q5_1 8
#else
@@ -373,7 +373,7 @@ static void ggml_moe_q5_1_q8_1_cuda(
}

#if defined(USE_ROCM)
#define MOE_X_Q8_0 64
#define MOE_X_Q8_0 8
#define MOE_Y_Q8_0 128
#define NWARPS_Q8_0 8
#else
@@ -434,7 +434,7 @@ static void ggml_moe_q8_0_q8_1_cuda(
}

#if defined(USE_ROCM)
#define MOE_X_Q2_K 64
#define MOE_X_Q2_K 8
#define MOE_Y_Q2_K 128
#define NWARPS_Q2_K 8
#else
@@ -495,7 +495,7 @@ static void ggml_moe_q2_K_q8_1_cuda(
}

#if defined(USE_ROCM)
#define MOE_X_Q3_K 64
#define MOE_X_Q3_K 8
#define MOE_Y_Q3_K 128
#define NWARPS_Q3_K 8
#else
@@ -556,7 +556,7 @@ static void ggml_moe_q3_K_q8_1_cuda(
}

#if defined(USE_ROCM)
#define MOE_X_Q4_K 64
#define MOE_X_Q4_K 8
#define MOE_Y_Q4_K 128
#define NWARPS_Q4_K 8
#else
@@ -617,7 +617,7 @@ static void ggml_moe_q4_K_q8_1_cuda(
}

#if defined(USE_ROCM)
#define MOE_X_Q5_K 64
#define MOE_X_Q5_K 8
#define MOE_Y_Q5_K 128
#define NWARPS_Q5_K 8
#else
@@ -678,7 +678,7 @@ static void ggml_moe_q5_K_q8_1_cuda(
}

#if defined(USE_ROCM)
#define MOE_X_Q6_K 64
#define MOE_X_Q6_K 8
#define MOE_Y_Q6_K 128
#define NWARPS_Q6_K 8
#else
@@ -1785,7 +1785,7 @@ __global__ void Marlin(
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, \
num_groups, prob_m, prob_n, prob_k, lda, locks, \
use_atomic_add, use_fp32_reduce); \
part_use_atomic_add, use_fp32_reduce); \
} \
}

@@ -2215,6 +2215,10 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
thread_m_blocks = exec_cfg.max_m_blocks;
}

// atomic add reduce have better performance only when m * n is small
bool part_use_atomic_add =
use_atomic_add && div_ceil(prob_m, 64) * prob_n <= 2048;

if (false) {
}
GPTQ_CALL_IF(vllm::kU4B8, 16, 4, 256)
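The new part_use_atomic_add guard only keeps the atomic-add reduction for small output tiles: ceil(prob_m / 64) * prob_n must stay at or below 2048. A standalone host-side sketch of that check (div_ceil is redefined locally here and the shapes are made-up examples, not values from the diff):

#include <cstdio>

// Local stand-in for the div_ceil helper used in marlin_mm.
static int div_ceil(int a, int b) { return (a + b - 1) / b; }

int main() {
    // (prob_m, prob_n) pairs: a wide output vs. a narrower one.
    int shapes[2][2] = {{16, 4096}, {16, 1024}};
    for (auto& s : shapes) {
        bool small_mn = div_ceil(s[0], 64) * s[1] <= 2048;
        // {16, 4096}: 1 * 4096 = 4096 > 2048 -> atomic add disabled.
        // {16, 1024}: 1 * 1024 = 1024 <= 2048 -> atomic add allowed.
        std::printf("m=%d n=%d -> part_use_atomic_add=%d\n", s[0], s[1], small_mn);
    }
    return 0;
}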
@@ -9,7 +9,11 @@
#include <cuda_runtime.h>
#include <iostream>

namespace marlin {
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin
#endif

namespace MARLIN_NAMESPACE_NAME {

// Marlin params

@@ -23,6 +27,7 @@ static constexpr int pipe_stages =

static constexpr int min_thread_n = 64;
static constexpr int min_thread_k = 64;
static constexpr int max_thread_n = 256;

static constexpr int tile_size = 16;
static constexpr int max_par = 16;
@@ -84,4 +89,4 @@ __device__ inline void cp_async_wait() {

#endif

} // namespace marlin
} // namespace MARLIN_NAMESPACE_NAME
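Wrapping the Marlin code in a configurable MARLIN_NAMESPACE_NAME (defaulting to marlin) lets a second copy of these kernels be compiled under a different namespace without symbol clashes. A sketch of how a consumer could pick the namespace before including the header; the include name below is a placeholder, not the actual file path:

// Sketch only: compile this translation unit's Marlin symbols into
// namespace marlin_moe instead of the default namespace marlin.
#define MARLIN_NAMESPACE_NAME marlin_moe
#include "marlin_common_header.h"  // placeholder for the header shown above

void example() {
    // Constants such as tile_size now live under the chosen namespace.
    static_assert(marlin_moe::tile_size == 16, "tile_size comes from the namespaced header");
}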
@@ -5,7 +5,11 @@
#include <cuda_fp16.h>
#include <cuda_bf16.h>

namespace marlin {
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin
#endif

namespace MARLIN_NAMESPACE_NAME {

template <typename scalar_t>
class ScalarType {};
@@ -54,7 +58,7 @@ class ScalarType<nv_bfloat16> {
using FragS = Vec<nv_bfloat162, 1>;
using FragZP = Vec<nv_bfloat162, 4>;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
static __device__ float inline num2float(const nv_bfloat16 x) {
return __bfloat162float(x);
}
@@ -74,6 +78,6 @@ class ScalarType<nv_bfloat16> {
#endif
};

} // namespace marlin
} // namespace MARLIN_NAMESPACE_NAME

#endif
@@ -272,6 +272,7 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride,
@@ -291,6 +292,13 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
const int rowid = laneid / 16;

const auto seq_idx = blockIdx.x;
// NOTE queries with sequence len > 1 are prefills and taken care by another
// kernel.
if (query_start_loc_ptr != nullptr &&
(query_start_loc_ptr[seq_idx + 1] - query_start_loc_ptr[seq_idx]) != 1) {
return;
}

const auto partition_idx = blockIdx.y;

constexpr int T_PAR_SIZE = 256; // token partition size set to 256
@@ -377,9 +385,10 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
// fetch Q in shared across warps and then write to registers
const int local_qhead_idx = 4 * warpid + rowid;
const int global_qhead_idx = wg_start_head_idx + local_qhead_idx;
const int64_t seq_idx64 = static_cast<int64_t>(seq_idx);
const int64_t query_start_off = static_cast<int64_t>(
query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx);
const scalar_t* q_ptr =
q + seq_idx64 * q_stride + global_qhead_idx * HEAD_SIZE;
q + query_start_off * q_stride + global_qhead_idx * HEAD_SIZE;

const int qhead_element = lane16id * CONTIGUOUS_SCALAR_ELEMS_16B;
if ((local_qhead_idx < GQA_RATIO) && (qhead_element < HEAD_SIZE)) {
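query_start_loc holds cumulative query-token offsets per sequence, so a decode sequence is one whose slot width is exactly 1; that is what the early-return above checks, and the same array supplies the row offset used for q_ptr. A standalone sketch with a made-up batch (not from the diff) showing both uses:

#include <cstdio>

int main() {
    // Hypothetical batch: seq0 is a 5-token prefill, seq1 and seq2 are decodes.
    // query_start_loc has num_seqs + 1 entries of cumulative offsets.
    const int query_start_loc[] = {0, 5, 6, 7};
    const int num_seqs = 3;
    for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
        const int q_len = query_start_loc[seq_idx + 1] - query_start_loc[seq_idx];
        const bool is_decode = (q_len == 1);            // the kernel's early-return test
        const int query_row = query_start_loc[seq_idx]; // row offset used for q_ptr
        std::printf("seq %d: q_len=%d decode=%d q_row=%d\n",
                    seq_idx, q_len, is_decode, query_row);
    }
    return 0;
}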
@@ -777,6 +786,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride,
@@ -794,6 +804,12 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const int lane4id = laneid % 4;

const auto seq_idx = blockIdx.x;
// NOTE queries with sequence len > 1 are prefills and taken care by another
// kernel.
if (query_start_loc_ptr != nullptr &&
(query_start_loc_ptr[seq_idx + 1] - query_start_loc_ptr[seq_idx] != 1)) {
return;
}
const auto partition_idx = blockIdx.y;
const auto partition_size = blockDim.x;
const auto max_num_partitions = gridDim.y;
@@ -882,9 +898,11 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
}

// fetch q elements
// every 4 lanes fetch 8 elems, so warp fetches 8*16 = 128 elems
// every 4 lanes fetch 8 elems, so warp fetches 8*16 = 128 elemsc
const int64_t query_start_off = static_cast<int64_t>(
query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx);
const scalar_t* q_ptr =
q + seq_idx * q_stride + wg_start_head_idx * HEAD_SIZE;
q + query_start_off * q_stride + wg_start_head_idx * HEAD_SIZE;
const _B16x8* q_ptrh8 = reinterpret_cast<const _B16x8*>(q_ptr);
const int qhead_elemh8 = laneid / 4;
@@ -1267,10 +1285,19 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_partitions) {
const auto num_heads = gridDim.x;
const auto head_idx = blockIdx.x;
const auto seq_idx = blockIdx.y;

// NOTE queries with sequence len > 1 are prefills and taken care by another
// kernel.
if (query_start_loc_ptr != nullptr &&
(query_start_loc_ptr[seq_idx + 1] - query_start_loc_ptr[seq_idx] != 1)) {
return;
}

const int context_len = context_lens[seq_idx];
const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
[[maybe_unused]] constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
@@ -1439,7 +1466,9 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
__fdividef(1.0f, shared_global_exp_sum + 1e-6f);
acc *= inv_global_exp_sum;

OUTT* out_ptr = out + static_cast<int64_t>(seq_idx) * num_heads * HEAD_SIZE +
const int64_t query_start_off = static_cast<int64_t>(
query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx);
OUTT* out_ptr = out + query_start_off * num_heads * HEAD_SIZE +
static_cast<int64_t>(head_idx) * HEAD_SIZE;
if constexpr (std::is_same<OUTT, bit8_t>::value) {
out_ptr[threadIdx.x] =
@@ -1466,6 +1495,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel(
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride,
@@ -1492,6 +1522,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride,
@@ -1515,6 +1546,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_partitions) {
UNREACHABLE_CODE
}
@@ -1522,34 +1554,34 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(

#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
#define LAUNCH_CUSTOM_ATTENTION_MFMA16(GQA_RATIO) \
paged_attention_ll4mi_QKV_mfma16_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE, \
HEAD_SIZE, NTHR, ALIBI_ENABLED, \
GQA_RATIO> \
<<<grid, block, 0, stream>>>( \
query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks, \
k_scale_ptr, v_scale_ptr);
#define LAUNCH_CUSTOM_ATTENTION_MFMA16(GQA_RATIO) \
paged_attention_ll4mi_QKV_mfma16_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE, \
HEAD_SIZE, NTHR, ALIBI_ENABLED, \
GQA_RATIO> \
<<<grid, block, 0, stream>>>( \
query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \
max_ctx_blocks, k_scale_ptr, v_scale_ptr);

#define LAUNCH_CUSTOM_ATTENTION_MFMA4(GQA_RATIO) \
paged_attention_ll4mi_QKV_mfma4_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE, \
HEAD_SIZE, NTHR, ALIBI_ENABLED, \
GQA_RATIO> \
<<<grid, block, 0, stream>>>( \
query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks, \
k_scale_ptr, v_scale_ptr);
#define LAUNCH_CUSTOM_ATTENTION_MFMA4(GQA_RATIO) \
paged_attention_ll4mi_QKV_mfma4_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE, \
HEAD_SIZE, NTHR, ALIBI_ENABLED, \
GQA_RATIO> \
<<<grid, block, 0, stream>>>( \
query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \
max_ctx_blocks, k_scale_ptr, v_scale_ptr);

#define LAUNCH_CUSTOM_REDUCTION(NPAR_LOOPS) \
paged_attention_ll4mi_reduce_kernel<T, OUTT, HEAD_SIZE, HEAD_SIZE, \
PARTITION_SIZE, NPAR_LOOPS> \
<<<reduce_grid, reduce_block, 0, stream>>>( \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, \
context_lens_ptr, max_num_partitions);
context_lens_ptr, query_start_loc_ptr, max_num_partitions);
template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE,
int BLOCK_SIZE, int HEAD_SIZE, typename OUTT, int PARTITION_SIZE_OLD,
@@ -1559,9 +1591,10 @@ void paged_attention_custom_launcher(
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, const int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& context_lens,
int max_context_len, const std::optional<torch::Tensor>& alibi_slopes,
torch::Tensor& k_scale, torch::Tensor& v_scale) {
int num_seqs = query.size(0);
const std::optional<torch::Tensor>& query_start_loc, int max_context_len,
const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
torch::Tensor& v_scale) {
int num_seqs = block_tables.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
int max_num_blocks_per_seq = block_tables.size(1);
@@ -1569,6 +1602,13 @@ void paged_attention_custom_launcher(
int kv_block_stride = key_cache.stride(0);
int kv_head_stride = key_cache.stride(1);

// NOTE: query start location is optional for V0 decode should not be used.
// If batch contains mix of prefills and decode, prefills should be skipped.
const int* query_start_loc_ptr =
query_start_loc
? reinterpret_cast<const int*>(query_start_loc.value().data_ptr())
: nullptr;

// NOTE: alibi_slopes is optional.
const float* alibi_slopes_ptr =
alibi_slopes
@@ -1700,8 +1740,8 @@ void paged_attention_custom_launcher(
paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T, \
PSIZE, ALIBI_ENABLED>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, context_lens, max_context_len, \
alibi_slopes, k_scale, v_scale);
num_kv_heads, scale, block_tables, context_lens, query_start_loc, \
max_context_len, alibi_slopes, k_scale, v_scale);

#define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
PSIZE) \
@@ -1750,6 +1790,7 @@ void paged_attention(
double scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& context_lens, // [num_seqs]
const std::optional<torch::Tensor>& query_start_loc, // [num_seqs]
int64_t block_size, int64_t max_context_len,
const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
@@ -2,13 +2,23 @@

#include <torch/all.h>

torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b,
const int64_t rows_per_block);

torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
const int64_t CuCount);

void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
at::Tensor& scale_a, at::Tensor& scale_b, const int64_t CuCount);

void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums,
torch::Tensor& max_logits, torch::Tensor& tmp_out,
torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads,
double scale, torch::Tensor& block_tables,
torch::Tensor& context_lens, int64_t block_size,
int64_t max_context_len,
torch::Tensor& context_lens,
const std::optional<torch::Tensor>& query_start_loc,
int64_t block_size, int64_t max_context_len,
const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale);
csrc/rocm/skinny_gemms.cu — 1600 lines (new file); file diff suppressed because it is too large.
Some files were not shown because too many files have changed in this diff.