Compare commits
489 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| bc3b20f81f | |||
| 54be44ee74 | |||
| 2815bd6143 | |||
| 17bccecb1c | |||
| c335930d75 | |||
| a0304dc504 | |||
| c7941cca18 | |||
| b6dd32aa07 | |||
| f94886946e | |||
| 72dfe4c74f | |||
| 8b464d9660 | |||
| 889ebb2638 | |||
| 3ad986c28b | |||
| 344e193b7d | |||
| fb1c933ade | |||
| 72c5b97231 | |||
| fa93cd9f60 | |||
| aec9674dbe | |||
| 7fcc4223dc | |||
| 8262a3e23b | |||
| f211331c48 | |||
| 9053d0b134 | |||
| cb3f2d8d10 | |||
| c12df53b60 | |||
| d1aeea7553 | |||
| d8bccde686 | |||
| 20e489eaa1 | |||
| 4213475ec7 | |||
| d92879baf6 | |||
| 690fe019f0 | |||
| ed7a29d9f8 | |||
| 756848e79e | |||
| 18445edd0f | |||
| 30215ca61f | |||
| 838cedade7 | |||
| 4283a28c2f | |||
| 93a126fbc7 | |||
| 8e4b351a0c | |||
| 9869453c42 | |||
| 3642c59aa8 | |||
| 43eea2953b | |||
| de7eb10ce4 | |||
| fd11a325b8 | |||
| 4d17e20310 | |||
| 10fd1d7380 | |||
| 52b4f4a8d7 | |||
| e782e0a170 | |||
| dc2ceca5c5 | |||
| f8acd01ff7 | |||
| c48334d405 | |||
| 909fdaf152 | |||
| 8c1c926d00 | |||
| df6f3ce883 | |||
| 513f074766 | |||
| b07bf83c7d | |||
| 53e8cf53a4 | |||
| 54271bb766 | |||
| 9e96f56efb | |||
| b278911229 | |||
| 7bd0c7745c | |||
| 1cf0719ebd | |||
| 537d5ee025 | |||
| c8e5be35f7 | |||
| a6e72e1e4f | |||
| 5e83a7277f | |||
| 68af5f6c5c | |||
| 8de2901fea | |||
| c53e0730cb | |||
| a0e619e62a | |||
| 70116459c3 | |||
| 65e262b93b | |||
| 43faa0461a | |||
| 48cb2109b6 | |||
| a5450f11c9 | |||
| 9d98ab5ec6 | |||
| df5c879527 | |||
| 423e9f1cbe | |||
| 0bd7f8fca5 | |||
| d5615af9ae | |||
| 19dcc02a72 | |||
| 7feae92c1f | |||
| f851b84266 | |||
| fc966e9cc6 | |||
| ef19e67d2c | |||
| a41351f363 | |||
| 6aae216b4e | |||
| b22980a1dc | |||
| 881f735827 | |||
| 2f54045508 | |||
| 5aa6efb9a5 | |||
| 6ca0234478 | |||
| 649818995f | |||
| 7a0a9da72b | |||
| 69bff9bc89 | |||
| 41ca7eb491 | |||
| eef364723c | |||
| 0d6e187e88 | |||
| 9420a1fc30 | |||
| 583e900996 | |||
| 05e1fbfc52 | |||
| fe92176321 | |||
| 6d0df0ebeb | |||
| 0fa939e2d1 | |||
| 0422ce109f | |||
| 47bdee409c | |||
| 49f189439d | |||
| 5adf6f6b7f | |||
| 4115f19958 | |||
| 340d7b1b21 | |||
| 1bcbcbf574 | |||
| 82e43b2d7e | |||
| 67309a1cb5 | |||
| b724afe343 | |||
| 21f4f1c9a4 | |||
| b0c1f6202d | |||
| c0dfd97519 | |||
| a9138e85b1 | |||
| 0a05ed57e6 | |||
| 14288d1332 | |||
| b411418ff0 | |||
| 2bc0f72ae5 | |||
| 9c1244de57 | |||
| db2f8d915c | |||
| 6167c0e5d2 | |||
| ed2e464653 | |||
| 2c8ed8ee48 | |||
| ed50f46641 | |||
| 46e678bcff | |||
| 6b2427f995 | |||
| b07d741661 | |||
| 41fb013d29 | |||
| 32d4b669d0 | |||
| 3cde34a4a4 | |||
| bdb3660312 | |||
| f3a21e9c68 | |||
| 8e630d680e | |||
| af869f6dff | |||
| 53c0fa1e25 | |||
| f7912cba3d | |||
| 6317a5174a | |||
| aa72d9a4ea | |||
| ce17db8085 | |||
| 8c87a9ad46 | |||
| ec69124eb4 | |||
| d0da99fb70 | |||
| b2f195c429 | |||
| 047797ef90 | |||
| eb8ef4224d | |||
| 56a735261c | |||
| e1cf90e099 | |||
| 6bc1e30ef9 | |||
| 7e081ba7ca | |||
| 1e013fa388 | |||
| bc7c4d206b | |||
| f67e9e9f22 | |||
| 36fe78769f | |||
| 83d933718c | |||
| 5175b884f7 | |||
| 5536b30a4c | |||
| 7f58fb9718 | |||
| 30bc3e0f66 | |||
| f34410715f | |||
| 68d4c33202 | |||
| f961d7f6ef | |||
| d059110498 | |||
| 571e8dd65e | |||
| 4b91c927f6 | |||
| 0e237f0035 | |||
| 8f7bace7c3 | |||
| e4d6144232 | |||
| 8d32dc603d | |||
| c4ab9f3e71 | |||
| 2689d5c027 | |||
| acba33a0f1 | |||
| a114bf20a3 | |||
| 3097ce3a32 | |||
| d6da9322c8 | |||
| 71ce44047f | |||
| 188b7f9b8c | |||
| b9b4746950 | |||
| 7b8a2ab76f | |||
| c9acbf1141 | |||
| 5b794cae8d | |||
| 0e4254492f | |||
| 1311913f55 | |||
| 29f395c97c | |||
| fa3bba2a53 | |||
| 986537f1c3 | |||
| 210207525e | |||
| 71eda0bb76 | |||
| 471fe65630 | |||
| 3a0fba5cf4 | |||
| 299ebb62b2 | |||
| f728ab8e35 | |||
| 63e26fff78 | |||
| fe3462c774 | |||
| 3b34fd5273 | |||
| 55d6d3fdb8 | |||
| 7272bfae77 | |||
| d9ac9e3dc5 | |||
| d41faaf9df | |||
| b34f33438a | |||
| 26c0406555 | |||
| 4c41278b77 | |||
| bb3605db85 | |||
| fe742aef5a | |||
| 4b07d36891 | |||
| 87aaadef73 | |||
| 682e0b6d2f | |||
| d6195a748b | |||
| 205d84aaa9 | |||
| 5124f5bf51 | |||
| 83f3c3bd91 | |||
| d9737ca1c6 | |||
| 9d4ca19d50 | |||
| 2ef0dc53b8 | |||
| 1d4680fad2 | |||
| 2c1bd848a6 | |||
| 5c9121203c | |||
| 490b1698a5 | |||
| 5a5e29de88 | |||
| 3d3ab3689f | |||
| 686623c5e7 | |||
| aadb656562 | |||
| 87e067de41 | |||
| 26507f8973 | |||
| 9c1d5b456d | |||
| e31045f95c | |||
| aaec845f8e | |||
| 7bdfd29a35 | |||
| e78587a64c | |||
| 7eb4255628 | |||
| 6a0f547561 | |||
| 30ed81b7ca | |||
| 7a4a5de729 | |||
| c16fb5dae8 | |||
| e37073efd7 | |||
| 183dad7a85 | |||
| 3408e47159 | |||
| 0377b8310b | |||
| e4755f7fac | |||
| 92edf35826 | |||
| eb5819b2d9 | |||
| 5989f4684d | |||
| 5125d72f02 | |||
| a018e555fd | |||
| 6211b92273 | |||
| 05fcd1b430 | |||
| 7c02d6a137 | |||
| 11c3b98491 | |||
| dbe7f07001 | |||
| c69bf4ee06 | |||
| d27ea94034 | |||
| 99ed526101 | |||
| 207da28186 | |||
| 5b1aca2ae3 | |||
| d8e557b5e5 | |||
| 61a44a0b22 | |||
| a6481525b8 | |||
| 8cac35ba43 | |||
| 9dbf7a2dc1 | |||
| 607029e515 | |||
| cb072ce93b | |||
| 95aca283b4 | |||
| 2b05b8ce69 | |||
| 3c776dcefb | |||
| 2cbd4d2999 | |||
| 3092375e27 | |||
| 3cd91dc955 | |||
| 8a7368e069 | |||
| 93e561ec4d | |||
| e1b004839a | |||
| ee378f3d49 | |||
| e82ee40de3 | |||
| facbe2a114 | |||
| 7168920491 | |||
| 21378a2323 | |||
| 976711d9db | |||
| 44fa4d556c | |||
| 3ac98edcb1 | |||
| 966c742ed2 | |||
| 0d7d05f4b6 | |||
| 96bb8aa68b | |||
| 3badb0213b | |||
| fdcb850f14 | |||
| 54a66e5fee | |||
| 280d62b8a2 | |||
| 1666e66443 | |||
| 1575c1701a | |||
| 6ae996a873 | |||
| b590adfdc1 | |||
| b4fe16c75b | |||
| bc5dd4f669 | |||
| dbb036cf61 | |||
| 70e7ed841d | |||
| d06ba4ed3f | |||
| 6b40996ae8 | |||
| d2020acac7 | |||
| 1eb3c2ed48 | |||
| c64ee87267 | |||
| b1308b84a3 | |||
| 7b5ecf79bd | |||
| 9883a18859 | |||
| b3f2fddd17 | |||
| aa29841ede | |||
| 6bf27affb6 | |||
| 1dd23386ec | |||
| 7cbfc10943 | |||
| ce4ddd2d1a | |||
| e51929ebca | |||
| dc1b4a6f13 | |||
| 63d2705edb | |||
| d085a44082 | |||
| f49e5aff11 | |||
| 6c11ecf8d3 | |||
| 93e5f3c5fb | |||
| 70363bccfa | |||
| 3cdc57669f | |||
| 68bb122eb4 | |||
| d9fc8cd9da | |||
| f069f3ea74 | |||
| c5bc0e7fcc | |||
| 4a3a518722 | |||
| fbf722c6e6 | |||
| e92d7085bf | |||
| bd6028d6b0 | |||
| 802329dee9 | |||
| 41cc883c29 | |||
| 57504a4bcf | |||
| ed4792c990 | |||
| 87b836ba77 | |||
| 56c76c2e0e | |||
| c09632a66c | |||
| a3bf8d4a2b | |||
| 16eda8c43a | |||
| cd77382ac1 | |||
| 71b9cde010 | |||
| 5285589f37 | |||
| f41647ee6b | |||
| 4d022cbc75 | |||
| 70de35a881 | |||
| 34b2cf3b33 | |||
| 9e90c9f73f | |||
| e9528f6dc6 | |||
| 51baa9c333 | |||
| 35e076b3a8 | |||
| a26f59ccbc | |||
| aa3b3d76e0 | |||
| f7030df3be | |||
| 905e91e9ac | |||
| f8f9c0ba62 | |||
| dda811021a | |||
| 93195146ea | |||
| ed37599544 | |||
| 99ef59cf7f | |||
| d544d141ec | |||
| 3e397a9484 | |||
| 268c325078 | |||
| 3cc9af88ff | |||
| 7cd0bd7212 | |||
| 56d4aefa33 | |||
| dd143ef541 | |||
| daefed052c | |||
| 5fbab20e02 | |||
| e8224f3dca | |||
| 9665313c39 | |||
| 0c54fc7273 | |||
| c1b57855ec | |||
| 83b824c8b4 | |||
| 7678fcd5b6 | |||
| 8661c0241d | |||
| ce8d6b75fc | |||
| 61de3ef74b | |||
| ec1f9c8c91 | |||
| 65e09094c4 | |||
| c70cf0fe06 | |||
| a5d11a54dc | |||
| 3d4c87758e | |||
| a9bd832fc5 | |||
| 417bcefbae | |||
| baada0e737 | |||
| 82eb61dd4c | |||
| 0d4d06fe2f | |||
| 4aed0ca6a2 | |||
| 1621b25288 | |||
| a564797151 | |||
| 1da6a09274 | |||
| 1e44ffc3ff | |||
| a454748544 | |||
| 1bff42c4b7 | |||
| cb391d85dc | |||
| fee5b8d37f | |||
| b2ce859bd2 | |||
| 566f10a929 | |||
| c3b5189137 | |||
| a25866ac8d | |||
| 098900d7c2 | |||
| 98d01d3ce2 | |||
| d55244df31 | |||
| 04149cce27 | |||
| 24834f4894 | |||
| ec7da6fcf3 | |||
| 819d548e8a | |||
| 477d2a8aa2 | |||
| e484e02857 | |||
| 24f6b9a713 | |||
| 9cdde47289 | |||
| b1eb4ca152 | |||
| 87b4ac56c2 | |||
| cb84e45ac7 | |||
| 4716377fbc | |||
| 4e9cf8c1dd | |||
| 2976dc27e9 | |||
| 102bf967f0 | |||
| 1f4b09b525 | |||
| 86c3369eb8 | |||
| 2755c34a8f | |||
| db10422184 | |||
| e1a2c699dd | |||
| 0115ccd5c0 | |||
| 40b4284fe3 | |||
| 4ebc0b9640 | |||
| dc96fd54c6 | |||
| 1f5d13ab9f | |||
| 90cb44eb02 | |||
| e11880deea | |||
| 9351f91be9 | |||
| 5a1e1c8353 | |||
| 69ecaa7c79 | |||
| 7f00899ff7 | |||
| 995e3d1f41 | |||
| b4ac449a83 | |||
| 8e5314a468 | |||
| 87918e40c4 | |||
| f6b32efb7f | |||
| b99733d092 | |||
| 05a015d6a5 | |||
| ad971af8c7 | |||
| f2ebb6f541 | |||
| 1d01211264 | |||
| f94ab12f79 | |||
| a865bc1ca6 | |||
| 21802c4b6d | |||
| 652907b354 | |||
| 24f1c01e0f | |||
| fad6e2538e | |||
| 7f6d47c1a2 | |||
| 3147586ebd | |||
| ed636d99ca | |||
| 090c856d76 | |||
| ad434d4cfe | |||
| 66d433b94f | |||
| 027b204ff1 | |||
| 55dcce91df | |||
| 8017c8db7f | |||
| dc3529dbf6 | |||
| 7699258ef0 | |||
| e9ba99f296 | |||
| 7c80368710 | |||
| 95d63f38c0 | |||
| bb8dab821e | |||
| fc0f87768a | |||
| 0a57386721 | |||
| 3749e28774 | |||
| 86fc2321ff | |||
| 2549c0dfef | |||
| b10e519895 | |||
| 9bde5ba127 | |||
| 72c8f1ad04 | |||
| da224daaa9 | |||
| 3a100b9278 | |||
| 242a637aea | |||
| c2a9671510 | |||
| d5ae4f7f42 | |||
| b6c502a150 | |||
| 9ca710e525 | |||
| eb07c8cb5b | |||
| ba10801961 | |||
| 620fc2d09e | |||
| 29283eaa7e | |||
| 2fa66ef713 | |||
| 13affc432d | |||
| d8f094a92a | |||
| 97ae6d777f | |||
| 6baeee70d1 | |||
| d2517a4939 | |||
| 6342adc438 | |||
| 0adba91547 | |||
| 4285e423a6 |
@ -1,3 +1,4 @@
|
||||
# For vllm script, with -t option (tensor parallel size).
|
||||
# bash ./run-lm-eval-gsm-vllm-baseline.sh -m deepseek-ai/DeepSeek-V2-Lite-Chat -b "auto" -l 1000 -f 5 -t 2
|
||||
model_name: "deepseek-ai/DeepSeek-V2-Lite-Chat"
|
||||
tasks:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
# For hf script, without -t option (tensor parallel size).
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m nm-testing/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform -b auto -l 1000 -f 5
|
||||
model_name: "nm-testing/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform"
|
||||
tasks:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
# For hf script, without -t option (tensor parallel size).
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m meta-llama/Meta-Llama-3-70B-Instruct -b 32 -l 250 -f 5
|
||||
model_name: "meta-llama/Meta-Llama-3-70B-Instruct"
|
||||
tasks:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
# For vllm script, with -t option (tensor parallel size).
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-W8A8-FP8-Channelwise-compressed-tensors -b auto -l 1000 -f 5 -t 1
|
||||
model_name: "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-FP8-Channelwise-compressed-tensors"
|
||||
tasks:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
# For vllm script, with -t option (tensor parallel size).
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform -b auto -l 1000 -f 5 -t 1
|
||||
model_name: "nm-testing/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform"
|
||||
tasks:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
# For vllm script, with -t option (tensor parallel size).
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test -b 32 -l 1000 -f 5 -t 1
|
||||
model_name: "nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test"
|
||||
tasks:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
# For vllm script, with -t option (tensor parallel size).
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Meta-Llama-3-8B-Instruct-FP8 -b 32 -l 250 -f 5 -t 1
|
||||
model_name: "neuralmagic/Meta-Llama-3-8B-Instruct-FP8"
|
||||
tasks:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
# For vllm script, with -t option (tensor parallel size).
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Asym-Per-Token-Test -b "auto" -l 250 -f 5 -t 1
|
||||
model_name: "nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Asym-Per-Token-Test"
|
||||
tasks:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
# For vllm script, with -t option (tensor parallel size).
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Per-Token-Test -b "auto" -l 250 -f 5 -t 1
|
||||
model_name: "nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Per-Token-Test"
|
||||
tasks:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
# For vllm script, with -t option (tensor parallel size).
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-nonuniform-test -b auto -l 1000 -f 5 -t 1
|
||||
model_name: "nm-testing/Meta-Llama-3-8B-Instruct-nonuniform-test"
|
||||
tasks:
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m meta-llama/Meta-Llama-3-8B-Instruct -b 32 -l 250 -f 5 -t 1
|
||||
# For hf script, without -t option (tensor parallel size).
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m meta-llama/Meta-Llama-3-8B-Instruct -b 32 -l 250 -f 5
|
||||
model_name: "meta-llama/Meta-Llama-3-8B-Instruct"
|
||||
tasks:
|
||||
- name: "gsm8k"
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
# For vllm script, with -t option (tensor parallel size).
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m HandH1998/QQQ-Llama-3-8b-g128 -b 32 -l 1000 -f 5 -t 1
|
||||
model_name: "HandH1998/QQQ-Llama-3-8b-g128"
|
||||
tasks:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
# For vllm script, with -t option (tensor parallel size).
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8 -b "auto" -l 1000 -f 5 -t 1
|
||||
model_name: "neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8"
|
||||
tasks:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
# For vllm script, with -t option (tensor parallel size).
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m mgoin/Minitron-4B-Base-FP8 -b auto -l 1000 -f 5 -t 1
|
||||
model_name: "mgoin/Minitron-4B-Base-FP8"
|
||||
tasks:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
# For vllm script, with -t option (tensor parallel size).
|
||||
# bash ./run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Mixtral-8x22B-Instruct-v0.1-FP8-dynamic -b "auto" -l 250 -f 5 -t 8
|
||||
model_name: "neuralmagic/Mixtral-8x22B-Instruct-v0.1-FP8-dynamic"
|
||||
tasks:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
# For vllm script, with -t option (tensor parallel size).
|
||||
# bash ./run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8 -b "auto" -l 250 -f 5 -t 4
|
||||
model_name: "neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8"
|
||||
tasks:
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m neuralmagic/Mixtral-8x7B-Instruct-v0.1 -b 32 -l 250 -f 5 -t 4
|
||||
# For hf script, without -t option (tensor parallel size).
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m neuralmagic/Mixtral-8x7B-Instruct-v0.1 -b 32 -l 250 -f 5
|
||||
model_name: "mistralai/Mixtral-8x7B-Instruct-v0.1"
|
||||
tasks:
|
||||
- name: "gsm8k"
|
||||
|
||||
@ -0,0 +1,12 @@
|
||||
# For vllm script, with -t option (tensor parallel size).
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16 -b auto -l 1319 -f 5 -t 1
|
||||
model_name: "nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16"
|
||||
tasks:
|
||||
- name: "gsm8k"
|
||||
metrics:
|
||||
- name: "exact_match,strict-match"
|
||||
value: 0.30
|
||||
- name: "exact_match,flexible-extract"
|
||||
value: 0.465
|
||||
limit: 1319
|
||||
num_fewshot: 5
|
||||
@ -1,3 +1,4 @@
|
||||
# For vllm script, with -t option (tensor parallel size).
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Qwen2-1.5B-Instruct-FP8W8 -b auto -l 1000 -f 5 -t 1
|
||||
model_name: "nm-testing/Qwen2-1.5B-Instruct-FP8W8"
|
||||
tasks:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
# For vllm script, with -t option (tensor parallel size).
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Qwen2-1.5B-Instruct-quantized.w8a8 -b "auto" -l 1000 -f 5 -t 1
|
||||
model_name: "neuralmagic/Qwen2-1.5B-Instruct-quantized.w8a8"
|
||||
tasks:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
# For vllm script, with -t option (tensor parallel size).
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Qwen2-1.5B-Instruct-W8A16-Channelwise -b "auto" -l 1000 -f 5 -t 1
|
||||
model_name: "nm-testing/Qwen2-1.5B-Instruct-W8A16-Channelwise"
|
||||
tasks:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
# For vllm script, with -t option (tensor parallel size).
|
||||
# bash ./run-lm-eval-gsm-vllm-baseline.sh -m Qwen/Qwen2-57B-A14B-Instruct -b "auto" -l 250 -f 5 -t 4
|
||||
model_name: "Qwen/Qwen2-57B-A14B-Instruct"
|
||||
tasks:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
# For vllm script, with -t option (tensor parallel size).
|
||||
# bash ./run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM -b "auto" -t 2
|
||||
model_name: "nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM"
|
||||
tasks:
|
||||
|
||||
@ -4,7 +4,7 @@ Meta-Llama-3.2-1B-Instruct-INT8-compressed-tensors.yaml
|
||||
Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml
|
||||
Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
|
||||
Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
|
||||
Minitron-4B-Base-FP8.yaml
|
||||
Qwen1.5-MoE-W4A16-compressed-tensors.yaml
|
||||
Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
|
||||
Qwen2-1.5B-Instruct-FP8W8.yaml
|
||||
Meta-Llama-3-8B-QQQ.yaml
|
||||
|
||||
@ -16,7 +16,7 @@ import numpy
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
RTOL = 0.05
|
||||
RTOL = 0.08
|
||||
TEST_DATA_FILE = os.environ.get(
|
||||
"LM_EVAL_TEST_DATA_FILE",
|
||||
".buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct.yaml")
|
||||
|
||||
@ -86,3 +86,18 @@ steps:
|
||||
- "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version)"
|
||||
env:
|
||||
DOCKER_BUILDKIT: "1"
|
||||
|
||||
- block: "Build Neuron release image"
|
||||
key: block-neuron-release-image-build
|
||||
depends_on: ~
|
||||
|
||||
- label: "Build and publish Neuron release image"
|
||||
depends_on: block-neuron-release-image-build
|
||||
agents:
|
||||
queue: neuron-postmerge
|
||||
commands:
|
||||
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:latest --progress plain -f docker/Dockerfile.neuron ."
|
||||
- "docker push public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:$(buildkite-agent meta-data get release-version)"
|
||||
env:
|
||||
DOCKER_BUILDKIT: "1"
|
||||
|
||||
@ -98,6 +98,13 @@ if [[ $commands == *" kernels "* ]]; then
|
||||
--ignore=kernels/test_machete_mm.py \
|
||||
--ignore=kernels/test_mha_attn.py \
|
||||
--ignore=kernels/test_block_fp8.py \
|
||||
--ignore=kernels/test_cutlass_moe.py \
|
||||
--ignore=kernels/test_mamba_ssm_ssd.py \
|
||||
--ignore=kernels/test_attention.py \
|
||||
--ignore=kernels/test_block_int8.py \
|
||||
--ignore=kernels/test_fused_quant_layernorm.py \
|
||||
--ignore=kernels/test_int8_kernel.py \
|
||||
--ignore=kernels/test_triton_moe_ptpc_fp8.py \
|
||||
--ignore=kernels/test_permute_cols.py"
|
||||
fi
|
||||
|
||||
|
||||
@ -5,10 +5,41 @@
|
||||
set -ex
|
||||
|
||||
# Setup cleanup
|
||||
remove_docker_container() { docker rm -f cpu-test || true; docker system prune -f; }
|
||||
remove_docker_container() {
|
||||
if [[ -n "$container_id" ]]; then
|
||||
podman rm -f "$container_id" || true
|
||||
fi
|
||||
podman system prune -f
|
||||
}
|
||||
trap remove_docker_container EXIT
|
||||
remove_docker_container
|
||||
|
||||
# Try building the docker image
|
||||
docker build -t cpu-test -f docker/Dockerfile.ppc64le .
|
||||
podman build -t cpu-test-ubi9-ppc -f docker/Dockerfile.ppc64le .
|
||||
|
||||
# Run the image
|
||||
container_id=$(podman run -itd --entrypoint /bin/bash -v /tmp/:/root/.cache/huggingface --privileged=true --network host -e HF_TOKEN cpu-test-ubi9-ppc)
|
||||
|
||||
function cpu_tests() {
|
||||
|
||||
# offline inference
|
||||
podman exec -it "$container_id" bash -c "
|
||||
set -e
|
||||
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m"
|
||||
|
||||
# Run basic model test
|
||||
podman exec -it "$container_id" bash -c "
|
||||
set -e
|
||||
pip install pytest pytest-asyncio einops peft Pillow soundfile transformers_stream_generator matplotlib
|
||||
pip install sentence-transformers datamodel_code_generator
|
||||
pytest -v -s tests/models/embedding/language/test_cls_models.py::test_classification_models[float-jason9693/Qwen2.5-1.5B-apeach]
|
||||
pytest -v -s tests/models/embedding/language/test_embedding.py::test_models[half-BAAI/bge-base-en-v1.5]
|
||||
pytest -v -s tests/models/encoder_decoder/language -m cpu_model"
|
||||
}
|
||||
|
||||
# All of CPU tests are expected to be finished less than 40 mins.
|
||||
|
||||
export container_id
|
||||
export -f cpu_tests
|
||||
timeout 40m bash -c cpu_tests
|
||||
|
||||
|
||||
13
.buildkite/scripts/hardware_ci/run-cpu-test-s390x.sh
Executable file
13
.buildkite/scripts/hardware_ci/run-cpu-test-s390x.sh
Executable file
@ -0,0 +1,13 @@
|
||||
#!/bin/bash
|
||||
|
||||
# This script build the CPU docker image and run the offline inference inside the container.
|
||||
# It serves a sanity check for compilation and basic model usage.
|
||||
set -ex
|
||||
|
||||
# Setup cleanup
|
||||
remove_docker_container() { docker rm -f cpu-test || true; docker system prune -f; }
|
||||
trap remove_docker_container EXIT
|
||||
remove_docker_container
|
||||
|
||||
# Try building the docker image
|
||||
docker build -t cpu-test -f docker/Dockerfile.s390x .
|
||||
@ -17,10 +17,13 @@ source /etc/environment
|
||||
docker run --privileged --net host --shm-size=16G -it \
|
||||
-e "HF_TOKEN=$HF_TOKEN" --name tpu-test \
|
||||
vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git \
|
||||
&& python3 -m pip install pytest \
|
||||
&& python3 -m pip install pytest pytest-asyncio tpu-info \
|
||||
&& python3 -m pip install lm_eval[api]==0.4.4 \
|
||||
&& export VLLM_XLA_CACHE_PATH= \
|
||||
&& export VLLM_USE_V1=1 \
|
||||
&& export VLLM_XLA_CHECK_RECOMPILATION=1 \
|
||||
&& echo HARDWARE \
|
||||
&& tpu-info \
|
||||
&& echo TEST_0 \
|
||||
&& pytest -v -s /workspace/vllm/tests/v1/tpu/test_perf.py \
|
||||
&& echo TEST_1 \
|
||||
@ -40,7 +43,11 @@ docker run --privileged --net host --shm-size=16G -it \
|
||||
&& echo TEST_8 \
|
||||
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_topk_topp_sampler.py \
|
||||
&& echo TEST_9 \
|
||||
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py" \
|
||||
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_multimodal.py \
|
||||
&& echo TEST_10 \
|
||||
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py \
|
||||
&& echo TEST_11 \
|
||||
&& pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py" \
|
||||
|
||||
|
||||
# TODO: This test fails because it uses RANDOM_SEED sampling
|
||||
|
||||
@ -5,8 +5,8 @@
|
||||
set -ex
|
||||
set -o pipefail
|
||||
|
||||
# cd into parent directory of this file
|
||||
cd "$(dirname "${BASH_SOURCE[0]}")/.."
|
||||
# cd 2 levels into the working directory
|
||||
cd "$(dirname "${BASH_SOURCE[0]}")/../.."
|
||||
|
||||
(which wget && which curl) || (apt-get update && apt-get install -y wget curl)
|
||||
|
||||
|
||||
@ -8,6 +8,7 @@
|
||||
# Documentation
|
||||
# label(str): the name of the test. emoji allowed.
|
||||
# fast_check(bool): whether to run this on each commit on fastcheck pipeline.
|
||||
# torch_nightly(bool): whether to run this on vllm against torch nightly pipeline.
|
||||
# fast_check_only(bool): run this test on fastcheck pipeline only
|
||||
# optional(bool): never run this test by default (i.e. need to unblock manually) unless it's scheduled nightly run.
|
||||
# command(str): the single command to run for tests. incompatible with commands.
|
||||
@ -70,6 +71,7 @@ steps:
|
||||
- label: Basic Correctness Test # 30min
|
||||
#mirror_hardwares: [amd]
|
||||
fast_check: true
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/basic_correctness/test_basic_correctness
|
||||
@ -104,6 +106,7 @@ steps:
|
||||
- label: Entrypoints Test # 40min
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
fast_check: true
|
||||
torch_nightly: true
|
||||
#mirror_hardwares: [amd]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
@ -118,7 +121,7 @@ steps:
|
||||
- pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
|
||||
- pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process
|
||||
- VLLM_USE_V1=0 pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process
|
||||
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/correctness/
|
||||
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/test_openai_schema.py
|
||||
- pytest -v -s entrypoints/test_chat_utils.py
|
||||
- VLLM_USE_V1=0 pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests
|
||||
|
||||
@ -163,11 +166,6 @@ steps:
|
||||
- tests/tracing
|
||||
commands:
|
||||
- pytest -v -s metrics
|
||||
- "pip install \
|
||||
'opentelemetry-sdk>=1.26.0,<1.27.0' \
|
||||
'opentelemetry-api>=1.26.0,<1.27.0' \
|
||||
'opentelemetry-exporter-otlp>=1.26.0,<1.27.0' \
|
||||
'opentelemetry-semantic-conventions-ai>=0.4.1,<0.5.0'"
|
||||
- pytest -v -s tracing
|
||||
|
||||
##### fast check tests #####
|
||||
@ -210,6 +208,8 @@ steps:
|
||||
- pytest -v -s v1/sample
|
||||
- pytest -v -s v1/worker
|
||||
- pytest -v -s v1/structured_output
|
||||
- pytest -v -s v1/spec_decode
|
||||
- pytest -v -s v1/test_serial_utils.py
|
||||
- pytest -v -s v1/test_stats.py
|
||||
- pytest -v -s v1/test_utils.py
|
||||
- pytest -v -s v1/test_oracle.py
|
||||
@ -292,6 +292,15 @@ steps:
|
||||
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py
|
||||
parallelism: 4
|
||||
|
||||
- label: PyTorch Compilation Unit Tests
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/compile
|
||||
commands:
|
||||
- pytest -v -s compile/test_pass_manager.py
|
||||
- pytest -v -s compile/test_fusion.py
|
||||
- pytest -v -s compile/test_sequence_parallelism.py
|
||||
|
||||
- label: PyTorch Fullgraph Smoke Test # 9min
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
@ -301,7 +310,6 @@ steps:
|
||||
# these tests need to be separated, cannot combine
|
||||
- pytest -v -s compile/piecewise/test_simple.py
|
||||
- pytest -v -s compile/piecewise/test_toy_llama.py
|
||||
- pytest -v -s compile/test_pass_manager.py
|
||||
|
||||
- label: PyTorch Fullgraph Test # 18min
|
||||
source_file_dependencies:
|
||||
@ -310,15 +318,46 @@ steps:
|
||||
commands:
|
||||
- pytest -v -s compile/test_full_graph.py
|
||||
|
||||
- label: Kernels Test %N # 1h each
|
||||
# mirror_hardwares: [amd]
|
||||
- label: Kernels Core Operation Test
|
||||
source_file_dependencies:
|
||||
- csrc/
|
||||
- vllm/attention
|
||||
- tests/kernels
|
||||
- tests/kernels/core
|
||||
commands:
|
||||
- pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
|
||||
parallelism: 4
|
||||
- pytest -v -s kernels/core
|
||||
|
||||
- label: Kernels Attention Test %N
|
||||
source_file_dependencies:
|
||||
- csrc/attention/
|
||||
- vllm/attention
|
||||
- vllm/v1/attention
|
||||
- tests/kernels/attention
|
||||
commands:
|
||||
- pytest -v -s kernels/attention --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
|
||||
parallelism: 2
|
||||
|
||||
- label: Kernels Quantization Test %N
|
||||
source_file_dependencies:
|
||||
- csrc/quantization/
|
||||
- vllm/model_executor/layers/quantization
|
||||
- tests/kernels/quantization
|
||||
commands:
|
||||
- pytest -v -s kernels/quantization --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
|
||||
parallelism: 2
|
||||
|
||||
- label: Kernels MoE Test
|
||||
source_file_dependencies:
|
||||
- csrc/moe/
|
||||
- tests/kernels/moe
|
||||
- vllm/model_executor/layers/fused_moe/
|
||||
commands:
|
||||
- pytest -v -s kernels/moe
|
||||
|
||||
- label: Kernels Mamba Test
|
||||
source_file_dependencies:
|
||||
- csrc/mamba/
|
||||
- tests/kernels/mamba
|
||||
commands:
|
||||
- pytest -v -s kernels/mamba
|
||||
|
||||
- label: Tensorizer Test # 11min
|
||||
# mirror_hardwares: [amd]
|
||||
@ -339,6 +378,13 @@ steps:
|
||||
commands:
|
||||
- bash scripts/run-benchmarks.sh
|
||||
|
||||
- label: Benchmarks CLI Test # 10min
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/benchmarks/
|
||||
commands:
|
||||
- pytest -v -s benchmarks/
|
||||
|
||||
- label: Quantization Test # 33min
|
||||
source_file_dependencies:
|
||||
- csrc/
|
||||
@ -376,8 +422,10 @@ steps:
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/tool_use
|
||||
- tests/mistral_tool_use
|
||||
commands:
|
||||
- pytest -v -s tool_use
|
||||
- pytest -v -s mistral_tool_use
|
||||
|
||||
##### models test #####
|
||||
|
||||
@ -389,7 +437,9 @@ steps:
|
||||
- pytest -v -s models/test_transformers.py
|
||||
- pytest -v -s models/test_registry.py
|
||||
# V1 Test: https://github.com/vllm-project/vllm/issues/14531
|
||||
- VLLM_USE_V1=0 pytest -v -s models/test_initialization.py
|
||||
- VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2'
|
||||
- VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'llama4'
|
||||
- VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'plamo2'
|
||||
|
||||
- label: Language Models Test (Standard) # 32min
|
||||
#mirror_hardwares: [amd]
|
||||
@ -399,6 +449,8 @@ steps:
|
||||
- tests/models/embedding/language
|
||||
- tests/models/encoder_decoder/language
|
||||
commands:
|
||||
# Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile.
|
||||
- pip install causal-conv1d
|
||||
- pytest -v -s models/decoder_only/language -m 'core_model or quant_model'
|
||||
- pytest -v -s models/embedding/language -m core_model
|
||||
|
||||
@ -410,6 +462,8 @@ steps:
|
||||
- tests/models/embedding/language
|
||||
- tests/models/encoder_decoder/language
|
||||
commands:
|
||||
# Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile.
|
||||
- pip install causal-conv1d
|
||||
- pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model'
|
||||
- pytest -v -s models/embedding/language -m 'not core_model'
|
||||
|
||||
@ -426,7 +480,7 @@ steps:
|
||||
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
|
||||
- pytest -v -s models/multimodal
|
||||
- pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model'
|
||||
- pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model or quant_model'
|
||||
- pytest -v -s models/decoder_only/vision_language -m 'core_model or quant_model'
|
||||
- pytest -v -s models/embedding/vision_language -m core_model
|
||||
- pytest -v -s models/encoder_decoder/audio_language -m core_model
|
||||
- pytest -v -s models/encoder_decoder/language -m core_model
|
||||
@ -445,10 +499,7 @@ steps:
|
||||
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
|
||||
- pytest -v -s models/decoder_only/audio_language -m 'not core_model and not quant_model'
|
||||
- pytest -v -s models/decoder_only/vision_language/test_models.py -m 'split(group=0) and not core_model and not quant_model'
|
||||
# HACK - run phi3v tests separately to sidestep this transformers bug
|
||||
# https://github.com/huggingface/transformers/issues/34307
|
||||
- pytest -v -s models/decoder_only/vision_language/test_phi3v.py
|
||||
- pytest -v -s --ignore models/decoder_only/vision_language/test_models.py --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'not core_model and not quant_model'
|
||||
- pytest -v -s --ignore models/decoder_only/vision_language/test_models.py models/decoder_only/vision_language -m 'not core_model and not quant_model'
|
||||
- pytest -v -s models/embedding/vision_language -m 'not core_model'
|
||||
- pytest -v -s models/encoder_decoder/language -m 'not core_model'
|
||||
- pytest -v -s models/encoder_decoder/vision_language -m 'not core_model'
|
||||
@ -533,11 +584,14 @@ steps:
|
||||
- pytest models/encoder_decoder/language/test_bart.py -v -s -m 'distributed(num_gpus=2)'
|
||||
- pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m 'distributed(num_gpus=2)'
|
||||
- pytest models/decoder_only/vision_language/test_models.py -v -s -m 'distributed(num_gpus=2)'
|
||||
# test sequence parallel
|
||||
- pytest -v -s distributed/test_sequence_parallel.py
|
||||
# this test fails consistently.
|
||||
# TODO: investigate and fix
|
||||
# - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
|
||||
- VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
|
||||
- VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/test_disagg.py
|
||||
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown
|
||||
|
||||
- label: Plugin Tests (2 GPUs) # 40min
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
|
||||
1
.github/CODEOWNERS
vendored
1
.github/CODEOWNERS
vendored
@ -12,6 +12,7 @@
|
||||
/vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth
|
||||
/vllm/model_executor/guided_decoding @mgoin @russellb
|
||||
/vllm/multimodal @DarkLight1337 @ywang96
|
||||
/vllm/vllm_flash_attn @LucasWilkinson
|
||||
CMakeLists.txt @tlrmchlsmth
|
||||
|
||||
# vLLM V1
|
||||
|
||||
2
.github/ISSUE_TEMPLATE/200-installation.yml
vendored
2
.github/ISSUE_TEMPLATE/200-installation.yml
vendored
@ -14,7 +14,7 @@ body:
|
||||
description: |
|
||||
Please run the following and paste the output below.
|
||||
```sh
|
||||
wget https://raw.githubusercontent.com/vllm-project/vllm/main/collect_env.py
|
||||
wget https://raw.githubusercontent.com/vllm-project/vllm/main/vllm/collect_env.py
|
||||
# For security purposes, please feel free to check the contents of collect_env.py before running it.
|
||||
python collect_env.py
|
||||
```
|
||||
|
||||
2
.github/ISSUE_TEMPLATE/300-usage.yml
vendored
2
.github/ISSUE_TEMPLATE/300-usage.yml
vendored
@ -14,7 +14,7 @@ body:
|
||||
description: |
|
||||
Please run the following and paste the output below.
|
||||
```sh
|
||||
wget https://raw.githubusercontent.com/vllm-project/vllm/main/collect_env.py
|
||||
wget https://raw.githubusercontent.com/vllm-project/vllm/main/vllm/collect_env.py
|
||||
# For security purposes, please feel free to check the contents of collect_env.py before running it.
|
||||
python collect_env.py
|
||||
```
|
||||
|
||||
2
.github/ISSUE_TEMPLATE/400-bug-report.yml
vendored
2
.github/ISSUE_TEMPLATE/400-bug-report.yml
vendored
@ -14,7 +14,7 @@ body:
|
||||
description: |
|
||||
Please run the following and paste the output below.
|
||||
```sh
|
||||
wget https://raw.githubusercontent.com/vllm-project/vllm/main/collect_env.py
|
||||
wget https://raw.githubusercontent.com/vllm-project/vllm/main/vllm/collect_env.py
|
||||
# For security purposes, please feel free to check the contents of collect_env.py before running it.
|
||||
python collect_env.py
|
||||
```
|
||||
|
||||
2
.github/ISSUE_TEMPLATE/600-new-model.yml
vendored
2
.github/ISSUE_TEMPLATE/600-new-model.yml
vendored
@ -9,7 +9,7 @@ body:
|
||||
value: >
|
||||
#### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/vllm-project/vllm/issues?q=is%3Aissue+sort%3Acreated-desc+).
|
||||
|
||||
#### We also highly recommend you read https://docs.vllm.ai/en/latest/contributing/model/adding_model.html first to understand how to add a new model.
|
||||
#### We also highly recommend you read https://docs.vllm.ai/en/latest/contributing/model/index.html first to understand how to add a new model.
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: The model to consider.
|
||||
|
||||
@ -35,7 +35,7 @@ body:
|
||||
description: |
|
||||
Please run the following and paste the output below.
|
||||
```sh
|
||||
wget https://raw.githubusercontent.com/vllm-project/vllm/main/collect_env.py
|
||||
wget https://raw.githubusercontent.com/vllm-project/vllm/main/vllm/collect_env.py
|
||||
# For security purposes, please feel free to check the contents of collect_env.py before running it.
|
||||
python collect_env.py
|
||||
```
|
||||
|
||||
2
.github/PULL_REQUEST_TEMPLATE.md
vendored
2
.github/PULL_REQUEST_TEMPLATE.md
vendored
@ -3,4 +3,4 @@ FILL IN THE PR DESCRIPTION HERE
|
||||
FIX #xxxx (*link existing issues this PR will resolve*)
|
||||
|
||||
<!--- pyml disable-next-line no-emphasis-as-heading -->
|
||||
**BEFORE SUBMITTING, PLEASE READ <https://docs.vllm.ai/en/latest/contributing/overview.html>**
|
||||
**BEFORE SUBMITTING, PLEASE READ <https://docs.vllm.ai/en/latest/contributing/overview.html>** (anything written below this line will be removed by GitHub Actions)
|
||||
|
||||
34
.github/mergify.yml
vendored
34
.github/mergify.yml
vendored
@ -55,11 +55,19 @@ pull_request_rules:
|
||||
description: Automatically apply structured-output label
|
||||
conditions:
|
||||
- or:
|
||||
- files~=^benchmarks/structured_schemas/
|
||||
- files=benchmarks/benchmark_serving_structured_output.py
|
||||
- files=benchmarks/run_structured_output_benchmark.sh
|
||||
- files=docs/source/features/structured_outputs.md
|
||||
- files=examples/offline_inference/structured_outputs.py
|
||||
- files=examples/online_serving/openai_chat_completion_structured_outputs.py
|
||||
- files=examples/online_serving/openai_chat_completion_structured_outputs_with_reasoning.py
|
||||
- files~=^vllm/model_executor/guided_decoding/
|
||||
- files=tests/model_executor/test_guided_processors.py
|
||||
- files=tests/entrypoints/llm/test_guided_generate.py
|
||||
- files=benchmarks/benchmark_serving_guided.py
|
||||
- files=benchmarks/benchmark_guided.py
|
||||
- files~=^tests/v1/structured_output/
|
||||
- files=tests/v1/entrypoints/llm/test_guided_generate.py
|
||||
- files~=^vllm/v1/structured_output/
|
||||
actions:
|
||||
label:
|
||||
add:
|
||||
@ -118,6 +126,28 @@ pull_request_rules:
|
||||
remove:
|
||||
- tpu
|
||||
|
||||
- name: label-tool-calling
|
||||
description: Automatically add tool-calling label
|
||||
conditions:
|
||||
- or:
|
||||
- files~=^tests/tool_use/
|
||||
- files~=^tests/mistral_tool_use/
|
||||
- files~=^tests/entrypoints/openai/tool_parsers/
|
||||
- files=tests/entrypoints/openai/test_chat_with_tool_reasoning.py
|
||||
- files~=^vllm/entrypoints/openai/tool_parsers/
|
||||
- files=docs/source/features/tool_calling.md
|
||||
- files=docs/source/getting_started/examples/openai_chat_completion_client_with_tools.md
|
||||
- files=docs/source/getting_started/examples/chat_with_tools.md
|
||||
- files~=^examples/tool_chat_*
|
||||
- files=examples/offline_inference/chat_with_tools.py
|
||||
- files=examples/online_serving/openai_chat_completion_client_with_tools_required.py
|
||||
- files=examples/online_serving/openai_chat_completion_tool_calls_with_reasoning.py
|
||||
- files=examples/online_serving/openai_chat_completion_client_with_tools.py
|
||||
actions:
|
||||
label:
|
||||
add:
|
||||
- tool-calling
|
||||
|
||||
- name: ping author on conflicts and add 'needs-rebase' label
|
||||
conditions:
|
||||
- conflict
|
||||
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@ -3,7 +3,6 @@
|
||||
|
||||
# vllm-flash-attn built from source
|
||||
vllm/vllm_flash_attn/*
|
||||
!vllm/vllm_flash_attn/fa_utils.py
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
@ -203,3 +202,6 @@ benchmarks/**/*.json
|
||||
# Linting
|
||||
actionlint
|
||||
shellcheck*/
|
||||
|
||||
# Ingore moe/marlin_moe gen code
|
||||
csrc/moe/marlin_moe_wna16/kernel_*
|
||||
|
||||
@ -11,7 +11,6 @@ repos:
|
||||
hooks:
|
||||
- id: yapf
|
||||
args: [--in-place, --verbose]
|
||||
additional_dependencies: [toml] # TODO: Remove when yapf is upgraded
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.9.3
|
||||
hooks:
|
||||
@ -122,6 +121,12 @@ repos:
|
||||
language: system
|
||||
always_run: true
|
||||
pass_filenames: false
|
||||
- id: update-dockerfile-graph
|
||||
name: Update Dockerfile dependency graph
|
||||
entry: tools/update-dockerfile-graph.sh
|
||||
language: script
|
||||
files: ^docker/Dockerfile$
|
||||
pass_filenames: false
|
||||
# Keep `suggestion` last
|
||||
- id: suggestion
|
||||
name: Suggestion
|
||||
|
||||
@ -230,6 +230,7 @@ set(VLLM_EXT_SRC
|
||||
"csrc/cache_kernels.cu"
|
||||
"csrc/attention/paged_attention_v1.cu"
|
||||
"csrc/attention/paged_attention_v2.cu"
|
||||
"csrc/attention/merge_attn_states.cu"
|
||||
"csrc/pos_encoding_kernels.cu"
|
||||
"csrc/activation_kernels.cu"
|
||||
"csrc/layernorm_kernels.cu"
|
||||
@ -250,7 +251,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
|
||||
# Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case.
|
||||
# Please keep this in sync with FetchContent_Declare line below.
|
||||
set(CUTLASS_REVISION "v3.8.0" CACHE STRING "CUTLASS revision to use")
|
||||
set(CUTLASS_REVISION "v3.9.0" CACHE STRING "CUTLASS revision to use")
|
||||
|
||||
# Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
|
||||
if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR})
|
||||
@ -268,7 +269,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
cutlass
|
||||
GIT_REPOSITORY https://github.com/nvidia/cutlass.git
|
||||
# Please keep this in sync with CUTLASS_REVISION line above.
|
||||
GIT_TAG v3.8.0
|
||||
GIT_TAG v3.9.0
|
||||
GIT_PROGRESS TRUE
|
||||
|
||||
# Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
|
||||
@ -289,7 +290,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
"csrc/quantization/fp4/nvfp4_quant_entry.cu"
|
||||
"csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
|
||||
"csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
|
||||
"csrc/cutlass_extensions/common.cpp")
|
||||
"csrc/cutlass_extensions/common.cpp"
|
||||
"csrc/attention/mla/cutlass_mla_entry.cu")
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${VLLM_EXT_SRC}"
|
||||
@ -462,7 +464,26 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
set(FP4_ARCHS)
|
||||
endif()
|
||||
|
||||
#
|
||||
# CUTLASS MLA Archs and flags
|
||||
cuda_archs_loose_intersection(MLA_ARCHS "10.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND MLA_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/attention/mla/cutlass_mla_kernels.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${MLA_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MLA=1")
|
||||
# Add MLA-specific include directories only to MLA source files
|
||||
set_source_files_properties(${SRCS}
|
||||
PROPERTIES INCLUDE_DIRECTORIES "${CUTLASS_DIR}/examples/77_blackwell_fmha;${CUTLASS_DIR}/examples/common")
|
||||
message(STATUS "Building CUTLASS MLA for archs: ${MLA_ARCHS}")
|
||||
else()
|
||||
message(STATUS "Not building CUTLASS MLA as no compatible archs were found.")
|
||||
# clear MLA_ARCHS
|
||||
set(MLA_ARCHS)
|
||||
endif()
|
||||
|
||||
# CUTLASS MoE kernels
|
||||
|
||||
# The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and only works
|
||||
@ -608,21 +629,51 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}")
|
||||
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
|
||||
if (MARLIN_MOE_ARCHS)
|
||||
set(MARLIN_MOE_SRC
|
||||
"csrc/moe/marlin_kernels/marlin_moe_kernel.h"
|
||||
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h"
|
||||
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu"
|
||||
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h"
|
||||
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu"
|
||||
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h"
|
||||
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu"
|
||||
"csrc/moe/marlin_moe_ops.cu")
|
||||
|
||||
#
|
||||
# For the Marlin MOE kernels we automatically generate sources for various
|
||||
# preselected input type pairs and schedules.
|
||||
# Generate sources:
|
||||
set(MOE_MARLIN_GEN_SCRIPT
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/csrc/moe/marlin_moe_wna16/generate_kernels.py)
|
||||
file(MD5 ${MOE_MARLIN_GEN_SCRIPT} MOE_MARLIN_GEN_SCRIPT_HASH)
|
||||
|
||||
message(STATUS "Marlin MOE generation script hash: ${MOE_MARLIN_GEN_SCRIPT_HASH}")
|
||||
message(STATUS "Last run Marlin MOE generate script hash: $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH}")
|
||||
|
||||
if (NOT DEFINED CACHE{MOE_MARLIN_GEN_SCRIPT_HASH}
|
||||
OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH} STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH})
|
||||
execute_process(
|
||||
COMMAND ${CMAKE_COMMAND} -E env
|
||||
PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$PYTHONPATH
|
||||
${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT}
|
||||
RESULT_VARIABLE moe_marlin_generation_result
|
||||
OUTPUT_VARIABLE moe_marlin_generation_output
|
||||
OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log
|
||||
ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log
|
||||
)
|
||||
|
||||
if (NOT moe_marlin_generation_result EQUAL 0)
|
||||
message(FATAL_ERROR "Marlin MOE generation failed."
|
||||
" Result: \"${moe_marlin_generation_result}\""
|
||||
"\nCheck the log for details: "
|
||||
"${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log")
|
||||
else()
|
||||
set(MOE_MARLIN_GEN_SCRIPT_HASH ${MOE_MARLIN_GEN_SCRIPT_HASH}
|
||||
CACHE STRING "Last run Marlin MOE generate script hash" FORCE)
|
||||
message(STATUS "Marlin MOE generation completed successfully.")
|
||||
endif()
|
||||
else()
|
||||
message(STATUS "Marlin MOE generation script has not changed, skipping generation.")
|
||||
endif()
|
||||
|
||||
file(GLOB MOE_WNAA16_MARLIN_SRC "csrc/moe/marlin_moe_wna16/*.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_MOE_SRC}"
|
||||
SRCS "${MOE_WNAA16_MARLIN_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_MOE_ARCHS}")
|
||||
|
||||
list(APPEND VLLM_MOE_EXT_SRC "${MARLIN_MOE_SRC}")
|
||||
list(APPEND VLLM_MOE_EXT_SRC ${MOE_WNAA16_MARLIN_SRC})
|
||||
|
||||
message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_ARCHS}")
|
||||
else()
|
||||
message(STATUS "Not building Marlin MOE kernels as no compatible archs found"
|
||||
@ -647,6 +698,7 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
|
||||
#
|
||||
set(VLLM_ROCM_EXT_SRC
|
||||
"csrc/rocm/torch_bindings.cpp"
|
||||
"csrc/rocm/skinny_gemms.cu"
|
||||
"csrc/rocm/attention.cu")
|
||||
|
||||
define_gpu_extension_target(
|
||||
|
||||
@ -10,16 +10,13 @@ Easy, fast, and cheap LLM serving for everyone
|
||||
</h3>
|
||||
|
||||
<p align="center">
|
||||
| <a href="https://docs.vllm.ai"><b>Documentation</b></a> | <a href="https://vllm.ai"><b>Blog</b></a> | <a href="https://arxiv.org/abs/2309.06180"><b>Paper</b></a> | <a href="https://x.com/vllm_project"><b>Twitter/X</b></a> | <a href="https://discuss.vllm.ai"><b>User Forum</b></a> | <a href="https://slack.vllm.ai"><b>Developer Slack</b></a> |
|
||||
| <a href="https://docs.vllm.ai"><b>Documentation</b></a> | <a href="https://blog.vllm.ai/"><b>Blog</b></a> | <a href="https://arxiv.org/abs/2309.06180"><b>Paper</b></a> | <a href="https://x.com/vllm_project"><b>Twitter/X</b></a> | <a href="https://discuss.vllm.ai"><b>User Forum</b></a> | <a href="https://slack.vllm.ai"><b>Developer Slack</b></a> |
|
||||
</p>
|
||||
|
||||
---
|
||||
|
||||
[2025/04] We're hosting our first-ever *vLLM Asia Developer Day* in Singapore on *April 3rd*! This is a full-day event (9 AM - 9 PM SGT) in partnership with SGInnovate, AMD, and Embedded LLM. Meet the vLLM team and learn about LLM inference for RL, MI300X, and more! [Register Now](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day)
|
||||
|
||||
---
|
||||
|
||||
*Latest News* 🔥
|
||||
- [2025/04] We hosted [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing).
|
||||
- [2025/03] We hosted [vLLM x Ollama Inference Night](https://lu.ma/vllm-ollama)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/16T2PDD1YwRnZ4Tu8Q5r6n53c5Lr5c73UV9Vd2_eBo4U/edit?usp=sharing).
|
||||
- [2025/03] We hosted [the first vLLM China Meetup](https://mp.weixin.qq.com/s/n77GibL2corAtQHtVEAzfg)! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1REHvfQMKGnvz6p3Fd23HhSO4c8j5WPGZV0bKYLwnHyQ/edit?usp=sharing).
|
||||
- [2025/03] We hosted [the East Coast vLLM Meetup](https://lu.ma/7mu4k4xx)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1NHiv8EUFF1NLd3fEYODm56nDmL26lEeXCaDgyDlTsRs/edit#slide=id.g31441846c39_0_0).
|
||||
|
||||
@ -204,6 +204,24 @@ python3 vllm/benchmarks/benchmark_serving.py \
|
||||
--seed 42
|
||||
```
|
||||
|
||||
### Running With Sampling Parameters
|
||||
|
||||
When using OpenAI-compatible backends such as `vllm`, optional sampling
|
||||
parameters can be specified. Example client command:
|
||||
|
||||
```bash
|
||||
python3 vllm/benchmarks/benchmark_serving.py \
|
||||
--backend vllm \
|
||||
--model NousResearch/Hermes-3-Llama-3.1-8B \
|
||||
--endpoint /v1/completions \
|
||||
--dataset-name sharegpt \
|
||||
--dataset-path <your data path>/ShareGPT_V3_unfiltered_cleaned_split.json \
|
||||
--top-k 10 \
|
||||
--top-p 0.9 \
|
||||
--temperature 0.5 \
|
||||
--num-prompts 10
|
||||
```
|
||||
|
||||
---
|
||||
## Example - Offline Throughput Benchmark
|
||||
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
@ -32,6 +33,7 @@ class RequestFuncInput:
|
||||
extra_body: Optional[dict] = None
|
||||
multi_modal_content: Optional[dict] = None
|
||||
ignore_eos: bool = False
|
||||
language: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -436,6 +438,110 @@ async def async_request_openai_chat_completions(
|
||||
return output
|
||||
|
||||
|
||||
async def async_request_openai_audio(
|
||||
request_func_input: RequestFuncInput,
|
||||
pbar: Optional[tqdm] = None,
|
||||
) -> RequestFuncOutput:
|
||||
# Lazy import without PlaceholderModule to avoid vllm dep.
|
||||
import soundfile
|
||||
api_url = request_func_input.api_url
|
||||
assert api_url.endswith(
|
||||
("transcriptions", "translations"
|
||||
)), "OpenAI Chat Completions API URL must end with 'transcriptions' "
|
||||
"or `translations`."
|
||||
|
||||
async with aiohttp.ClientSession(trust_env=True,
|
||||
timeout=AIOHTTP_TIMEOUT) as session:
|
||||
content = [{"type": "text", "text": request_func_input.prompt}]
|
||||
payload = {
|
||||
"model": request_func_input.model_name \
|
||||
if request_func_input.model_name else request_func_input.model,
|
||||
"temperature": 0.0,
|
||||
"max_completion_tokens": request_func_input.output_len,
|
||||
"stream": True,
|
||||
"language": "en",
|
||||
# Flattened due to multipart/form-data
|
||||
"stream_include_usage": True,
|
||||
"stream_continuous_usage_stats": True
|
||||
}
|
||||
if request_func_input.extra_body:
|
||||
payload.update(request_func_input.extra_body)
|
||||
headers = {
|
||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
|
||||
}
|
||||
|
||||
# Send audio file
|
||||
def to_bytes(y, sr):
|
||||
buffer = io.BytesIO()
|
||||
soundfile.write(buffer, y, sr, format="WAV")
|
||||
buffer.seek(0)
|
||||
return buffer
|
||||
|
||||
with to_bytes(*request_func_input.multi_modal_content['audio']) as f:
|
||||
form = aiohttp.FormData()
|
||||
form.add_field('file', f, content_type='audio/wav')
|
||||
for key, value in payload.items():
|
||||
form.add_field(key, str(value))
|
||||
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
|
||||
generated_text = ""
|
||||
ttft = 0.0
|
||||
st = time.perf_counter()
|
||||
most_recent_timestamp = st
|
||||
try:
|
||||
async with session.post(url=api_url,
|
||||
data=form,
|
||||
headers=headers) as response:
|
||||
if response.status == 200:
|
||||
async for chunk_bytes in response.content:
|
||||
chunk_bytes = chunk_bytes.strip()
|
||||
if not chunk_bytes:
|
||||
continue
|
||||
|
||||
chunk = chunk_bytes.decode("utf-8").removeprefix(
|
||||
"data: ")
|
||||
if chunk != "[DONE]":
|
||||
timestamp = time.perf_counter()
|
||||
data = json.loads(chunk)
|
||||
|
||||
if choices := data.get("choices"):
|
||||
content = choices[0]["delta"].get(
|
||||
"content")
|
||||
# First token
|
||||
if ttft == 0.0:
|
||||
ttft = timestamp - st
|
||||
output.ttft = ttft
|
||||
|
||||
# Decoding phase
|
||||
else:
|
||||
output.itl.append(
|
||||
timestamp - most_recent_timestamp)
|
||||
|
||||
generated_text += content or ""
|
||||
elif usage := data.get("usage"):
|
||||
output.output_tokens = usage.get(
|
||||
"completion_tokens")
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
|
||||
output.generated_text = generated_text
|
||||
output.success = True
|
||||
output.latency = most_recent_timestamp - st
|
||||
else:
|
||||
output.error = response.reason or ""
|
||||
output.success = False
|
||||
except Exception:
|
||||
output.success = False
|
||||
exc_info = sys.exc_info()
|
||||
output.error = "".join(traceback.format_exception(*exc_info))
|
||||
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
return output
|
||||
|
||||
|
||||
def get_model(pretrained_model_name_or_path: str) -> str:
|
||||
if os.getenv('VLLM_USE_MODELSCOPE', 'False').lower() == 'true':
|
||||
from modelscope import snapshot_download
|
||||
@ -493,7 +599,14 @@ ASYNC_REQUEST_FUNCS = {
|
||||
"deepspeed-mii": async_request_deepspeed_mii,
|
||||
"openai": async_request_openai_completions,
|
||||
"openai-chat": async_request_openai_chat_completions,
|
||||
"openai-audio": async_request_openai_audio,
|
||||
"tensorrt-llm": async_request_trt_llm,
|
||||
"scalellm": async_request_openai_completions,
|
||||
"sglang": async_request_openai_completions,
|
||||
}
|
||||
|
||||
OPENAI_COMPATIBLE_BACKENDS = [
|
||||
k for k, v in ASYNC_REQUEST_FUNCS.items()
|
||||
if v in (async_request_openai_completions,
|
||||
async_request_openai_chat_completions)
|
||||
]
|
||||
|
||||
@ -64,6 +64,7 @@ class SampleRequest:
|
||||
|
||||
class BenchmarkDataset(ABC):
|
||||
DEFAULT_SEED = 0
|
||||
IS_MULTIMODAL = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -188,6 +189,9 @@ class BenchmarkDataset(ABC):
|
||||
"""
|
||||
if len(requests) < num_requests:
|
||||
random.seed(self.random_seed)
|
||||
logger.info("Current number of requests: %d", len(requests))
|
||||
logger.info("Oversampled requests to reach %d total samples.",
|
||||
num_requests)
|
||||
additional = random.choices(requests,
|
||||
k=num_requests - len(requests))
|
||||
requests.extend(additional)
|
||||
@ -288,7 +292,7 @@ def process_image(image: Any) -> Mapping[str, Any]:
|
||||
class RandomDataset(BenchmarkDataset):
|
||||
# Default values copied from benchmark_serving.py for the random dataset.
|
||||
DEFAULT_PREFIX_LEN = 0
|
||||
DEFAULT_RANGE_RATIO = 1.0
|
||||
DEFAULT_RANGE_RATIO = 0.0
|
||||
DEFAULT_INPUT_LEN = 1024
|
||||
DEFAULT_OUTPUT_LEN = 128
|
||||
|
||||
@ -308,19 +312,32 @@ class RandomDataset(BenchmarkDataset):
|
||||
output_len: int = DEFAULT_OUTPUT_LEN,
|
||||
**kwargs,
|
||||
) -> list[SampleRequest]:
|
||||
# Enforce range_ratio < 1
|
||||
assert range_ratio < 1.0, (
|
||||
"random_range_ratio must be < 1.0 to ensure a valid sampling range"
|
||||
)
|
||||
|
||||
vocab_size = tokenizer.vocab_size
|
||||
|
||||
prefix_token_ids = (np.random.randint(
|
||||
0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else [])
|
||||
|
||||
input_low = int(input_len * range_ratio)
|
||||
output_low = int(output_len * range_ratio)
|
||||
# New sampling logic: [X * (1 - b), X * (1 + b)]
|
||||
input_low = int(input_len * (1 - range_ratio))
|
||||
input_high = int(input_len * (1 + range_ratio))
|
||||
output_low = int(output_len * (1 - range_ratio))
|
||||
output_high = int(output_len * (1 + range_ratio))
|
||||
|
||||
# Add logging for debugging
|
||||
logger.info("Sampling input_len from [%s, %s]", input_low, input_high)
|
||||
logger.info("Sampling output_len from [%s, %s]", output_low,
|
||||
output_high)
|
||||
|
||||
input_lens = np.random.randint(input_low,
|
||||
input_len + 1,
|
||||
input_high + 1,
|
||||
size=num_requests)
|
||||
output_lens = np.random.randint(output_low,
|
||||
output_len + 1,
|
||||
output_high + 1,
|
||||
size=num_requests)
|
||||
offsets = np.random.randint(0, vocab_size, size=num_requests)
|
||||
|
||||
@ -388,6 +405,13 @@ class ShareGPTDataset(BenchmarkDataset):
|
||||
entry["conversations"][1]["value"],
|
||||
)
|
||||
|
||||
prompt = tokenizer.apply_chat_template([{
|
||||
"role": "user",
|
||||
"content": prompt
|
||||
}],
|
||||
add_generation_prompt=True,
|
||||
tokenize=False)
|
||||
|
||||
lora_request, tokenizer = self.get_random_lora_request(
|
||||
tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path)
|
||||
prompt_ids = tokenizer(prompt).input_ids
|
||||
@ -472,11 +496,11 @@ class SonnetDataset(BenchmarkDataset):
|
||||
|
||||
# Determine how many poem lines to use.
|
||||
num_input_lines = round((input_len - base_offset) / avg_len)
|
||||
num_prefix_lines = round((prefix_len - base_offset) / avg_len)
|
||||
num_prefix_lines = max(round((prefix_len - base_offset) / avg_len), 0)
|
||||
prefix_lines = self.data[:num_prefix_lines]
|
||||
|
||||
samples = []
|
||||
for _ in range(num_requests):
|
||||
while len(samples) < num_requests:
|
||||
extra_lines = random.choices(self.data,
|
||||
k=num_input_lines - num_prefix_lines)
|
||||
prompt = f"{base_prompt}{''.join(prefix_lines + extra_lines)}"
|
||||
@ -484,13 +508,14 @@ class SonnetDataset(BenchmarkDataset):
|
||||
prompt_formatted = tokenizer.apply_chat_template(
|
||||
msg, add_generation_prompt=True, tokenize=False)
|
||||
prompt_len = len(tokenizer(prompt_formatted).input_ids)
|
||||
samples.append(
|
||||
SampleRequest(
|
||||
prompt=prompt_formatted
|
||||
if return_prompt_formatted else prompt,
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=output_len,
|
||||
))
|
||||
if prompt_len <= input_len:
|
||||
samples.append(
|
||||
SampleRequest(
|
||||
prompt=prompt_formatted
|
||||
if return_prompt_formatted else prompt,
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=output_len,
|
||||
))
|
||||
return samples
|
||||
|
||||
|
||||
@ -607,6 +632,7 @@ class ConversationDataset(HuggingFaceDataset):
|
||||
SUPPORTED_DATASET_PATHS = {
|
||||
'lmms-lab/LLaVA-OneVision-Data', 'Aeala/ShareGPT_Vicuna_unfiltered'
|
||||
}
|
||||
IS_MULTIMODAL = True
|
||||
|
||||
def sample(self,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
@ -671,6 +697,7 @@ class VisionArenaDataset(HuggingFaceDataset):
|
||||
"lmarena-ai/vision-arena-bench-v0.1":
|
||||
lambda x: x["turns"][0][0]["content"]
|
||||
}
|
||||
IS_MULTIMODAL = True
|
||||
|
||||
def sample(
|
||||
self,
|
||||
@ -743,6 +770,14 @@ class InstructCoderDataset(HuggingFaceDataset):
|
||||
if len(sampled_requests) >= num_requests:
|
||||
break
|
||||
prompt = f"{item['instruction']}:\n{item['input']}"
|
||||
|
||||
prompt = tokenizer.apply_chat_template([{
|
||||
"role": "user",
|
||||
"content": prompt
|
||||
}],
|
||||
add_generation_prompt=True,
|
||||
tokenize=False)
|
||||
|
||||
prompt_len = len(tokenizer(prompt).input_ids)
|
||||
sampled_requests.append(
|
||||
SampleRequest(
|
||||
@ -776,11 +811,18 @@ class AIMODataset(HuggingFaceDataset):
|
||||
sampled_requests = []
|
||||
dynamic_output = output_len is None
|
||||
|
||||
for item in self.data:
|
||||
for i, item in enumerate(self.data):
|
||||
if len(sampled_requests) >= num_requests:
|
||||
break
|
||||
prompt, completion = item['problem'], item["solution"]
|
||||
|
||||
prompt = tokenizer.apply_chat_template([{
|
||||
"role": "user",
|
||||
"content": prompt
|
||||
}],
|
||||
add_generation_prompt=True,
|
||||
tokenize=False)
|
||||
|
||||
prompt_ids = tokenizer(prompt).input_ids
|
||||
completion_ids = tokenizer(completion).input_ids
|
||||
prompt_len = len(prompt_ids)
|
||||
@ -801,3 +843,180 @@ class AIMODataset(HuggingFaceDataset):
|
||||
))
|
||||
self.maybe_oversample_requests(sampled_requests, num_requests)
|
||||
return sampled_requests
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# ASR Dataset Implementation
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ASRDataset(HuggingFaceDataset):
|
||||
"""
|
||||
Dataset class for processing a ASR dataset for transcription.
|
||||
Tested on the following set:
|
||||
|
||||
+----------------+----------------------------------------+--------------------------+-----------------------------+
|
||||
| Dataset | Domain | Speaking Style | hf-subset |
|
||||
+----------------+----------------------------------------+--------------------------+-----------------------------+
|
||||
| TED-LIUM | TED talks | Oratory | release1, release2, release3|
|
||||
| | | | release3-speaker-adaptation |
|
||||
| VoxPopuli | European Parliament | Oratory | en, de, it, fr, ... |
|
||||
| LibriSpeech | Audiobook | Narrated | "LIUM/tedlium" |
|
||||
| GigaSpeech | Audiobook, podcast, YouTube | Narrated, spontaneous | xs, s, m, l, xl, dev, test |
|
||||
| SPGISpeech | Financial meetings | Oratory, spontaneous | S, M, L, dev, test |
|
||||
| AMI | Meetings | Spontaneous | ihm, sdm |
|
||||
+----------------+----------------------------------------+--------------------------+-----------------------------+
|
||||
|
||||
""" # noqa: E501
|
||||
SUPPORTED_DATASET_PATHS = {
|
||||
"openslr/librispeech_asr", "facebook/voxpopuli", "LIUM/tedlium",
|
||||
"edinburghcstr/ami", "speechcolab/gigaspeech", "kensho/spgispeech"
|
||||
}
|
||||
|
||||
DEFAULT_OUTPUT_LEN = 128
|
||||
IS_MULTIMODAL = True
|
||||
|
||||
# TODO Whisper-specific. Abstract interface when more models are supported.
|
||||
TRANSCRIPTION_PREAMBLE = "<|startoftranscript|><|en|><|transcribe|>"\
|
||||
"<|notimestamps|>"
|
||||
skip_long_audios: bool = True
|
||||
|
||||
def sample(
|
||||
self,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
num_requests: int,
|
||||
output_len: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> list:
|
||||
import librosa
|
||||
output_len = (output_len
|
||||
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
|
||||
prompt = ASRDataset.TRANSCRIPTION_PREAMBLE
|
||||
prompt_len = len(tokenizer(prompt).input_ids)
|
||||
sampled_requests = []
|
||||
skipped = 0
|
||||
for item in self.data:
|
||||
if len(sampled_requests) >= num_requests:
|
||||
break
|
||||
audio = item["audio"]
|
||||
y, sr = audio["array"], audio["sampling_rate"]
|
||||
duration_s = librosa.get_duration(y=y, sr=sr)
|
||||
# Whisper max supported duration
|
||||
if self.skip_long_audios and duration_s > 30:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
mm_content = {"audio": (y, sr)}
|
||||
sampled_requests.append(
|
||||
SampleRequest(
|
||||
prompt=prompt,
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=output_len,
|
||||
multi_modal_data=mm_content,
|
||||
))
|
||||
if skipped:
|
||||
logger.warning("%d samples discarded from dataset due to" \
|
||||
" their length being greater than" \
|
||||
" what Whisper supports.", skipped)
|
||||
self.maybe_oversample_requests(sampled_requests, num_requests)
|
||||
return sampled_requests
|
||||
|
||||
|
||||
class MTBenchDataset(HuggingFaceDataset):
|
||||
"""
|
||||
MT-Bench Dataset.
|
||||
https://huggingface.co/datasets/philschmid/mt-bench
|
||||
|
||||
We create a single turn dataset for MT-Bench.
|
||||
This is similar to Spec decoding benchmark setup in vLLM
|
||||
https://github.com/vllm-project/vllm/blob/9d98ab5ec/examples/offline_inference/eagle.py#L14-L18
|
||||
""" # noqa: E501
|
||||
|
||||
DEFAULT_OUTPUT_LEN = 256 # avg len used in SD bench in vLLM
|
||||
SUPPORTED_DATASET_PATHS = {
|
||||
"philschmid/mt-bench",
|
||||
}
|
||||
|
||||
def sample(self,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
num_requests: int,
|
||||
output_len: Optional[int] = None,
|
||||
enable_multimodal_chat: bool = False,
|
||||
**kwargs) -> list:
|
||||
output_len = (output_len
|
||||
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
|
||||
sampled_requests = []
|
||||
|
||||
for item in self.data:
|
||||
if len(sampled_requests) >= num_requests:
|
||||
break
|
||||
prompt = item['turns'][0]
|
||||
|
||||
# apply template
|
||||
prompt = tokenizer.apply_chat_template([{
|
||||
"role": "user",
|
||||
"content": prompt
|
||||
}],
|
||||
add_generation_prompt=True,
|
||||
tokenize=False)
|
||||
|
||||
prompt_len = len(tokenizer(prompt).input_ids)
|
||||
sampled_requests.append(
|
||||
SampleRequest(
|
||||
prompt=prompt,
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=output_len,
|
||||
))
|
||||
self.maybe_oversample_requests(sampled_requests, num_requests)
|
||||
return sampled_requests
|
||||
|
||||
|
||||
class CNNDailyMailDataset(HuggingFaceDataset):
|
||||
"""
|
||||
MT-Bench Dataset.
|
||||
https://huggingface.co/datasets/philschmid/mt-bench
|
||||
|
||||
We create a single turn dataset for MT-Bench.
|
||||
This is similar to Spec decoding benchmark setup in vLLM
|
||||
https://github.com/vllm-project/vllm/blob/9d98ab5ec/examples/offline_inference/eagle.py#L14-L18
|
||||
""" # noqa: E501
|
||||
|
||||
DEFAULT_OUTPUT_LEN = 256 # avg len used in SD bench in vLLM
|
||||
SUPPORTED_DATASET_PATHS = {
|
||||
"abisee/cnn_dailymail",
|
||||
}
|
||||
|
||||
def sample(self,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
num_requests: int,
|
||||
output_len: Optional[int] = None,
|
||||
enable_multimodal_chat: bool = False,
|
||||
**kwargs) -> list:
|
||||
output_len = (output_len
|
||||
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
|
||||
sampled_requests = []
|
||||
|
||||
for item in self.data:
|
||||
if len(sampled_requests) >= num_requests:
|
||||
break
|
||||
instruction = "Could you summarize the following article, " \
|
||||
"please reuse text from the article if possible: "
|
||||
prompt = instruction + item['article']
|
||||
|
||||
# apply template
|
||||
prompt = tokenizer.apply_chat_template([{
|
||||
"role": "user",
|
||||
"content": prompt
|
||||
}],
|
||||
add_generation_prompt=True,
|
||||
tokenize=False)
|
||||
|
||||
prompt_len = len(tokenizer(prompt).input_ids)
|
||||
sampled_requests.append(
|
||||
SampleRequest(
|
||||
prompt=prompt,
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=output_len,
|
||||
))
|
||||
self.maybe_oversample_requests(sampled_requests, num_requests)
|
||||
return sampled_requests
|
||||
|
||||
@ -63,14 +63,16 @@ class Request:
|
||||
output_len: int
|
||||
|
||||
|
||||
def sample_tokens(tokenizer: PreTrainedTokenizerBase, length: int) -> str:
|
||||
def sample_tokens(tokenizer: PreTrainedTokenizerBase,
|
||||
length: int) -> list[int]:
|
||||
vocab = tokenizer.get_vocab()
|
||||
all_special_ids = set(tokenizer.all_special_ids)
|
||||
|
||||
# Remove the special tokens.
|
||||
vocab = {
|
||||
k: v
|
||||
for k, v in vocab.items() if k not in tokenizer.all_special_ids
|
||||
}
|
||||
return random.choices(list(vocab.values()), k=length)
|
||||
return random.choices(
|
||||
[v for k, v in vocab.items() if k not in all_special_ids],
|
||||
k=length,
|
||||
)
|
||||
|
||||
|
||||
def sample_requests_from_dataset(
|
||||
|
||||
@ -34,7 +34,8 @@ from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput,
|
||||
from backend_request_func import (ASYNC_REQUEST_FUNCS,
|
||||
OPENAI_COMPATIBLE_BACKENDS, RequestFuncInput,
|
||||
RequestFuncOutput)
|
||||
from tqdm.asyncio import tqdm
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
@ -49,7 +50,7 @@ try:
|
||||
except ImportError:
|
||||
from argparse import ArgumentParser as FlexibleArgumentParser
|
||||
|
||||
from benchmark_dataset import (AIMODataset, BurstGPTDataset,
|
||||
from benchmark_dataset import (AIMODataset, ASRDataset, BurstGPTDataset,
|
||||
ConversationDataset, HuggingFaceDataset,
|
||||
InstructCoderDataset, RandomDataset,
|
||||
SampleRequest, ShareGPTDataset, SonnetDataset,
|
||||
@ -155,7 +156,7 @@ def calculate_metrics(
|
||||
if outputs[i].success:
|
||||
output_len = outputs[i].output_tokens
|
||||
|
||||
if output_len is None:
|
||||
if not output_len:
|
||||
# We use the tokenizer to count the number of output tokens
|
||||
# for some serving backends instead of looking at
|
||||
# len(outputs[i].itl) since multiple output tokens may be
|
||||
@ -260,6 +261,7 @@ async def benchmark(
|
||||
goodput_config_dict: dict[str, float],
|
||||
max_concurrency: Optional[int],
|
||||
lora_modules: Optional[Iterable[str]],
|
||||
extra_body: Optional[dict],
|
||||
):
|
||||
if backend in ASYNC_REQUEST_FUNCS:
|
||||
request_func = ASYNC_REQUEST_FUNCS[backend]
|
||||
@ -272,10 +274,6 @@ async def benchmark(
|
||||
input_requests[0].expected_output_len, \
|
||||
input_requests[0].multi_modal_data
|
||||
|
||||
if backend != "openai-chat" and test_mm_content is not None:
|
||||
# multi-modal benchmark is only available on OpenAI Chat backend.
|
||||
raise ValueError(
|
||||
"Multi-modal content is only supported on 'openai-chat' backend.")
|
||||
assert test_mm_content is None or isinstance(test_mm_content, dict)
|
||||
test_input = RequestFuncInput(
|
||||
model=model_id,
|
||||
@ -287,6 +285,7 @@ async def benchmark(
|
||||
logprobs=logprobs,
|
||||
multi_modal_content=test_mm_content,
|
||||
ignore_eos=ignore_eos,
|
||||
extra_body=extra_body,
|
||||
)
|
||||
|
||||
test_output = await request_func(request_func_input=test_input)
|
||||
@ -313,7 +312,8 @@ async def benchmark(
|
||||
output_len=test_output_len,
|
||||
logprobs=logprobs,
|
||||
multi_modal_content=test_mm_content,
|
||||
ignore_eos=ignore_eos)
|
||||
ignore_eos=ignore_eos,
|
||||
extra_body=extra_body)
|
||||
profile_output = await request_func(request_func_input=profile_input)
|
||||
if profile_output.success:
|
||||
print("Profiler started")
|
||||
@ -363,7 +363,8 @@ async def benchmark(
|
||||
output_len=output_len,
|
||||
logprobs=logprobs,
|
||||
multi_modal_content=mm_content,
|
||||
ignore_eos=ignore_eos)
|
||||
ignore_eos=ignore_eos,
|
||||
extra_body=extra_body)
|
||||
tasks.append(
|
||||
asyncio.create_task(
|
||||
limited_request_func(request_func_input=request_func_input,
|
||||
@ -599,6 +600,9 @@ def main(args: argparse.Namespace):
|
||||
elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_class = AIMODataset
|
||||
args.hf_split = "train"
|
||||
elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_class = ASRDataset
|
||||
args.hf_split = "train"
|
||||
else:
|
||||
supported_datasets = set([
|
||||
dataset_name for cls in HuggingFaceDataset.__subclasses__()
|
||||
@ -610,6 +614,13 @@ def main(args: argparse.Namespace):
|
||||
f" from one of following: {supported_datasets}. "
|
||||
"Please consider contributing if you would "
|
||||
"like to add support for additional dataset formats.")
|
||||
|
||||
if (dataset_class.IS_MULTIMODAL and backend not in \
|
||||
["openai-chat", "openai-audio"]):
|
||||
# multi-modal benchmark is only available on OpenAI Chat backend.
|
||||
raise ValueError(
|
||||
"Multi-modal content is only supported on 'openai-chat' and " \
|
||||
"'openai-audio' backend.")
|
||||
input_requests = dataset_class(
|
||||
dataset_path=args.dataset_path,
|
||||
dataset_subset=args.hf_subset,
|
||||
@ -652,6 +663,26 @@ def main(args: argparse.Namespace):
|
||||
raise ValueError(f"Unknown dataset: {args.dataset_name}") from err
|
||||
goodput_config_dict = check_goodput_args(args)
|
||||
|
||||
# Collect the sampling parameters.
|
||||
sampling_params = {
|
||||
k: v
|
||||
for k, v in {
|
||||
"top_p": args.top_p,
|
||||
"top_k": args.top_k,
|
||||
"min_p": args.min_p,
|
||||
"temperature": args.temperature
|
||||
}.items() if v is not None
|
||||
}
|
||||
|
||||
# Sampling parameters are only supported by openai-compatible backend.
|
||||
if sampling_params and args.backend not in OPENAI_COMPATIBLE_BACKENDS:
|
||||
raise ValueError(
|
||||
"Sampling parameters are only supported by openai-compatible "
|
||||
"backends.")
|
||||
|
||||
if "temperature" not in sampling_params:
|
||||
sampling_params["temperature"] = 0.0 # Default to greedy decoding.
|
||||
|
||||
# Avoid GC processing "static" data - reduce pause times.
|
||||
gc.collect()
|
||||
gc.freeze()
|
||||
@ -678,10 +709,11 @@ def main(args: argparse.Namespace):
|
||||
goodput_config_dict=goodput_config_dict,
|
||||
max_concurrency=args.max_concurrency,
|
||||
lora_modules=args.lora_modules,
|
||||
extra_body=sampling_params,
|
||||
))
|
||||
|
||||
# Save config and results to json
|
||||
if args.save_result:
|
||||
if args.save_result or args.append_result:
|
||||
result_json: dict[str, Any] = {}
|
||||
|
||||
# Setup
|
||||
@ -702,6 +734,14 @@ def main(args: argparse.Namespace):
|
||||
raise ValueError(
|
||||
"Invalid metadata format. Please use KEY=VALUE format."
|
||||
)
|
||||
# Traffic
|
||||
result_json["request_rate"] = (args.request_rate if args.request_rate
|
||||
< float("inf") else "inf")
|
||||
result_json["burstiness"] = args.burstiness
|
||||
result_json["max_concurrency"] = args.max_concurrency
|
||||
|
||||
# Merge with benchmark result
|
||||
result_json = {**result_json, **benchmark_result}
|
||||
|
||||
if not args.save_detailed:
|
||||
# Remove fields with too many data points
|
||||
@ -712,15 +752,6 @@ def main(args: argparse.Namespace):
|
||||
if field in result_json:
|
||||
del result_json[field]
|
||||
|
||||
# Traffic
|
||||
result_json["request_rate"] = (args.request_rate if args.request_rate
|
||||
< float("inf") else "inf")
|
||||
result_json["burstiness"] = args.burstiness
|
||||
result_json["max_concurrency"] = args.max_concurrency
|
||||
|
||||
# Merge with benchmark result
|
||||
result_json = {**result_json, **benchmark_result}
|
||||
|
||||
# Save to file
|
||||
base_model_id = model_id.split("/")[-1]
|
||||
max_concurrency_str = (f"-concurrency{args.max_concurrency}"
|
||||
@ -730,7 +761,12 @@ def main(args: argparse.Namespace):
|
||||
file_name = args.result_filename
|
||||
if args.result_dir:
|
||||
file_name = os.path.join(args.result_dir, file_name)
|
||||
with open(file_name, "w", encoding='utf-8') as outfile:
|
||||
with open(file_name,
|
||||
mode="a+" if args.append_result else "w",
|
||||
encoding='utf-8') as outfile:
|
||||
# Append a newline.
|
||||
if args.append_result and outfile.tell() != 0:
|
||||
outfile.write("\n")
|
||||
json.dump(result_json, outfile)
|
||||
save_to_pytorch_benchmark_format(args, result_json, file_name)
|
||||
|
||||
@ -862,6 +898,11 @@ if __name__ == "__main__":
|
||||
help="When saving the results, whether to include per request "
|
||||
"information such as response, error, ttfs, tpots, etc.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--append-result",
|
||||
action="store_true",
|
||||
help="Append the benchmark result to the existing json file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--metadata",
|
||||
metavar="KEY=VALUE",
|
||||
@ -895,7 +936,7 @@ if __name__ == "__main__":
|
||||
"--percentile-metrics",
|
||||
type=str,
|
||||
default="ttft,tpot,itl",
|
||||
help="Comma-seperated list of selected metrics to report percentils. "
|
||||
help="Comma-separated list of selected metrics to report percentils. "
|
||||
"This argument specifies the metrics to report percentiles. "
|
||||
"Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". "
|
||||
"Default value is \"ttft,tpot,itl\".")
|
||||
@ -903,7 +944,7 @@ if __name__ == "__main__":
|
||||
"--metric-percentiles",
|
||||
type=str,
|
||||
default="99",
|
||||
help="Comma-seperated list of percentiles for selected metrics. "
|
||||
help="Comma-separated list of percentiles for selected metrics. "
|
||||
"To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". "
|
||||
"Default value is \"99\". "
|
||||
"Use \"--percentile-metrics\" to select metrics.",
|
||||
@ -970,18 +1011,23 @@ if __name__ == "__main__":
|
||||
random_group.add_argument(
|
||||
"--random-range-ratio",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Range of sampled ratio of input/output length, "
|
||||
"used only for random sampling.",
|
||||
default=0.0,
|
||||
help="Range ratio for sampling input/output length, "
|
||||
"used only for random sampling. Must be in the range [0, 1) to define "
|
||||
"a symmetric sampling range"
|
||||
"[length * (1 - range_ratio), length * (1 + range_ratio)].",
|
||||
)
|
||||
random_group.add_argument(
|
||||
"--random-prefix-len",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Number of fixed prefix tokens before random "
|
||||
" context. The length range of context in a random "
|
||||
" request is [random-prefix-len, "
|
||||
" random-prefix-len + random-prefix-len * random-range-ratio).")
|
||||
help=("Number of fixed prefix tokens before the random context "
|
||||
"in a request. "
|
||||
"The total input length is the sum of `random-prefix-len` and "
|
||||
"a random "
|
||||
"context length sampled from [input_len * (1 - range_ratio), "
|
||||
"input_len * (1 + range_ratio)]."),
|
||||
)
|
||||
|
||||
hf_group = parser.add_argument_group("hf dataset options")
|
||||
hf_group.add_argument("--hf-subset",
|
||||
@ -1000,6 +1046,33 @@ if __name__ == "__main__":
|
||||
"from the sampled HF dataset.",
|
||||
)
|
||||
|
||||
sampling_group = parser.add_argument_group("sampling parameters")
|
||||
sampling_group.add_argument(
|
||||
"--top-p",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Top-p sampling parameter. Only has effect on openai-compatible "
|
||||
"backends.")
|
||||
sampling_group.add_argument(
|
||||
"--top-k",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Top-k sampling parameter. Only has effect on openai-compatible "
|
||||
"backends.")
|
||||
sampling_group.add_argument(
|
||||
"--min-p",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Min-p sampling parameter. Only has effect on openai-compatible "
|
||||
"backends.")
|
||||
sampling_group.add_argument(
|
||||
"--temperature",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Temperature sampling parameter. Only has effect on "
|
||||
"openai-compatible backends. If not specified, default to greedy "
|
||||
"decoding (i.e. temperature==0.0).")
|
||||
|
||||
parser.add_argument(
|
||||
'--tokenizer-mode',
|
||||
type=str,
|
||||
|
||||
@ -11,7 +11,7 @@ On the client side, run:
|
||||
--model <your_model> \
|
||||
--dataset json \
|
||||
--structured-output-ratio 1.0 \
|
||||
--structured-output-backend xgrammar \
|
||||
--structured-output-backend auto \
|
||||
--request-rate 10 \
|
||||
--num-prompts 1000
|
||||
|
||||
@ -51,7 +51,7 @@ try:
|
||||
except ImportError:
|
||||
from argparse import ArgumentParser as FlexibleArgumentParser
|
||||
|
||||
from vllm.v1.structured_output.utils import (
|
||||
from vllm.v1.structured_output.backend_xgrammar import (
|
||||
has_xgrammar_unsupported_json_features)
|
||||
|
||||
MILLISECONDS_TO_SECONDS_CONVERSION = 1000
|
||||
@ -130,10 +130,11 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
|
||||
"description":
|
||||
"An unique optional field to avoid cached schemas"
|
||||
}
|
||||
else:
|
||||
json_schemas = [schema] * args.num_prompts
|
||||
|
||||
def gen_prompt(index: int):
|
||||
schema = json_schemas[index % len(json_schemas)]
|
||||
return f"Generate an example of a user profile given the following schema: {json.dumps(schema)}" # noqa: E501
|
||||
return f"Generate an example of a user profile given the following schema: {json.dumps(get_schema(index))}" # noqa: E501
|
||||
|
||||
def get_schema(index: int):
|
||||
return json_schemas[index % len(json_schemas)]
|
||||
@ -149,17 +150,17 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
|
||||
|
||||
elif args.dataset == "grammar":
|
||||
schema = """
|
||||
?start: select_statement
|
||||
root ::= select_statement
|
||||
|
||||
?select_statement: "SELECT " column_list " FROM " table_name
|
||||
select_statement ::= "SELECT " column " from " table " where " condition
|
||||
|
||||
?column_list: column_name ("," column_name)*
|
||||
column ::= "col_1 " | "col_2 "
|
||||
|
||||
?table_name: identifier
|
||||
table ::= "table_1 " | "table_2 "
|
||||
|
||||
?column_name: identifier
|
||||
condition ::= column "= " number
|
||||
|
||||
?identifier: /[a-zA-Z_][a-zA-Z0-9_]*/
|
||||
number ::= "1 " | "2 "
|
||||
"""
|
||||
prompt = "Generate an SQL query to show the 'username' \
|
||||
and 'email' from the 'users' table."
|
||||
@ -963,7 +964,7 @@ if __name__ == "__main__":
|
||||
"--percentile-metrics",
|
||||
type=str,
|
||||
default="ttft,tpot,itl",
|
||||
help="Comma-seperated list of selected metrics to report percentils. "
|
||||
help="Comma-separated list of selected metrics to report percentils. "
|
||||
"This argument specifies the metrics to report percentiles. "
|
||||
"Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". "
|
||||
"Default value is \"ttft,tpot,itl\".")
|
||||
@ -971,7 +972,7 @@ if __name__ == "__main__":
|
||||
"--metric-percentiles",
|
||||
type=str,
|
||||
default="99",
|
||||
help="Comma-seperated list of percentiles for selected metrics. "
|
||||
help="Comma-separated list of percentiles for selected metrics. "
|
||||
"To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". "
|
||||
"Default value is \"99\". "
|
||||
"Use \"--percentile-metrics\" to select metrics.",
|
||||
@ -996,12 +997,14 @@ if __name__ == "__main__":
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Ratio of Structured Outputs requests")
|
||||
parser.add_argument(
|
||||
"--structured-output-backend",
|
||||
type=str,
|
||||
choices=["outlines", "lm-format-enforcer", "xgrammar", "guidance"],
|
||||
default="xgrammar",
|
||||
help="Backend to use for structured outputs")
|
||||
parser.add_argument("--structured-output-backend",
|
||||
type=str,
|
||||
choices=[
|
||||
"outlines", "lm-format-enforcer", "xgrammar",
|
||||
"guidance", "auto"
|
||||
],
|
||||
default="auto",
|
||||
help="Backend to use for structured outputs")
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
||||
@ -12,7 +12,8 @@ from typing import Any, Optional, Union
|
||||
import torch
|
||||
import uvloop
|
||||
from benchmark_dataset import (AIMODataset, BurstGPTDataset,
|
||||
ConversationDataset, InstructCoderDataset,
|
||||
CNNDailyMailDataset, ConversationDataset,
|
||||
InstructCoderDataset, MTBenchDataset,
|
||||
RandomDataset, SampleRequest, ShareGPTDataset,
|
||||
SonnetDataset, VisionArenaDataset)
|
||||
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
|
||||
@ -57,9 +58,9 @@ def run_vllm(
|
||||
sampling_params.append(
|
||||
SamplingParams(
|
||||
n=n,
|
||||
temperature=1.0,
|
||||
temperature=0,
|
||||
top_p=1.0,
|
||||
ignore_eos=True,
|
||||
ignore_eos=False,
|
||||
max_tokens=request.expected_output_len,
|
||||
detokenize=not disable_detokenize,
|
||||
))
|
||||
@ -123,9 +124,9 @@ def run_vllm_chat(
|
||||
sampling_params.append(
|
||||
SamplingParams(
|
||||
n=n,
|
||||
temperature=1.0,
|
||||
temperature=0,
|
||||
top_p=1.0,
|
||||
ignore_eos=True,
|
||||
ignore_eos=False,
|
||||
max_tokens=request.expected_output_len,
|
||||
detokenize=not disable_detokenize,
|
||||
))
|
||||
@ -167,9 +168,9 @@ async def run_vllm_async(
|
||||
sampling_params.append(
|
||||
SamplingParams(
|
||||
n=n,
|
||||
temperature=1.0,
|
||||
temperature=0,
|
||||
top_p=1.0,
|
||||
ignore_eos=True,
|
||||
ignore_eos=False,
|
||||
max_tokens=request.expected_output_len,
|
||||
detokenize=not disable_detokenize,
|
||||
))
|
||||
@ -213,14 +214,17 @@ def run_hf(
|
||||
max_prompt_len = 0
|
||||
max_output_len = 0
|
||||
for i in range(len(requests)):
|
||||
prompt, prompt_len, output_len = requests[i]
|
||||
prompt = requests[i].prompt
|
||||
prompt_len = requests[i].prompt_len
|
||||
output_len = requests[i].expected_output_len
|
||||
# Add the prompt to the batch.
|
||||
batch.append(prompt)
|
||||
max_prompt_len = max(max_prompt_len, prompt_len)
|
||||
max_output_len = max(max_output_len, output_len)
|
||||
if len(batch) < max_batch_size and i != len(requests) - 1:
|
||||
# Check if we can add more requests to the batch.
|
||||
_, next_prompt_len, next_output_len = requests[i + 1]
|
||||
next_prompt_len = requests[i + 1].prompt_len
|
||||
next_output_len = requests[i + 1].expected_output_len
|
||||
if (max(max_prompt_len, next_prompt_len) +
|
||||
max(max_output_len, next_output_len)) <= 2048:
|
||||
# We can add more requests to the batch.
|
||||
@ -336,6 +340,14 @@ def get_requests(args, tokenizer):
|
||||
dataset_cls = AIMODataset
|
||||
common_kwargs['dataset_subset'] = None
|
||||
common_kwargs['dataset_split'] = "train"
|
||||
elif args.dataset_path in MTBenchDataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_cls = MTBenchDataset
|
||||
common_kwargs['dataset_subset'] = None
|
||||
common_kwargs['dataset_split'] = "train"
|
||||
elif args.dataset_path in CNNDailyMailDataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_cls = CNNDailyMailDataset
|
||||
common_kwargs['dataset_subset'] = '3.0.0'
|
||||
common_kwargs['dataset_split'] = "train"
|
||||
else:
|
||||
raise ValueError(f"Unknown dataset name: {args.dataset_name}")
|
||||
# Remove None values
|
||||
@ -474,8 +486,11 @@ def validate_args(args):
|
||||
VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys()
|
||||
| ConversationDataset.SUPPORTED_DATASET_PATHS):
|
||||
assert args.backend == "vllm-chat", f"{args.dataset_path} needs to use vllm-chat as the backend." #noqa: E501
|
||||
elif args.dataset_path in (InstructCoderDataset.SUPPORTED_DATASET_PATHS
|
||||
| AIMODataset.SUPPORTED_DATASET_PATHS):
|
||||
elif args.dataset_path in (
|
||||
InstructCoderDataset.SUPPORTED_DATASET_PATHS
|
||||
| AIMODataset.SUPPORTED_DATASET_PATHS
|
||||
| MTBenchDataset.SUPPORTED_DATASET_PATHS
|
||||
| CNNDailyMailDataset.SUPPORTED_DATASET_PATHS):
|
||||
assert args.backend == "vllm", f"{args.dataset_path} needs to use vllm as the backend." #noqa: E501
|
||||
else:
|
||||
raise ValueError(
|
||||
@ -520,6 +535,13 @@ def validate_args(args):
|
||||
raise ValueError(
|
||||
"Tokenizer must be the same as the model for MII backend.")
|
||||
|
||||
# --data-parallel is not supported currently.
|
||||
# https://github.com/vllm-project/vllm/issues/16222
|
||||
if args.data_parallel_size > 1:
|
||||
raise ValueError(
|
||||
"Data parallel is not supported in offline benchmark, \
|
||||
please use benchmark serving instead")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser(description="Benchmark the throughput.")
|
||||
@ -591,18 +613,30 @@ if __name__ == "__main__":
|
||||
default=None,
|
||||
help="Path to the lora adapters to use. This can be an absolute path, "
|
||||
"a relative path, or a Hugging Face model identifier.")
|
||||
parser.add_argument("--prefix-len",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Number of prefix tokens per request."
|
||||
"This is for the RandomDataset and SonnetDataset")
|
||||
parser.add_argument(
|
||||
"--prefix-len",
|
||||
type=int,
|
||||
default=None,
|
||||
help=f"Number of prefix tokens to be used in RandomDataset "
|
||||
"and SonnetDataset. For RandomDataset, the total input "
|
||||
"length is the sum of prefix-len (default: "
|
||||
f"{RandomDataset.DEFAULT_PREFIX_LEN}) and a random context length "
|
||||
"sampled from [input_len * (1 - range_ratio), "
|
||||
"input_len * (1 + range_ratio)]. For SonnetDataset, "
|
||||
f"prefix_len (default: {SonnetDataset.DEFAULT_PREFIX_LEN}) "
|
||||
"controls how much of the input is fixed lines versus "
|
||||
"random lines, but the total input length remains approximately "
|
||||
"input_len tokens.")
|
||||
# random dataset
|
||||
parser.add_argument(
|
||||
"--random-range-ratio",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Range of sampled ratio of input/output length, "
|
||||
"used only for RandomDataSet.",
|
||||
help=f"Range ratio (default : {RandomDataset.DEFAULT_RANGE_RATIO}) "
|
||||
"for sampling input/output length, "
|
||||
"used only for RandomDataset. Must be in the range [0, 1) to "
|
||||
"define a symmetric sampling range "
|
||||
"[length * (1 - range_ratio), length * (1 + range_ratio)].",
|
||||
)
|
||||
|
||||
# hf dtaset
|
||||
|
||||
236
benchmarks/kernels/benchmark_bitblas.py
Normal file
236
benchmarks/kernels/benchmark_bitblas.py
Normal file
@ -0,0 +1,236 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
|
||||
MINIMUM_BITBLAS_VERSION)
|
||||
|
||||
try:
|
||||
import bitblas
|
||||
if bitblas.__version__ < MINIMUM_BITBLAS_VERSION:
|
||||
raise ImportError("bitblas version is wrong. Please "
|
||||
f"install bitblas>={MINIMUM_BITBLAS_VERSION}")
|
||||
except ImportError as e:
|
||||
bitblas_import_exception = e
|
||||
raise ValueError("Trying to use the bitblas backend, but could not import"
|
||||
f"with the following error: {bitblas_import_exception}. "
|
||||
"Please install bitblas through the following command: "
|
||||
f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`"
|
||||
) from bitblas_import_exception
|
||||
|
||||
from bitblas import Matmul, MatmulConfig, auto_detect_nvidia_target
|
||||
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
parser = FlexibleArgumentParser(
|
||||
description="Benchmark BitBLAS int4 on a specific target.")
|
||||
|
||||
# Add arguments to the parser
|
||||
parser.add_argument(
|
||||
"--target",
|
||||
type=str,
|
||||
default=auto_detect_nvidia_target(),
|
||||
help="Specify the target device for benchmarking.",
|
||||
)
|
||||
parser.add_argument("--group_size",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Group size for grouped quantization.")
|
||||
parser.add_argument(
|
||||
"--A_dtype",
|
||||
type=str,
|
||||
default="float16",
|
||||
choices=["float16", "float32", "float64", "int32", "int8"],
|
||||
help="Data type of activation A.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--W_dtype",
|
||||
type=str,
|
||||
default="int4",
|
||||
choices=[
|
||||
"float16",
|
||||
"float32",
|
||||
"float64",
|
||||
"int32",
|
||||
"int8",
|
||||
"int4",
|
||||
"int2",
|
||||
"int1",
|
||||
"nf4",
|
||||
"fp4_e2m1",
|
||||
],
|
||||
help="Data type of weight W.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--accum_dtype",
|
||||
type=str,
|
||||
default="float16",
|
||||
choices=["float16", "int32"],
|
||||
help="Data type for accumulation.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--out_dtype",
|
||||
type=str,
|
||||
default="float16",
|
||||
choices=["float16", "float32", "int32", "int8"],
|
||||
help="Data type for output.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--layout",
|
||||
type=str,
|
||||
default="nt",
|
||||
choices=["nt", "nn"],
|
||||
help="Matrix layout, 'nt' for non-transpose A and transpose W.",
|
||||
)
|
||||
parser.add_argument("--with_bias",
|
||||
action="store_true",
|
||||
help="Include bias in the benchmark.")
|
||||
parser.add_argument(
|
||||
"--with_scaling",
|
||||
action="store_true",
|
||||
help="Include scaling factor in the quantization.",
|
||||
)
|
||||
parser.add_argument("--with_zeros",
|
||||
action="store_true",
|
||||
help="Include zeros in the quantization.")
|
||||
parser.add_argument(
|
||||
"--zeros_mode",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["original", "rescale", "quantized"],
|
||||
help="Specify the mode for calculating zeros.",
|
||||
)
|
||||
|
||||
# Parse the arguments
|
||||
args = parser.parse_args()
|
||||
|
||||
# Assign arguments to variables
|
||||
target = args.target
|
||||
A_dtype = args.A_dtype
|
||||
W_dtype = args.W_dtype
|
||||
accum_dtype = args.accum_dtype
|
||||
out_dtype = args.out_dtype
|
||||
layout = args.layout
|
||||
with_bias = args.with_bias
|
||||
group_size = args.group_size
|
||||
with_scaling = args.with_scaling
|
||||
with_zeros = args.with_zeros
|
||||
zeros_mode = args.zeros_mode
|
||||
|
||||
# Define a list of shared arguments that repeat in every config
|
||||
shared_args = [
|
||||
A_dtype,
|
||||
W_dtype,
|
||||
out_dtype,
|
||||
accum_dtype,
|
||||
layout,
|
||||
with_bias,
|
||||
group_size,
|
||||
with_scaling,
|
||||
with_zeros,
|
||||
zeros_mode,
|
||||
]
|
||||
|
||||
# Define just the (M, K, N) shapes in a more compact list
|
||||
shapes = [
|
||||
# square test
|
||||
(1, 16384, 16384),
|
||||
# BLOOM-176B
|
||||
(1, 43008, 14336),
|
||||
(1, 14336, 14336),
|
||||
(1, 57344, 14336),
|
||||
(1, 14336, 57344),
|
||||
# OPT-65B
|
||||
(1, 9216, 9216),
|
||||
(1, 36864, 9216),
|
||||
(1, 9216, 36864),
|
||||
(1, 22016, 8192),
|
||||
# LLAMA-70B/65B
|
||||
(1, 8192, 22016),
|
||||
(1, 8192, 8192),
|
||||
(1, 28672, 8192),
|
||||
(1, 8192, 28672),
|
||||
# square test
|
||||
(16384, 16384, 16384),
|
||||
# BLOOM-176B
|
||||
(8192, 43008, 14336),
|
||||
(8192, 14336, 14336),
|
||||
(8192, 57344, 14336),
|
||||
(8192, 14336, 57344),
|
||||
# OPT-65B
|
||||
(8192, 9216, 9216),
|
||||
(8192, 36864, 9216),
|
||||
(8192, 9216, 36864),
|
||||
(8192, 22016, 8192),
|
||||
# LLAMA-70B/65B
|
||||
(8192, 8192, 22016),
|
||||
(8192, 8192, 8192),
|
||||
(8192, 28672, 8192),
|
||||
(8192, 8192, 28672),
|
||||
]
|
||||
|
||||
# Build test shapes with all the shared arguments
|
||||
test_shapes = [(MatmulConfig, Matmul, (*shape, *shared_args))
|
||||
for shape in shapes]
|
||||
|
||||
benchmark_sets = []
|
||||
benchmark_sets.extend(test_shapes)
|
||||
|
||||
benchmark_results = {}
|
||||
for config_class, operator, input_args in benchmark_sets:
|
||||
config = config_class(*input_args)
|
||||
matmul = operator(config, target=target, enable_tuning=True)
|
||||
kernel_latency = matmul.profile_latency()
|
||||
|
||||
print("Time cost is: {:.3f} ms".format(kernel_latency))
|
||||
|
||||
profile_config = {
|
||||
f"{operator.__name__}-{'-'.join([str(i) for i in input_args])}": {
|
||||
"BitBLAS_top20_latency": kernel_latency,
|
||||
}
|
||||
}
|
||||
|
||||
benchmark_results.update(profile_config)
|
||||
|
||||
# Define headers for the table
|
||||
headers = [
|
||||
"PrimFunc",
|
||||
"Input Arguments",
|
||||
"BitBLAS Top20 Latency",
|
||||
]
|
||||
|
||||
# Calculate column widths for pretty printing
|
||||
col_widths = [0, 0, 0]
|
||||
for config_key, values in benchmark_results.items():
|
||||
args_split = config_key.split("-")
|
||||
func_name = args_split[0]
|
||||
input_args_str = "-".join(args_split[1:])
|
||||
col_widths[0] = max(col_widths[0], len(func_name) + 2, len(headers[0]) + 2)
|
||||
col_widths[1] = max(col_widths[1],
|
||||
len(input_args_str) + 2,
|
||||
len(headers[1]) + 2)
|
||||
col_widths[2] = max(col_widths[2],
|
||||
len(f"{values['BitBLAS_top20_latency']:.3f} ms") + 2,
|
||||
len(headers[2]) + 2)
|
||||
# break only if you want to measure widths from a single example;
|
||||
# otherwise, let it loop over all items.
|
||||
|
||||
# Print header
|
||||
for i, header in enumerate(headers):
|
||||
headers[i] = header.ljust(col_widths[i])
|
||||
print("".join(headers))
|
||||
print("-" * sum(col_widths))
|
||||
|
||||
# Print rows
|
||||
for config_key, values in benchmark_results.items():
|
||||
args_split = config_key.split("-")
|
||||
func_name = args_split[0]
|
||||
input_args_str = "-".join(args_split[1:])
|
||||
row = [
|
||||
func_name,
|
||||
input_args_str,
|
||||
f"{values['BitBLAS_top20_latency']:.3f} ms",
|
||||
]
|
||||
row_str = "".join(
|
||||
[str(cell).ljust(col_widths[idx]) for idx, cell in enumerate(row)])
|
||||
print(row_str)
|
||||
@ -17,8 +17,14 @@ from torch.utils.benchmark import Measurement as TMeasurement
|
||||
from utils import ArgPool, Bench, CudaGraphBenchParams
|
||||
from weight_shapes import WEIGHT_SHAPES
|
||||
|
||||
from vllm.lora.ops.triton_ops import LoRAKernelMeta, lora_expand, lora_shrink
|
||||
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
|
||||
if HAS_TRITON:
|
||||
from vllm.lora.ops.triton_ops import (LoRAKernelMeta, lora_expand,
|
||||
lora_shrink)
|
||||
from vllm.lora.ops.triton_ops.utils import (_LORA_A_PTR_DICT,
|
||||
_LORA_B_PTR_DICT)
|
||||
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
|
||||
|
||||
@ -553,6 +553,8 @@ def main(args: argparse.Namespace):
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||
else:
|
||||
# Support for llama4
|
||||
config = config.get_text_config()
|
||||
# Default: Mixtral.
|
||||
E = config.num_local_experts
|
||||
topk = config.num_experts_per_tok
|
||||
|
||||
213
benchmarks/run.sh
Normal file
213
benchmarks/run.sh
Normal file
@ -0,0 +1,213 @@
|
||||
# python benchmarks/benchmark_throughput.py \
|
||||
# --model meta-llama/Meta-Llama-3-8B-Instruct \
|
||||
# --dataset-name sonnet \
|
||||
# --dataset-path /data/lily/batch-sd/benchmarks/sonnet.txt \
|
||||
# --prefix-len 0 \
|
||||
# --output-len 512 \
|
||||
# --num-prompts 200 \
|
||||
# --speculative_config '{"method": "ngram", "num_speculative_tokens": 20, "prompt_lookup_min": 2, "prompt_lookup_max": 5}'
|
||||
|
||||
|
||||
# python benchmarks/benchmark_throughput.py \
|
||||
# --model meta-llama/Meta-Llama-3-8B-Instruct \
|
||||
# --dataset-name sharegpt \
|
||||
# --dataset-path /data/lily/ShareGPT_V3_unfiltered_cleaned_split.json \
|
||||
# --prefix-len 0 \
|
||||
# --output-len 512 \
|
||||
# --num-prompts 200 \
|
||||
# --speculative_config '{"method": "ngram", "num_speculative_tokens": 20, "prompt_lookup_min": 2, "prompt_lookup_max": 5}'
|
||||
|
||||
# python benchmarks/benchmark_throughput.py \
|
||||
# --model meta-llama/Meta-Llama-3-8B-Instruct \
|
||||
# --dataset-name hf \
|
||||
# --dataset-path likaixin/InstructCoder \
|
||||
# --prefix-len 0 \
|
||||
# --output-len 512 \
|
||||
# --num-prompts 200 \
|
||||
# --speculative_config '{"method": "ngram", "num_speculative_tokens": 20, "prompt_lookup_min": 2, "prompt_lookup_max": 5}'
|
||||
|
||||
|
||||
# python benchmarks/benchmark_throughput.py \
|
||||
# --model meta-llama/Meta-Llama-3-8B-Instruct \
|
||||
# --dataset-name sonnet \
|
||||
# --dataset-path /data/lily/batch-sd/benchmarks/sonnet.txt \
|
||||
# --prefix-len 0 \
|
||||
# --output-len 512 \
|
||||
# --num-prompts 200 \
|
||||
# --speculative_config '{"method": "eagle", "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B", "num_speculative_tokens": 20}'
|
||||
|
||||
# python benchmarks/benchmark_throughput.py \
|
||||
# --model meta-llama/Meta-Llama-3-8B-Instruct \
|
||||
# --dataset-name sharegpt \
|
||||
# --dataset-path /data/lily/ShareGPT_V3_unfiltered_cleaned_split.json \
|
||||
# --prefix-len 0 \
|
||||
# --output-len 512 \
|
||||
# --num-prompts 200 \
|
||||
# --speculative_config '{"method": "eagle", "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B", "num_speculative_tokens": 20}'
|
||||
|
||||
|
||||
# python benchmarks/benchmark_throughput.py \
|
||||
# --model meta-llama/Meta-Llama-3-8B-Instruct \
|
||||
# --dataset-name hf \
|
||||
# --dataset-path likaixin/InstructCoder \
|
||||
# --prefix-len 0 \
|
||||
# --output-len 512 \
|
||||
# --num-prompts 200 \
|
||||
# --speculative_config '{"method": "eagle", "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B", "num_speculative_tokens": 20}'
|
||||
|
||||
|
||||
# python benchmarks/benchmark_throughput.py \
|
||||
# --model meta-llama/Meta-Llama-3.1-8B-Instruct \
|
||||
# --dataset-name hf \
|
||||
# --dataset-path likaixin/InstructCoder \
|
||||
# --prefix-len 0 \
|
||||
# --output-len 512 \
|
||||
# --num-prompts 200 \
|
||||
# --speculative_config '{"method": "eagle3", "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", "num_speculative_tokens": 20}'
|
||||
|
||||
|
||||
python benchmarks/benchmark_throughput.py \
|
||||
--model meta-llama/Meta-Llama-3.1-8B-Instruct\
|
||||
--dataset-name hf \
|
||||
--dataset-path philschmid/mt-bench \
|
||||
--prefix-len 0 \
|
||||
--output-len 512 \
|
||||
--num-prompts 200 \
|
||||
--speculative_config '{"method": "eagle3", "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", "num_speculative_tokens": 20}'
|
||||
|
||||
|
||||
# python benchmarks/benchmark_throughput.py \
|
||||
# --model meta-llama/Meta-Llama-3.1-8B-Instruct \
|
||||
# --dataset-name sharegpt \
|
||||
# --dataset-path /data/lily/ShareGPT_V3_unfiltered_cleaned_split.json \
|
||||
# --prefix-len 0 \
|
||||
# --output-len 512 \
|
||||
# --num-prompts 200 \
|
||||
# --speculative_config '{"method": "eagle3", "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", "num_speculative_tokens": 20}'
|
||||
|
||||
# python benchmarks/benchmark_throughput.py \
|
||||
# --model meta-llama/Meta-Llama-3.1-8B-Instruct \
|
||||
# --dataset-name sonnet \
|
||||
# --dataset-path /data/lily/batch-sd/benchmarks/sonnet.txt \
|
||||
# --prefix-len 0 \
|
||||
# --output-len 512 \
|
||||
# --num-prompts 200 \
|
||||
# --speculative_config '{"method": "eagle3", "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", "num_speculative_tokens": 20}'
|
||||
|
||||
|
||||
# python benchmarks/benchmark_throughput.py \
|
||||
# --model meta-llama/Meta-Llama-3.1-8B-Instruct \
|
||||
# --dataset-name hf \
|
||||
# --dataset-path likaixin/InstructCoder \
|
||||
# --prefix-len 0 \
|
||||
# --output-len 512 \
|
||||
# --num-prompts 200 \
|
||||
# --speculative_config '{"method": "eagle", "model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", "num_speculative_tokens": 20}'
|
||||
|
||||
# python benchmarks/benchmark_throughput.py \
|
||||
# --model meta-llama/Meta-Llama-3.1-8B-Instruct \
|
||||
# --dataset-name sharegpt \
|
||||
# --dataset-path /data/lily/ShareGPT_V3_unfiltered_cleaned_split.json \
|
||||
# --prefix-len 0 \
|
||||
# --output-len 512 \
|
||||
# --num-prompts 200 \
|
||||
# --speculative_config '{"method": "eagle", "model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", "num_speculative_tokens": 20}'
|
||||
|
||||
# python benchmarks/benchmark_throughput.py \
|
||||
# --model meta-llama/Meta-Llama-3.1-8B-Instruct \
|
||||
# --dataset-name hf \
|
||||
# --dataset-path likaixin/InstructCoder \
|
||||
# --prefix-len 0 \
|
||||
# --output-len 512 \
|
||||
# --num-prompts 200 \
|
||||
# --speculative_config '{"method": "ngram", "num_speculative_tokens": 20, "prompt_lookup_min": 2, "prompt_lookup_max": 5}'
|
||||
|
||||
|
||||
# python benchmarks/benchmark_throughput.py \
|
||||
# --model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
|
||||
# --dataset-name hf \
|
||||
# --dataset-path AI-MO/aimo-validation-aime \
|
||||
# --prefix-len 0 \
|
||||
# --output-len 1024 \
|
||||
# --num-prompts 90 \
|
||||
# --speculative_config '{"method": "eagle3", "num_speculative_tokens": 20, "model": "yuhuili/EAGLE3-DeepSeek-R1-Distill-LLaMA-8B"}'
|
||||
|
||||
|
||||
|
||||
# python benchmarks/benchmark_throughput.py \
|
||||
# --model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
|
||||
# --dataset-name hf \
|
||||
# --dataset-path AI-MO/aimo-validation-aime \
|
||||
# --prefix-len 0 \
|
||||
# --output-len 1024 \
|
||||
# --num-prompts 90 \
|
||||
# --speculative_config '{"method": "ngram", "num_speculative_tokens": 20, "prompt_lookup_min": 2, "prompt_lookup_max": 5}'
|
||||
|
||||
|
||||
|
||||
# python benchmarks/benchmark_throughput.py \
|
||||
# --model meta-llama/Meta-Llama-3.1-8B-Instruct \
|
||||
# --dataset-name sharegpt \
|
||||
# --dataset-path /data/lily/ShareGPT_V3_unfiltered_cleaned_split.json \
|
||||
# --prefix-len 0 \
|
||||
# --output-len 512 \
|
||||
# --num-prompts 200 \
|
||||
# --speculative_config '{"method": "ngram", "num_speculative_tokens": 20, "prompt_lookup_min": 2, "prompt_lookup_max": 5}'
|
||||
|
||||
|
||||
# python benchmarks/benchmark_throughput.py \
|
||||
# --model meta-llama/Meta-Llama-3.1-8B-Instruct \
|
||||
# --dataset-name hf \
|
||||
# --dataset-path philschmid/mt-bench \
|
||||
# --prefix-len 0 \
|
||||
# --output-len 512 \
|
||||
# --num-prompts 200 \
|
||||
# --speculative_config '{"method": "ngram", "num_speculative_tokens": 20, "prompt_lookup_min": 2, "prompt_lookup_max": 5}'
|
||||
|
||||
# python benchmarks/benchmark_throughput.py \
|
||||
# --model meta-llama/Meta-Llama-3.1-8B-Instruct \
|
||||
# --dataset-name hf \
|
||||
# --dataset-path philschmid/mt-bench \
|
||||
# --prefix-len 0 \
|
||||
# --output-len 512 \
|
||||
# --num-prompts 200 \
|
||||
# --speculative_config '{"method": "eagle", "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B", "num_speculative_tokens": 20}'
|
||||
|
||||
|
||||
# python benchmarks/benchmark_throughput.py \
|
||||
# --model meta-llama/Meta-Llama-3.1-8B-Instruct \
|
||||
# --dataset-name hf \
|
||||
# --dataset-path abisee/cnn_dailymail \
|
||||
# --prefix-len 0 \
|
||||
# --output-len 512 \
|
||||
# --num-prompts 200 \
|
||||
# --speculative_config '{"method": "eagle", "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B", "num_speculative_tokens": 20}'
|
||||
|
||||
# python benchmarks/benchmark_throughput.py \
|
||||
# --model meta-llama/Meta-Llama-3.1-8B-Instruct \
|
||||
# --dataset-name hf \
|
||||
# --dataset-path abisee/cnn_dailymail \
|
||||
# --prefix-len 0 \
|
||||
# --output-len 512 \
|
||||
# --num-prompts 200 \
|
||||
# --speculative_config '{"method": "ngram", "num_speculative_tokens": 20, "prompt_lookup_min": 2, "prompt_lookup_max": 5}'
|
||||
|
||||
|
||||
# python benchmarks/benchmark_throughput.py \
|
||||
# --model meta-llama/Meta-Llama-3.1-8B-Instruct \
|
||||
# --dataset-name hf \
|
||||
# --dataset-path philschmid/mt-bench \
|
||||
# --prefix-len 0 \
|
||||
# --output-len 512 \
|
||||
# --num-prompts 10 \
|
||||
# --speculative_config '{"method": "eagle3", "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", "num_speculative_tokens": 20}'
|
||||
|
||||
|
||||
# python benchmarks/benchmark_throughput.py \
|
||||
# --model meta-llama/Meta-Llama-3.1-8B-Instruct \
|
||||
# --dataset-name hf \
|
||||
# --dataset-path abisee/cnn_dailymail \
|
||||
# --prefix-len 0 \
|
||||
# --output-len 512 \
|
||||
# --num-prompts 200 \
|
||||
# --speculative_config '{"method": "ngram", "num_speculative_tokens": 20, "prompt_lookup_min": 2, "prompt_lookup_max": 5}'
|
||||
63
benchmarks/visualize/common.py
Normal file
63
benchmarks/visualize/common.py
Normal file
@ -0,0 +1,63 @@
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
|
||||
MODEL_TO_NAMES = {
|
||||
"r1-distill-llama-8B" : "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
|
||||
"llama3-8B" : "meta-llama/Meta-Llama-3-8B-Instruct",
|
||||
"llama3.1-8B" : "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"llama3.1-70B" : "meta-llama/Llama-3.1-70B-Instruct",
|
||||
}
|
||||
|
||||
@dataclass
|
||||
class AccStats:
|
||||
lens: list[int]
|
||||
probs: list[float] = None
|
||||
entropies: list[float] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.probs is not None:
|
||||
assert len(self.lens) == len(self.probs), "Length of lens and probs must match"
|
||||
if self.entropies is not None:
|
||||
assert len(self.lens) == len(self.entropies), "Length of lens and entropies must match"
|
||||
|
||||
# remove the prefill accepted lens
|
||||
self.lens = self.lens[1:]
|
||||
|
||||
# remove the last proposed tokens
|
||||
if self.probs:
|
||||
self.probs = self.probs[:-1]
|
||||
if self.entropies:
|
||||
self.entropies = self.entropies[:-1]
|
||||
|
||||
@property
|
||||
def length(self):
|
||||
return len(self.lens)
|
||||
|
||||
# def cleanup(acc_stats: AccStats) ->
|
||||
# # Remove the prefill phase
|
||||
# data = data[1:]
|
||||
# # Cap the maximum value to 10
|
||||
# data = [min(x, 10) for x in data]
|
||||
# return data
|
||||
|
||||
def load_data(datapath, tokenizer, verbose=False):
|
||||
acceptance_stats = []
|
||||
with open(datapath, "r") as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
data = json.loads(line)
|
||||
stat = AccStats(
|
||||
lens=data['acc']['acc_len'],
|
||||
probs=data['acc'].get('acc_prob', None),
|
||||
entropies=data['acc'].get('acc_entropy', None)
|
||||
)
|
||||
acceptance_stats.append(stat)
|
||||
if verbose:
|
||||
print("Input:", tokenizer.decode(data['prompt_token_ids']))
|
||||
print("Output:", tokenizer.decode(data['generated_token_ids']))
|
||||
print("=============================================")
|
||||
|
||||
max_length = max(stats.length for stats in acceptance_stats)
|
||||
|
||||
print(f"Load {len(acceptance_stats)} with max length {max_length}")
|
||||
return acceptance_stats
|
||||
108
benchmarks/visualize/vis_acc.py
Normal file
108
benchmarks/visualize/vis_acc.py
Normal file
@ -0,0 +1,108 @@
|
||||
import json
|
||||
import seaborn as sns
|
||||
import matplotlib.pyplot as plt
|
||||
from transformers import AutoTokenizer
|
||||
from .common import MODEL_TO_NAMES, load_data
|
||||
import requests
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
class AcceptanceStatsClient:
|
||||
"""Client for fetching and processing acceptance statistics data."""
|
||||
|
||||
def __init__(self, model_name, method, dataset, data_path=None):
|
||||
"""Initialize the client with model and dataset info."""
|
||||
self.model_name = model_name
|
||||
self.method = method
|
||||
self.dataset = dataset
|
||||
|
||||
if data_path is None:
|
||||
self.data_path = f"/data/lily/batch-sd/data/{model_name}/{method}_{dataset}_acceptance_stats.jsonl"
|
||||
else:
|
||||
self.data_path = data_path
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(MODEL_TO_NAMES[model_name], use_fast=False)
|
||||
self.acceptance_stats = None
|
||||
|
||||
def load_data(self):
|
||||
"""Load the acceptance statistics from file."""
|
||||
self.acceptance_stats = load_data(self.data_path, self.tokenizer)
|
||||
return self.acceptance_stats
|
||||
|
||||
def plot_heatmap(self, output_dir="figures"):
|
||||
"""Plot the acceptance statistics as a heatmap."""
|
||||
if self.acceptance_stats is None:
|
||||
self.load_data()
|
||||
|
||||
fig, ax = plt.subplots(figsize=(12, 8))
|
||||
sns.heatmap(self.acceptance_stats, cmap="YlGnBu")
|
||||
plt.xlabel("Position")
|
||||
plt.ylabel("Request ID")
|
||||
|
||||
# Add Y-axis labels on the right
|
||||
ax2 = ax.twinx()
|
||||
ax2.set_ylim(ax.get_ylim())
|
||||
ax2.set_yticks([])
|
||||
ax2.set_ylabel("# of Accepted Tokens", labelpad=10)
|
||||
|
||||
plt.title(f"Acceptance Statistics: {self.model_name} - {self.method} - {self.dataset}")
|
||||
plt.tight_layout()
|
||||
|
||||
# Create output directory if it doesn't exist
|
||||
output_path = Path(output_dir) / self.model_name
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
|
||||
output_file = output_path / f"{self.method}_{self.dataset}_acceptance_stats.pdf"
|
||||
plt.savefig(output_file)
|
||||
print(f"Saved heatmap to {output_file}")
|
||||
return fig
|
||||
|
||||
def get_summary_stats(self):
|
||||
"""Get summary statistics about the acceptance data."""
|
||||
if self.acceptance_stats is None:
|
||||
self.load_data()
|
||||
|
||||
# Calculate average acceptance rate for each position
|
||||
avg_by_position = [sum(col)/len(col) for col in zip(*self.acceptance_stats) if sum(1 for v in col if v >= 0) > 0]
|
||||
|
||||
# Calculate average acceptance rate for each request
|
||||
avg_by_request = [sum(row)/len(row) for row in self.acceptance_stats]
|
||||
|
||||
return {
|
||||
"total_requests": len(self.acceptance_stats),
|
||||
"max_position": len(avg_by_position),
|
||||
"avg_acceptance_rate": sum(avg_by_request)/len(avg_by_request),
|
||||
"avg_by_position": avg_by_position,
|
||||
"avg_by_request": avg_by_request
|
||||
}
|
||||
|
||||
# Example model configuration
|
||||
model = "llama3.1-8B"
|
||||
# model = "r1-distill-llama-8B"
|
||||
method = "eagle3"
|
||||
dataset = "mtbench"
|
||||
# dataset = "aime"
|
||||
# method = "ngram"
|
||||
# dataset = "cnndailymail"
|
||||
# datapath = f"/data/lily/batch-sd/data/{model}/{method}_{dataset}_acceptance_stats.jsonl"
|
||||
datapath = "acceptance_stats.jsonl"
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_TO_NAMES[model], use_fast=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Use the client instead of directly loading data
|
||||
client = AcceptanceStatsClient(model, method, dataset, datapath)
|
||||
acceptance_stats = client.load_data()
|
||||
|
||||
# Get summary statistics
|
||||
summary = client.get_summary_stats()
|
||||
print("Summary Statistics:")
|
||||
print(f"Total Requests: {summary['total_requests']}")
|
||||
print(f"Max Position: {summary['max_position']}")
|
||||
print(f"Average Acceptance Rate: {summary['avg_acceptance_rate']:.2f}")
|
||||
|
||||
# Create heatmap visualization
|
||||
plot_heatmap = False
|
||||
if plot_heatmap:
|
||||
client.plot_heatmap()
|
||||
|
||||
69
benchmarks/visualize/vis_acc_diff.py
Normal file
69
benchmarks/visualize/vis_acc_diff.py
Normal file
@ -0,0 +1,69 @@
|
||||
import json
|
||||
import seaborn as sns
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.colors import LinearSegmentedColormap
|
||||
|
||||
model = "llama3.1-8B"
|
||||
dataset = "instructcode"
|
||||
method1 = "ngram"
|
||||
method2 = "eagle3"
|
||||
|
||||
def get_datapath(method):
|
||||
datapath = f"/data/lily/batch-sd/data/{model}/{method}_{dataset}_acceptance_stats.jsonl"
|
||||
return datapath
|
||||
|
||||
def cleanup(data):
|
||||
# Remove the prefill phase
|
||||
data = data[1:]
|
||||
# Cap the maximum value to 10
|
||||
data = [min(x, 10) for x in data]
|
||||
return data
|
||||
|
||||
def load_data(datapath):
|
||||
acceptance_stats = {}
|
||||
with open(datapath, "r") as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
data = json.loads(line)
|
||||
key = hash(tuple(data['prompt_token_ids']))
|
||||
acceptance_stats[key] = cleanup(data['acc'])
|
||||
# Pad the acceptance stats to the same length
|
||||
max_length = max(len(stats) for k, stats in acceptance_stats.items())
|
||||
|
||||
for key in acceptance_stats:
|
||||
acceptance_stats[key] += [-2] * (max_length - len(acceptance_stats[key]))
|
||||
|
||||
print(f"Load {len(acceptance_stats)} with max length {max_length} from {datapath}")
|
||||
return acceptance_stats
|
||||
|
||||
def diff(acceptance_stats1, acceptance_stats2):
|
||||
diff = {}
|
||||
for key in acceptance_stats1:
|
||||
if key in acceptance_stats2:
|
||||
diff[key] = [a - b for a, b in zip(acceptance_stats1[key], acceptance_stats2[key])]
|
||||
return diff
|
||||
|
||||
datapath_1 = get_datapath(method1)
|
||||
datapath_2 = get_datapath(method2)
|
||||
acceptance_stats_1 = load_data(datapath_1)
|
||||
acceptance_stats_2 = load_data(datapath_2)
|
||||
acceptance_stats_diff = diff(acceptance_stats_1, acceptance_stats_2)
|
||||
|
||||
acceptance_stats = list(acceptance_stats_diff.values())
|
||||
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
colors = ["red", "white", "blue"]
|
||||
custom_cmap = LinearSegmentedColormap.from_list("custom", colors, N=256)
|
||||
sns.heatmap(acceptance_stats, cmap=custom_cmap, center=0)
|
||||
plt.xlabel("Position")
|
||||
plt.ylabel("Request ID")
|
||||
# Add Y-axis labels on the right
|
||||
ax2 = ax.twinx()
|
||||
ax2.set_ylim(ax.get_ylim()) # Match y-axis range
|
||||
ax2.set_yticks([]) # Remove right tick marks if undesired
|
||||
ax2.set_ylabel("# of Accepted Tokens", labelpad=10) # Set right y-axis label
|
||||
plt.title(f"Diff between {method2} - {method1} acceptance stats for {dataset}")
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(f"figures/{model}/diff_{method2}_{method1}_{dataset}_acceptance_stats.pdf")
|
||||
38
benchmarks/visualize/vis_prob_entropy.py
Normal file
38
benchmarks/visualize/vis_prob_entropy.py
Normal file
@ -0,0 +1,38 @@
|
||||
from transformers import AutoTokenizer
|
||||
from common import MODEL_TO_NAMES, load_data
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
def plot_prob_entropy(acceptance_stats,
|
||||
output_path):
|
||||
|
||||
acc_probs = []
|
||||
rej_probs = []
|
||||
for stat in acceptance_stats:
|
||||
for i, acc_len in enumerate(stat.lens):
|
||||
acc_probs.extend(stat.probs[i][:acc_len-1])
|
||||
rej_probs.extend(stat.probs[i][acc_len-1:])
|
||||
|
||||
fig, ax = plt.subplots(figsize=(12, 8))
|
||||
plt.hist(acc_probs, bins=100, alpha=0.5,
|
||||
label='Accepted Probabilities', color='green')
|
||||
plt.hist(rej_probs, bins=100, alpha=0.5,
|
||||
label='Rejected Probabilities', color='red')
|
||||
plt.xlabel('Probability')
|
||||
plt.ylabel('Frequency')
|
||||
plt.title('Distribution of Accepted and Rejected Probabilities')
|
||||
plt.legend()
|
||||
plt.tight_layout()
|
||||
plt.savefig(output_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
datapath = "/data/lily/sd-benchmark-paper/batch-sd/acceptance_stats.jsonl"
|
||||
model = "llama3.1-8B"
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_TO_NAMES[model],
|
||||
use_fast=False)
|
||||
acceptance_stats = load_data(datapath, tokenizer)
|
||||
plot_prob_entropy(acceptance_stats, output_path="prob_entropy_figures")
|
||||
|
||||
|
||||
|
||||
@ -33,8 +33,6 @@ endif()
|
||||
|
||||
if(MACOSX_FOUND)
|
||||
list(APPEND CXX_COMPILE_FLAGS
|
||||
"-Xpreprocessor"
|
||||
"-fopenmp"
|
||||
"-DVLLM_CPU_EXTENSION")
|
||||
else()
|
||||
list(APPEND CXX_COMPILE_FLAGS
|
||||
|
||||
@ -38,7 +38,7 @@ else()
|
||||
FetchContent_Declare(
|
||||
vllm-flash-attn
|
||||
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
|
||||
GIT_TAG dc9d410b3e2d6534a4c70724c2515f4def670a22
|
||||
GIT_TAG 8798f27777fb57f447070301bf33a9f9c607f491
|
||||
GIT_PROGRESS TRUE
|
||||
# Don't share the vllm-flash-attn build between build types
|
||||
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
|
||||
|
||||
178
csrc/attention/merge_attn_states.cu
Normal file
178
csrc/attention/merge_attn_states.cu
Normal file
@ -0,0 +1,178 @@
|
||||
#include <optional>
|
||||
#include <torch/all.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <algorithm>
|
||||
|
||||
#include "attention_dtypes.h"
|
||||
#include "attention_utils.cuh"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
|
||||
// can be used to combine partial attention results (in the split-KV case)
|
||||
template <typename scalar_t, const uint NUM_THREADS>
|
||||
__global__ void merge_attn_states_kernel(
|
||||
scalar_t* output, float* output_lse, const scalar_t* prefix_output,
|
||||
const float* prefix_lse, const scalar_t* suffix_output,
|
||||
const float* suffix_lse, const uint num_tokens, const uint num_heads,
|
||||
const uint head_size) {
|
||||
using pack_128b_t = uint4;
|
||||
const uint pack_size = 16 / sizeof(scalar_t);
|
||||
const uint threads_per_head = head_size / pack_size;
|
||||
|
||||
const uint global_idx = blockIdx.x * NUM_THREADS + threadIdx.x;
|
||||
const uint token_head_threads = num_tokens * num_heads * threads_per_head;
|
||||
|
||||
if (global_idx >= token_head_threads) return;
|
||||
|
||||
// global_idx -> token_idx + head_idx + pack_idx
|
||||
const uint token_head_idx = global_idx / threads_per_head;
|
||||
const uint pack_idx = global_idx % threads_per_head;
|
||||
|
||||
const uint token_idx = token_head_idx / num_heads;
|
||||
const uint head_idx = token_head_idx % num_heads;
|
||||
|
||||
const uint pack_offset = pack_idx * pack_size; // (0~15)*8, etc.
|
||||
const uint head_offset =
|
||||
token_idx * num_heads * head_size + head_idx * head_size;
|
||||
const scalar_t* prefix_head_ptr = prefix_output + head_offset;
|
||||
const scalar_t* suffix_head_ptr = suffix_output + head_offset;
|
||||
scalar_t* output_head_ptr = output + head_offset;
|
||||
|
||||
float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
|
||||
float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
|
||||
p_lse = std::isinf(p_lse) ? -std::numeric_limits<float>::infinity() : p_lse;
|
||||
s_lse = std::isinf(s_lse) ? -std::numeric_limits<float>::infinity() : s_lse;
|
||||
|
||||
const float max_lse = fmaxf(p_lse, s_lse);
|
||||
p_lse = p_lse - max_lse;
|
||||
s_lse = s_lse - max_lse;
|
||||
const float p_se = expf(p_lse);
|
||||
const float s_se = expf(s_lse);
|
||||
const float out_se = p_se + s_se;
|
||||
const float p_scale = p_se / out_se;
|
||||
const float s_scale = s_se / out_se;
|
||||
|
||||
if (pack_offset < head_size) {
|
||||
// Pack 128b load
|
||||
pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(
|
||||
prefix_head_ptr)[pack_offset / pack_size];
|
||||
pack_128b_t s_out_pack = reinterpret_cast<const pack_128b_t*>(
|
||||
suffix_head_ptr)[pack_offset / pack_size];
|
||||
pack_128b_t o_out_pack;
|
||||
|
||||
#pragma unroll
|
||||
for (uint i = 0; i < pack_size; ++i) {
|
||||
// Always use float for FMA to keep high precision.
|
||||
// half(uint16_t), bfloat16, float -> float.
|
||||
const float p_out_f =
|
||||
vllm::to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
|
||||
const float s_out_f =
|
||||
vllm::to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
|
||||
// fma: a * b + c = p_out_f * p_scale + (s_out_f * s_scale)
|
||||
const float o_out_f = p_out_f * p_scale + (s_out_f * s_scale);
|
||||
// float -> half(uint16_t), bfloat16, float.
|
||||
vllm::from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i], o_out_f);
|
||||
}
|
||||
|
||||
// Pack 128b storage
|
||||
reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] =
|
||||
o_out_pack;
|
||||
}
|
||||
// We only need to write to output_lse once per head.
|
||||
if (output_lse != nullptr && pack_idx == 0) {
|
||||
float out_lse = logf(out_se) + max_lse;
|
||||
output_lse[head_idx * num_tokens + token_idx] = out_lse;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
// The following macro is used to dispatch the conversion function based on
|
||||
// the output data type. The FN is a macro that calls a function with
|
||||
// template<typename scalar_t>.
|
||||
#define DISPATCH_BY_SCALAR_DTYPE(scalar_dtype, fn) \
|
||||
{ \
|
||||
if (scalar_dtype == at::ScalarType::Float) { \
|
||||
fn(float); \
|
||||
} else if (scalar_dtype == at::ScalarType::Half) { \
|
||||
fn(uint16_t); \
|
||||
} else if (scalar_dtype == at::ScalarType::BFloat16) { \
|
||||
fn(__nv_bfloat16); \
|
||||
} else { \
|
||||
TORCH_CHECK(false, "Unsupported data type of O: ", scalar_dtype); \
|
||||
} \
|
||||
}
|
||||
|
||||
#define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS) \
|
||||
{ \
|
||||
vllm::merge_attn_states_kernel<scalar_t, NUM_THREADS> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<scalar_t*>(output.data_ptr()), output_lse_ptr, \
|
||||
reinterpret_cast<scalar_t*>(prefix_output.data_ptr()), \
|
||||
reinterpret_cast<float*>(prefix_lse.data_ptr()), \
|
||||
reinterpret_cast<scalar_t*>(suffix_output.data_ptr()), \
|
||||
reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens, \
|
||||
num_heads, head_size); \
|
||||
}
|
||||
|
||||
/*@brief Merges the attention states from prefix and suffix
|
||||
* into the output tensor. NUM_TOKENS: n, NUM_HEADS: h, HEAD_SIZE: d
|
||||
*
|
||||
* @param output [n,h,d] The output tensor to store the merged attention states.
|
||||
* @param output_lse [h,d] Optional tensor to store the log-sum-exp values.
|
||||
* @param prefix_output [n,h,d] The prefix attention states.
|
||||
* @param prefix_lse [h,n] The log-sum-exp values for the prefix attention
|
||||
* states.
|
||||
* @param suffix_output [n,h,d] The suffix attention states.
|
||||
* @param suffix_lse [h,n] The log-sum-exp values for the suffix attention
|
||||
* states.
|
||||
*/
|
||||
template <typename scalar_t>
|
||||
void merge_attn_states_launcher(torch::Tensor& output,
|
||||
std::optional<torch::Tensor> output_lse,
|
||||
const torch::Tensor& prefix_output,
|
||||
const torch::Tensor& prefix_lse,
|
||||
const torch::Tensor& suffix_output,
|
||||
const torch::Tensor& suffix_lse) {
|
||||
constexpr uint NUM_THREADS = 128;
|
||||
const uint num_tokens = output.size(0);
|
||||
const uint num_heads = output.size(1);
|
||||
const uint head_size = output.size(2);
|
||||
const uint pack_size = 16 / sizeof(scalar_t);
|
||||
TORCH_CHECK(head_size % pack_size == 0,
|
||||
"headsize must be multiple of pack_size:", pack_size);
|
||||
float* output_lse_ptr = nullptr;
|
||||
if (output_lse.has_value()) {
|
||||
output_lse_ptr = output_lse.value().data_ptr<float>();
|
||||
}
|
||||
// Process one pack elements per thread. for float, the
|
||||
// pack_size is 4 for half/bf16, the pack_size is 8.
|
||||
const uint threads_per_head = head_size / pack_size;
|
||||
const uint total_threads = num_tokens * num_heads * threads_per_head;
|
||||
|
||||
dim3 block(NUM_THREADS);
|
||||
dim3 grid((total_threads + NUM_THREADS - 1) / NUM_THREADS);
|
||||
|
||||
const c10::cuda::OptionalCUDAGuard device_guard(prefix_output.device());
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS);
|
||||
}
|
||||
|
||||
#define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t) \
|
||||
{ \
|
||||
merge_attn_states_launcher<scalar_t>(output, output_lse, prefix_output, \
|
||||
prefix_lse, suffix_output, \
|
||||
suffix_lse); \
|
||||
}
|
||||
|
||||
void merge_attn_states(torch::Tensor& output,
|
||||
std::optional<torch::Tensor> output_lse,
|
||||
const torch::Tensor& prefix_output,
|
||||
const torch::Tensor& prefix_lse,
|
||||
const torch::Tensor& suffix_output,
|
||||
const torch::Tensor& suffix_lse) {
|
||||
DISPATCH_BY_SCALAR_DTYPE(output.dtype(), CALL_MERGE_ATTN_STATES_LAUNCHER);
|
||||
}
|
||||
38
csrc/attention/mla/cutlass_mla_entry.cu
Normal file
38
csrc/attention/mla/cutlass_mla_entry.cu
Normal file
@ -0,0 +1,38 @@
|
||||
/*
|
||||
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
#if defined ENABLE_CUTLASS_MLA && ENABLE_CUTLASS_MLA
|
||||
void cutlass_mla_decode_sm100a(torch::Tensor const& out,
|
||||
torch::Tensor const& q_nope,
|
||||
torch::Tensor const& q_pe,
|
||||
torch::Tensor const& kv_c_and_k_pe_cache,
|
||||
torch::Tensor const& seq_lens,
|
||||
torch::Tensor const& page_table, double scale);
|
||||
#endif
|
||||
|
||||
void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope,
|
||||
torch::Tensor const& q_pe,
|
||||
torch::Tensor const& kv_c_and_k_pe_cache,
|
||||
torch::Tensor const& seq_lens,
|
||||
torch::Tensor const& page_table, double scale) {
|
||||
#if defined ENABLE_CUTLASS_MLA && ENABLE_CUTLASS_MLA
|
||||
return cutlass_mla_decode_sm100a(out, q_nope, q_pe, kv_c_and_k_pe_cache,
|
||||
seq_lens, page_table, scale);
|
||||
#endif
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled cutlass MLA");
|
||||
}
|
||||
225
csrc/attention/mla/cutlass_mla_kernels.cu
Normal file
225
csrc/attention/mla/cutlass_mla_kernels.cu
Normal file
@ -0,0 +1,225 @@
|
||||
/*
|
||||
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/kernel_hardware_info.h"
|
||||
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
|
||||
#include "device/sm100_mla.hpp"
|
||||
#include "kernel/sm100_mla_tile_scheduler.hpp"
|
||||
|
||||
using namespace cute;
|
||||
using namespace cutlass::fmha::kernel;
|
||||
|
||||
template <typename T, bool PersistenceOption = true>
|
||||
struct MlaSm100 {
|
||||
using Element = T;
|
||||
using ElementAcc = float;
|
||||
using ElementOut = T;
|
||||
|
||||
using TileShape = Shape<_128, _128, Shape<_512, _64>>;
|
||||
using TileShapeH = cute::tuple_element_t<0, TileShape>;
|
||||
using TileShapeD = cute::tuple_element_t<2, TileShape>;
|
||||
|
||||
// H K (D_latent D_rope) B
|
||||
using ProblemShape = cute::tuple<TileShapeH, int, TileShapeD, int>;
|
||||
|
||||
using StrideQ = cute::tuple<int64_t, _1, int64_t>; // H D B
|
||||
using StrideK = cute::tuple<int64_t, _1, int64_t>; // K D B
|
||||
using StrideO = StrideK; // H D B
|
||||
using StrideLSE = cute::tuple<_1, int>; // H B
|
||||
|
||||
using TileScheduler =
|
||||
std::conditional_t<PersistenceOption, Sm100MlaPersistentTileScheduler,
|
||||
Sm100MlaIndividualTileScheduler>;
|
||||
|
||||
using FmhaKernel =
|
||||
cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized<
|
||||
TileShape, Element, ElementAcc, ElementOut, ElementAcc, TileScheduler,
|
||||
/*kIsCpAsync=*/true>;
|
||||
using Fmha = cutlass::fmha::device::MLA<FmhaKernel>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
typename T::Fmha::Arguments args_from_options(
|
||||
at::Tensor const& out, at::Tensor const& q_nope, at::Tensor const& q_pe,
|
||||
at::Tensor const& kv_c_and_k_pe_cache, at::Tensor const& seq_lens,
|
||||
at::Tensor const& page_table, double scale) {
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
hw_info.device_id = q_nope.device().index();
|
||||
hw_info.sm_count =
|
||||
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
|
||||
hw_info.device_id);
|
||||
|
||||
int batches = q_nope.sizes()[0];
|
||||
int page_count_per_seq = page_table.sizes()[1];
|
||||
int page_count_total = kv_c_and_k_pe_cache.sizes()[0];
|
||||
int page_size = kv_c_and_k_pe_cache.sizes()[1];
|
||||
int max_seq_len = page_size * page_count_per_seq;
|
||||
using TileShapeH = typename T::TileShapeH;
|
||||
using TileShapeD = typename T::TileShapeD;
|
||||
auto problem_shape =
|
||||
cute::make_tuple(TileShapeH{}, max_seq_len, TileShapeD{}, batches);
|
||||
|
||||
auto [H, K, D, B] = problem_shape;
|
||||
auto [D_latent, D_rope] = D;
|
||||
|
||||
using StrideQ = typename T::StrideQ;
|
||||
using StrideK = typename T::StrideK;
|
||||
using StrideO = typename T::StrideO;
|
||||
using StrideLSE = typename T::StrideLSE;
|
||||
|
||||
StrideQ stride_Q_latent = cute::make_tuple(
|
||||
static_cast<int64_t>(D_latent), _1{}, static_cast<int64_t>(H * D_latent));
|
||||
StrideQ stride_Q_rope = cute::make_tuple(static_cast<int64_t>(D_rope), _1{},
|
||||
static_cast<int64_t>(H * D_rope));
|
||||
StrideK stride_C =
|
||||
cute::make_tuple(static_cast<int64_t>(D_latent + D_rope), _1{},
|
||||
static_cast<int64_t>(page_size * (D_latent + D_rope)));
|
||||
StrideLSE stride_PT = cute::make_stride(_1{}, page_count_per_seq);
|
||||
StrideLSE stride_LSE = cute::make_tuple(_1{}, static_cast<int>(H));
|
||||
StrideO stride_O = cute::make_tuple(static_cast<int64_t>(D_latent), _1{},
|
||||
static_cast<int64_t>(H * D_latent));
|
||||
|
||||
using Element = typename T::Element;
|
||||
using ElementOut = typename T::ElementOut;
|
||||
using ElementAcc = typename T::ElementAcc;
|
||||
auto Q_latent_ptr = static_cast<Element*>(q_nope.data_ptr());
|
||||
auto Q_rope_ptr = static_cast<Element*>(q_pe.data_ptr());
|
||||
auto C_ptr = static_cast<Element*>(kv_c_and_k_pe_cache.data_ptr());
|
||||
auto scale_f = static_cast<float>(scale);
|
||||
typename T::Fmha::Arguments arguments{
|
||||
problem_shape,
|
||||
{scale_f, Q_latent_ptr, stride_Q_latent, Q_rope_ptr, stride_Q_rope, C_ptr,
|
||||
stride_C, C_ptr + D_latent, stride_C,
|
||||
static_cast<int*>(seq_lens.data_ptr()),
|
||||
static_cast<int*>(page_table.data_ptr()), stride_PT, page_count_total,
|
||||
page_size},
|
||||
{static_cast<ElementOut*>(out.data_ptr()), stride_O,
|
||||
static_cast<ElementAcc*>(nullptr), stride_LSE},
|
||||
hw_info,
|
||||
-1, // split_kv
|
||||
nullptr, // is_var_split_kv
|
||||
};
|
||||
// TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
|
||||
// split_kv automatically based on batch size and sequence length to balance
|
||||
// workload across available SMs. Consider using var_split_kv for manual
|
||||
// control if needed.
|
||||
T::Fmha::set_split_kv(arguments);
|
||||
return arguments;
|
||||
}
|
||||
|
||||
template <typename Element>
|
||||
void runMla(at::Tensor const& out, at::Tensor const& q_nope,
|
||||
at::Tensor const& q_pe, at::Tensor const& kv_c_and_k_pe_cache,
|
||||
at::Tensor const& seq_lens, at::Tensor const& page_table,
|
||||
float scale, cudaStream_t stream) {
|
||||
using MlaSm100Type = MlaSm100<Element>;
|
||||
typename MlaSm100Type::Fmha fmha;
|
||||
auto arguments = args_from_options<MlaSm100Type>(
|
||||
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, scale);
|
||||
size_t workspace_size = MlaSm100Type::Fmha::get_workspace_size(arguments);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(q_nope.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
|
||||
CUTLASS_CHECK(fmha.can_implement(arguments));
|
||||
|
||||
CUTLASS_CHECK(fmha.initialize(arguments, workspace.data_ptr(), stream));
|
||||
|
||||
CUTLASS_CHECK(fmha.run(arguments, workspace.data_ptr(), stream));
|
||||
}
|
||||
|
||||
void cutlass_mla_decode_sm100a(torch::Tensor const& out,
|
||||
torch::Tensor const& q_nope,
|
||||
torch::Tensor const& q_pe,
|
||||
torch::Tensor const& kv_c_and_k_pe_cache,
|
||||
torch::Tensor const& seq_lens,
|
||||
torch::Tensor const& page_table, double scale) {
|
||||
TORCH_CHECK(q_nope.device().is_cuda(), "q_nope must be on CUDA");
|
||||
TORCH_CHECK(q_nope.dim() == 3, "q_nope must be a 3D tensor");
|
||||
TORCH_CHECK(q_pe.dim() == 3, "q_pe must be a 3D tensor");
|
||||
TORCH_CHECK(kv_c_and_k_pe_cache.dim() == 3,
|
||||
"kv_c_and_k_pe_cache must be a 3D tensor");
|
||||
TORCH_CHECK(seq_lens.dim() == 1, "seq_lens must be a 1D tensor");
|
||||
TORCH_CHECK(page_table.dim() == 2, "page_table must be a 2D tensor");
|
||||
TORCH_CHECK(out.dim() == 3, "out must be a 3D tensor");
|
||||
|
||||
auto B_q_nope = q_nope.size(0);
|
||||
auto H_q_nope = q_nope.size(1);
|
||||
auto D_q_nope = q_nope.size(2);
|
||||
auto B_q_pe = q_pe.size(0);
|
||||
auto H_q_pe = q_pe.size(1);
|
||||
auto D_q_pe = q_pe.size(2);
|
||||
auto B_pt = page_table.size(0);
|
||||
auto PAGE_NUM = page_table.size(1);
|
||||
auto PAGE_SIZE = kv_c_and_k_pe_cache.size(1);
|
||||
auto D_ckv = kv_c_and_k_pe_cache.size(2);
|
||||
auto B_o = out.size(0);
|
||||
auto H_o = out.size(1);
|
||||
auto D_o = out.size(2);
|
||||
|
||||
TORCH_CHECK(D_q_nope == 512, "D_q_nope must be equal to 512");
|
||||
TORCH_CHECK(D_q_pe == 64, "D_q_pe must be equal to 64");
|
||||
TORCH_CHECK(D_ckv == 576, "D_ckv must be equal to 576");
|
||||
TORCH_CHECK(H_q_nope == H_q_pe && H_q_nope == H_o && H_o == 128,
|
||||
"H_q_nope, H_q_pe, and H_o must be equal to 128");
|
||||
TORCH_CHECK(PAGE_SIZE > 0 && (PAGE_SIZE & (PAGE_SIZE - 1)) == 0,
|
||||
"PAGE_SIZE must be a power of 2");
|
||||
TORCH_CHECK(
|
||||
B_q_nope == B_q_pe && B_q_nope == B_pt && B_q_nope == B_o,
|
||||
"Batch dims must be same for page_table, q_nope and q_pe, and out");
|
||||
TORCH_CHECK(PAGE_NUM % (128 / PAGE_SIZE) == 0,
|
||||
"PAGE_NUM must be divisible by 128 / PAGE_SIZE");
|
||||
TORCH_CHECK(D_o == 512, "D_o must be equal to 512");
|
||||
|
||||
TORCH_CHECK(q_nope.dtype() == at::ScalarType::Half ||
|
||||
q_nope.dtype() == at::ScalarType::BFloat16 ||
|
||||
q_nope.dtype() == at::ScalarType::Float8_e4m3fn,
|
||||
"q_nope must be a half, bfloat16, or float8_e4m3fn tensor");
|
||||
TORCH_CHECK(kv_c_and_k_pe_cache.dtype() == q_nope.dtype() &&
|
||||
q_nope.dtype() == q_pe.dtype(),
|
||||
"kv_c_and_k_pe_cache, q_nope, and q_pe must be the same type");
|
||||
TORCH_CHECK(seq_lens.dtype() == torch::kInt32,
|
||||
"seq_lens must be a 32-bit integer tensor");
|
||||
TORCH_CHECK(page_table.dtype() == torch::kInt32,
|
||||
"page_table must be a 32-bit integer tensor");
|
||||
|
||||
auto in_dtype = q_nope.dtype();
|
||||
at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()};
|
||||
const cudaStream_t stream =
|
||||
at::cuda::getCurrentCUDAStream(q_nope.get_device());
|
||||
if (in_dtype == at::ScalarType::Half) {
|
||||
runMla<cutlass::half_t>(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens,
|
||||
page_table, scale, stream);
|
||||
} else if (in_dtype == at::ScalarType::BFloat16) {
|
||||
runMla<cutlass::bfloat16_t>(out, q_nope, q_pe, kv_c_and_k_pe_cache,
|
||||
seq_lens, page_table, scale, stream);
|
||||
} else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
|
||||
runMla<cutlass::float_e4m3_t>(out, q_nope, q_pe, kv_c_and_k_pe_cache,
|
||||
seq_lens, page_table, scale, stream);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported input data type of MLA");
|
||||
}
|
||||
}
|
||||
@ -270,9 +270,10 @@ __global__ void reshape_and_cache_flash_kernel(
|
||||
cache_t* __restrict__ value_cache, // [num_blocks, block_size, num_heads,
|
||||
// head_size]
|
||||
const int64_t* __restrict__ slot_mapping, // [num_tokens]
|
||||
const int block_stride, const int key_stride, const int value_stride,
|
||||
const int num_heads, const int head_size, const int block_size,
|
||||
const float* k_scale, const float* v_scale) {
|
||||
const int64_t block_stride, const int64_t page_stride,
|
||||
const int64_t head_stride, const int64_t key_stride,
|
||||
const int64_t value_stride, const int num_heads, const int head_size,
|
||||
const int block_size, const float* k_scale, const float* v_scale) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
const int64_t slot_idx = slot_mapping[token_idx];
|
||||
// NOTE: slot_idx can be -1 if the token is padded
|
||||
@ -288,8 +289,8 @@ __global__ void reshape_and_cache_flash_kernel(
|
||||
const int head_idx = i / head_size;
|
||||
const int head_offset = i % head_size;
|
||||
const int64_t tgt_key_value_idx = block_idx * block_stride +
|
||||
block_offset * num_heads * head_size +
|
||||
head_idx * head_size + head_offset;
|
||||
block_offset * page_stride +
|
||||
head_idx * head_stride + head_offset;
|
||||
scalar_t tgt_key = key[src_key_idx];
|
||||
scalar_t tgt_value = value[src_value_idx];
|
||||
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
|
||||
@ -396,16 +397,16 @@ void reshape_and_cache(
|
||||
// KV_T is the data type of key and value tensors.
|
||||
// CACHE_T is the stored data type of kv-cache.
|
||||
// KV_DTYPE is the real data type of kv-cache.
|
||||
#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE) \
|
||||
vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<KV_T*>(key.data_ptr()), \
|
||||
reinterpret_cast<KV_T*>(value.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
|
||||
slot_mapping.data_ptr<int64_t>(), block_stride, key_stride, \
|
||||
value_stride, num_heads, head_size, block_size, \
|
||||
reinterpret_cast<const float*>(k_scale.data_ptr()), \
|
||||
#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE) \
|
||||
vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<KV_T*>(key.data_ptr()), \
|
||||
reinterpret_cast<KV_T*>(value.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
|
||||
slot_mapping.data_ptr<int64_t>(), block_stride, page_stride, \
|
||||
head_stride, key_stride, value_stride, num_heads, head_size, \
|
||||
block_size, reinterpret_cast<const float*>(k_scale.data_ptr()), \
|
||||
reinterpret_cast<const float*>(v_scale.data_ptr()));
|
||||
|
||||
void reshape_and_cache_flash(
|
||||
@ -432,9 +433,11 @@ void reshape_and_cache_flash(
|
||||
int head_size = key.size(2);
|
||||
int block_size = key_cache.size(1);
|
||||
|
||||
int key_stride = key.stride(0);
|
||||
int value_stride = value.stride(0);
|
||||
int block_stride = key_cache.stride(0);
|
||||
int64_t key_stride = key.stride(0);
|
||||
int64_t value_stride = value.stride(0);
|
||||
int64_t block_stride = key_cache.stride(0);
|
||||
int64_t page_stride = key_cache.stride(1);
|
||||
int64_t head_stride = key_cache.stride(2);
|
||||
TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0));
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
|
||||
@ -4,6 +4,11 @@
|
||||
#include <string>
|
||||
#include <sched.h>
|
||||
#endif
|
||||
#if __GLIBC__ == 2 && __GLIBC_MINOR__ < 30
|
||||
#include <unistd.h>
|
||||
#include <sys/syscall.h>
|
||||
#define gettid() syscall(SYS_gettid)
|
||||
#endif
|
||||
|
||||
#include "cpu_types.hpp"
|
||||
|
||||
|
||||
@ -375,7 +375,7 @@ class CustomAllreduce {
|
||||
bool fully_connected_;
|
||||
|
||||
RankSignals sg_;
|
||||
// Stores an map from a pointer to its peer pointers from all ranks.
|
||||
// Stores a map from a pointer to its peer pointers from all ranks.
|
||||
std::unordered_map<void*, RankData*> buffers_;
|
||||
Signal* self_sg_;
|
||||
|
||||
|
||||
@ -422,7 +422,7 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
|
||||
int final_state_position = ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize);
|
||||
// in case the final state is separated between the last "smem_exchange" and
|
||||
// and the one before it (chunk = n_chunks - 1 and chunk = n_chunks - 2),
|
||||
// (which occurs when `final_state_position` is a non-positivie index)
|
||||
// (which occurs when `final_state_position` is a non-positive index)
|
||||
// we load the correct data from smem_exchange from both chunks, the last chunk iteration and the one before it
|
||||
if (conv_states != nullptr && final_state_position < 0 && seqlen > kWidth){
|
||||
input_t vals_load[kNElts] = {0};
|
||||
|
||||
103
csrc/moe/marlin_moe_wna16/generate_kernels.py
Normal file
103
csrc/moe/marlin_moe_wna16/generate_kernels.py
Normal file
@ -0,0 +1,103 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import glob
|
||||
import itertools
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
import jinja2
|
||||
|
||||
FILE_HEAD = """
|
||||
// auto generated by generate.py
|
||||
// clang-format off
|
||||
|
||||
#include "kernel.h"
|
||||
#include "marlin_template.h"
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
""".strip()
|
||||
|
||||
TEMPLATE = ("template __global__ void Marlin<"
|
||||
"{{scalar_t}}, "
|
||||
"{{w_type_id}}, "
|
||||
"{{threads}}, "
|
||||
"{{thread_m_blocks}}, "
|
||||
"{{thread_n_blocks}}, "
|
||||
"{{thread_k_blocks}}, "
|
||||
"{{'true' if m_block_size_8 else 'false'}}, "
|
||||
"{{stages}}, "
|
||||
"{{'true' if has_act_order else 'false'}}, "
|
||||
"{{'true' if has_zp else 'false'}}, "
|
||||
"{{group_blocks}}, "
|
||||
"{{'true' if is_zp_float else 'false'}}>"
|
||||
"( MARLIN_KERNEL_PARAMS );")
|
||||
|
||||
# int8 with zero point case (vllm::kU8) is also supported,
|
||||
# we don't add it to reduce wheel size.
|
||||
SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128"]
|
||||
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]
|
||||
|
||||
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
|
||||
# group_blocks:
|
||||
# = 0 : act order case
|
||||
# = -1 : channelwise quantization
|
||||
# > 0 : group_size=16*group_blocks
|
||||
GROUP_BLOCKS = [0, -1, 2, 4, 8]
|
||||
DTYPES = ["fp16", "bf16"]
|
||||
|
||||
|
||||
def remove_old_kernels():
|
||||
for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"):
|
||||
subprocess.call(["rm", "-f", filename])
|
||||
|
||||
|
||||
def generate_new_kernels():
|
||||
for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
|
||||
has_zp = "B" not in scalar_type
|
||||
all_template_str_list = []
|
||||
|
||||
for group_blocks, m_blocks, thread_configs in itertools.product(
|
||||
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS):
|
||||
|
||||
has_act_order = group_blocks == 0
|
||||
if has_zp and has_act_order:
|
||||
continue
|
||||
if thread_configs[2] == 256:
|
||||
if m_blocks <= 1 and thread_configs[0] != 128:
|
||||
continue
|
||||
if m_blocks > 1 and thread_configs[0] != 64:
|
||||
continue
|
||||
|
||||
k_blocks = thread_configs[0] // 16
|
||||
n_blocks = thread_configs[1] // 16
|
||||
threads = thread_configs[2]
|
||||
|
||||
c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"
|
||||
|
||||
template_str = jinja2.Template(TEMPLATE).render(
|
||||
scalar_t=c_dtype,
|
||||
w_type_id=scalar_type + ".id()",
|
||||
threads=threads,
|
||||
thread_m_blocks=max(m_blocks, 1),
|
||||
thread_n_blocks=n_blocks,
|
||||
thread_k_blocks=k_blocks,
|
||||
m_block_size_8=m_blocks == 0.5,
|
||||
stages="pipe_stages",
|
||||
has_act_order=has_act_order,
|
||||
has_zp=has_zp,
|
||||
group_blocks=group_blocks,
|
||||
is_zp_float=False,
|
||||
)
|
||||
|
||||
all_template_str_list.append(template_str)
|
||||
|
||||
file_content = FILE_HEAD + "\n\n"
|
||||
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
|
||||
filename = f"kernel_{dtype}_{scalar_type[6:].lower()}.cu"
|
||||
|
||||
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
|
||||
f.write(file_content)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
remove_old_kernels()
|
||||
generate_new_kernels()
|
||||
44
csrc/moe/marlin_moe_wna16/kernel.h
Normal file
44
csrc/moe/marlin_moe_wna16/kernel.h
Normal file
@ -0,0 +1,44 @@
|
||||
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
|
||||
#endif
|
||||
|
||||
#include "quantization/gptq_marlin/marlin.cuh"
|
||||
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
|
||||
#include "core/scalar_type.hpp"
|
||||
|
||||
#define MARLIN_KERNEL_PARAMS \
|
||||
const int4 *__restrict__ A, const int4 *__restrict__ B, \
|
||||
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
|
||||
const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr, \
|
||||
const int *__restrict__ g_idx, \
|
||||
const int32_t *__restrict__ sorted_token_ids_ptr, \
|
||||
const int32_t *__restrict__ expert_ids_ptr, \
|
||||
const int32_t *__restrict__ num_tokens_past_padded_ptr, \
|
||||
const float *__restrict__ topk_weights_ptr, int top_k, \
|
||||
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
|
||||
int prob_n, int prob_k, int *locks, bool use_atomic_add, \
|
||||
bool use_fp32_reduce
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
template <typename scalar_t, // compute dtype, half or nv_float16
|
||||
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
|
||||
const int threads, // number of threads in a threadblock
|
||||
const int thread_m_blocks, // number of 16x16 blocks in the m
|
||||
// dimension (batchsize) of the
|
||||
// threadblock
|
||||
const int thread_n_blocks, // same for n dimension (output)
|
||||
const int thread_k_blocks, // same for k dimension (reduction)
|
||||
const bool m_block_size_8, // whether m_block_size == 8
|
||||
// only works when thread_m_blocks == 1
|
||||
const int stages, // number of stages for the async global->shared
|
||||
// fetch pipeline
|
||||
const bool has_act_order, // whether act_order is enabled
|
||||
const bool has_zp, // whether zero-points are enabled
|
||||
const int group_blocks, // number of consecutive 16x16 blocks
|
||||
// with a separate quantization scale
|
||||
const bool is_zp_float // is zero point of float16 type?
|
||||
>
|
||||
__global__ void Marlin(MARLIN_KERNEL_PARAMS);
|
||||
|
||||
}
|
||||
1917
csrc/moe/marlin_moe_wna16/marlin_template.h
Normal file
1917
csrc/moe/marlin_moe_wna16/marlin_template.h
Normal file
File diff suppressed because it is too large
Load Diff
927
csrc/moe/marlin_moe_wna16/ops.cu
Normal file
927
csrc/moe/marlin_moe_wna16/ops.cu
Normal file
@ -0,0 +1,927 @@
|
||||
/*
|
||||
* Modified by Neural Magic
|
||||
* Copyright (C) Marlin.2024 Elias Frantar
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/*
|
||||
* Adapted from https://github.com/IST-DASLab/marlin
|
||||
*/
|
||||
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
|
||||
#endif
|
||||
|
||||
#include "kernel.h"
|
||||
#include "core/registration.h"
|
||||
|
||||
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
|
||||
static_assert(std::is_same<scalar_t, half>::value || \
|
||||
std::is_same<scalar_t, nv_bfloat16>::value, \
|
||||
"only float16 and bfloat16 is supported");
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
__global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){};
|
||||
|
||||
using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS);
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
|
||||
template <int moe_block_size>
|
||||
__global__ void permute_cols_kernel(
|
||||
int4 const* __restrict__ a_int4_ptr, int const* __restrict__ perm_int_ptr,
|
||||
int4* __restrict__ out_int4_ptr,
|
||||
const int32_t* __restrict__ sorted_token_ids_ptr,
|
||||
const int32_t* __restrict__ expert_ids_ptr,
|
||||
const int32_t* __restrict__ num_tokens_past_padded_ptr, int size_m,
|
||||
int size_k, int top_k) {};
|
||||
|
||||
} // namespace marlin
|
||||
|
||||
torch::Tensor moe_wna16_marlin_gemm(
|
||||
torch::Tensor& a, std::optional<torch::Tensor> const& c_or_none,
|
||||
torch::Tensor& b_q_weight, torch::Tensor& b_scales,
|
||||
std::optional<torch::Tensor> const& b_zeros_or_none,
|
||||
std::optional<torch::Tensor> const& g_idx_or_none,
|
||||
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
|
||||
torch::Tensor& sorted_token_ids, torch::Tensor& expert_ids,
|
||||
torch::Tensor& num_tokens_past_padded, torch::Tensor& topk_weights,
|
||||
int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, bool is_ep,
|
||||
vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n,
|
||||
int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
|
||||
bool is_zp_float) {
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false,
|
||||
"marlin_gemm(..) requires CUDA_ARCH >= 8.0");
|
||||
return torch::empty({1, 1});
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
// For a given "a" of size [M,K] performs a permutation of the K columns based
|
||||
// on the given "perm" indices.
|
||||
template <int moe_block_size>
|
||||
__global__ void permute_cols_kernel(
|
||||
int4 const* __restrict__ a_int4_ptr, int const* __restrict__ perm_int_ptr,
|
||||
int4* __restrict__ out_int4_ptr,
|
||||
const int32_t* __restrict__ sorted_token_ids_ptr,
|
||||
const int32_t* __restrict__ expert_ids_ptr,
|
||||
const int32_t* __restrict__ num_tokens_past_padded_ptr, int size_m,
|
||||
int size_k, int top_k) {
|
||||
int num_tokens_past_padded = num_tokens_past_padded_ptr[0];
|
||||
int num_moe_blocks = div_ceil(num_tokens_past_padded, moe_block_size);
|
||||
int32_t block_sorted_ids[moe_block_size];
|
||||
int block_num_valid_tokens = 0;
|
||||
int64_t old_expert_id = 0;
|
||||
int64_t expert_id = 0;
|
||||
int row_stride = size_k * sizeof(half) / 16;
|
||||
|
||||
auto read_moe_block_data = [&](int block_id) {
|
||||
block_num_valid_tokens = moe_block_size;
|
||||
int4* tmp_block_sorted_ids = reinterpret_cast<int4*>(block_sorted_ids);
|
||||
for (int i = 0; i < moe_block_size / 4; i++) {
|
||||
tmp_block_sorted_ids[i] =
|
||||
((int4*)sorted_token_ids_ptr)[block_id * moe_block_size / 4 + i];
|
||||
}
|
||||
for (int i = 0; i < moe_block_size; i++) {
|
||||
if (block_sorted_ids[i] >= size_m * top_k) {
|
||||
block_num_valid_tokens = i;
|
||||
break;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
auto permute_row = [&](int row) {
|
||||
int iters = size_k / default_threads;
|
||||
int rest = size_k % default_threads;
|
||||
|
||||
int in_offset = (row / top_k) * row_stride;
|
||||
int out_offset = row * row_stride;
|
||||
|
||||
half const* a_row_half =
|
||||
reinterpret_cast<half const*>(a_int4_ptr + in_offset);
|
||||
half* out_half = reinterpret_cast<half*>(out_int4_ptr + out_offset);
|
||||
|
||||
int base_k = 0;
|
||||
|
||||
for (int i = 0; i < iters; i++) {
|
||||
int cur_k = base_k + threadIdx.x;
|
||||
int src_pos = perm_int_ptr[cur_k];
|
||||
|
||||
out_half[cur_k] = a_row_half[src_pos];
|
||||
|
||||
base_k += default_threads;
|
||||
}
|
||||
|
||||
if (rest) {
|
||||
if (threadIdx.x < rest) {
|
||||
int cur_k = base_k + threadIdx.x;
|
||||
int src_pos = perm_int_ptr[cur_k];
|
||||
|
||||
out_half[cur_k] = a_row_half[src_pos];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
for (int index = blockIdx.x; index < num_moe_blocks; index += gridDim.x) {
|
||||
old_expert_id = expert_id;
|
||||
int tmp_expert_id = expert_ids_ptr[index];
|
||||
if (tmp_expert_id == -1) continue;
|
||||
expert_id = tmp_expert_id;
|
||||
perm_int_ptr += (expert_id - old_expert_id) * size_k;
|
||||
read_moe_block_data(index);
|
||||
|
||||
for (int i = 0; i < block_num_valid_tokens; i++)
|
||||
permute_row(block_sorted_ids[i]);
|
||||
}
|
||||
}
|
||||
|
||||
typedef struct {
|
||||
int thread_k;
|
||||
int thread_n;
|
||||
int num_threads;
|
||||
} thread_config_t;
|
||||
|
||||
thread_config_t small_batch_thread_configs[] = {
|
||||
// Ordered by priority
|
||||
|
||||
// thread_k, thread_n, num_threads
|
||||
{128, 128, 256},
|
||||
{64, 128, 128}};
|
||||
|
||||
thread_config_t large_batch_thread_configs[] = {
|
||||
// Ordered by priority
|
||||
|
||||
// thread_k, thread_n, num_threads
|
||||
{64, 256, 256},
|
||||
{64, 128, 128}};
|
||||
|
||||
typedef struct {
|
||||
int blocks_per_sm;
|
||||
thread_config_t tb_cfg;
|
||||
} exec_config_t;
|
||||
|
||||
int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
|
||||
int prob_n, int prob_k, int num_bits, int group_size,
|
||||
bool has_act_order, bool is_k_full) {
|
||||
bool cache_scales_chunk = has_act_order && !is_k_full;
|
||||
|
||||
int tb_n = th_config.thread_n;
|
||||
int tb_k = th_config.thread_k;
|
||||
|
||||
// Get max scale groups per thread-block
|
||||
int tb_groups;
|
||||
if (group_size == -1) {
|
||||
tb_groups = 1;
|
||||
} else if (group_size == 0) {
|
||||
tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size
|
||||
} else {
|
||||
tb_groups = div_ceil(tb_k, group_size);
|
||||
}
|
||||
|
||||
if (cache_scales_chunk) {
|
||||
int load_groups =
|
||||
tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K
|
||||
load_groups = max(load_groups, 32); // We load at least 32 scale groups
|
||||
return load_groups * tb_n * 2;
|
||||
|
||||
} else {
|
||||
int tb_scales = tb_groups * tb_n * 2;
|
||||
|
||||
return tb_scales * pipe_stages;
|
||||
}
|
||||
}
|
||||
|
||||
int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
|
||||
int prob_m, int prob_n, int prob_k, int num_bits,
|
||||
int group_size, bool has_act_order, bool is_k_full,
|
||||
int has_zp, int is_zp_float) {
|
||||
int pack_factor = 32 / num_bits;
|
||||
|
||||
// Get B size
|
||||
int tb_k = th_config.thread_k;
|
||||
int tb_n = th_config.thread_n;
|
||||
int tb_m = thread_m_blocks * 16;
|
||||
|
||||
// shm size for block_sorted_ids/block_topk_weights
|
||||
// both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32)
|
||||
int sh_block_meta_size = tb_m * 4 * 2;
|
||||
int sh_a_size = pipe_stages * (tb_m * tb_k) * 2;
|
||||
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
|
||||
int sh_s_size =
|
||||
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
|
||||
group_size, has_act_order, is_k_full);
|
||||
int sh_g_idx_size = has_act_order && !is_k_full ? pipe_stages * tb_k / 4 : 0;
|
||||
int sh_zp_size = 0;
|
||||
if (has_zp) {
|
||||
if (is_zp_float)
|
||||
sh_zp_size = sh_s_size;
|
||||
else if (num_bits == 4)
|
||||
sh_zp_size = sh_s_size / 4;
|
||||
else if (num_bits == 8)
|
||||
sh_zp_size = sh_s_size / 2;
|
||||
}
|
||||
|
||||
int total_size = sh_a_size + sh_b_size + sh_s_size + sh_zp_size +
|
||||
sh_g_idx_size + sh_block_meta_size;
|
||||
|
||||
return total_size;
|
||||
}
|
||||
|
||||
bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
|
||||
int prob_m, int prob_n, int prob_k, int num_bits,
|
||||
int group_size, bool has_act_order, bool is_k_full,
|
||||
int has_zp, int is_zp_float, int max_shared_mem) {
|
||||
// Sanity
|
||||
if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
|
||||
th_config.num_threads == -1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Verify K/N are divisible by thread K/N
|
||||
if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Verify min for thread K/N
|
||||
if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// num_threads must be at least 128 (= 4 warps)
|
||||
if (th_config.num_threads < 128) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check that pipeline fits into cache
|
||||
int cache_size = get_kernel_cache_size(
|
||||
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size,
|
||||
has_act_order, is_k_full, has_zp, is_zp_float);
|
||||
return cache_size <= max_shared_mem;
|
||||
}
|
||||
|
||||
#define __GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
|
||||
M_BLOCK_SIZE_8, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \
|
||||
NUM_THREADS, IS_ZP_FLOAT) \
|
||||
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
|
||||
thread_n_blocks == THREAD_N_BLOCKS && \
|
||||
thread_k_blocks == THREAD_K_BLOCKS && \
|
||||
m_block_size_8 == M_BLOCK_SIZE_8 && \
|
||||
has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \
|
||||
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
|
||||
is_zp_float == IS_ZP_FLOAT) { \
|
||||
kernel = Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
|
||||
THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, \
|
||||
pipe_stages, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \
|
||||
IS_ZP_FLOAT>; \
|
||||
}
|
||||
|
||||
#define GPTQ_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, true, false, 0, NUM_THREADS, \
|
||||
false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
|
||||
NUM_THREADS, false) \
|
||||
\
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 2, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 4, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 8, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
|
||||
NUM_THREADS, false)
|
||||
|
||||
#define GPTQ_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
|
||||
NUM_THREADS, false) \
|
||||
\
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
|
||||
NUM_THREADS, false) \
|
||||
\
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
|
||||
NUM_THREADS, false) \
|
||||
\
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
|
||||
NUM_THREADS, false)
|
||||
|
||||
#define AWQ_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 2, NUM_THREADS, \
|
||||
false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 4, NUM_THREADS, \
|
||||
false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 8, NUM_THREADS, \
|
||||
false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
|
||||
NUM_THREADS, false)
|
||||
|
||||
#define AWQ_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
|
||||
NUM_THREADS, false) \
|
||||
\
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
|
||||
NUM_THREADS, false) \
|
||||
\
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
|
||||
NUM_THREADS, false)
|
||||
|
||||
// We currently have 4-bit models only with group_blocks == 4
|
||||
#define HQQ_GET_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 4, NUM_THREADS, \
|
||||
true) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
|
||||
NUM_THREADS, true) \
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
|
||||
NUM_THREADS, true) \
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
|
||||
NUM_THREADS, true) \
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
|
||||
NUM_THREADS, true)
|
||||
|
||||
template <typename scalar_t>
|
||||
MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
|
||||
int thread_m_blocks, int thread_n_blocks,
|
||||
int thread_k_blocks, bool m_block_size_8,
|
||||
bool has_act_order, bool has_zp,
|
||||
int group_blocks, int num_threads,
|
||||
bool is_zp_float) {
|
||||
int num_bits = q_type.size_bits();
|
||||
auto kernel = MarlinDefault;
|
||||
if (false) {
|
||||
}
|
||||
GPTQ_GET_IF_M1(vllm::kU4B8, 8, 8, 256)
|
||||
GPTQ_GET_IF_M1(vllm::kU4B8, 8, 4, 128)
|
||||
|
||||
GPTQ_GET_IF_M234(vllm::kU4B8, 16, 4, 256)
|
||||
GPTQ_GET_IF_M234(vllm::kU4B8, 8, 4, 128)
|
||||
|
||||
GPTQ_GET_IF_M1(vllm::kU8B128, 8, 8, 256)
|
||||
GPTQ_GET_IF_M1(vllm::kU8B128, 8, 4, 128)
|
||||
|
||||
GPTQ_GET_IF_M234(vllm::kU8B128, 16, 4, 256)
|
||||
GPTQ_GET_IF_M234(vllm::kU8B128, 8, 4, 128)
|
||||
|
||||
AWQ_GET_IF_M1(vllm::kU4, 8, 8, 256)
|
||||
AWQ_GET_IF_M1(vllm::kU4, 8, 4, 128)
|
||||
|
||||
AWQ_GET_IF_M234(vllm::kU4, 16, 4, 256)
|
||||
AWQ_GET_IF_M234(vllm::kU4, 8, 4, 128)
|
||||
|
||||
return kernel;
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
|
||||
int prob_n, int prob_k, int thread_m_blocks,
|
||||
bool m_block_size_8, int num_bits,
|
||||
int group_size, bool has_act_order,
|
||||
bool is_k_full, bool has_zp,
|
||||
bool is_zp_float, int max_shared_mem) {
|
||||
exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}};
|
||||
thread_config_t* thread_configs = thread_m_blocks > 1
|
||||
? large_batch_thread_configs
|
||||
: small_batch_thread_configs;
|
||||
int thread_configs_size =
|
||||
thread_m_blocks > 1
|
||||
? sizeof(large_batch_thread_configs) / sizeof(thread_config_t)
|
||||
: sizeof(small_batch_thread_configs) / sizeof(thread_config_t);
|
||||
|
||||
int count = 0;
|
||||
constexpr int device_max_reg_size = 255 * 1024;
|
||||
for (int i = 0; i < thread_configs_size; i++) {
|
||||
thread_config_t th_config = thread_configs[i];
|
||||
|
||||
if (!is_valid_config(th_config, thread_m_blocks, prob_m, prob_n, prob_k,
|
||||
num_bits, group_size, has_act_order, is_k_full, has_zp,
|
||||
is_zp_float, max_shared_mem)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
int cache_size = get_kernel_cache_size(
|
||||
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits,
|
||||
group_size, has_act_order, is_k_full, has_zp, is_zp_float);
|
||||
|
||||
int group_blocks = 0;
|
||||
if (!has_act_order) {
|
||||
group_blocks = group_size == -1 ? -1 : group_size / 16;
|
||||
}
|
||||
|
||||
auto kernel = get_marlin_kernel<scalar_t>(
|
||||
q_type, thread_m_blocks, th_config.thread_n / 16,
|
||||
th_config.thread_k / 16, m_block_size_8, has_act_order, has_zp,
|
||||
group_blocks, th_config.num_threads, is_zp_float);
|
||||
|
||||
if (kernel == MarlinDefault) continue;
|
||||
|
||||
if (thread_m_blocks > 1) {
|
||||
exec_cfg = {1, th_config};
|
||||
break;
|
||||
} else {
|
||||
cudaFuncAttributes attr;
|
||||
cudaFuncGetAttributes(&attr, kernel);
|
||||
int reg_size = max(attr.numRegs, 1) * th_config.num_threads * 4;
|
||||
int allow_count = min(device_max_reg_size / reg_size,
|
||||
max_shared_mem / (cache_size + 1024));
|
||||
allow_count = max(min(allow_count, 4), 1);
|
||||
if (allow_count > count) {
|
||||
count = allow_count;
|
||||
exec_cfg = {count, th_config};
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
return exec_cfg;
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
||||
void* zp, void* g_idx, void* perm, void* a_tmp,
|
||||
void* sorted_token_ids, void* expert_ids,
|
||||
void* num_tokens_past_padded, void* topk_weights,
|
||||
int moe_block_size, int top_k, bool mul_topk_weights, bool is_ep,
|
||||
int prob_m, int prob_n, int prob_k, void* workspace,
|
||||
vllm::ScalarType const& q_type, bool has_act_order,
|
||||
bool is_k_full, bool has_zp, int num_groups, int group_size,
|
||||
int dev, cudaStream_t stream, int thread_k, int thread_n,
|
||||
int sms, bool use_atomic_add, bool use_fp32_reduce,
|
||||
bool is_zp_float) {
|
||||
int thread_m_blocks = div_ceil(moe_block_size, 16);
|
||||
bool m_block_size_8 = moe_block_size == 8;
|
||||
|
||||
if (has_zp) {
|
||||
TORCH_CHECK(
|
||||
q_type == vllm::kU4 || q_type == vllm::kU8,
|
||||
"q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str());
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
q_type == vllm::kU4B8 || q_type == vllm::kU8B128,
|
||||
"q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
|
||||
q_type.str());
|
||||
}
|
||||
|
||||
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
|
||||
", ", prob_n, ", ", prob_k, "]");
|
||||
|
||||
int group_blocks = 0;
|
||||
if (has_act_order) {
|
||||
if (is_k_full) {
|
||||
TORCH_CHECK(group_size != -1);
|
||||
group_blocks = group_size / 16;
|
||||
TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
|
||||
" is not divisible by group_blocks = ", group_blocks);
|
||||
} else {
|
||||
TORCH_CHECK(group_size == 0);
|
||||
group_blocks = 0;
|
||||
}
|
||||
} else {
|
||||
if (group_size == -1) {
|
||||
group_blocks = -1;
|
||||
} else {
|
||||
group_blocks = group_size / 16;
|
||||
TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
|
||||
" is not divisible by group_blocks = ", group_blocks);
|
||||
}
|
||||
}
|
||||
|
||||
int num_bits = q_type.size_bits();
|
||||
const int4* A_ptr = (const int4*)A;
|
||||
const int4* B_ptr = (const int4*)B;
|
||||
int4* C_ptr = (int4*)C;
|
||||
int4* C_tmp_ptr = (int4*)C_tmp;
|
||||
const int4* s_ptr = (const int4*)s;
|
||||
const int4* zp_ptr = (const int4*)zp;
|
||||
const int* g_idx_ptr = (const int*)g_idx;
|
||||
const int* perm_ptr = (const int*)perm;
|
||||
int4* a_tmp_ptr = (int4*)a_tmp;
|
||||
const int32_t* sorted_token_ids_ptr = (const int32_t*)sorted_token_ids;
|
||||
const int32_t* expert_ids_ptr = (const int32_t*)expert_ids;
|
||||
const int32_t* num_tokens_past_padded_ptr =
|
||||
(const int32_t*)num_tokens_past_padded;
|
||||
const float* topk_weights_ptr = (const float*)topk_weights;
|
||||
int* locks = (int*)workspace;
|
||||
|
||||
if (has_act_order) {
|
||||
// Permute A columns
|
||||
auto kernel = permute_cols_kernel<8>;
|
||||
if (moe_block_size == 8) {
|
||||
} else if (moe_block_size == 16)
|
||||
kernel = permute_cols_kernel<16>;
|
||||
else if (moe_block_size == 32)
|
||||
kernel = permute_cols_kernel<32>;
|
||||
else if (moe_block_size == 48)
|
||||
kernel = permute_cols_kernel<48>;
|
||||
else if (moe_block_size == 64)
|
||||
kernel = permute_cols_kernel<64>;
|
||||
else
|
||||
TORCH_CHECK(false, "unsupported moe_block_size ", moe_block_size);
|
||||
|
||||
// avoid ">>>" being formatted to "> > >"
|
||||
// clang-format off
|
||||
kernel<<<sms, default_threads, 0, stream>>>(
|
||||
A_ptr, perm_ptr, a_tmp_ptr, sorted_token_ids_ptr, expert_ids_ptr,
|
||||
num_tokens_past_padded_ptr, prob_m, prob_k, top_k);
|
||||
// clang-format on
|
||||
A_ptr = a_tmp_ptr;
|
||||
prob_m = prob_m * top_k;
|
||||
top_k = 1;
|
||||
|
||||
// If we have a full K, then we can run the non-act-order version of Marlin
|
||||
// (since the weight rows are reordered by increasing group ids, and by
|
||||
// having a full K, we have full original groups)
|
||||
if (is_k_full) has_act_order = false;
|
||||
}
|
||||
|
||||
int max_shared_mem = 0;
|
||||
cudaDeviceGetAttribute(&max_shared_mem,
|
||||
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
|
||||
TORCH_CHECK(max_shared_mem > 0);
|
||||
|
||||
// Set thread config
|
||||
exec_config_t exec_cfg;
|
||||
thread_config_t thread_tfg;
|
||||
if (thread_k != -1 && thread_n != -1) {
|
||||
thread_tfg = thread_config_t{thread_k, thread_n, default_threads};
|
||||
exec_cfg = exec_config_t{1, thread_tfg};
|
||||
TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
|
||||
" is not divisible by thread_n = ", thread_n);
|
||||
TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
|
||||
" is not divisible by thread_k = ", thread_k);
|
||||
} else {
|
||||
// Auto config
|
||||
exec_cfg = determine_exec_config<scalar_t>(
|
||||
q_type, prob_m, prob_n, prob_k, thread_m_blocks, m_block_size_8,
|
||||
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float,
|
||||
max_shared_mem);
|
||||
thread_tfg = exec_cfg.tb_cfg;
|
||||
}
|
||||
|
||||
int num_threads = thread_tfg.num_threads;
|
||||
thread_k = thread_tfg.thread_k;
|
||||
thread_n = thread_tfg.thread_n;
|
||||
int blocks = sms * exec_cfg.blocks_per_sm;
|
||||
if (exec_cfg.blocks_per_sm > 1)
|
||||
max_shared_mem = max_shared_mem / exec_cfg.blocks_per_sm - 1024;
|
||||
|
||||
int thread_k_blocks = thread_k / 16;
|
||||
int thread_n_blocks = thread_n / 16;
|
||||
|
||||
TORCH_CHECK(is_valid_config(thread_tfg, thread_m_blocks, prob_m, prob_n,
|
||||
prob_k, num_bits, group_size, has_act_order,
|
||||
is_k_full, has_zp, is_zp_float, max_shared_mem),
|
||||
"Invalid thread config: thread_m_blocks = ", thread_m_blocks,
|
||||
", thread_k = ", thread_tfg.thread_k,
|
||||
", thread_n = ", thread_tfg.thread_n,
|
||||
", num_threads = ", thread_tfg.num_threads, " for MKN = [",
|
||||
prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
|
||||
", group_size = ", group_size,
|
||||
", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full,
|
||||
", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float,
|
||||
", max_shared_mem = ", max_shared_mem);
|
||||
|
||||
auto kernel = get_marlin_kernel<scalar_t>(
|
||||
q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks, m_block_size_8,
|
||||
has_act_order, has_zp, group_blocks, num_threads, is_zp_float);
|
||||
|
||||
if (kernel == MarlinDefault) {
|
||||
TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
|
||||
", ", prob_k, "]", ", has_act_order = ", has_act_order,
|
||||
", num_groups = ", num_groups, ", group_size = ", group_size,
|
||||
", thread_m_blocks = ", thread_m_blocks,
|
||||
", thread_n_blocks = ", thread_n_blocks,
|
||||
", thread_k_blocks = ", thread_k_blocks,
|
||||
", num_bits = ", num_bits);
|
||||
}
|
||||
|
||||
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
max_shared_mem);
|
||||
// avoid ">>>" being formatted to "> > >"
|
||||
// clang-format off
|
||||
kernel<<<blocks, num_threads, max_shared_mem, stream>>>(
|
||||
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr,
|
||||
sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr,
|
||||
topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m,
|
||||
prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce);
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
} // namespace MARLIN_NAMESPACE_NAME
|
||||
|
||||
torch::Tensor moe_wna16_marlin_gemm(
|
||||
torch::Tensor& a, std::optional<torch::Tensor> const& c_or_none,
|
||||
torch::Tensor& b_q_weight, torch::Tensor& b_scales,
|
||||
std::optional<torch::Tensor> const& b_zeros_or_none,
|
||||
std::optional<torch::Tensor> const& g_idx_or_none,
|
||||
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
|
||||
torch::Tensor& sorted_token_ids, torch::Tensor& expert_ids,
|
||||
torch::Tensor& num_tokens_past_padded, torch::Tensor& topk_weights,
|
||||
int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, bool is_ep,
|
||||
vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n,
|
||||
int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
|
||||
bool is_zp_float) {
|
||||
vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
|
||||
int pack_factor = 32 / b_q_type.size_bits();
|
||||
|
||||
if (moe_block_size != 8) {
|
||||
TORCH_CHECK(moe_block_size % 16 == 0,
|
||||
"unsupported moe_block_size=", moe_block_size);
|
||||
TORCH_CHECK(moe_block_size >= 16 && moe_block_size <= 64,
|
||||
"unsupported moe_block_size=", moe_block_size);
|
||||
}
|
||||
|
||||
// Verify A
|
||||
TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0),
|
||||
", size_m = ", size_m);
|
||||
TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1),
|
||||
", size_k = ", size_k);
|
||||
|
||||
// Verify B
|
||||
TORCH_CHECK(
|
||||
size_k % MARLIN_NAMESPACE_NAME::tile_size == 0, "size_k = ", size_k,
|
||||
" is not divisible by tile_size = ", MARLIN_NAMESPACE_NAME::tile_size);
|
||||
TORCH_CHECK((size_k / MARLIN_NAMESPACE_NAME::tile_size) == b_q_weight.size(1),
|
||||
"Shape mismatch: b_q_weight.size(1) = ", b_q_weight.size(1),
|
||||
", size_k = ", size_k,
|
||||
", tile_size = ", MARLIN_NAMESPACE_NAME::tile_size);
|
||||
TORCH_CHECK(
|
||||
b_q_weight.size(2) % MARLIN_NAMESPACE_NAME::tile_size == 0,
|
||||
"b_q_weight.size(2) = ", b_q_weight.size(2),
|
||||
" is not divisible by tile_size = ", MARLIN_NAMESPACE_NAME::tile_size);
|
||||
int actual_size_n =
|
||||
(b_q_weight.size(2) / MARLIN_NAMESPACE_NAME::tile_size) * pack_factor;
|
||||
TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n,
|
||||
", actual_size_n = ", actual_size_n);
|
||||
|
||||
// Verify device and strides
|
||||
TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
|
||||
TORCH_CHECK(a.is_contiguous(), "A is not contiguous");
|
||||
|
||||
TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
|
||||
TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
|
||||
|
||||
TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
|
||||
TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
|
||||
|
||||
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as
|
||||
// auto -1)
|
||||
int thread_k = -1;
|
||||
// thread_n: `n` size of a thread_tile in `weights` (can usually be left as
|
||||
// auto -1)
|
||||
int thread_n = -1;
|
||||
// sms: number of SMs to use for the kernel
|
||||
int sms = -1;
|
||||
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device());
|
||||
|
||||
// Alloc buffers
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
|
||||
auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
|
||||
torch::Tensor c;
|
||||
if (c_or_none.has_value()) {
|
||||
c = c_or_none.value();
|
||||
TORCH_CHECK(c.device().is_cuda(), "c is not on GPU");
|
||||
TORCH_CHECK(c.is_contiguous(), "c is not contiguous");
|
||||
TORCH_CHECK(c.size(0) == size_m * top_k,
|
||||
"Shape mismatch: c.size(0) = ", c.size(0),
|
||||
", size_m * topk = ", size_m * top_k);
|
||||
TORCH_CHECK(c.size(1) == size_n, "Shape mismatch: c.size(1) = ", c.size(1),
|
||||
", size_n = ", size_n);
|
||||
} else {
|
||||
c = torch::empty({size_m * top_k, size_n}, options);
|
||||
}
|
||||
|
||||
// Alloc C tmp buffer that is going to be used for the global reduce
|
||||
torch::Tensor c_tmp;
|
||||
auto options_fp32 =
|
||||
torch::TensorOptions().dtype(at::kFloat).device(a.device());
|
||||
if (use_fp32_reduce && !use_atomic_add) {
|
||||
// max num of threadblocks is sms * 4
|
||||
long max_c_tmp_size = min(
|
||||
(long)size_n * sorted_token_ids.size(0),
|
||||
(long)sms * 4 * moe_block_size * MARLIN_NAMESPACE_NAME::max_thread_n);
|
||||
if (moe_block_size == 8) max_c_tmp_size *= 2;
|
||||
c_tmp = torch::empty({max_c_tmp_size}, options_fp32);
|
||||
} else {
|
||||
c_tmp = torch::empty({0}, options_fp32);
|
||||
}
|
||||
|
||||
// Detect groupsize and act_order
|
||||
int num_groups = -1;
|
||||
int group_size = -1;
|
||||
|
||||
int rank = b_scales.sizes().size();
|
||||
TORCH_CHECK(rank == 3, "b_scales rank = ", rank, " is not 3");
|
||||
TORCH_CHECK(b_scales.size(2) == size_n, "b_scales dim 2 = ", b_scales.size(2),
|
||||
" is not size_n = ", size_n);
|
||||
num_groups = b_scales.size(1);
|
||||
|
||||
torch::Tensor g_idx, perm, a_tmp;
|
||||
;
|
||||
if (g_idx_or_none.has_value() && perm_or_none.has_value()) {
|
||||
g_idx = g_idx_or_none.value();
|
||||
perm = perm_or_none.value();
|
||||
|
||||
TORCH_CHECK(g_idx.device().is_cuda(), "g_idx is not on GPU");
|
||||
TORCH_CHECK(g_idx.is_contiguous(), "g_idx is not contiguous");
|
||||
TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU");
|
||||
TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous");
|
||||
|
||||
// Verify g_idx and perm
|
||||
TORCH_CHECK((g_idx.size(-1) == 0 && perm.size(-1) == 0) ||
|
||||
(g_idx.size(-1) == size_k && perm.size(-1) == size_k),
|
||||
"Unexpected g_idx.size(-1) = ", g_idx.size(-1),
|
||||
" and perm.size(-1) = ", perm.size(-1),
|
||||
", where size_k = ", size_k);
|
||||
} else {
|
||||
g_idx = torch::empty({0}, options);
|
||||
perm = torch::empty({0}, options);
|
||||
a_tmp = torch::empty({0}, options);
|
||||
}
|
||||
bool has_act_order = g_idx.size(-1) > 0 && perm.size(-1) > 0;
|
||||
|
||||
if (has_act_order) {
|
||||
a_tmp = torch::empty({size_m * top_k, size_k}, options);
|
||||
if (is_k_full) {
|
||||
TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1");
|
||||
TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k,
|
||||
", is not divisible by num_groups = ", num_groups);
|
||||
group_size = size_k / num_groups;
|
||||
} else {
|
||||
group_size = 0;
|
||||
}
|
||||
|
||||
} else {
|
||||
a_tmp = torch::empty({0}, options);
|
||||
if (num_groups > 1) {
|
||||
TORCH_CHECK(
|
||||
size_k % num_groups == 0, "size_k = ", size_k,
|
||||
", is not divisible by b_scales.size(1) = ", b_scales.size(1));
|
||||
group_size = size_k / num_groups;
|
||||
} else {
|
||||
group_size = -1;
|
||||
}
|
||||
}
|
||||
|
||||
torch::Tensor b_zeros;
|
||||
if (b_zeros_or_none.has_value()) {
|
||||
b_zeros = b_zeros_or_none.value();
|
||||
TORCH_CHECK(b_zeros.device().is_cuda(), "b_zeros is not on GPU");
|
||||
TORCH_CHECK(b_zeros.is_contiguous(), "b_zeros is not contiguous");
|
||||
} else {
|
||||
b_zeros = torch::empty({0}, options);
|
||||
}
|
||||
bool has_zp = b_zeros.size(-1) > 0;
|
||||
|
||||
if (has_zp) {
|
||||
TORCH_CHECK(
|
||||
b_q_type == vllm::kU4,
|
||||
"b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str());
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128,
|
||||
"b_q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
|
||||
b_q_type.str());
|
||||
}
|
||||
|
||||
if (has_zp && is_zp_float) {
|
||||
TORCH_CHECK(a.scalar_type() == at::ScalarType::Half,
|
||||
"Computation type must be float16 (half) when using float zero "
|
||||
"points.");
|
||||
}
|
||||
|
||||
// Verify b_zeros
|
||||
if (has_zp) {
|
||||
int rank = b_zeros.sizes().size();
|
||||
TORCH_CHECK(rank == 3, "b_zeros rank = ", rank, " is not 3");
|
||||
if (is_zp_float) {
|
||||
TORCH_CHECK(b_zeros.size(2) == size_n,
|
||||
"b_zeros dim 2 = ", b_zeros.size(2),
|
||||
" is not size_n = ", size_n);
|
||||
TORCH_CHECK(num_groups == b_zeros.size(1),
|
||||
"b_zeros dim 1 = ", b_zeros.size(1),
|
||||
" is not num_groups = ", num_groups);
|
||||
TORCH_CHECK(num_groups != -1, "num_groups must be != -1");
|
||||
} else {
|
||||
TORCH_CHECK(b_zeros.size(1) == num_groups,
|
||||
"b_zeros dim 1 = ", b_zeros.size(1),
|
||||
" is not num_groups = ", num_groups);
|
||||
TORCH_CHECK(b_zeros.size(2) == size_n / pack_factor,
|
||||
"b_zeros dim 2 = ", b_zeros.size(2),
|
||||
" is not size_n / pack_factor = ", size_n / pack_factor);
|
||||
}
|
||||
}
|
||||
|
||||
// Verify workspace size
|
||||
TORCH_CHECK(size_n % MARLIN_NAMESPACE_NAME::min_thread_n == 0,
|
||||
"size_n = ", size_n, ", is not divisible by min_thread_n = ",
|
||||
MARLIN_NAMESPACE_NAME::min_thread_n);
|
||||
|
||||
int max_n_tiles = size_n / MARLIN_NAMESPACE_NAME::min_thread_n;
|
||||
int min_workspace_size = min(
|
||||
max_n_tiles * (int)(sorted_token_ids.size(0) / moe_block_size), sms * 4);
|
||||
TORCH_CHECK(workspace.numel() >= min_workspace_size,
|
||||
"workspace.numel = ", workspace.numel(),
|
||||
" is below min_workspace_size = ", min_workspace_size);
|
||||
|
||||
int dev = a.get_device();
|
||||
if (a.scalar_type() == at::ScalarType::Half) {
|
||||
MARLIN_NAMESPACE_NAME::marlin_mm<half>(
|
||||
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
|
||||
c_tmp.data_ptr<float>(), b_scales.data_ptr<at::Half>(),
|
||||
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
|
||||
a_tmp.data_ptr<at::Half>(), sorted_token_ids.data_ptr(),
|
||||
expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(),
|
||||
topk_weights.data_ptr(), moe_block_size, top_k, mul_topk_weights, is_ep,
|
||||
size_m, size_n, size_k, workspace.data_ptr(), b_q_type, has_act_order,
|
||||
is_k_full, has_zp, num_groups, group_size, dev,
|
||||
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
|
||||
use_atomic_add, use_fp32_reduce, is_zp_float);
|
||||
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
|
||||
MARLIN_NAMESPACE_NAME::marlin_mm<nv_bfloat16>(
|
||||
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
|
||||
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
|
||||
b_scales.data_ptr<at::BFloat16>(), b_zeros.data_ptr(), g_idx.data_ptr(),
|
||||
perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
|
||||
sorted_token_ids.data_ptr(), expert_ids.data_ptr(),
|
||||
num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(),
|
||||
moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k,
|
||||
workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
|
||||
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
|
||||
thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float);
|
||||
} else {
|
||||
TORCH_CHECK(false,
|
||||
"moe_wna16_marlin_gemm only supports bfloat16 and float16");
|
||||
}
|
||||
|
||||
return c;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("moe_wna16_marlin_gemm", &moe_wna16_marlin_gemm);
|
||||
}
|
||||
@ -13,7 +13,6 @@
|
||||
template <typename scalar_t, int bit, int GROUPS>
|
||||
__global__ void moe_wna16_gemm_kernel(
|
||||
const scalar_t* __restrict__ input, scalar_t* __restrict__ output,
|
||||
|
||||
const uint32_t* __restrict__ qweight, const scalar_t* __restrict__ scales,
|
||||
const uint32_t* __restrict__ qzeros,
|
||||
|
||||
@ -54,8 +53,6 @@ __global__ void moe_wna16_gemm_kernel(
|
||||
if (token_index / top_k >= size_m) break;
|
||||
|
||||
num_valid_tokens = m + 1;
|
||||
if (blockIdx.z == 0 && offset_n < size_n)
|
||||
output[token_index * size_n + offset_n] = Dtype::int2num(0);
|
||||
|
||||
if (expert_id != -1) {
|
||||
int k_per_thread = DIVIDE(BLOCK_SIZE_K, BLOCK_SIZE_N);
|
||||
@ -284,8 +281,7 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
|
||||
int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
|
||||
int64_t BLOCK_SIZE_K, int64_t bit) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
auto options =
|
||||
torch::TensorOptions().dtype(input.dtype()).device(input.device());
|
||||
output.zero_();
|
||||
|
||||
const int num_experts = b_qweight.size(0);
|
||||
const int size_m = input.size(0);
|
||||
@ -302,9 +298,9 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
|
||||
const uint32_t* b_qzeros_ptr;
|
||||
if (b_qzeros.has_value())
|
||||
b_qzeros_ptr = (const uint32_t*)b_qzeros.value().data_ptr<uint8_t>();
|
||||
const float* topk_weights_ptr;
|
||||
const float* topk_weights_ptr = nullptr;
|
||||
if (topk_weights.has_value())
|
||||
topk_weights_ptr = (const float*)topk_weights.value().data_ptr();
|
||||
topk_weights_ptr = (const float*)topk_weights.value().data_ptr<float>();
|
||||
|
||||
int groups_per_block_row = BLOCK_SIZE_K / group_size;
|
||||
TORCH_CHECK(bit == 4 || bit == 8, "bit must be 4 or 8");
|
||||
|
||||
@ -43,14 +43,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
||||
m.impl("moe_wna16_gemm", torch::kCUDA, &moe_wna16_gemm);
|
||||
|
||||
m.def(
|
||||
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
|
||||
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
|
||||
"b_zeros, Tensor! g_idx, Tensor! perm, Tensor! workspace, "
|
||||
"int b_q_type, SymInt size_m, "
|
||||
"SymInt size_n, SymInt size_k, bool is_k_full, int num_experts, int "
|
||||
"topk, "
|
||||
"int moe_block_size, bool replicate_input, bool apply_weights)"
|
||||
" -> Tensor");
|
||||
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
|
||||
"Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none,"
|
||||
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
|
||||
"Tensor sorted_token_ids,"
|
||||
"Tensor! expert_ids, Tensor! num_tokens_past_padded,"
|
||||
"Tensor! topk_weights, int moe_block_size, int top_k, "
|
||||
"bool mul_topk_weights, bool is_ep, int b_q_type_id,"
|
||||
"int size_m, int size_n, int size_k,"
|
||||
"bool is_full_k, bool use_atomic_add,"
|
||||
"bool use_fp32_reduce, bool is_zp_float) -> Tensor");
|
||||
|
||||
// conditionally compiled so impl registration is in source file
|
||||
|
||||
#endif
|
||||
|
||||
15
csrc/ops.h
15
csrc/ops.h
@ -52,6 +52,15 @@ void paged_attention_v2(
|
||||
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||
const int64_t blocksparse_head_sliding_step);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
void merge_attn_states(torch::Tensor& output,
|
||||
std::optional<torch::Tensor> output_lse,
|
||||
const torch::Tensor& prefix_output,
|
||||
const torch::Tensor& prefix_lse,
|
||||
const torch::Tensor& suffix_output,
|
||||
const torch::Tensor& suffix_lse);
|
||||
#endif
|
||||
|
||||
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
|
||||
double epsilon);
|
||||
|
||||
@ -119,6 +128,12 @@ void advance_step_flashinfer(
|
||||
torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
|
||||
torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds);
|
||||
|
||||
void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope,
|
||||
torch::Tensor const& q_pe,
|
||||
torch::Tensor const& kv_c_and_k_pe_cache,
|
||||
torch::Tensor const& seq_lens,
|
||||
torch::Tensor const& page_table, double scale);
|
||||
|
||||
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
|
||||
@ -46,14 +46,26 @@ __global__ void compute_expert_offsets(
|
||||
}
|
||||
|
||||
__global__ void compute_arg_sorts(const int* __restrict__ topk_ids,
|
||||
const int32_t* __restrict__ expert_offsets,
|
||||
int32_t* input_permutation,
|
||||
int32_t* output_permutation,
|
||||
int32_t* atomic_buffer, const int topk_length,
|
||||
const int topk) {
|
||||
int expert_id = blockIdx.x;
|
||||
int const blk_expert_id = blockIdx.x;
|
||||
int const num_experts = gridDim.x;
|
||||
int32_t const num_tokens = expert_offsets[num_experts];
|
||||
|
||||
for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
|
||||
if (topk_ids[i] == expert_id) {
|
||||
int const expert_id = topk_ids[i];
|
||||
if (expert_id == -1 && blockIdx.x == 0) {
|
||||
// output_permutation is used to re-order the moe outputs. It is
|
||||
// used as c2 = c2[c_map], where c2 is a torch.tensor that is the
|
||||
// output of the cutlass kernels and c_map is the output_permutation.
|
||||
// c2 is initialized to zeros, therefore by setting the output_permutation
|
||||
// to num_tokens, we are guaranteed to fill the moe outputs to zero
|
||||
// for "invalid" topk_ids.
|
||||
output_permutation[i] = num_tokens;
|
||||
} else if (expert_id == blk_expert_id) {
|
||||
int start = atomicAdd(&atomic_buffer[expert_id], 1);
|
||||
input_permutation[start] = i / topk;
|
||||
output_permutation[i] = start;
|
||||
@ -83,6 +95,7 @@ void get_cutlass_moe_mm_data_caller(
|
||||
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
|
||||
compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
|
||||
static_cast<const int32_t*>(topk_ids.data_ptr()),
|
||||
static_cast<const int32_t*>(expert_offsets.data_ptr()),
|
||||
static_cast<int32_t*>(input_permutation.data_ptr()),
|
||||
static_cast<int32_t*>(output_permutation.data_ptr()),
|
||||
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(),
|
||||
|
||||
@ -336,7 +336,7 @@ inline void cutlass_gemm_sm89_fp8_dispatch(torch::Tensor& out,
|
||||
|
||||
uint32_t const m = a.size(0);
|
||||
uint32_t const mp2 =
|
||||
std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
|
||||
std::max(static_cast<uint32_t>(16), next_pow_2(m)); // next power of 2
|
||||
|
||||
if (mp2 <= 16) {
|
||||
// M in [1, 16]
|
||||
|
||||
@ -321,7 +321,7 @@ inline void cutlass_gemm_sm89_int8_dispatch(torch::Tensor& out,
|
||||
|
||||
uint32_t const m = a.size(0);
|
||||
uint32_t const mp2 =
|
||||
std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
|
||||
std::max(static_cast<uint32_t>(16), next_pow_2(m)); // next power of 2
|
||||
|
||||
if (mp2 <= 16) {
|
||||
// M in [1, 16]
|
||||
|
||||
@ -134,7 +134,7 @@ typename T::Gemm::Arguments args_from_options(
|
||||
using StrideB = typename T::StrideB;
|
||||
using StrideD = typename T::StrideD;
|
||||
using Sm100BlkScaledConfig =
|
||||
typename T::Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;
|
||||
typename T::Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
|
||||
|
||||
int m = static_cast<int>(M);
|
||||
int n = static_cast<int>(N);
|
||||
|
||||
@ -129,7 +129,7 @@ static __device__ __forceinline__ void moe_q(
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MOE_X_Q4_0 64
|
||||
#define MOE_X_Q4_0 8
|
||||
#define MOE_Y_Q4_0 128
|
||||
#define NWARPS_Q4_0 8
|
||||
#else
|
||||
@ -190,7 +190,7 @@ static void ggml_moe_q4_0_q8_1_cuda(
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MOE_X_Q4_1 64
|
||||
#define MOE_X_Q4_1 8
|
||||
#define MOE_Y_Q4_1 128
|
||||
#define NWARPS_Q4_1 8
|
||||
#else
|
||||
@ -251,7 +251,7 @@ static void ggml_moe_q4_1_q8_1_cuda(
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MOE_X_Q5_0 64
|
||||
#define MOE_X_Q5_0 8
|
||||
#define MOE_Y_Q5_0 128
|
||||
#define NWARPS_Q5_0 8
|
||||
#else
|
||||
@ -312,7 +312,7 @@ static void ggml_moe_q5_0_q8_1_cuda(
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MOE_X_Q5_1 64
|
||||
#define MOE_X_Q5_1 8
|
||||
#define MOE_Y_Q5_1 128
|
||||
#define NWARPS_Q5_1 8
|
||||
#else
|
||||
@ -373,7 +373,7 @@ static void ggml_moe_q5_1_q8_1_cuda(
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MOE_X_Q8_0 64
|
||||
#define MOE_X_Q8_0 8
|
||||
#define MOE_Y_Q8_0 128
|
||||
#define NWARPS_Q8_0 8
|
||||
#else
|
||||
@ -434,7 +434,7 @@ static void ggml_moe_q8_0_q8_1_cuda(
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MOE_X_Q2_K 64
|
||||
#define MOE_X_Q2_K 8
|
||||
#define MOE_Y_Q2_K 128
|
||||
#define NWARPS_Q2_K 8
|
||||
#else
|
||||
@ -495,7 +495,7 @@ static void ggml_moe_q2_K_q8_1_cuda(
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MOE_X_Q3_K 64
|
||||
#define MOE_X_Q3_K 8
|
||||
#define MOE_Y_Q3_K 128
|
||||
#define NWARPS_Q3_K 8
|
||||
#else
|
||||
@ -556,7 +556,7 @@ static void ggml_moe_q3_K_q8_1_cuda(
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MOE_X_Q4_K 64
|
||||
#define MOE_X_Q4_K 8
|
||||
#define MOE_Y_Q4_K 128
|
||||
#define NWARPS_Q4_K 8
|
||||
#else
|
||||
@ -617,7 +617,7 @@ static void ggml_moe_q4_K_q8_1_cuda(
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MOE_X_Q5_K 64
|
||||
#define MOE_X_Q5_K 8
|
||||
#define MOE_Y_Q5_K 128
|
||||
#define NWARPS_Q5_K 8
|
||||
#else
|
||||
@ -678,7 +678,7 @@ static void ggml_moe_q5_K_q8_1_cuda(
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MOE_X_Q6_K 64
|
||||
#define MOE_X_Q6_K 8
|
||||
#define MOE_Y_Q6_K 128
|
||||
#define NWARPS_Q6_K 8
|
||||
#else
|
||||
|
||||
@ -1785,7 +1785,7 @@ __global__ void Marlin(
|
||||
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
|
||||
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, \
|
||||
num_groups, prob_m, prob_n, prob_k, lda, locks, \
|
||||
use_atomic_add, use_fp32_reduce); \
|
||||
part_use_atomic_add, use_fp32_reduce); \
|
||||
} \
|
||||
}
|
||||
|
||||
@ -2215,6 +2215,10 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
||||
thread_m_blocks = exec_cfg.max_m_blocks;
|
||||
}
|
||||
|
||||
// atomic add reduce have better performance only when m * n is small
|
||||
bool part_use_atomic_add =
|
||||
use_atomic_add && div_ceil(prob_m, 64) * prob_n <= 2048;
|
||||
|
||||
if (false) {
|
||||
}
|
||||
GPTQ_CALL_IF(vllm::kU4B8, 16, 4, 256)
|
||||
|
||||
@ -9,7 +9,11 @@
|
||||
#include <cuda_runtime.h>
|
||||
#include <iostream>
|
||||
|
||||
namespace marlin {
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin
|
||||
#endif
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
// Marlin params
|
||||
|
||||
@ -23,6 +27,7 @@ static constexpr int pipe_stages =
|
||||
|
||||
static constexpr int min_thread_n = 64;
|
||||
static constexpr int min_thread_k = 64;
|
||||
static constexpr int max_thread_n = 256;
|
||||
|
||||
static constexpr int tile_size = 16;
|
||||
static constexpr int max_par = 16;
|
||||
@ -84,4 +89,4 @@ __device__ inline void cp_async_wait() {
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace marlin
|
||||
} // namespace MARLIN_NAMESPACE_NAME
|
||||
|
||||
@ -5,7 +5,11 @@
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
namespace marlin {
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin
|
||||
#endif
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
template <typename scalar_t>
|
||||
class ScalarType {};
|
||||
@ -54,7 +58,7 @@ class ScalarType<nv_bfloat16> {
|
||||
using FragS = Vec<nv_bfloat162, 1>;
|
||||
using FragZP = Vec<nv_bfloat162, 4>;
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
|
||||
static __device__ float inline num2float(const nv_bfloat16 x) {
|
||||
return __bfloat162float(x);
|
||||
}
|
||||
@ -74,6 +78,6 @@ class ScalarType<nv_bfloat16> {
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace marlin
|
||||
} // namespace MARLIN_NAMESPACE_NAME
|
||||
|
||||
#endif
|
||||
|
||||
@ -2,6 +2,15 @@
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b,
|
||||
const int64_t rows_per_block);
|
||||
|
||||
torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
|
||||
const int64_t CuCount);
|
||||
|
||||
void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
|
||||
at::Tensor& scale_a, at::Tensor& scale_b, const int64_t CuCount);
|
||||
|
||||
void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums,
|
||||
torch::Tensor& max_logits, torch::Tensor& tmp_out,
|
||||
torch::Tensor& query, torch::Tensor& key_cache,
|
||||
|
||||
1600
csrc/rocm/skinny_gemms.cu
Normal file
1600
csrc/rocm/skinny_gemms.cu
Normal file
File diff suppressed because it is too large
Load Diff
@ -14,6 +14,24 @@
|
||||
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
|
||||
// vLLM custom ops for rocm
|
||||
|
||||
// Custom gemm op for matrix-vector multiplication
|
||||
rocm_ops.def(
|
||||
"LLMM1(Tensor in_a, Tensor in_b, int rows_per_block) -> "
|
||||
"Tensor");
|
||||
rocm_ops.impl("LLMM1", torch::kCUDA, &LLMM1);
|
||||
|
||||
// Custom gemm op for skinny matrix-matrix multiplication
|
||||
rocm_ops.def(
|
||||
"wvSplitK(Tensor in_a, Tensor in_b, int CuCount) -> "
|
||||
"Tensor");
|
||||
rocm_ops.impl("wvSplitK", torch::kCUDA, &wvSplitK);
|
||||
|
||||
// wvSplitK for fp8
|
||||
rocm_ops.def(
|
||||
"wvSplitKQ(Tensor in_a, Tensor in_b, Tensor! out_c, Tensor scale_a, "
|
||||
" Tensor scale_b, int CuCount) -> ()");
|
||||
rocm_ops.impl("wvSplitKQ", torch::kCUDA, &wvSplitKQ);
|
||||
|
||||
// Custom attention op
|
||||
// Compute the attention between an input query and the cached
|
||||
// keys/values using PagedAttention.
|
||||
|
||||
@ -64,6 +64,21 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
" int blocksparse_head_sliding_step) -> ()");
|
||||
ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
// Merge attn states
|
||||
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
|
||||
// can be used to combine partial attention results (in the split-KV case)
|
||||
ops.def(
|
||||
"merge_attn_states("
|
||||
" Tensor! output,"
|
||||
" Tensor!? output_lse,"
|
||||
" Tensor prefix_output,"
|
||||
" Tensor prefix_lse,"
|
||||
" Tensor suffix_output,"
|
||||
" Tensor suffix_lse) -> ()");
|
||||
ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states);
|
||||
#endif
|
||||
|
||||
// Activation ops
|
||||
// Activation function used in SwiGLU.
|
||||
ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
|
||||
@ -115,6 +130,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
") -> ()");
|
||||
ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer);
|
||||
|
||||
// Compute MLA decode using cutlass.
|
||||
ops.def(
|
||||
"cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe,"
|
||||
" Tensor kv_c_and_k_pe_cache, Tensor seq_lens,"
|
||||
" Tensor page_table, float scale) -> ()");
|
||||
ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
|
||||
|
||||
// Layernorm
|
||||
// Apply Root Mean Square (RMS) Normalization to the input tensor.
|
||||
ops.def(
|
||||
|
||||
@ -162,6 +162,9 @@ ENV UV_HTTP_TIMEOUT=500
|
||||
COPY requirements/lint.txt requirements/lint.txt
|
||||
COPY requirements/test.txt requirements/test.txt
|
||||
COPY requirements/dev.txt requirements/dev.txt
|
||||
# Workaround for #17068
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system mamba-ssm==2.2.4 --no-build-isolation
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system -r requirements/dev.txt
|
||||
#################### DEV IMAGE ####################
|
||||
@ -240,6 +243,8 @@ if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \
|
||||
uv pip install --system https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.1.post2/flashinfer_python-0.2.1.post2+cu124torch2.6-cp38-abi3-linux_x86_64.whl ; \
|
||||
fi
|
||||
COPY examples examples
|
||||
COPY benchmarks benchmarks
|
||||
COPY ./vllm/collect_env.py .
|
||||
|
||||
# Although we build Flashinfer with AOT mode, there's still
|
||||
# some issues w.r.t. JIT compilation. Therefore we need to
|
||||
@ -263,6 +268,9 @@ ADD . /vllm-workspace/
|
||||
ENV UV_HTTP_TIMEOUT=500
|
||||
|
||||
# install development dependencies (for testing)
|
||||
# Workaround for #17068
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system mamba-ssm==2.2.4 --no-build-isolation
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system -r requirements/dev.txt
|
||||
|
||||
@ -289,6 +297,7 @@ RUN mv vllm test_docs/
|
||||
#################### OPENAI API SERVER ####################
|
||||
# base openai image with additional requirements, for any subsequent openai-style images
|
||||
FROM vllm-base AS vllm-openai-base
|
||||
ARG TARGETPLATFORM
|
||||
|
||||
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
|
||||
# Reference: https://github.com/astral-sh/uv/pull/1694
|
||||
|
||||
@ -18,6 +18,8 @@ WORKDIR /workspace/
|
||||
ARG PYTHON_VERSION=3.12
|
||||
ARG PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu"
|
||||
|
||||
ENV LD_PRELOAD=""
|
||||
|
||||
# Install minimal dependencies and uv
|
||||
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
|
||||
--mount=type=cache,target=/var/lib/apt,sharing=locked \
|
||||
@ -32,6 +34,7 @@ ENV CMAKE_CXX_COMPILER_LAUNCHER=ccache
|
||||
|
||||
ENV PATH="/root/.local/bin:$PATH"
|
||||
ENV VIRTUAL_ENV="/opt/venv"
|
||||
ENV UV_PYTHON_INSTALL_DIR=/opt/uv/python
|
||||
RUN uv venv --python ${PYTHON_VERSION} --seed ${VIRTUAL_ENV}
|
||||
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
|
||||
|
||||
@ -118,6 +121,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
ADD ./tests/ ./tests/
|
||||
ADD ./examples/ ./examples/
|
||||
ADD ./benchmarks/ ./benchmarks/
|
||||
ADD ./vllm/collect_env.py .
|
||||
|
||||
# install development dependencies (for testing)
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
FROM vault.habana.ai/gaudi-docker/1.19.1/ubuntu22.04/habanalabs/pytorch-installer-2.5.1:latest
|
||||
FROM vault.habana.ai/gaudi-docker/1.20.1/ubuntu22.04/habanalabs/pytorch-installer-2.6.0:latest
|
||||
|
||||
COPY ./ /workspace/vllm
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# default base image
|
||||
# https://gallery.ecr.aws/neuron/pytorch-inference-neuronx
|
||||
ARG BASE_IMAGE="public.ecr.aws/neuron/pytorch-inference-neuronx:2.5.1-neuronx-py310-sdk2.21.0-ubuntu22.04"
|
||||
ARG BASE_IMAGE="public.ecr.aws/neuron/pytorch-inference-neuronx:2.5.1-neuronx-py310-sdk2.22.0-ubuntu22.04"
|
||||
|
||||
FROM $BASE_IMAGE
|
||||
|
||||
@ -21,9 +21,9 @@ VOLUME [ ${APP_MOUNT} ]
|
||||
WORKDIR ${APP_MOUNT}/vllm
|
||||
|
||||
RUN python3 -m pip install --upgrade pip
|
||||
RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas
|
||||
RUN python3 -m pip install sentencepiece transformers==4.45.2 -U
|
||||
RUN python3 -m pip install neuronx-cc==2.16.345.0 --extra-index-url=https://pip.repos.neuron.amazonaws.com -U
|
||||
RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas tenacity
|
||||
RUN python3 -m pip install sentencepiece transformers==4.48.0 -U
|
||||
RUN python3 -m pip install neuronx-cc==2.17.194.0 --extra-index-url=https://pip.repos.neuron.amazonaws.com -U
|
||||
RUN python3 -m pip install pytest
|
||||
|
||||
# uninstall transformers-neuronx package explicitly to avoid version conflict
|
||||
|
||||
307
docker/Dockerfile.nightly_torch
Normal file
307
docker/Dockerfile.nightly_torch
Normal file
@ -0,0 +1,307 @@
|
||||
# The vLLM Dockerfile is used to construct vLLM image against torch nightly that can be directly used for testing
|
||||
|
||||
# for torch nightly, cuda >=12.6 is required,
|
||||
# use 12.8 due to FlashAttention issue with cuda 12.6 (https://github.com/vllm-project/vllm/issues/15435#issuecomment-2775924628)
|
||||
ARG CUDA_VERSION=12.8.0
|
||||
#
|
||||
#################### BASE BUILD IMAGE ####################
|
||||
# prepare basic build environment
|
||||
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 AS base
|
||||
ARG CUDA_VERSION=12.8.0
|
||||
ARG PYTHON_VERSION=3.12
|
||||
ARG TARGETPLATFORM
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
# Install Python and other dependencies
|
||||
RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
|
||||
&& echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
|
||||
&& apt-get update -y \
|
||||
&& apt-get install -y ccache software-properties-common git curl sudo \
|
||||
&& add-apt-repository ppa:deadsnakes/ppa \
|
||||
&& apt-get update -y \
|
||||
&& apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
|
||||
&& update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \
|
||||
&& update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \
|
||||
&& ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \
|
||||
&& curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \
|
||||
&& python3 --version \
|
||||
&& python3 -m pip --version
|
||||
# Install uv for faster pip installs
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
python3 -m pip install uv
|
||||
|
||||
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
|
||||
# Reference: https://github.com/astral-sh/uv/pull/1694
|
||||
ENV UV_HTTP_TIMEOUT=500
|
||||
|
||||
# Upgrade to GCC 10 to avoid https://gcc.gnu.org/bugzilla/show_bug.cgi?id=92519
|
||||
# as it was causing spam when compiling the CUTLASS kernels
|
||||
RUN apt-get install -y gcc-10 g++-10
|
||||
RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-10 110 --slave /usr/bin/g++ g++ /usr/bin/g++-10
|
||||
RUN <<EOF
|
||||
gcc --version
|
||||
EOF
|
||||
|
||||
# Workaround for https://github.com/openai/triton/issues/2507 and
|
||||
# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
|
||||
# this won't be needed for future versions of this docker image
|
||||
# or future versions of triton.
|
||||
RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
# install build and runtime dependencies
|
||||
COPY requirements/common.txt requirements/common.txt
|
||||
COPY use_existing_torch.py use_existing_torch.py
|
||||
COPY pyproject.toml pyproject.toml
|
||||
|
||||
# install build and runtime dependencies without stable torch version
|
||||
RUN python3 use_existing_torch.py
|
||||
|
||||
# install torch nightly
|
||||
ARG PINNED_TORCH_VERSION
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
if [ -n "$PINNED_TORCH_VERSION" ]; then \
|
||||
pkgs="$PINNED_TORCH_VERSION"; \
|
||||
else \
|
||||
pkgs="torch torchaudio torchvision"; \
|
||||
fi && \
|
||||
uv pip install --system $pkgs --index-url https://download.pytorch.org/whl/nightly/cu128
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system numba==0.61.2
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system -r requirements/common.txt
|
||||
|
||||
# must put before installing xformers, so it can install the correct version of xfomrers.
|
||||
ARG torch_cuda_arch_list='8.0;8.6;8.9;9.0'
|
||||
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
|
||||
|
||||
# Build xformers with cuda and torch nightly
|
||||
# following official xformers guidance: https://github.com/facebookresearch/xformers#build
|
||||
# todo(elainewy): cache xformers build result for faster build
|
||||
ARG max_jobs=16
|
||||
ENV MAX_JOBS=${max_jobs}
|
||||
ARG XFORMERS_COMMIT=f2de641ef670510cadab099ce6954031f52f191c
|
||||
|
||||
ENV CCACHE_DIR=/root/.cache/ccache
|
||||
RUN --mount=type=cache,target=/root/.cache/ccache \
|
||||
--mount=type=cache,target=/root/.cache/uv \
|
||||
echo 'git clone xformers...' \
|
||||
&& git clone https://github.com/facebookresearch/xformers.git --recursive \
|
||||
&& cd xformers \
|
||||
&& git checkout ${XFORMERS_COMMIT} \
|
||||
&& git submodule update --init --recursive \
|
||||
&& echo 'finish git clone xformers...' \
|
||||
&& rm -rf build \
|
||||
&& python3 setup.py bdist_wheel --dist-dir=../xformers-dist --verbose \
|
||||
&& cd .. \
|
||||
&& rm -rf xformers
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system xformers-dist/*.whl --verbose
|
||||
|
||||
# build can take a long time, and the torch nightly version fetched from url can be different in next docker stage.
|
||||
# track the nightly torch version used in the build, when we set up runtime environment we can make sure the version is the same
|
||||
RUN uv pip freeze | grep -i '^torch\|^torchvision\|^torchaudio' > torch_build_versions.txt
|
||||
RUN cat torch_build_versions.txt
|
||||
|
||||
# cuda arch list used by torch
|
||||
# can be useful for `test`
|
||||
# explicitly set the list to avoid issues with torch 2.2
|
||||
# see https://github.com/pytorch/pytorch/pull/123243
|
||||
|
||||
# Override the arch list for flash-attn to reduce the binary size
|
||||
ARG vllm_fa_cmake_gpu_arches='80-real;90-real'
|
||||
ENV VLLM_FA_CMAKE_GPU_ARCHES=${vllm_fa_cmake_gpu_arches}
|
||||
#################### BASE BUILD IMAGE ####################
|
||||
|
||||
#################### WHEEL BUILD IMAGE ####################
|
||||
FROM base AS build
|
||||
ARG TARGETPLATFORM
|
||||
|
||||
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
|
||||
# Reference: https://github.com/astral-sh/uv/pull/1694
|
||||
ENV UV_HTTP_TIMEOUT=500
|
||||
|
||||
COPY . .
|
||||
|
||||
RUN python3 use_existing_torch.py
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system -r requirements/build.txt
|
||||
|
||||
ARG GIT_REPO_CHECK=0
|
||||
RUN --mount=type=bind,source=.git,target=.git \
|
||||
if [ "$GIT_REPO_CHECK" != "0" ]; then bash tools/check_repo.sh ; fi
|
||||
|
||||
# Max jobs used by Ninja to build extensions
|
||||
ARG max_jobs=16
|
||||
ENV MAX_JOBS=${max_jobs}
|
||||
ARG nvcc_threads=2
|
||||
ENV NVCC_THREADS=$nvcc_threads
|
||||
|
||||
ARG USE_SCCACHE
|
||||
ARG SCCACHE_BUCKET_NAME=vllm-build-sccache
|
||||
ARG SCCACHE_REGION_NAME=us-west-2
|
||||
ARG SCCACHE_S3_NO_CREDENTIALS=0
|
||||
|
||||
# if USE_SCCACHE is set, use sccache to speed up compilation
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,source=.git,target=.git \
|
||||
if [ "$USE_SCCACHE" = "1" ]; then \
|
||||
echo "Installing sccache..." \
|
||||
&& curl -L -o sccache.tar.gz https://github.com/mozilla/sccache/releases/download/v0.8.1/sccache-v0.8.1-x86_64-unknown-linux-musl.tar.gz \
|
||||
&& tar -xzf sccache.tar.gz \
|
||||
&& sudo mv sccache-v0.8.1-x86_64-unknown-linux-musl/sccache /usr/bin/sccache \
|
||||
&& rm -rf sccache.tar.gz sccache-v0.8.1-x86_64-unknown-linux-musl \
|
||||
&& export SCCACHE_BUCKET=${SCCACHE_BUCKET_NAME} \
|
||||
&& export SCCACHE_REGION=${SCCACHE_REGION_NAME} \
|
||||
&& export SCCACHE_S3_NO_CREDENTIALS=${SCCACHE_S3_NO_CREDENTIALS} \
|
||||
&& export SCCACHE_IDLE_TIMEOUT=0 \
|
||||
&& export CMAKE_BUILD_TYPE=Release \
|
||||
&& sccache --show-stats \
|
||||
&& python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38 \
|
||||
&& sccache --show-stats; \
|
||||
fi
|
||||
|
||||
ENV CCACHE_DIR=/root/.cache/ccache
|
||||
RUN --mount=type=cache,target=/root/.cache/ccache \
|
||||
--mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,source=.git,target=.git \
|
||||
if [ "$USE_SCCACHE" != "1" ]; then \
|
||||
# Clean any existing CMake artifacts
|
||||
rm -rf .deps && \
|
||||
mkdir -p .deps && \
|
||||
python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38; \
|
||||
fi
|
||||
|
||||
#################### WHEEL BUILD IMAGE ####################
|
||||
|
||||
################### VLLM INSTALLED IMAGE ####################
|
||||
# Setup clean environment for vLLM and its dependencies for test and api server using ubuntu22.04 with AOT flashinfer
|
||||
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 AS vllm-base
|
||||
# prepare for environment starts
|
||||
ARG CUDA_VERSION=12.8.0
|
||||
ARG PYTHON_VERSION=3.12
|
||||
WORKDIR /vllm-workspace
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ARG TARGETPLATFORM
|
||||
|
||||
RUN PYTHON_VERSION_STR=$(echo ${PYTHON_VERSION} | sed 's/\.//g') && \
|
||||
echo "export PYTHON_VERSION_STR=${PYTHON_VERSION_STR}" >> /etc/environment
|
||||
|
||||
# Install Python and other dependencies
|
||||
RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
|
||||
&& echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
|
||||
&& apt-get update -y \
|
||||
&& apt-get install -y ccache software-properties-common git curl wget sudo vim python3-pip \
|
||||
&& apt-get install -y ffmpeg libsm6 libxext6 libgl1 \
|
||||
&& add-apt-repository ppa:deadsnakes/ppa \
|
||||
&& apt-get update -y \
|
||||
&& apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv libibverbs-dev \
|
||||
&& update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \
|
||||
&& update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \
|
||||
&& ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \
|
||||
&& curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \
|
||||
&& python3 --version && python3 -m pip --version
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
python3 -m pip install uv
|
||||
|
||||
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
|
||||
# Reference: https://github.com/astral-sh/uv/pull/1694
|
||||
ENV UV_HTTP_TIMEOUT=500
|
||||
|
||||
# Workaround for https://github.com/openai/triton/issues/2507 and
|
||||
# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
|
||||
# this won't be needed for future versions of this docker image
|
||||
# or future versions of triton.
|
||||
RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/
|
||||
|
||||
# get the nightly torch version used in the build to make sure the version is the same
|
||||
COPY --from=base /workspace/torch_build_versions.txt ./torch_build_versions.txt
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system $(cat torch_build_versions.txt | xargs) --index-url https://download.pytorch.org/whl/nightly/cu128
|
||||
|
||||
# install the vllm wheel
|
||||
RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/vllm-dist \
|
||||
--mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system vllm-dist/*.whl --verbose
|
||||
|
||||
# install xformers again for the new environment
|
||||
RUN --mount=type=bind,from=base,src=/workspace/xformers-dist,target=/vllm-workspace/xformers-dist \
|
||||
--mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system /vllm-workspace/xformers-dist/*.whl --verbose
|
||||
|
||||
ARG torch_cuda_arch_list='8.0;8.6;8.9;9.0'
|
||||
|
||||
# install package for build flashinfer
|
||||
# see issue: https://github.com/flashinfer-ai/flashinfer/issues/738
|
||||
RUN pip install setuptools==75.6.0 packaging==23.2 ninja==1.11.1.3 build==1.2.2.post1
|
||||
|
||||
|
||||
# build flashinfer for torch nightly from source around 10 mins
|
||||
# release version: v0.2.2.post1
|
||||
# todo(elainewy): cache flashinfer build result for faster build
|
||||
ENV CCACHE_DIR=/root/.cache/ccache
|
||||
RUN --mount=type=cache,target=/root/.cache/ccache \
|
||||
--mount=type=cache,target=/root/.cache/uv \
|
||||
echo "git clone flashinfer..." \
|
||||
&& git clone --recursive https://github.com/flashinfer-ai/flashinfer.git \
|
||||
&& cd flashinfer \
|
||||
&& git checkout v0.2.2.post1 \
|
||||
&& git submodule update --init --recursive \
|
||||
&& echo "finish git clone flashinfer..." \
|
||||
&& rm -rf build \
|
||||
&& export TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list} \
|
||||
&& FLASHINFER_ENABLE_AOT=1 python3 setup.py bdist_wheel --dist-dir=../flashinfer-dist --verbose \
|
||||
&& cd .. \
|
||||
&& rm -rf flashinfer
|
||||
|
||||
# install flashinfer
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system flashinfer-dist/*.whl --verbose
|
||||
|
||||
# install common packages
|
||||
COPY requirements/common.txt requirements/common.txt
|
||||
COPY use_existing_torch.py use_existing_torch.py
|
||||
COPY pyproject.toml pyproject.toml
|
||||
|
||||
COPY examples examples
|
||||
COPY benchmarks benchmarks
|
||||
COPY ./vllm/collect_env.py .
|
||||
|
||||
RUN python3 use_existing_torch.py
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system -r requirements/common.txt
|
||||
|
||||
################### VLLM INSTALLED IMAGE ####################
|
||||
|
||||
|
||||
#################### UNITTEST IMAGE #############################
|
||||
FROM vllm-base as test
|
||||
COPY tests/ tests/
|
||||
|
||||
# install build and runtime dependencies without stable torch version
|
||||
COPY requirements/nightly_torch_test.txt requirements/nightly_torch_test.txt
|
||||
|
||||
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
|
||||
# Reference: https://github.com/astral-sh/uv/pull/1694
|
||||
ENV UV_HTTP_TIMEOUT=500
|
||||
|
||||
# install development dependencies (for testing)
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system -e tests/vllm_test_utils
|
||||
|
||||
# enable fast downloads from hf (for testing)
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system hf_transfer
|
||||
ENV HF_HUB_ENABLE_HF_TRANSFER 1
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system -r requirements/nightly_torch_test.txt
|
||||
|
||||
#################### UNITTEST IMAGE #############################
|
||||
|
||||
@ -126,13 +126,16 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
FROM base-builder AS cv-builder
|
||||
|
||||
ARG MAX_JOBS
|
||||
ARG OPENCV_VERSION=84
|
||||
ARG OPENCV_VERSION=86
|
||||
# patch for version 4.11.0.86
|
||||
ARG OPENCV_PATCH=97f3f39
|
||||
ARG ENABLE_HEADLESS=1
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
source /opt/rh/gcc-toolset-13/enable && \
|
||||
git clone --recursive https://github.com/opencv/opencv-python.git -b ${OPENCV_VERSION} && \
|
||||
cd opencv-python && \
|
||||
sed -i 's/"setuptools==59.2.0",/"setuptools<70.0",/g' pyproject.toml && \
|
||||
sed -i -E -e 's/"setuptools.+",/"setuptools",/g' pyproject.toml && \
|
||||
cd opencv && git cherry-pick --no-commit $OPENCV_PATCH && cd .. && \
|
||||
python -m build --wheel --installer=uv --outdir /opencvwheels/
|
||||
|
||||
###############################################################
|
||||
@ -148,9 +151,15 @@ COPY --from=arrow-builder /tmp/control /dev/null
|
||||
COPY --from=cv-builder /tmp/control /dev/null
|
||||
|
||||
ARG VLLM_TARGET_DEVICE=cpu
|
||||
ARG GRPC_PYTHON_BUILD_SYSTEM_OPENSSL=1
|
||||
|
||||
# this step installs vllm and populates uv cache
|
||||
# with all the transitive dependencies
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
source /opt/rh/gcc-toolset-13/enable && \
|
||||
git clone https://github.com/huggingface/xet-core.git && cd xet-core/hf_xet/ && \
|
||||
uv pip install maturin && \
|
||||
uv build --wheel --out-dir /hf_wheels/
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,from=torch-builder,source=/torchwheels/,target=/torchwheels/,ro \
|
||||
--mount=type=bind,from=arrow-builder,source=/arrowwheels/,target=/arrowwheels/,ro \
|
||||
@ -159,7 +168,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
source /opt/rh/gcc-toolset-13/enable && \
|
||||
uv pip install /opencvwheels/*.whl /arrowwheels/*.whl /torchwheels/*.whl && \
|
||||
sed -i -e 's/.*torch.*//g' /src/pyproject.toml /src/requirements/*.txt && \
|
||||
uv pip install pandas pythran pybind11 && \
|
||||
uv pip install pandas pythran pybind11 /hf_wheels/*.whl && \
|
||||
# sentencepiece.pc is in some pkgconfig inside uv cache
|
||||
export PKG_CONFIG_PATH=$(find / -type d -name "pkgconfig" 2>/dev/null | tr '\n' ':') && \
|
||||
uv pip install -r /src/requirements/common.txt -r /src/requirements/cpu.txt -r /src/requirements/build.txt --no-build-isolation && \
|
||||
@ -247,8 +256,9 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,from=torch-builder,source=/torchwheels/,target=/torchwheels/,ro \
|
||||
--mount=type=bind,from=arrow-builder,source=/arrowwheels/,target=/arrowwheels/,ro \
|
||||
--mount=type=bind,from=cv-builder,source=/opencvwheels/,target=/opencvwheels/,ro \
|
||||
--mount=type=bind,from=vllmcache-builder,source=/hf_wheels/,target=/hf_wheels/,ro \
|
||||
--mount=type=bind,from=vllmcache-builder,source=/vllmwheel/,target=/vllmwheel/,ro \
|
||||
HOME=/root uv pip install /opencvwheels/*.whl /arrowwheels/*.whl /torchwheels/*.whl /vllmwheel/*.whl
|
||||
HOME=/root uv pip install /opencvwheels/*.whl /arrowwheels/*.whl /torchwheels/*.whl /hf_wheels/*.whl /vllmwheel/*.whl
|
||||
|
||||
COPY ./ /workspace/vllm
|
||||
WORKDIR /workspace/vllm
|
||||
|
||||
@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
|
||||
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
|
||||
ARG FA_BRANCH="1a7f4dfa"
|
||||
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
|
||||
ARG AITER_BRANCH="8970b25b"
|
||||
ARG AITER_BRANCH="7e1ed08"
|
||||
ARG AITER_REPO="https://github.com/ROCm/aiter.git"
|
||||
|
||||
FROM ${BASE_IMAGE} AS base
|
||||
|
||||
@ -58,7 +58,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
cd ../../python && \
|
||||
export PYARROW_PARALLEL=4 && \
|
||||
export ARROW_BUILD_TYPE=release && \
|
||||
uv pip install -r requirements/build.txt && \
|
||||
uv pip install -r requirements-build.txt && \
|
||||
python setup.py build_ext --build-type=$ARROW_BUILD_TYPE --bundle-arrow-cpp bdist_wheel
|
||||
|
||||
FROM python-install AS numa-build
|
||||
@ -96,6 +96,22 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install -v torch==${TORCH_VERSION} --extra-index-url https://download.pytorch.org/whl/nightly/cpu && \
|
||||
python setup.py bdist_wheel
|
||||
|
||||
FROM python-install AS hf-xet-builder
|
||||
# Install hf-xet
|
||||
WORKDIR /tmp
|
||||
ENV CARGO_HOME=/root/.cargo
|
||||
ENV RUSTUP_HOME=/root/.rustup
|
||||
ENV PATH="$CARGO_HOME/bin:$RUSTUP_HOME/bin:$PATH"
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,from=rust,source=/root/.cargo,target=/root/.cargo,rw \
|
||||
--mount=type=bind,from=rust,source=/root/.rustup,target=/root/.rustup,rw \
|
||||
git clone https://github.com/huggingface/xet-core.git && \
|
||||
cd xet-core/hf_xet/ && \
|
||||
uv pip install maturin patchelf && \
|
||||
python -m maturin build --release --out dist && \
|
||||
mkdir -p /tmp/hf-xet/dist && \
|
||||
cp dist/*.whl /tmp/hf-xet/dist/
|
||||
|
||||
# Final build stage
|
||||
FROM python-install AS vllm-cpu
|
||||
ARG PYTHON_VERSION
|
||||
@ -120,12 +136,15 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,from=rust,source=/root/.rustup,target=/root/.rustup,rw \
|
||||
--mount=type=bind,from=pyarrow,source=/tmp/arrow/python/dist,target=/tmp/arrow-wheels \
|
||||
--mount=type=bind,from=torch-vision,source=/tmp/vision/dist,target=/tmp/vision-wheels/ \
|
||||
--mount=type=bind,from=hf-xet-builder,source=/tmp/hf-xet/dist,target=/tmp/hf-xet-wheels/ \
|
||||
sed -i '/^torch/d' requirements/build.txt && \
|
||||
ARROW_WHL_FILE=$(ls /tmp/arrow-wheels/pyarrow-*.whl | head -n 1) && \
|
||||
VISION_WHL_FILE=$(ls /tmp/vision-wheels/*.whl | head -n 1) && \
|
||||
HF_XET_WHL_FILE=$(ls /tmp/hf-xet-wheels/*.whl | head -n 1) && \
|
||||
uv pip install -v \
|
||||
$ARROW_WHL_FILE \
|
||||
$VISION_WHL_FILE \
|
||||
$HF_XET_WHL_FILE \
|
||||
--extra-index-url https://download.pytorch.org/whl/nightly/cpu \
|
||||
--index-strategy unsafe-best-match \
|
||||
-r requirements/build.txt \
|
||||
@ -149,4 +168,5 @@ USER 2000
|
||||
WORKDIR /home/vllm
|
||||
|
||||
# Set the default entrypoint
|
||||
ENTRYPOINT ["python", "-m", "vllm.entrypoints.openai.api_server"]
|
||||
ENTRYPOINT ["python", "-m", "vllm.entrypoints.openai.api_server"]
|
||||
|
||||
|
||||
BIN
docs/source/assets/deployment/anything-llm-chat-with-doc.png
Normal file
BIN
docs/source/assets/deployment/anything-llm-chat-with-doc.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 118 KiB |
BIN
docs/source/assets/deployment/anything-llm-chat-without-doc.png
Normal file
BIN
docs/source/assets/deployment/anything-llm-chat-without-doc.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 136 KiB |
BIN
docs/source/assets/deployment/anything-llm-provider.png
Normal file
BIN
docs/source/assets/deployment/anything-llm-provider.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 110 KiB |
BIN
docs/source/assets/deployment/anything-llm-upload-doc.png
Normal file
BIN
docs/source/assets/deployment/anything-llm-upload-doc.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 111 KiB |
BIN
docs/source/assets/deployment/open_webui.png
Normal file
BIN
docs/source/assets/deployment/open_webui.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 68 KiB |
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user