Compare commits: `gemma3n-mm` ... `wide_ep_wo` (683 commits)
```diff
@@ -46,6 +46,6 @@ while getopts "m:b:l:f:t:" OPT; do
 done

 lm_eval --model vllm \
-  --model_args "pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,distributed_executor_backend=ray,trust_remote_code=true,max_model_len=4096" \
+  --model_args "pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,add_bos_token=true,trust_remote_code=true,max_model_len=4096" \
   --tasks gsm8k --num_fewshot "$FEWSHOT" --limit "$LIMIT" \
   --batch_size "$BATCH_SIZE"
```
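For reference, a rough sketch of the fully expanded command this baseline script ends up running after the change above. Only the flag layout comes from the hunk; the model name and numeric values are illustrative, not defaults:

```bash
# Illustrative expansion of the baseline command above (example values, not defaults).
lm_eval --model vllm \
  --model_args "pretrained=meta-llama/Meta-Llama-3-8B-Instruct,tensor_parallel_size=1,add_bos_token=true,trust_remote_code=true,max_model_len=4096" \
  --tasks gsm8k --num_fewshot 5 --limit 250 \
  --batch_size auto
```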
```diff
@@ -18,12 +18,14 @@ RTOL = 0.08

 def launch_lm_eval(eval_config, tp_size):
     trust_remote_code = eval_config.get("trust_remote_code", False)
+    max_model_len = eval_config.get("max_model_len", 4096)
     model_args = (
         f"pretrained={eval_config['model_name']},"
         f"tensor_parallel_size={tp_size},"
         f"enforce_eager=true,"
         f"add_bos_token=true,"
-        f"trust_remote_code={trust_remote_code}"
+        f"trust_remote_code={trust_remote_code},"
+        f"max_model_len={max_model_len}"
     )
     results = lm_eval.simple_evaluate(
         model="vllm",
```

```diff
@@ -11,7 +11,7 @@ See [vLLM performance dashboard](https://perf.vllm.ai) for the latest performanc

 ## Performance benchmark quick overview

-**Benchmarking Coverage**: latency, throughput and fix-qps serving on A100 (the support for FP8 benchmark on H100 is coming!), with different models.
+**Benchmarking Coverage**: latency, throughput and fix-qps serving on A100 (the support for FP8 benchmark on H100 is coming!) and Intel® Xeon® Processors, with different models.

 **Benchmarking Duration**: about 1hr.
```

````diff
@@ -31,13 +31,27 @@ Performance benchmark will be triggered when:
 - A PR being merged into vllm.
 - Every commit for those PRs with `perf-benchmarks` label AND `ready` label.

+Manually Trigger the benchmark
+
+```bash
+bash .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh
+```
+
+Runtime environment variables:
+- `ON_CPU`: set the value to '1' on Intel® Xeon® Processors. Default value is 0.
+- `SERVING_JSON`: JSON file to use for the serving tests. Default value is empty string (use default file).
+- `LATENCY_JSON`: JSON file to use for the latency tests. Default value is empty string (use default file).
+- `THROUGHPUT_JSON`: JSON file to use for the throughput tests. Default value is empty string (use default file).
+- `REMOTE_HOST`: IP for the remote vLLM service to benchmark. Default value is empty string.
+- `REMOTE_PORT`: Port for the remote vLLM service to benchmark. Default value is empty string.
+
 Nightly benchmark will be triggered when:
 - Every commit for those PRs with `perf-benchmarks` label and `nightly-benchmarks` label.

 ## Performance benchmark details

 See [performance-benchmarks-descriptions.md](performance-benchmarks-descriptions.md) for detailed descriptions, and use `tests/latency-tests.json`, `tests/throughput-tests.json`, `tests/serving-tests.json` to configure the test cases.

+> NOTE: For Intel® Xeon® Processors, use `tests/latency-tests-cpu.json`, `tests/throughput-tests-cpu.json`, `tests/serving-tests-cpu.json` instead.
 ### Latency test

 Here is an example of one test inside `latency-tests.json`:
````

```diff
@@ -60,7 +74,7 @@ Here is an example of one test inside `latency-tests.json`:
 In this example:

 - The `test_name` attributes is a unique identifier for the test. In `latency-tests.json`, it must start with `latency_`.
-- The `parameters` attribute control the command line arguments to be used for `benchmark_latency.py`. Note that please use underline `_` instead of the dash `-` when specifying the command line arguments, and `run-performance-benchmarks.sh` will convert the underline to dash when feeding the arguments to `benchmark_latency.py`. For example, the corresponding command line arguments for `benchmark_latency.py` will be `--model meta-llama/Meta-Llama-3-8B --tensor-parallel-size 1 --load-format dummy --num-iters-warmup 5 --num-iters 15`
+- The `parameters` attribute control the command line arguments to be used for `vllm bench latency`. Note that please use underline `_` instead of the dash `-` when specifying the command line arguments, and `run-performance-benchmarks.sh` will convert the underline to dash when feeding the arguments to `vllm bench latency`. For example, the corresponding command line arguments for `vllm bench latency` will be `--model meta-llama/Meta-Llama-3-8B --tensor-parallel-size 1 --load-format dummy --num-iters-warmup 5 --num-iters 15`

 Note that the performance numbers are highly sensitive to the value of the parameters. Please make sure the parameters are set correctly.
```
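To make the underscore-to-dash conversion concrete, here is one way to picture what the harness does with a `parameters` object. This is an illustrative jq one-liner, not the actual `json2args` implementation in `run-performance-benchmarks.sh`:

```bash
# Illustrative only: turn the JSON parameters into CLI flags with dashes.
echo '{"model": "meta-llama/Meta-Llama-3-8B", "tensor_parallel_size": 1, "load_format": "dummy"}' \
  | jq -r 'to_entries | map("--" + (.key | gsub("_"; "-")) + " " + (.value | tostring)) | join(" ")'
# --model meta-llama/Meta-Llama-3-8B --tensor-parallel-size 1 --load-format dummy
```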
````diff
@@ -68,13 +82,13 @@ WARNING: The benchmarking script will save json results by itself, so please do

 ### Throughput test

-The tests are specified in `throughput-tests.json`. The syntax is similar to `latency-tests.json`, except for that the parameters will be fed forward to `benchmark_throughput.py`.
+The tests are specified in `throughput-tests.json`. The syntax is similar to `latency-tests.json`, except for that the parameters will be fed forward to `vllm bench throughput`.

 The number of this test is also stable -- a slight change on the value of this number might vary the performance numbers by a lot.

 ### Serving test

-We test the throughput by using `benchmark_serving.py` with request rate = inf to cover the online serving overhead. The corresponding parameters are in `serving-tests.json`, and here is an example:
+We test the throughput by using `vllm bench serve` with request rate = inf to cover the online serving overhead. The corresponding parameters are in `serving-tests.json`, and here is an example:

 ```json
 [
````

```diff
@@ -104,8 +118,8 @@ Inside this example:

 - The `test_name` attribute is also a unique identifier for the test. It must start with `serving_`.
 - The `server-parameters` includes the command line arguments for vLLM server.
-- The `client-parameters` includes the command line arguments for `benchmark_serving.py`.
-- The `qps_list` controls the list of qps for test. It will be used to configure the `--request-rate` parameter in `benchmark_serving.py`
+- The `client-parameters` includes the command line arguments for `vllm bench serve`.
+- The `qps_list` controls the list of qps for test. It will be used to configure the `--request-rate` parameter in `vllm bench serve`

 The number of this test is less stable compared to the delay and latency benchmarks (due to randomized sharegpt dataset sampling inside `benchmark_serving.py`), but a large change on this number (e.g. 5% change) still vary the output greatly.
```

```diff
@@ -119,6 +133,30 @@ If you do not see the table, please wait till the benchmark finish running.
 The json version of the table (together with the json version of the benchmark) will be also attached to the markdown file.
 The raw benchmarking results (in the format of json files) are in the `Artifacts` tab of the benchmarking.

+The `compare-json-results.py` helps to compare benchmark results JSON files converted using `convert-results-json-to-markdown.py`.
+When run, benchmark script generates results under `benchmark/results` folder, along with the `benchmark_results.md` and `benchmark_results.json`.
+`compare-json-results.py` compares two `benchmark_results.json` files and provides performance ratio e.g. for Output Tput, Median TTFT and Median TPOT.
+
+Here is an example using the script to compare result_a and result_b without detail test name.
+`python3 compare-json-results.py -f results_a/benchmark_results.json -f results_b/benchmark_results.json --ignore_test_name`
+
+|    | results_a/benchmark_results.json | results_b/benchmark_results.json | perf_ratio |
+|----|----------------------------------|----------------------------------|------------|
+| 0  | 142.633982 | 156.526018 | 1.097396 |
+| 1  | 241.620334 | 294.018783 | 1.216863 |
+| 2  | 218.298905 | 262.664916 | 1.203235 |
+| 3  | 242.743860 | 299.816190 | 1.235113 |
+
+Here is an example using the script to compare result_a and result_b with detail test name.
+`python3 compare-json-results.py -f results_a/benchmark_results.json -f results_b/benchmark_results.json`
+|   | results_a/benchmark_results.json_name | results_a/benchmark_results.json | results_b/benchmark_results.json_name | results_b/benchmark_results.json | perf_ratio |
+|---|---------------------------------------|----------------------------------|---------------------------------------|----------------------------------|------------|
+| 0 | serving_llama8B_tp1_sharegpt_qps_1 | 142.633982 | serving_llama8B_tp1_sharegpt_qps_1 | 156.526018 | 1.097396 |
+| 1 | serving_llama8B_tp1_sharegpt_qps_16 | 241.620334 | serving_llama8B_tp1_sharegpt_qps_16 | 294.018783 | 1.216863 |
+| 2 | serving_llama8B_tp1_sharegpt_qps_4 | 218.298905 | serving_llama8B_tp1_sharegpt_qps_4 | 262.664916 | 1.203235 |
+| 3 | serving_llama8B_tp1_sharegpt_qps_inf | 242.743860 | serving_llama8B_tp1_sharegpt_qps_inf | 299.816190 | 1.235113 |
+| 4 | serving_llama8B_tp2_random_1024_128_qps_1 | 96.613390 | serving_llama8B_tp4_random_1024_128_qps_1 | 108.404853 | 1.122048 |
+
 ## Nightly test details

 See [nightly-descriptions.md](nightly-descriptions.md) for the detailed description on test workload, models and docker containers of benchmarking other llm engines.
```

```diff
@@ -4,7 +4,8 @@
 - Input length: 32 tokens.
 - Output length: 128 tokens.
 - Batch size: fixed (8).
-- Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B.
+- GPU Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B.
+- CPU Models: llama-3.1 8B.
 - Evaluation metrics: end-to-end latency (mean, median, p99).

 {latency_tests_markdown_table}
```

```diff
@@ -14,7 +15,8 @@
 - Input length: randomly sample 200 prompts from ShareGPT dataset (with fixed random seed).
 - Output length: the corresponding output length of these 200 prompts.
 - Batch size: dynamically determined by vllm to achieve maximum throughput.
-- Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B.
+- GPU Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B.
+- CPU Models: llama-3.1 8B.
 - Evaluation metrics: throughput.

 {throughput_tests_markdown_table}
```

```diff
@@ -25,12 +27,18 @@
 - Output length: the corresponding output length of these 200 prompts.
 - Batch size: dynamically determined by vllm and the arrival pattern of the requests.
 - **Average QPS (query per second)**: 1, 4, 16 and inf. QPS = inf means all requests come at once. For other QPS values, the arrival time of each query is determined using a random Poisson process (with fixed random seed).
-- Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B.
-- We also added a speculative decoding test for llama-3 70B, under QPS 2
+- GPU Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B.
+- We also added a speculative decoding test for llama-3 70B on GPU, under QPS 2
+- CPU Models: llama-3.1 8B.
 - Evaluation metrics: throughput, TTFT (time to the first token, with mean, median and p99), ITL (inter-token latency, with mean, median and p99).
+- For CPU, we added random dataset tests to benchmark fixed input/output length with 100 prompts.

 {serving_tests_markdown_table}

+## Platform Information
+
+{platform_markdown_table}
+
 ## json version of the benchmarking tables

 This section contains the data of the markdown tables above in JSON format.
```

New file (`@@ -0,0 +1,66 @@`):

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse

import pandas as pd


def compare_data_columns(
    files, name_column, data_column, drop_column, ignore_test_name=False
):
    print("\ncompare_data_column: " + data_column)
    frames = []
    compare_frames = []
    for file in files:
        data_df = pd.read_json(file)
        serving_df = data_df.dropna(subset=[drop_column], ignore_index=True)
        if ignore_test_name is False:
            serving_df = serving_df.rename(columns={name_column: file + "_name"})
            frames.append(serving_df[file + "_name"])
        serving_df = serving_df.rename(columns={data_column: file})
        frames.append(serving_df[file])
        compare_frames.append(serving_df[file])
        if len(compare_frames) >= 2:
            # Compare numbers among two files
            ratio_df = compare_frames[1] / compare_frames[0]
            frames.append(ratio_df)
            compare_frames.pop(1)

    concat_df = pd.concat(frames, axis=1)
    return concat_df


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-f", "--file", action="append", type=str, help="input file name"
    )
    parser.add_argument(
        "--ignore_test_name", action="store_true", help="ignore_test_name or not"
    )
    args = parser.parse_args()
    files = args.file
    print("comparing : " + ", ".join(files))

    drop_column = "P99"
    name_column = "Test name"
    data_cols_to_compare = ["Output Tput (tok/s)", "Median TTFT (ms)", "Median"]
    html_msgs_for_data_cols = [
        "Compare Output Tokens /n",
        "Median TTFT /n",
        "Median TPOT /n",
    ]
    ignore_test_name = args.ignore_test_name
    with open("perf_comparison.html", "w") as text_file:
        for i in range(len(data_cols_to_compare)):
            output_df = compare_data_columns(
                files,
                name_column,
                data_cols_to_compare[i],
                drop_column,
                ignore_test_name=ignore_test_name,
            )
            print(output_df)
            html = output_df.to_html()
            text_file.write(html_msgs_for_data_cols[i])
            text_file.write(html)
```
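A quick usage sketch for the comparison script above. The result paths are illustrative; the flags are the ones defined by its argument parser, and the tables are written to `perf_comparison.html` in the working directory:

```bash
# Compare two benchmark_results.json files produced by convert-results-json-to-markdown.py.
# Paths are examples only.
python3 compare-json-results.py \
  -f results_a/benchmark_results.json \
  -f results_b/benchmark_results.json \
  --ignore_test_name
```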
```diff
@@ -3,9 +3,11 @@

 import json
 import os
+from importlib import util
 from pathlib import Path

 import pandas as pd
+import psutil
 from tabulate import tabulate

 results_folder = Path("results/")
```

```diff
@@ -29,11 +31,11 @@ throughput_results = []
 throughput_results_column_mapping = {
     "test_name": "Test name",
     "gpu_type": "GPU",
-    # "num_requests": "# of req.",
-    # "total_num_tokens": "Total # of tokens",
-    # "elapsed_time": "Elapsed time (s)",
+    "num_requests": "# of req.",
+    "total_num_tokens": "Total # of tokens",
+    "elapsed_time": "Elapsed time (s)",
     "requests_per_second": "Tput (req/s)",
-    # "tokens_per_second": "Tput (tok/s)",
+    "tokens_per_second": "Tput (tok/s)",
 }

 # serving results and the keys that will be printed into markdown
```

```diff
@@ -41,16 +43,18 @@ serving_results = []
 serving_column_mapping = {
     "test_name": "Test name",
     "gpu_type": "GPU",
-    # "completed": "# of req.",
+    "completed": "# of req.",
     "request_throughput": "Tput (req/s)",
-    # "input_throughput": "Input Tput (tok/s)",
-    # "output_throughput": "Output Tput (tok/s)",
+    "total_token_throughput": "Total Token Tput (tok/s)",
+    "output_throughput": "Output Tput (tok/s)",
+    "total_input_tokens": "Total input tokens",
+    "total_output_tokens": "Total output tokens",
     "mean_ttft_ms": "Mean TTFT (ms)",
     "median_ttft_ms": "Median TTFT (ms)",
     "p99_ttft_ms": "P99 TTFT (ms)",
-    # "mean_tpot_ms": "Mean TPOT (ms)",
-    # "median_tpot_ms": "Median",
-    # "p99_tpot_ms": "P99",
+    "mean_tpot_ms": "Mean TPOT (ms)",
+    "median_tpot_ms": "Median",
+    "p99_tpot_ms": "P99",
     "mean_itl_ms": "Mean ITL (ms)",
     "median_itl_ms": "Median ITL (ms)",
     "p99_itl_ms": "P99 ITL (ms)",
```

```diff
@@ -75,6 +79,20 @@ def results_to_json(latency, throughput, serving):
     )


+def get_size_with_unit(bytes, suffix="B"):
+    """
+    Scale bytes to its proper format
+    e.g:
+        1253656 => '1.20MB'
+        1253656678 => '1.17GB'
+    """
+    factor = 1024
+    for unit in ["", "K", "M", "G", "T", "P"]:
+        if bytes < factor:
+            return f"{bytes:.2f}{unit}{suffix}"
+        bytes /= factor
+
+
 if __name__ == "__main__":
     # collect results
     for test_file in results_folder.glob("*.json"):
```

```diff
@@ -82,7 +100,7 @@ if __name__ == "__main__":
             raw_result = json.loads(f.read())

         if "serving" in str(test_file):
-            # this result is generated via `benchmark_serving.py`
+            # this result is generated via `vllm bench serve` command

             # attach the benchmarking command to raw_result
             try:
```

```diff
@@ -102,7 +120,7 @@ if __name__ == "__main__":
                 continue

         elif "latency" in f.name:
-            # this result is generated via `benchmark_latency.py`
+            # this result is generated via `vllm bench latency` command

             # attach the benchmarking command to raw_result
             try:
```

```diff
@@ -130,7 +148,7 @@ if __name__ == "__main__":
                 continue

         elif "throughput" in f.name:
-            # this result is generated via `benchmark_throughput.py`
+            # this result is generated via `vllm bench throughput` command

             # attach the benchmarking command to raw_result
             try:
```

```diff
@@ -155,6 +173,27 @@ if __name__ == "__main__":
     serving_results = pd.DataFrame.from_dict(serving_results)
     throughput_results = pd.DataFrame.from_dict(throughput_results)

+    svmem = psutil.virtual_memory()
+    platform_data = {
+        "Physical cores": [psutil.cpu_count(logical=False)],
+        "Total cores": [psutil.cpu_count(logical=True)],
+        "Total Memory": [get_size_with_unit(svmem.total)],
+    }
+
+    if util.find_spec("numa") is not None:
+        from numa import info
+
+        platform_data["Total NUMA nodes"] = [info.get_num_configured_nodes()]
+
+    if util.find_spec("cpuinfo") is not None:
+        from cpuinfo import get_cpu_info
+
+        platform_data["CPU Brand"] = [get_cpu_info()["brand_raw"]]
+
+    platform_results = pd.DataFrame.from_dict(
+        platform_data, orient="index", columns=["Platform Info"]
+    )
+
     raw_results_json = results_to_json(
         latency_results, throughput_results, serving_results
     )
```
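The platform table above only gains the "Total NUMA nodes" and "CPU Brand" rows when the optional `numa` and `cpuinfo` modules are importable. A hedged install sketch follows; `psutil`, `pandas` and `tabulate` are used directly by the script, while the pip package names for the optional modules are assumptions:

```bash
# Used directly by the postprocessing script above.
pip install pandas tabulate psutil
# Optional: enables the NUMA and CPU-brand rows.
# Package names are assumptions for the `numa` and `cpuinfo` modules imported above.
pip install py-libnuma py-cpuinfo
```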
```diff
@@ -200,6 +239,9 @@ if __name__ == "__main__":
     throughput_md_table = tabulate(
         throughput_results, headers="keys", tablefmt="pipe", showindex=False
     )
+    platform_md_table = tabulate(
+        platform_results, headers="keys", tablefmt="pipe", showindex=True
+    )

     # document the result
     with open(results_folder / "benchmark_results.md", "w") as f:
```

```diff
@@ -211,6 +253,7 @@ if __name__ == "__main__":
             latency_tests_markdown_table=latency_md_table,
             throughput_tests_markdown_table=throughput_md_table,
             serving_tests_markdown_table=serving_md_table,
+            platform_markdown_table=platform_md_table,
             benchmarking_results_in_json_string=processed_results_json,
         )
         f.write(results)
```

```diff
@@ -73,7 +73,7 @@ get_current_llm_serving_engine() {
         echo "Container: vllm"
         # move to a completely irrelevant directory, to avoid import vllm from current folder
         export CURRENT_LLM_SERVING_ENGINE=vllm
-
+
         return
     fi
 }
```

```diff
@@ -95,12 +95,14 @@ json2args() {
 }

 kill_gpu_processes() {
-    pkill -f python
-    pkill -f python3
-    pkill -f tritonserver
-    pkill -f pt_main_thread
-    pkill -f text-generation
-    pkill -f lmdeploy
+    pkill -f '[p]ython'
+    pkill -f '[p]ython3'
+    pkill -f '[t]ritonserver'
+    pkill -f '[p]t_main_thread'
+    pkill -f '[t]ext-generation'
+    pkill -f '[l]mdeploy'
+    # vLLM now names the process with VLLM prefix after https://github.com/vllm-project/vllm/pull/21445
+    pkill -f '[V]LLM'

     while [ "$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits | head -n 1)" -ge 1000 ]; do
         sleep 1
```

```diff
@@ -125,7 +127,7 @@ ensure_installed() {
 }

 run_serving_tests() {
-    # run serving tests using `benchmark_serving.py`
+    # run serving tests using `vllm bench serve` command
     # $1: a json file specifying serving test cases

     local serving_test_file
```

```diff
@@ -225,7 +227,7 @@ run_serving_tests() {

         if [[ "$dataset_name" = "sharegpt" ]]; then

-            client_command="python3 benchmark_serving.py \
+            client_command="vllm bench serve \
                 --backend $backend \
                 --tokenizer /tokenizer_cache \
                 --model $model \
```

```diff
@@ -246,7 +248,7 @@ run_serving_tests() {
             sonnet_output_len=$(echo "$common_params" | jq -r '.sonnet_output_len')
             sonnet_prefix_len=$(echo "$common_params" | jq -r '.sonnet_prefix_len')

-            client_command="python3 benchmark_serving.py \
+            client_command="vllm bench serve \
                 --backend $backend \
                 --tokenizer /tokenizer_cache \
                 --model $model \
```

```diff
@@ -265,13 +267,13 @@ run_serving_tests() {
                 $client_args"

         else
-
+
             echo "The dataset name must be either 'sharegpt' or 'sonnet'. Got $dataset_name."
             exit 1

         fi




         echo "Running test case $test_name with qps $qps"
         echo "Client command: $client_command"
```

```diff
@@ -302,7 +304,7 @@ run_serving_tests() {
 }

 run_genai_perf_tests() {
-    # run genai-perf tests
+    # run genai-perf tests

    # $1: a json file specifying genai-perf test cases
    local genai_perf_test_file
```

```diff
@@ -311,14 +313,14 @@ run_genai_perf_tests() {
     # Iterate over genai-perf tests
     jq -c '.[]' "$genai_perf_test_file" | while read -r params; do
         # get the test name, and append the GPU type back to it.
-        test_name=$(echo "$params" | jq -r '.test_name')
-
+        test_name=$(echo "$params" | jq -r '.test_name')
+
         # if TEST_SELECTOR is set, only run the test cases that match the selector
         if [[ -n "$TEST_SELECTOR" ]] && [[ ! "$test_name" =~ $TEST_SELECTOR ]]; then
             echo "Skip test case $test_name."
             continue
         fi
-
+
         # prepend the current serving engine to the test name
         test_name=${CURRENT_LLM_SERVING_ENGINE}_${test_name}

```

```diff
@@ -369,10 +371,10 @@ run_genai_perf_tests() {
             qps=$num_prompts
             echo "now qps is $qps"
         fi
-
+
         new_test_name=$test_name"_qps_"$qps
         backend=$CURRENT_LLM_SERVING_ENGINE
-
+
         if [[ "$backend" == *"vllm"* ]]; then
             backend="vllm"
         fi
```

```diff
@@ -413,7 +415,7 @@ prepare_dataset() {
     do
         cat sonnet.txt >> sonnet_4x.txt
     done
-
+
 }

 main() {

```

```diff
@@ -31,6 +31,20 @@ check_gpus() {
     echo "GPU type is $gpu_type"
 }

+check_cpus() {
+    # check the number of CPUs and NUMA Node and GPU type.
+    declare -g numa_count=$(python3 -c "from numa import info;numa_size = info.get_num_configured_nodes(); print(numa_size)")
+    if [[ $numa_count -gt 0 ]]; then
+        echo "NUMA found."
+        echo $numa_count
+    else
+        echo "Need at least 1 NUMA to run benchmarking."
+        exit 1
+    fi
+    declare -g gpu_type="cpu"
+    echo "GPU type is $gpu_type"
+}
+
 check_hf_token() {
     # check if HF_TOKEN is available and valid
     if [[ -z "$HF_TOKEN" ]]; then
```

```diff
@@ -69,6 +83,22 @@ json2args() {
     echo "$args"
 }

+json2envs() {
+    # transforms the JSON string to environment variables.
+    # example:
+    # input: { "VLLM_CPU_KVCACHE_SPACE": 5 }
+    # output: VLLM_CPU_KVCACHE_SPACE=5
+    local json_string=$1
+    local args=$(
+        echo "$json_string" | jq -r '
+            to_entries |
+            map((.key ) + "=" + (.value | tostring)) |
+            join(" ")
+        '
+    )
+    echo "$args"
+}
+
 wait_for_server() {
     # wait for vllm server to start
     # return 1 if vllm server crashes
```
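To see what `json2envs` produces, the jq program above can be run directly. The input object here is the example from the function's own comment plus one extra key:

```bash
# Feed an environment_variables object through the same jq filter used by json2envs above.
echo '{"VLLM_CPU_KVCACHE_SPACE": 40, "VLLM_RPC_TIMEOUT": 100000}' | jq -r '
    to_entries |
    map((.key ) + "=" + (.value | tostring)) |
    join(" ")
'
# Output: VLLM_CPU_KVCACHE_SPACE=40 VLLM_RPC_TIMEOUT=100000
```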
```diff
@@ -96,7 +126,8 @@ kill_gpu_processes() {
     ps -aux
     lsof -t -i:8000 | xargs -r kill -9
     pgrep python3 | xargs -r kill -9
-
+    # vLLM now names the process with VLLM prefix after https://github.com/vllm-project/vllm/pull/21445
+    pgrep VLLM | xargs -r kill -9

     # wait until GPU memory usage smaller than 1GB
     if command -v nvidia-smi; then
```

```diff
@@ -134,7 +165,7 @@ upload_to_buildkite() {
 }

 run_latency_tests() {
-    # run latency tests using `benchmark_latency.py`
+    # run latency tests using `vllm bench latency` command
     # $1: a json file specifying latency test cases

     local latency_test_file
```

```diff
@@ -158,15 +189,24 @@ run_latency_tests() {
         # get arguments
         latency_params=$(echo "$params" | jq -r '.parameters')
         latency_args=$(json2args "$latency_params")
+        latency_environment_variables=$(echo "$params" | jq -r '.environment_variables')
+        latency_envs=$(json2envs "$latency_environment_variables")

         # check if there is enough GPU to run the test
         tp=$(echo "$latency_params" | jq -r '.tensor_parallel_size')
-        if [[ $gpu_count -lt $tp ]]; then
-            echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name."
-            continue
+        if [ "$ON_CPU" == "1" ];then
+            if [[ $numa_count -lt $tp ]]; then
+                echo "Required tensor-parallel-size $tp but only $numa_count NUMA nodes found. Skip testcase $test_name."
+                continue
+            fi
+        else
+            if [[ $gpu_count -lt $tp ]]; then
+                echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name."
+                continue
+            fi
         fi

-        latency_command="python3 benchmark_latency.py \
+        latency_command=" $latency_envs vllm bench latency \
            --output-json $RESULTS_FOLDER/${test_name}.json \
            $latency_args"

```

```diff
@@ -192,7 +232,7 @@ run_latency_tests() {
 }

 run_throughput_tests() {
-    # run throughput tests using `benchmark_throughput.py`
+    # run throughput tests using `vllm bench throughput`
     # $1: a json file specifying throughput test cases

     local throughput_test_file
```

```diff
@@ -216,15 +256,24 @@ run_throughput_tests() {
         # get arguments
         throughput_params=$(echo "$params" | jq -r '.parameters')
         throughput_args=$(json2args "$throughput_params")
+        throughput_environment_variables=$(echo "$params" | jq -r '.environment_variables')
+        throughput_envs=$(json2envs "$throughput_environment_variables")

         # check if there is enough GPU to run the test
         tp=$(echo "$throughput_params" | jq -r '.tensor_parallel_size')
-        if [[ $gpu_count -lt $tp ]]; then
-            echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name."
-            continue
+        if [ "$ON_CPU" == "1" ];then
+            if [[ $numa_count -lt $tp ]]; then
+                echo "Required tensor-parallel-size $tp but only $numa_count NUMA nodes found. Skip testcase $test_name."
+                continue
+            fi
+        else
+            if [[ $gpu_count -lt $tp ]]; then
+                echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name."
+                continue
+            fi
         fi

-        throughput_command="python3 benchmark_throughput.py \
+        throughput_command=" $throughput_envs vllm bench throughput \
            --output-json $RESULTS_FOLDER/${test_name}.json \
            $throughput_args"

```

```diff
@@ -249,7 +298,7 @@ run_throughput_tests() {
 }

 run_serving_tests() {
-    # run serving tests using `benchmark_serving.py`
+    # run serving tests using `vllm bench serve` command
     # $1: a json file specifying serving test cases

     local serving_test_file
```

```diff
@@ -272,18 +321,27 @@ run_serving_tests() {

         # get client and server arguments
         server_params=$(echo "$params" | jq -r '.server_parameters')
+        server_envs=$(echo "$params" | jq -r '.server_environment_variables')
         client_params=$(echo "$params" | jq -r '.client_parameters')
         server_args=$(json2args "$server_params")
+        server_envs=$(json2envs "$server_envs")
         client_args=$(json2args "$client_params")
         qps_list=$(echo "$params" | jq -r '.qps_list')
         qps_list=$(echo "$qps_list" | jq -r '.[] | @sh')
         echo "Running over qps list $qps_list"

-        # check if there is enough GPU to run the test
+        # check if there is enough resources to run the test
         tp=$(echo "$server_params" | jq -r '.tensor_parallel_size')
-        if [[ $gpu_count -lt $tp ]]; then
-            echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name."
-            continue
+        if [ "$ON_CPU" == "1" ];then
+            if [[ $numa_count -lt $tp ]]; then
+                echo "Required tensor-parallel-size $tp but only $numa_count NUMA nodes found. Skip testcase $test_name."
+                continue
+            fi
+        else
+            if [[ $gpu_count -lt $tp ]]; then
+                echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name."
+                continue
+            fi
         fi

         # check if server model and client model is aligned
```

```diff
@@ -294,23 +352,33 @@ run_serving_tests() {
             continue
         fi

-        server_command="python3 \
+        server_command="$server_envs python3 \
             -m vllm.entrypoints.openai.api_server \
             $server_args"

         # run the server
         echo "Running test case $test_name"
         echo "Server command: $server_command"
-        bash -c "$server_command" &
-        server_pid=$!
-
-        # wait until the server is alive
-        if wait_for_server; then
-            echo ""
-            echo "vllm server is up and running."
+        # support remote vllm server
+        client_remote_args=""
+        if [[ -z "${REMOTE_HOST}" ]]; then
+            bash -c "$server_command" &
+            server_pid=$!
+            # wait until the server is alive
+            if wait_for_server; then
+                echo ""
+                echo "vLLM server is up and running."
+            else
+                echo ""
+                echo "vLLM failed to start within the timeout period."
+            fi
         else
-            echo ""
-            echo "vllm failed to start within the timeout period."
+            server_command="Using Remote Server $REMOTE_HOST $REMOTE_PORT"
+            if [[ ${REMOTE_PORT} ]]; then
+                client_remote_args=" --host=$REMOTE_HOST --port=$REMOTE_PORT "
+            else
+                client_remote_args=" --host=$REMOTE_HOST "
+            fi
         fi

         # iterate over different QPS
```
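With the remote-server path above, the harness can point its clients at an existing vLLM deployment instead of launching one locally. A hedged invocation sketch, where the host and port are placeholders and the env var names are the ones documented in the README hunk earlier:

```bash
# Benchmark an already-running vLLM server instead of starting one inside the harness.
export REMOTE_HOST=10.0.0.42   # placeholder IP of the running vLLM service
export REMOTE_PORT=8000        # placeholder port
bash .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh
```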
```diff
@@ -326,13 +394,13 @@ run_serving_tests() {

             # pass the tensor parallel size to the client so that it can be displayed
             # on the benchmark dashboard
-            client_command="python3 benchmark_serving.py \
+            client_command="vllm bench serve \
                 --save-result \
                 --result-dir $RESULTS_FOLDER \
                 --result-filename ${new_test_name}.json \
                 --request-rate $qps \
                 --metadata "tensor_parallel_size=$tp" \
-                $client_args"
+                $client_args $client_remote_args "

             echo "Running test case $test_name with qps $qps"
             echo "Client command: $client_command"
```

```diff
@@ -360,7 +428,14 @@ run_serving_tests() {
 }

 main() {
-    check_gpus
+    local ARCH
+    ARCH=''
+    if [ "$ON_CPU" == "1" ];then
+        check_cpus
+        ARCH='-cpu'
+    else
+        check_gpus
+    fi
     check_hf_token

     # Set to v1 to run v1 benchmark
```

```diff
@@ -373,7 +448,7 @@ main() {
     (which jq) || (apt-get update && apt-get -y install jq)
     (which lsof) || (apt-get update && apt-get install -y lsof)

-    # get the current IP address, required by benchmark_serving.py
+    # get the current IP address, required by `vllm bench serve` command
     export VLLM_HOST_IP=$(hostname -I | awk '{print $1}')
     # turn of the reporting of the status of each request, to clean up the terminal output
     export VLLM_LOGGING_LEVEL="WARNING"
```

```diff
@@ -386,9 +461,9 @@ main() {
     QUICK_BENCHMARK_ROOT=../.buildkite/nightly-benchmarks/

     # benchmarking
-    run_serving_tests $QUICK_BENCHMARK_ROOT/tests/serving-tests.json
-    run_latency_tests $QUICK_BENCHMARK_ROOT/tests/latency-tests.json
-    run_throughput_tests $QUICK_BENCHMARK_ROOT/tests/throughput-tests.json
+    run_serving_tests $QUICK_BENCHMARK_ROOT/tests/"${SERVING_JSON:-serving-tests$ARCH.json}"
+    run_latency_tests $QUICK_BENCHMARK_ROOT/tests/"${LATENCY_JSON:-latency-tests$ARCH.json}"
+    run_throughput_tests $QUICK_BENCHMARK_ROOT/tests/"${THROUGHPUT_JSON:-throughput-tests$ARCH.json}"

     # postprocess benchmarking results
     pip install tabulate pandas

```
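Putting the pieces together, a CPU run with a custom serving-test file might look like the sketch below. The env var names come from the README and `main()` above; the custom JSON file name is illustrative:

```bash
# Run the performance benchmarks on an Intel Xeon host (ON_CPU=1 selects the *-cpu.json test files),
# overriding only the serving tests with a custom JSON file under tests/.
export ON_CPU=1
export SERVING_JSON=serving-tests-cpu-custom.json   # illustrative file name
bash .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh
```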
`.buildkite/nightly-benchmarks/tests/latency-tests-cpu.json` (new file, `@@ -0,0 +1,30 @@`):

```json
[
    {
        "test_name": "latency_llama8B_tp1",
        "environment_variables": {
            "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
            "VLLM_CPU_KVCACHE_SPACE": 40
        },
        "parameters": {
            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
            "tensor_parallel_size": 1,
            "load_format": "dummy",
            "num_iters_warmup": 5,
            "num_iters": 15
        }
    },
    {
        "test_name": "latency_llama8B_tp4",
        "environment_variables": {
            "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
            "VLLM_CPU_KVCACHE_SPACE": 40
        },
        "parameters": {
            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
            "tensor_parallel_size": 4,
            "load_format": "dummy",
            "num_iters_warmup": 5,
            "num_iters": 15
        }
    }
]
```
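For the first entry above, the harness composes the environment variables and parameters into a single command through `json2envs` and `json2args`, roughly like the sketch below. The output path is illustrative:

```bash
# Roughly what run-performance-benchmarks.sh runs for "latency_llama8B_tp1" on CPU:
# underscore keys become dashed CLI flags, and the env vars are prefixed onto the command.
VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 VLLM_CPU_KVCACHE_SPACE=40 \
  vllm bench latency \
    --output-json results/latency_llama8B_tp1.json \
    --model meta-llama/Meta-Llama-3.1-8B-Instruct \
    --tensor-parallel-size 1 \
    --load-format dummy \
    --num-iters-warmup 5 \
    --num-iters 15
```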
`.buildkite/nightly-benchmarks/tests/serving-tests-cpu.json` (new file, `@@ -0,0 +1,158 @@`):

```json
[
    {
        "test_name": "serving_llama8B_tp1_sharegpt",
        "qps_list": [1, 4, 16, "inf"],
        "server_environment_variables": {
            "VLLM_RPC_TIMEOUT": 100000,
            "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
            "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120,
            "VLLM_CPU_KVCACHE_SPACE": 40
        },
        "server_parameters": {
            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
            "tensor_parallel_size": 1,
            "dtype": "bfloat16",
            "distributed_executor_backend": "mp",
            "block_size": 128,
            "trust_remote_code": "",
            "disable_log_stats": "",
            "disable_log_requests": "",
            "enforce_eager": "",
            "load_format": "dummy"
        },
        "client_parameters": {
            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
            "backend": "vllm",
            "dataset_name": "sharegpt",
            "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
            "max_concurrency": 60,
            "num_prompts": 200
        }
    },
    {
        "test_name": "serving_llama8B_tp2_sharegpt",
        "qps_list": [1, 4, 16, "inf"],
        "server_environment_variables": {
            "VLLM_RPC_TIMEOUT": 100000,
            "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
            "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120,
            "VLLM_CPU_KVCACHE_SPACE": 40
        },
        "server_parameters": {
            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
            "tensor_parallel_size": 2,
            "dtype": "bfloat16",
            "distributed_executor_backend": "mp",
            "block_size": 128,
            "trust_remote_code": "",
            "disable_log_stats": "",
            "disable_log_requests": "",
            "enforce_eager": "",
            "load_format": "dummy"
        },
        "client_parameters": {
            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
            "backend": "vllm",
            "dataset_name": "sharegpt",
            "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
            "max_concurrency": 60,
            "num_prompts": 200
        }
    },
    {
        "test_name": "serving_llama8B_tp4_sharegpt",
        "qps_list": [1, 4, 16, "inf"],
        "server_environment_variables": {
            "VLLM_RPC_TIMEOUT": 100000,
            "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
            "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120,
            "VLLM_CPU_KVCACHE_SPACE": 40
        },
        "server_parameters": {
            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
            "tensor_parallel_size": 4,
            "dtype": "bfloat16",
            "distributed_executor_backend": "mp",
            "block_size": 128,
            "trust_remote_code": "",
            "disable_log_stats": "",
            "disable_log_requests": "",
            "enforce_eager": "",
            "load_format": "dummy"
        },
        "client_parameters": {
            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
            "backend": "vllm",
            "dataset_name": "sharegpt",
            "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
            "max_concurrency": 60,
            "num_prompts": 200
        }
    },
    {
        "test_name": "serving_llama8B_tp4_random_1024_128",
        "qps_list": [1, 4, 16, "inf"],
        "server_environment_variables": {
            "VLLM_RPC_TIMEOUT": 100000,
            "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
            "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120,
            "VLLM_CPU_KVCACHE_SPACE": 40
        },
        "server_parameters": {
            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
            "tensor_parallel_size": 4,
            "dtype": "bfloat16",
            "distributed_executor_backend": "mp",
            "block_size": 128,
            "trust_remote_code": "",
            "enable_chunked_prefill": "",
            "disable_log_stats": "",
            "disable_log_requests": "",
            "enforce_eager": "",
            "load_format": "dummy"
        },
        "client_parameters": {
            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
            "backend": "vllm",
            "dataset_name": "random",
            "random-input-len": 1024,
            "random-output-len": 128,
            "ignore-eos": "",
            "max_concurrency": 100,
            "num_prompts": 100
        }
    },
    {
        "test_name": "serving_llama8B_pp6_random_1024_128",
        "qps_list": [1, 4, 16, "inf"],
        "server_environment_variables": {
            "VLLM_RPC_TIMEOUT": 100000,
            "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
            "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120,
            "VLLM_CPU_KVCACHE_SPACE": 40
        },
        "server_parameters": {
            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
            "pipeline_parallel_size": 6,
            "dtype": "bfloat16",
            "distributed_executor_backend": "mp",
            "block_size": 128,
            "trust_remote_code": "",
            "enable_chunked_prefill": "",
            "disable_log_stats": "",
            "disable_log_requests": "",
            "enforce_eager": "",
            "load_format": "dummy"
        },
        "client_parameters": {
            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
            "backend": "vllm",
            "dataset_name": "random",
            "random-input-len": 1024,
            "random-output-len": 128,
            "ignore-eos": "",
            "max_concurrency": 100,
            "num_prompts": 100
        }
    }
]
```
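For the `serving_llama8B_tp1_sharegpt` entry, the server and client commands the harness builds look roughly like the sketch below. The composition follows the script hunks earlier (env prefix plus `json2args` flag conversion); the result directory is illustrative and the exact flag expansion may differ slightly:

```bash
# Server side (sketch): server_environment_variables prefixed onto the OpenAI-compatible API server.
VLLM_RPC_TIMEOUT=100000 VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 VLLM_ENGINE_ITERATION_TIMEOUT_S=120 VLLM_CPU_KVCACHE_SPACE=40 \
  python3 -m vllm.entrypoints.openai.api_server \
    --model meta-llama/Meta-Llama-3.1-8B-Instruct --tensor-parallel-size 1 \
    --dtype bfloat16 --distributed-executor-backend mp --block-size 128 \
    --trust-remote-code --disable-log-stats --disable-log-requests \
    --enforce-eager --load-format dummy

# Client side (sketch): one run per value in qps_list, here qps=1.
vllm bench serve \
  --save-result --result-dir results/ --result-filename serving_llama8B_tp1_sharegpt_qps_1.json \
  --request-rate 1 --metadata tensor_parallel_size=1 \
  --model meta-llama/Meta-Llama-3.1-8B-Instruct --backend vllm \
  --dataset-name sharegpt --dataset-path ./ShareGPT_V3_unfiltered_cleaned_split.json \
  --max-concurrency 60 --num-prompts 200
```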
@ -0,0 +1,32 @@
|
||||
[
|
||||
{
|
||||
"test_name": "throughput_llama8B_tp1",
|
||||
"environment_variables": {
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 1,
|
||||
"load_format": "dummy",
|
||||
"dataset": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
"num_prompts": 200,
|
||||
"backend": "vllm"
|
||||
}
|
||||
},
|
||||
{
|
||||
"test_name": "throughput_llama8B_tp4",
|
||||
"environment_variables": {
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 4,
|
||||
"load_format": "dummy",
|
||||
"dataset": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
"num_prompts": 200,
|
||||
"backend": "vllm"
|
||||
}
|
||||
}
|
||||
]
|
||||
@ -52,7 +52,7 @@ steps:
|
||||
queue: cpu_queue_postmerge
|
||||
commands:
|
||||
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain -f docker/Dockerfile ."
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain -f docker/Dockerfile ."
|
||||
- "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT"
|
||||
|
||||
- label: "Annotate release workflow"
|
||||
@ -101,7 +101,7 @@ steps:
|
||||
queue: cpu_queue_postmerge
|
||||
commands:
|
||||
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest --progress plain --target vllm-openai -f docker/Dockerfile.cpu ."
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --build-arg VLLM_CPU_AVX512BF16=true --build-arg VLLM_CPU_AVX512VNNI=true --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest --progress plain --target vllm-openai -f docker/Dockerfile.cpu ."
|
||||
- "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest"
|
||||
- "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version)"
|
||||
env:
|
||||
|
||||
@ -107,10 +107,8 @@ fi
|
||||
|
||||
if [[ $commands == *" kernels/attention"* ]]; then
|
||||
commands="${commands} \
|
||||
--ignore=kernels/attention/stest_attention_selector.py \
|
||||
--ignore=kernels/attention/test_blocksparse_attention.py \
|
||||
--ignore=kernels/attention/test_encoder_decoder_attn.py \
|
||||
--ignore=kernels/attention/test_attention_selector.py \
|
||||
--ignore=kernels/attention/test_encoder_decoder_attn.py \
|
||||
--ignore=kernels/attention/test_flash_attn.py \
|
||||
--ignore=kernels/attention/test_flashinfer.py \
|
||||
--ignore=kernels/attention/test_prefix_prefill.py \
|
||||
|
||||
@ -6,15 +6,16 @@ set -ex
|
||||
|
||||
# allow to bind to different cores
|
||||
CORE_RANGE=${CORE_RANGE:-48-95}
|
||||
# used for TP/PP E2E test
|
||||
OMP_CORE_RANGE=${OMP_CORE_RANGE:-48-95}
|
||||
NUMA_NODE=${NUMA_NODE:-1}
|
||||
|
||||
export CMAKE_BUILD_PARALLEL_LEVEL=32
|
||||
|
||||
# Setup cleanup
|
||||
remove_docker_container() {
|
||||
set -e;
|
||||
docker rm -f cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE"-avx2 || true;
|
||||
remove_docker_container() {
|
||||
set -e;
|
||||
docker rm -f cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE"-avx2 || true;
|
||||
}
|
||||
trap remove_docker_container EXIT
|
||||
remove_docker_container
|
||||
@ -24,8 +25,8 @@ numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --tag cpu-test-"$NUMA_NODE
|
||||
numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" --tag cpu-test-"$NUMA_NODE"-avx2 --target vllm-test -f docker/Dockerfile.cpu .
|
||||
|
||||
# Run the image, setting --shm-size=4g for tensor parallel.
|
||||
docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_OMP_THREADS_BIND="$OMP_CORE_RANGE" --env VLLM_CPU_CI_ENV=1 --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE"
|
||||
docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_OMP_THREADS_BIND="$OMP_CORE_RANGE" --env VLLM_CPU_CI_ENV=1 --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2
|
||||
docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE"
|
||||
docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2
|
||||
|
||||
function cpu_tests() {
|
||||
set -e
|
||||
@ -48,9 +49,16 @@ function cpu_tests() {
|
||||
# Run basic model test
|
||||
docker exec cpu-test-"$NUMA_NODE" bash -c "
|
||||
set -e
|
||||
pytest -v -s tests/kernels/attention/test_cache.py -m cpu_model
|
||||
pytest -v -s tests/kernels/attention/test_mla_decode_cpu.py -m cpu_model
|
||||
pytest -v -s tests/models/language/generation -m cpu_model
|
||||
# Note: disabled until these tests support V1
|
||||
# pytest -v -s tests/kernels/attention/test_cache.py -m cpu_model
|
||||
# pytest -v -s tests/kernels/attention/test_mla_decode_cpu.py -m cpu_model
|
||||
|
||||
# Note: disable Bart until it supports V1
|
||||
pytest -v -s tests/models/language/generation -m cpu_model \
|
||||
--ignore=tests/models/language/generation/test_bart.py
|
||||
VLLM_CPU_SGL_KERNEL=1 pytest -v -s tests/models/language/generation -m cpu_model \
|
||||
--ignore=tests/models/language/generation/test_bart.py
|
||||
|
||||
pytest -v -s tests/models/language/pooling -m cpu_model
|
||||
pytest -v -s tests/models/multimodal/generation \
|
||||
--ignore=tests/models/multimodal/generation/test_mllama.py \
|
||||
@ -61,33 +69,26 @@ function cpu_tests() {
|
||||
docker exec cpu-test-"$NUMA_NODE" bash -c "
|
||||
set -e
|
||||
pytest -s -v \
|
||||
tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_static_setup \
|
||||
tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_dynamic_per_token"
|
||||
tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_logprobs[False-10-32-neuralmagic/Llama-3.2-1B-quantized.w8a8]"
|
||||
|
||||
# Note: disabled until it supports V1
|
||||
# Run AWQ test
|
||||
docker exec cpu-test-"$NUMA_NODE" bash -c "
|
||||
set -e
|
||||
VLLM_USE_V1=0 pytest -s -v \
|
||||
tests/quantization/test_ipex_quant.py"
|
||||
|
||||
# Run chunked-prefill and prefix-cache test
|
||||
docker exec cpu-test-"$NUMA_NODE" bash -c "
|
||||
set -e
|
||||
pytest -s -v -k cpu_model \
|
||||
tests/basic_correctness/test_chunked_prefill.py"
|
||||
# docker exec cpu-test-"$NUMA_NODE" bash -c "
|
||||
# set -e
|
||||
# VLLM_USE_V1=0 pytest -s -v \
|
||||
# tests/quantization/test_ipex_quant.py"
|
||||
|
||||
# online serving
|
||||
docker exec cpu-test-"$NUMA_NODE" bash -c "
|
||||
docker exec cpu-test-"$NUMA_NODE" bash -c '
|
||||
set -e
|
||||
python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m --dtype half &
|
||||
timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
|
||||
VLLM_CPU_CI_ENV=0 python3 benchmarks/benchmark_serving.py \
|
||||
VLLM_CPU_OMP_THREADS_BIND=$E2E_OMP_THREADS VLLM_CPU_SGL_KERNEL=1 vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -pp=2 &
|
||||
timeout 600 bash -c "until curl localhost:8000/v1/models; do sleep 1; done" || exit 1
|
||||
vllm bench serve \
|
||||
--backend vllm \
|
||||
--dataset-name random \
|
||||
--model facebook/opt-125m \
|
||||
--model meta-llama/Llama-3.2-3B-Instruct \
|
||||
--num-prompts 20 \
|
||||
--endpoint /v1/completions \
|
||||
--tokenizer facebook/opt-125m"
|
||||
--endpoint /v1/completions'
|
||||
|
||||
# Run multi-lora tests
|
||||
docker exec cpu-test-"$NUMA_NODE" bash -c "
|
||||
@ -98,4 +99,4 @@ function cpu_tests() {
|
||||
|
||||
# All CPU tests are expected to finish in less than 40 mins.
|
||||
export -f cpu_tests
|
||||
timeout 1h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE"
|
||||
timeout 1.5h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE"
|
||||
|
||||
@ -2,10 +2,32 @@
|
||||
|
||||
# This script builds the HPU docker image and runs the offline inference inside the container.
|
||||
# It serves as a sanity check for compilation and basic model usage.
|
||||
set -ex
|
||||
set -exuo pipefail
|
||||
|
||||
# Try building the docker image
|
||||
docker build -t hpu-test-env -f docker/Dockerfile.hpu .
|
||||
cat <<EOF | docker build -t hpu-plugin-v1-test-env -f - .
|
||||
FROM gaudi-base-image:latest
|
||||
|
||||
COPY ./ /workspace/vllm
|
||||
|
||||
WORKDIR /workspace/vllm
|
||||
|
||||
ENV no_proxy=localhost,127.0.0.1
|
||||
ENV PT_HPU_ENABLE_LAZY_COLLECTIVES=true
|
||||
|
||||
RUN VLLM_TARGET_DEVICE=empty pip install .
|
||||
RUN pip install git+https://github.com/vllm-project/vllm-gaudi.git
|
||||
|
||||
# install development dependencies (for testing)
|
||||
RUN python3 -m pip install -e tests/vllm_test_utils
|
||||
|
||||
WORKDIR /workspace/
|
||||
|
||||
RUN git clone https://github.com/vllm-project/vllm-gaudi.git
|
||||
|
||||
RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks
|
||||
|
||||
EOF
|
||||
|
||||
# Setup cleanup
|
||||
# certain versions of HPU software stack have a bug that can
|
||||
@ -14,13 +36,21 @@ docker build -t hpu-test-env -f docker/Dockerfile.hpu .
|
||||
# functions, while other platforms only need one remove_docker_container
|
||||
# function.
|
||||
EXITCODE=1
|
||||
remove_docker_containers() { docker rm -f hpu-test || true; docker rm -f hpu-test-tp2 || true; }
|
||||
remove_docker_containers_and_exit() { remove_docker_containers; exit $EXITCODE; }
|
||||
trap remove_docker_containers_and_exit EXIT
|
||||
remove_docker_containers() { docker rm -f hpu-plugin-v1-test || true; }
|
||||
trap 'remove_docker_containers; exit $EXITCODE;' EXIT
|
||||
remove_docker_containers
|
||||
|
||||
# Run the image and launch offline inference
|
||||
docker run --runtime=habana --name=hpu-test --network=host -e HABANA_VISIBLE_DEVICES=all -e VLLM_SKIP_WARMUP=true --entrypoint="" hpu-test-env python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m
|
||||
docker run --runtime=habana --name=hpu-test-tp2 --network=host -e HABANA_VISIBLE_DEVICES=all -e VLLM_SKIP_WARMUP=true --entrypoint="" hpu-test-env python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --tensor-parallel-size 2
|
||||
echo "Running HPU plugin v1 test"
|
||||
docker run --rm --runtime=habana --name=hpu-plugin-v1-test --network=host \
|
||||
-e HABANA_VISIBLE_DEVICES=all \
|
||||
hpu-plugin-v1-test-env \
|
||||
/bin/bash "/workspace/vllm-gaudi/tests/upstream_tests/ci_tests.sh"
|
||||
|
||||
EXITCODE=$?
|
||||
if [ $EXITCODE -eq 0 ]; then
|
||||
echo "Test with basic model passed"
|
||||
else
|
||||
echo "Test with basic model FAILED with exit code: $EXITCODE" >&2
|
||||
fi
|
||||
|
||||
# The trap will handle the container removal and final exit.
|
||||
166
.buildkite/scripts/hardware_ci/run-tpu-v1-test-part2.sh
Executable file
@ -0,0 +1,166 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -xu
|
||||
|
||||
|
||||
remove_docker_container() {
|
||||
docker rm -f tpu-test || true;
|
||||
docker rm -f vllm-tpu || true;
|
||||
}
|
||||
|
||||
trap remove_docker_container EXIT
|
||||
|
||||
# Remove the container that might not be cleaned up in the previous run.
|
||||
remove_docker_container
|
||||
|
||||
# Build the docker image.
|
||||
docker build -f docker/Dockerfile.tpu -t vllm-tpu .
|
||||
|
||||
# Set up cleanup.
|
||||
cleanup_docker() {
|
||||
# Get Docker's root directory
|
||||
docker_root=$(docker info -f '{{.DockerRootDir}}')
|
||||
if [ -z "$docker_root" ]; then
|
||||
echo "Failed to determine Docker root directory."
|
||||
exit 1
|
||||
fi
|
||||
echo "Docker root directory: $docker_root"
|
||||
# Check disk usage of the filesystem where Docker's root directory is located
|
||||
disk_usage=$(df "$docker_root" | tail -1 | awk '{print $5}' | sed 's/%//')
|
||||
# Define the threshold
|
||||
threshold=70
|
||||
if [ "$disk_usage" -gt "$threshold" ]; then
|
||||
echo "Disk usage is above $threshold%. Cleaning up Docker images and volumes..."
|
||||
# Remove dangling images (those that are not tagged and not used by any container)
|
||||
docker image prune -f
|
||||
# Remove unused volumes / force the system prune for old images as well.
|
||||
docker volume prune -f && docker system prune --force --filter "until=72h" --all
|
||||
echo "Docker images and volumes cleanup completed."
|
||||
else
|
||||
echo "Disk usage is below $threshold%. No cleanup needed."
|
||||
fi
|
||||
}
|
||||
cleanup_docker
|
||||
|
||||
# For HF_TOKEN.
|
||||
source /etc/environment
|
||||
|
||||
docker run --privileged --net host --shm-size=16G -it \
|
||||
-e "HF_TOKEN=$HF_TOKEN" --name tpu-test \
|
||||
vllm-tpu /bin/bash -c '
|
||||
set -e # Exit immediately if a command exits with a non-zero status.
|
||||
set -u # Treat unset variables as an error.
|
||||
|
||||
echo "--- Starting script inside Docker container ---"
|
||||
|
||||
# Create results directory
|
||||
RESULTS_DIR=$(mktemp -d)
|
||||
# If mktemp fails, set -e will cause the script to exit.
|
||||
echo "Results will be stored in: $RESULTS_DIR"
|
||||
|
||||
# Install dependencies
|
||||
echo "--- Installing Python dependencies ---"
|
||||
python3 -m pip install --progress-bar off git+https://github.com/thuml/depyf.git \
|
||||
&& python3 -m pip install --progress-bar off pytest pytest-asyncio tpu-info \
|
||||
&& python3 -m pip install --progress-bar off lm_eval[api]==0.4.4 \
|
||||
&& python3 -m pip install --progress-bar off hf-transfer
|
||||
echo "--- Python dependencies installed ---"
|
||||
export VLLM_USE_V1=1
|
||||
export VLLM_XLA_CHECK_RECOMPILATION=1
|
||||
export VLLM_XLA_CACHE_PATH=
|
||||
echo "Using VLLM V1"
|
||||
|
||||
echo "--- Hardware Information ---"
|
||||
# tpu-info
|
||||
echo "--- Starting Tests ---"
|
||||
set +e
|
||||
overall_script_exit_code=0
|
||||
|
||||
# --- Test Definitions ---
|
||||
# If a test fails, this function will print logs and will not cause the main script to exit.
|
||||
run_test() {
|
||||
local test_num=$1
|
||||
local test_name=$2
|
||||
local test_command=$3
|
||||
local log_file="$RESULTS_DIR/test_${test_num}.log"
|
||||
local actual_exit_code
|
||||
|
||||
echo "--- TEST_$test_num: Running $test_name ---"
|
||||
|
||||
# Execute the test command.
|
||||
eval "$test_command" > >(tee -a "$log_file") 2> >(tee -a "$log_file" >&2)
|
||||
actual_exit_code=$?
|
||||
|
||||
echo "TEST_${test_num}_COMMAND_EXIT_CODE: $actual_exit_code" # This goes to main log
|
||||
echo "TEST_${test_num}_COMMAND_EXIT_CODE: $actual_exit_code" >> "$log_file" # Also to per-test log
|
||||
|
||||
if [ "$actual_exit_code" -ne 0 ]; then
|
||||
echo "TEST_$test_num ($test_name) FAILED with exit code $actual_exit_code." >&2
|
||||
echo "--- Log for failed TEST_$test_num ($test_name) ---" >&2
|
||||
if [ -f "$log_file" ]; then
|
||||
cat "$log_file" >&2
|
||||
else
|
||||
echo "Log file $log_file not found for TEST_$test_num ($test_name)." >&2
|
||||
fi
|
||||
echo "--- End of log for TEST_$test_num ($test_name) ---" >&2
|
||||
return "$actual_exit_code" # Return the failure code
|
||||
else
|
||||
echo "TEST_$test_num ($test_name) PASSED."
|
||||
return 0 # Return success
|
||||
fi
|
||||
}
|
||||
|
||||
# Helper function to call run_test and update the overall script exit code
|
||||
run_and_track_test() {
|
||||
local test_num_arg="$1"
|
||||
local test_name_arg="$2"
|
||||
local test_command_arg="$3"
|
||||
|
||||
# Run the test
|
||||
run_test "$test_num_arg" "$test_name_arg" "$test_command_arg"
|
||||
local test_specific_exit_code=$?
|
||||
|
||||
# If the test failed, set the overall script exit code to 1
|
||||
if [ "$test_specific_exit_code" -ne 0 ]; then
|
||||
# No need for extra echo here, run_test already logged the failure.
|
||||
overall_script_exit_code=1
|
||||
fi
|
||||
}
|
||||
|
||||
# --- Actual Test Execution ---
|
||||
run_and_track_test 1 "test_struct_output_generate.py" \
|
||||
"HF_HUB_DISABLE_XET=1 python3 -m pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py -k \"not test_structured_output_with_reasoning_matrices\""
|
||||
run_and_track_test 2 "test_moe_pallas.py" \
|
||||
"python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py"
|
||||
run_and_track_test 3 "test_lora.py" \
|
||||
"VLLM_XLA_CHECK_RECOMPILATION=0 python3 -m pytest -s -v /workspace/vllm/tests/tpu/lora/test_lora.py"
|
||||
run_and_track_test 4 "test_tpu_qkv_linear.py" \
|
||||
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_tpu_qkv_linear.py"
|
||||
run_and_track_test 5 "test_spmd_model_weight_loading.py" \
|
||||
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_spmd_model_weight_loading.py"
|
||||
run_and_track_test 6 "test_kv_cache_update_kernel.py" \
|
||||
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_kv_cache_update_kernel.py"
|
||||
|
||||
# After all tests have been attempted, exit with the overall status.
|
||||
if [ "$overall_script_exit_code" -ne 0 ]; then
|
||||
echo "--- One or more tests FAILED. Overall script exiting with failure code 1. ---"
|
||||
else
|
||||
echo "--- All tests have completed and PASSED. Overall script exiting with success code 0. ---"
|
||||
fi
|
||||
exit "$overall_script_exit_code"
|
||||
' # IMPORTANT: This is the closing single quote for the bash -c "..." command. Ensure it is present and correct.
|
||||
|
||||
# Capture the exit code of the docker run command
|
||||
DOCKER_RUN_EXIT_CODE=$?
|
||||
|
||||
# The trap will run for cleanup.
|
||||
# Exit the main script with the Docker run command's exit code.
|
||||
if [ "$DOCKER_RUN_EXIT_CODE" -ne 0 ]; then
|
||||
echo "Docker run command failed with exit code $DOCKER_RUN_EXIT_CODE."
|
||||
exit "$DOCKER_RUN_EXIT_CODE"
|
||||
else
|
||||
echo "Docker run command completed successfully."
|
||||
exit 0
|
||||
fi
|
||||
# TODO: This test fails because it uses RANDOM_SEED sampling
|
||||
# pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py \
|
||||
@ -62,7 +62,8 @@ echo "Results will be stored in: $RESULTS_DIR"
|
||||
echo "--- Installing Python dependencies ---"
|
||||
python3 -m pip install --progress-bar off git+https://github.com/thuml/depyf.git \
|
||||
&& python3 -m pip install --progress-bar off pytest pytest-asyncio tpu-info \
|
||||
&& python3 -m pip install --progress-bar off lm_eval[api]==0.4.4
|
||||
&& python3 -m pip install --progress-bar off lm_eval[api]==0.4.4 \
|
||||
&& python3 -m pip install --progress-bar off hf-transfer
|
||||
echo "--- Python dependencies installed ---"
|
||||
export VLLM_USE_V1=1
|
||||
export VLLM_XLA_CHECK_RECOMPILATION=1
|
||||
@ -70,7 +71,7 @@ export VLLM_XLA_CACHE_PATH=
|
||||
echo "Using VLLM V1"
|
||||
|
||||
echo "--- Hardware Information ---"
|
||||
tpu-info
|
||||
# tpu-info
|
||||
echo "--- Starting Tests ---"
|
||||
set +e
|
||||
overall_script_exit_code=0
|
||||
@ -134,7 +135,7 @@ run_and_track_test 1 "test_compilation.py" \
|
||||
run_and_track_test 2 "test_basic.py" \
|
||||
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_basic.py"
|
||||
run_and_track_test 3 "test_accuracy.py::test_lm_eval_accuracy_v1_engine" \
|
||||
"python3 -m pytest -s -v /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine"
|
||||
"HF_HUB_DISABLE_XET=1 python3 -m pytest -s -v /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine"
|
||||
run_and_track_test 4 "test_quantization_accuracy.py" \
|
||||
"python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py"
|
||||
run_and_track_test 5 "examples/offline_inference/tpu.py" \
|
||||
@ -149,18 +150,6 @@ run_and_track_test 9 "test_multimodal.py" \
|
||||
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_multimodal.py"
|
||||
run_and_track_test 10 "test_pallas.py" \
|
||||
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py"
|
||||
run_and_track_test 11 "test_struct_output_generate.py" \
|
||||
"python3 -m pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py -k \"not test_structured_output_with_reasoning_matrices\""
|
||||
run_and_track_test 12 "test_moe_pallas.py" \
|
||||
"python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py"
|
||||
run_and_track_test 13 "test_lora.py" \
|
||||
"VLLM_XLA_CHECK_RECOMPILATION=0 python3 -m pytest -s -v /workspace/vllm/tests/tpu/lora/test_lora.py"
|
||||
run_and_track_test 14 "test_tpu_qkv_linear.py" \
|
||||
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_tpu_qkv_linear.py"
|
||||
run_and_track_test 15 "test_spmd_model_weight_loading.py" \
|
||||
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_spmd_model_weight_loading.py"
|
||||
run_and_track_test 16 "test_kv_cache_update_kernel.py" \
|
||||
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_kv_cache_update_kernel.py"
|
||||
|
||||
# After all tests have been attempted, exit with the overall status.
|
||||
if [ "$overall_script_exit_code" -ne 0 ]; then
|
||||
|
||||
@ -11,8 +11,8 @@ container_name="xpu_${BUILDKITE_COMMIT}_$(tr -dc A-Za-z0-9 < /dev/urandom | head
|
||||
docker build -t ${image_name} -f docker/Dockerfile.xpu .
|
||||
|
||||
# Setup cleanup
|
||||
remove_docker_container() {
|
||||
docker rm -f "${container_name}" || true;
|
||||
remove_docker_container() {
|
||||
docker rm -f "${container_name}" || true;
|
||||
docker image rm -f "${image_name}" || true;
|
||||
docker system prune -f || true;
|
||||
}
|
||||
@ -26,7 +26,18 @@ docker run \
|
||||
--name "${container_name}" \
|
||||
"${image_name}" \
|
||||
sh -c '
|
||||
VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m
|
||||
VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m -tp 2
|
||||
VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
|
||||
VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray
|
||||
VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp
|
||||
cd tests
|
||||
pytest -v -s v1/core
|
||||
pytest -v -s v1/engine
|
||||
pytest -v -s v1/sample --ignore=v1/sample/test_logprobs.py --ignore=v1/sample/test_logprobs_e2e.py
|
||||
pytest -v -s v1/worker --ignore=v1/worker/test_gpu_model_runner.py
|
||||
pytest -v -s v1/structured_output
|
||||
pytest -v -s v1/spec_decode --ignore=v1/spec_decode/test_max_len.py --ignore=v1/spec_decode/test_eagle.py
|
||||
pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_nixl_connector.py
|
||||
pytest -v -s v1/test_serial_utils.py
|
||||
pytest -v -s v1/test_utils.py
|
||||
pytest -v -s v1/test_metrics_reader.py
|
||||
'
|
||||
|
||||
@ -11,10 +11,10 @@ cd "$(dirname "${BASH_SOURCE[0]}")/../.."
|
||||
(which wget && which curl) || (apt-get update && apt-get install -y wget curl)
|
||||
|
||||
# run python-based benchmarks and upload the result to buildkite
|
||||
python3 benchmarks/benchmark_latency.py --output-json latency_results.json 2>&1 | tee benchmark_latency.txt
|
||||
vllm bench latency --output-json latency_results.json 2>&1 | tee benchmark_latency.txt
|
||||
bench_latency_exit_code=$?
|
||||
|
||||
python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 --output-json throughput_results.json 2>&1 | tee benchmark_throughput.txt
|
||||
vllm bench throughput --input-len 256 --output-len 256 --output-json throughput_results.json 2>&1 | tee benchmark_throughput.txt
|
||||
bench_throughput_exit_code=$?
|
||||
|
||||
# run server-based benchmarks and upload the result to buildkite
|
||||
@ -24,7 +24,7 @@ wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/r
|
||||
|
||||
# wait for server to start, timeout after 600 seconds
|
||||
timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
|
||||
python3 benchmarks/benchmark_serving.py \
|
||||
vllm bench serve \
|
||||
--backend vllm \
|
||||
--dataset-name sharegpt \
|
||||
--dataset-path ./ShareGPT_V3_unfiltered_cleaned_split.json \
|
||||
|
||||
@ -22,16 +22,6 @@ trap remove_docker_container EXIT
|
||||
# Remove the container that might not be cleaned up in the previous run.
|
||||
remove_docker_container
|
||||
|
||||
# Build docker image.
|
||||
# TODO: build the image outside the script and share the image with other
|
||||
# TPU tests if the build time is too long.
|
||||
DOCKER_BUILDKIT=1 docker build \
|
||||
--build-arg max_jobs=16 \
|
||||
--build-arg USE_SCCACHE=1 \
|
||||
--build-arg GIT_REPO_CHECK=0 \
|
||||
--tag vllm/vllm-tpu-bm \
|
||||
--progress plain -f docker/Dockerfile.tpu .
|
||||
|
||||
LOG_ROOT=$(mktemp -d)
|
||||
# If mktemp fails, set -e will cause the script to exit.
|
||||
echo "Results will be stored in: $LOG_ROOT"
|
||||
|
||||
14
.buildkite/scripts/tpu/quantized_v6e_1.env
Normal file
@ -0,0 +1,14 @@
|
||||
# Environment config
|
||||
TEST_NAME=llama8bw8a8
|
||||
CONTAINER_NAME=vllm-tpu
|
||||
|
||||
# vllm config
|
||||
MODEL=RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8
|
||||
MAX_NUM_SEQS=128
|
||||
MAX_NUM_BATCHED_TOKENS=1024
|
||||
TENSOR_PARALLEL_SIZE=1
|
||||
MAX_MODEL_LEN=2048
|
||||
DOWNLOAD_DIR=/mnt/disks/persist
|
||||
EXPECTED_THROUGHPUT=10.0
|
||||
INPUT_LEN=1800
|
||||
OUTPUT_LEN=128
|
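The .env file above only declares benchmark knobs; a runner script is expected to source it and forward the values to the server and the serving benchmark. A minimal sketch of that wiring, under the assumption that the runner simply exports and passes each variable through (the actual TPU runner is not part of this diff):

```bash
# Hypothetical consumer of quantized_v6e_1.env.
set -a                                   # export every variable the env file defines
source .buildkite/scripts/tpu/quantized_v6e_1.env
set +a

vllm serve "$MODEL" \
  --max-num-seqs "$MAX_NUM_SEQS" \
  --max-num-batched-tokens "$MAX_NUM_BATCHED_TOKENS" \
  --tensor-parallel-size "$TENSOR_PARALLEL_SIZE" \
  --max-model-len "$MAX_MODEL_LEN" \
  --download-dir "$DOWNLOAD_DIR" &
# ...then drive it with `vllm bench serve` using $INPUT_LEN/$OUTPUT_LEN prompts
# and compare the measured throughput against $EXPECTED_THROUGHPUT.
```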
||||
@ -77,7 +77,7 @@ done
|
||||
echo "run benchmark test..."
|
||||
echo "logging to $BM_LOG"
|
||||
echo
|
||||
python benchmarks/benchmark_serving.py \
|
||||
vllm bench serve \
|
||||
--backend vllm \
|
||||
--model $MODEL \
|
||||
--dataset-name sonnet \
|
||||
|
||||
@ -117,7 +117,7 @@ steps:
|
||||
commands:
|
||||
- pytest -v -s core
|
||||
|
||||
- label: Entrypoints Test # 40min
|
||||
- label: Entrypoints Test (LLM) # 40min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
fast_check: true
|
||||
@ -125,8 +125,6 @@ steps:
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/entrypoints/llm
|
||||
- tests/entrypoints/openai
|
||||
- tests/entrypoints/test_chat_utils
|
||||
- tests/entrypoints/offline_mode
|
||||
commands:
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
@ -135,9 +133,21 @@ steps:
|
||||
- pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
|
||||
- pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process
|
||||
- VLLM_USE_V1=0 pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process
|
||||
- VLLM_USE_V1=0 pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests
|
||||
|
||||
- label: Entrypoints Test (API Server) # 40min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
fast_check: true
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/entrypoints/openai
|
||||
- tests/entrypoints/test_chat_utils
|
||||
commands:
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/
|
||||
- pytest -v -s entrypoints/test_chat_utils.py
|
||||
- VLLM_USE_V1=0 pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests
|
||||
|
||||
- label: Distributed Tests (4 GPUs) # 10min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
@ -149,12 +159,14 @@ steps:
|
||||
- tests/distributed/test_utils
|
||||
- tests/distributed/test_pynccl
|
||||
- tests/distributed/test_events
|
||||
- tests/spec_decode/e2e/test_integration_dist_tp4
|
||||
- tests/compile/test_basic_correctness
|
||||
- examples/offline_inference/rlhf.py
|
||||
- examples/offline_inference/rlhf_colocate.py
|
||||
- tests/examples/offline_inference/data_parallel.py
|
||||
- tests/v1/test_async_llm_dp.py
|
||||
- tests/v1/test_external_lb_dp.py
|
||||
- tests/v1/test_internal_lb_dp.py
|
||||
- tests/v1/test_hybrid_lb_dp.py
|
||||
- tests/v1/engine/test_engine_core_client.py
|
||||
commands:
|
||||
# test with tp=2 and external_dp=2
|
||||
@ -163,14 +175,16 @@ steps:
|
||||
# test with tp=2 and pp=2
|
||||
- PP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
|
||||
# test with internal dp
|
||||
- python3 ../examples/offline_inference/data_parallel.py
|
||||
- python3 ../examples/offline_inference/data_parallel.py --enforce-eager
|
||||
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
|
||||
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py
|
||||
- TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/test_internal_lb_dp.py
|
||||
- TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/test_hybrid_lb_dp.py
|
||||
- pytest -v -s v1/engine/test_engine_core_client.py::test_kv_cache_events_dp
|
||||
- pytest -v -s distributed/test_utils.py
|
||||
- pytest -v -s compile/test_basic_correctness.py
|
||||
- pytest -v -s distributed/test_pynccl.py
|
||||
- pytest -v -s distributed/test_events.py
|
||||
- pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
|
||||
# TODO: create a dedicated test section for multi-GPU example tests
|
||||
# when we have multiple distributed example tests
|
||||
- pushd ../examples/offline_inference
|
||||
@ -225,7 +239,7 @@ steps:
|
||||
working_dir: "/vllm-workspace/tests" # optional
|
||||
|
||||
- label: Engine Test # 10min
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/engine
|
||||
@ -254,6 +268,7 @@ steps:
|
||||
- pytest -v -s v1/structured_output
|
||||
- pytest -v -s v1/spec_decode
|
||||
- pytest -v -s v1/kv_connector/unit
|
||||
- pytest -v -s v1/metrics
|
||||
- pytest -v -s v1/test_serial_utils.py
|
||||
- pytest -v -s v1/test_utils.py
|
||||
- pytest -v -s v1/test_oracle.py
|
||||
@ -262,11 +277,11 @@ steps:
|
||||
# VLLM_USE_FLASHINFER_SAMPLER or not on H100.
|
||||
- pytest -v -s v1/e2e
|
||||
# Integration test for streaming correctness (requires special branch).
|
||||
- pip install -U git+https://github.com/robertgshaw2-neuralmagic/lm-evaluation-harness.git@streaming-api
|
||||
- pip install -U git+https://github.com/robertgshaw2-redhat/lm-evaluation-harness.git@streaming-api
|
||||
- pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine
|
||||
|
||||
- label: Examples Test # 25min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
working_dir: "/vllm-workspace/examples"
|
||||
source_file_dependencies:
|
||||
- vllm/entrypoints
|
||||
@ -280,7 +295,7 @@ steps:
|
||||
- python3 offline_inference/llm_engine_example.py
|
||||
- python3 offline_inference/audio_language.py --seed 0
|
||||
- python3 offline_inference/vision_language.py --seed 0
|
||||
- python3 offline_inference/vision_language_embedding.py --seed 0
|
||||
- python3 offline_inference/vision_language_pooling.py --seed 0
|
||||
- python3 offline_inference/vision_language_multi_image.py --seed 0
|
||||
- VLLM_USE_V1=0 python3 others/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 others/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
|
||||
- python3 offline_inference/encoder_decoder.py
|
||||
@ -300,7 +315,7 @@ steps:
|
||||
|
||||
|
||||
- label: Platform Tests (CUDA)
|
||||
mirror_hardwares: [amdexperimental]
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/cuda
|
||||
@ -318,19 +333,8 @@ steps:
|
||||
- pytest -v -s samplers
|
||||
- VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers
|
||||
|
||||
- label: Speculative decoding tests # 40min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- vllm/spec_decode
|
||||
- tests/spec_decode
|
||||
- vllm/model_executor/models/eagle.py
|
||||
commands:
|
||||
- pytest -v -s spec_decode/e2e/test_multistep_correctness.py
|
||||
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py --ignore=spec_decode/e2e/test_mtp_correctness.py
|
||||
- pytest -v -s spec_decode/e2e/test_eagle_correctness.py
|
||||
|
||||
- label: LoRA Test %N # 15min each
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- vllm/lora
|
||||
- tests/lora
|
||||
@ -338,7 +342,7 @@ steps:
|
||||
parallelism: 4
|
||||
|
||||
- label: PyTorch Compilation Unit Tests
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
@ -382,7 +386,7 @@ steps:
|
||||
- pytest -v -s kernels/core
|
||||
|
||||
- label: Kernels Attention Test %N
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- csrc/attention/
|
||||
- vllm/attention
|
||||
@ -393,7 +397,7 @@ steps:
|
||||
parallelism: 2
|
||||
|
||||
- label: Kernels Quantization Test %N
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- csrc/quantization/
|
||||
- vllm/model_executor/layers/quantization
|
||||
@ -412,7 +416,7 @@ steps:
|
||||
- pytest -v -s kernels/moe
|
||||
|
||||
- label: Kernels Mamba Test
|
||||
mirror_hardwares: [amdexperimental]
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
source_file_dependencies:
|
||||
- csrc/mamba/
|
||||
- tests/kernels/mamba
|
||||
@ -434,7 +438,6 @@ steps:
|
||||
|
||||
- label: Model Executor Test
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
soft_fail: true
|
||||
source_file_dependencies:
|
||||
- vllm/model_executor
|
||||
- tests/model_executor
|
||||
@ -491,7 +494,7 @@ steps:
|
||||
- pytest -s entrypoints/openai/correctness/
|
||||
|
||||
- label: Encoder Decoder tests # 5min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/encoder_decoder
|
||||
@ -499,7 +502,7 @@ steps:
|
||||
- pytest -v -s encoder_decoder
|
||||
|
||||
- label: OpenAI-Compatible Tool Use # 20 min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
fast_check: false
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
@ -512,7 +515,7 @@ steps:
|
||||
##### models test #####
|
||||
|
||||
- label: Basic Models Test # 24min
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
@ -536,6 +539,17 @@ steps:
|
||||
- pip freeze | grep -E 'torch'
|
||||
- pytest -v -s models/language -m core_model
|
||||
|
||||
- label: Language Models Test (Hybrid) # 35 min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/models/language/generation
|
||||
commands:
|
||||
# Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile.
|
||||
- pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8'
|
||||
- pytest -v -s models/language/generation -m hybrid_model
|
||||
|
||||
- label: Language Models Test (Extended Generation) # 1hr20min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
optional: true
|
||||
@ -545,7 +559,7 @@ steps:
|
||||
commands:
|
||||
# Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile.
|
||||
- pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8'
|
||||
- pytest -v -s models/language/generation -m 'not core_model'
|
||||
- pytest -v -s models/language/generation -m '(not core_model) and (not hybrid_model)'
|
||||
|
||||
- label: Language Models Test (Extended Pooling) # 36min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
@ -590,7 +604,7 @@ steps:
|
||||
- pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=0) and not core_model'
|
||||
|
||||
- label: Multi-Modal Models Test (Extended) 3
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
optional: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
@ -600,7 +614,7 @@ steps:
|
||||
- pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=1) and not core_model'
|
||||
|
||||
- label: Quantized Models Test
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- vllm/model_executor/layers/quantization
|
||||
- tests/models/quantization
|
||||
@ -617,6 +631,18 @@ steps:
|
||||
# e.g. pytest -v -s models/encoder_decoder/vision_language/test_mllama.py
|
||||
# *To avoid merge conflicts, remember to REMOVE (not just comment out) them before merging the PR*
|
||||
|
||||
- label: Transformers Nightly Models Test
|
||||
working_dir: "/vllm-workspace/"
|
||||
optional: true
|
||||
commands:
|
||||
- pip install --upgrade git+https://github.com/huggingface/transformers
|
||||
- pytest -v -s tests/models/test_initialization.py
|
||||
- pytest -v -s tests/models/multimodal/processing/
|
||||
- pytest -v -s tests/models/multimodal/test_mapping.py
|
||||
- python3 examples/offline_inference/basic/chat.py
|
||||
- python3 examples/offline_inference/audio_language.py --model-type whisper
|
||||
- python3 examples/offline_inference/vision_language.py --model-type qwen2_5_vl
|
||||
|
||||
##### 1 GPU test #####
|
||||
##### multi gpus test #####
|
||||
|
||||
@ -671,10 +697,12 @@ steps:
|
||||
- vllm/worker/model_runner.py
|
||||
- entrypoints/llm/test_collective_rpc.py
|
||||
- tests/v1/test_async_llm_dp.py
|
||||
- tests/v1/test_external_lb_dp.py
|
||||
- tests/v1/entrypoints/openai/test_multi_api_servers.py
|
||||
- vllm/v1/engine/
|
||||
commands:
|
||||
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
|
||||
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py
|
||||
- DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py
|
||||
- pytest -v -s entrypoints/llm/test_collective_rpc.py
|
||||
- pytest -v -s ./compile/test_basic_correctness.py
|
||||
@ -689,10 +717,10 @@ steps:
|
||||
- pytest -v -s distributed/test_sequence_parallel.py
|
||||
# this test fails consistently.
|
||||
# TODO: investigate and fix
|
||||
# - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
|
||||
- VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
|
||||
- VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/test_disagg.py
|
||||
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown
|
||||
- pytest -v -s models/multimodal/generation/test_maverick.py
|
||||
|
||||
- label: Plugin Tests (2 GPUs) # 40min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
|
||||
6
.gemini/config.yaml
Normal file
@ -0,0 +1,6 @@
|
||||
# https://developers.google.com/gemini-code-assist/docs/customize-gemini-behavior-github
|
||||
have_fun: false # Just review the code
|
||||
code_review:
|
||||
comment_severity_threshold: HIGH # Reduce quantity of comments
|
||||
pull_request_opened:
|
||||
summary: false # Don't summarize the PR in a separate comment
|
||||
16
.github/CODEOWNERS
vendored
@ -16,7 +16,8 @@
|
||||
/vllm/lora @jeejeelee
|
||||
/vllm/reasoning @aarnphm
|
||||
/vllm/entrypoints @aarnphm
|
||||
CMakeLists.txt @tlrmchlsmth
|
||||
/vllm/compilation @zou3519 @youkaichao @ProExpertProg
|
||||
CMakeLists.txt @tlrmchlsmth @LucasWilkinson
|
||||
|
||||
# Any change to the VllmConfig changes can have a large user-facing impact,
|
||||
# so spam a lot of people
|
||||
@ -42,7 +43,6 @@ CMakeLists.txt @tlrmchlsmth
|
||||
/tests/multimodal @DarkLight1337 @ywang96
|
||||
/tests/prefix_caching @comaniac @KuntaiDu
|
||||
/tests/quantization @mgoin @robertgshaw2-redhat
|
||||
/tests/spec_decode @njhill @LiuXiaoxuanPKU
|
||||
/tests/test_inputs.py @DarkLight1337 @ywang96
|
||||
/tests/v1/entrypoints/llm/test_struct_output_generate.py @mgoin @russellb @aarnphm
|
||||
/tests/v1/structured_output @mgoin @russellb @aarnphm
|
||||
@ -52,3 +52,15 @@ CMakeLists.txt @tlrmchlsmth
|
||||
# Docs
|
||||
/docs @hmellor
|
||||
mkdocs.yaml @hmellor
|
||||
|
||||
# CPU
|
||||
/vllm/v1/worker/^cpu @bigPYJ1151
|
||||
/csrc/cpu @bigPYJ1151
|
||||
/vllm/platforms/cpu.py @bigPYJ1151
|
||||
/cmake/cpu_extension.cmake @bigPYJ1151
|
||||
/docker/Dockerfile.cpu @bigPYJ1151
|
||||
|
||||
# Intel GPU
|
||||
/vllm/v1/worker/^xpu @jikunshang
|
||||
/vllm/platforms/xpu.py @jikunshang
|
||||
/docker/Dockerfile.xpu @jikunshang
|
||||
|
||||
2
.github/ISSUE_TEMPLATE/750-RFC.yml
vendored
@ -46,7 +46,7 @@ body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: >
|
||||
Thanks for contributing 🎉!
|
||||
Thanks for contributing 🎉! The vLLM core team hosts a biweekly RFC review session at 9:30AM Pacific Time. While most RFCs can be discussed online, you can optionally sign up for a slot to discuss your RFC in the review session [here](https://docs.google.com/document/d/1CiLVBZeIVfR7_PNAKVSusxpceywkoOOB78qoWqHvSZc/edit).
|
||||
- type: checkboxes
|
||||
id: askllm
|
||||
attributes:
|
||||
|
||||
38
.github/mergify.yml
vendored
@ -27,6 +27,22 @@ pull_request_rules:
|
||||
add:
|
||||
- ci/build
|
||||
|
||||
- name: label-deepseek
|
||||
description: Automatically apply deepseek label
|
||||
conditions:
|
||||
- or:
|
||||
- files~=^examples/.*deepseek.*\.py
|
||||
- files~=^tests/.*deepseek.*\.py
|
||||
- files~=^vllm/entrypoints/openai/tool_parsers/.*deepseek.*\.py
|
||||
- files~=^vllm/model_executor/models/.*deepseek.*\.py
|
||||
- files~=^vllm/reasoning/.*deepseek.*\.py
|
||||
- files~=^vllm/transformers_utils/.*deepseek.*\.py
|
||||
- title~=(?i)DeepSeek
|
||||
actions:
|
||||
label:
|
||||
add:
|
||||
- deepseek
|
||||
|
||||
- name: label-frontend
|
||||
description: Automatically apply frontend label
|
||||
conditions:
|
||||
@ -58,14 +74,23 @@ pull_request_rules:
|
||||
- files~=^vllm/multimodal/
|
||||
- files~=^tests/multimodal/
|
||||
- files~=^tests/models/multimodal/
|
||||
- files~=^tests/models/*/audio_language/
|
||||
- files~=^tests/models/*/vision_language/
|
||||
- files=tests/models/test_vision.py
|
||||
actions:
|
||||
label:
|
||||
add:
|
||||
- multi-modality
|
||||
|
||||
- name: label-new-model
|
||||
description: Automatically apply new-model label
|
||||
conditions:
|
||||
- and:
|
||||
- files~=^vllm/model_executor/models/
|
||||
- files=vllm/model_executor/models/registry.py
|
||||
actions:
|
||||
label:
|
||||
add:
|
||||
- new-model
|
||||
|
||||
- name: label-performance
|
||||
description: Automatically apply performance label
|
||||
conditions:
|
||||
@ -139,9 +164,12 @@ pull_request_rules:
|
||||
description: Automatically apply speculative-decoding label
|
||||
conditions:
|
||||
- or:
|
||||
- files~=^vllm/spec_decode/
|
||||
- files=vllm/model_executor/layers/spec_decode_base_sampler.py
|
||||
- files~=^tests/spec_decode/
|
||||
- files~=^vllm/v1/spec_decode/
|
||||
- files~=^tests/v1/spec_decode/
|
||||
- files~=^examples/.*(spec_decode|mlpspeculator|eagle|speculation).*\.py
|
||||
- files~=^vllm/model_executor/models/.*eagle.*\.py
|
||||
- files=vllm/model_executor/models/mlp_speculator.py
|
||||
- files~=^vllm/transformers_utils/configs/(eagle|medusa|mlp_speculator)\.py
|
||||
actions:
|
||||
label:
|
||||
add:
|
||||
|
||||
4
.github/workflows/lint-and-deploy.yaml
vendored
@ -7,7 +7,7 @@ permissions:
|
||||
|
||||
jobs:
|
||||
lint-and-deploy:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-24.04-arm
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
@ -68,7 +68,7 @@ jobs:
|
||||
export AWS_ACCESS_KEY_ID=minioadmin
|
||||
export AWS_SECRET_ACCESS_KEY=minioadmin
|
||||
sleep 30 && kubectl -n ns-vllm logs -f "$(kubectl -n ns-vllm get pods | awk '/deployment/ {print $1;exit}')" &
|
||||
helm install --wait --wait-for-jobs --timeout 5m0s --debug --create-namespace --namespace=ns-vllm test-vllm examples/online_serving/chart-helm -f examples/online_serving/chart-helm/values.yaml --set secrets.s3endpoint=http://minio:9000 --set secrets.s3bucketname=testbucket --set secrets.s3accesskeyid=$AWS_ACCESS_KEY_ID --set secrets.s3accesskey=$AWS_SECRET_ACCESS_KEY --set resources.requests.cpu=1 --set resources.requests.memory=4Gi --set resources.limits.cpu=2 --set resources.limits.memory=5Gi --set image.env[0].name=VLLM_CPU_KVCACHE_SPACE --set image.env[1].name=VLLM_LOGGING_LEVEL --set-string image.env[0].value="1" --set-string image.env[1].value="DEBUG" --set-string extraInit.s3modelpath="opt-125m/" --set-string 'resources.limits.nvidia\.com/gpu=0' --set-string 'resources.requests.nvidia\.com/gpu=0' --set-string image.repository="vllm-cpu-env"
|
||||
helm install --wait --wait-for-jobs --timeout 5m0s --debug --create-namespace --namespace=ns-vllm test-vllm examples/online_serving/chart-helm -f examples/online_serving/chart-helm/values.yaml --set secrets.s3endpoint=http://minio:9000 --set secrets.s3bucketname=testbucket --set secrets.s3accesskeyid=$AWS_ACCESS_KEY_ID --set secrets.s3accesskey=$AWS_SECRET_ACCESS_KEY --set resources.requests.cpu=1 --set resources.requests.memory=4Gi --set resources.limits.cpu=2 --set resources.limits.memory=5Gi --set image.env[0].name=VLLM_CPU_KVCACHE_SPACE --set image.env[1].name=VLLM_LOGGING_LEVEL --set image.env[2].name=VLLM_CPU_CI_ENV --set-string image.env[0].value="1" --set-string image.env[1].value="DEBUG" --set-string image.env[2].value="1" --set-string extraInit.s3modelpath="opt-125m/" --set-string 'resources.limits.nvidia\.com/gpu=0' --set-string 'resources.requests.nvidia\.com/gpu=0' --set-string image.repository="vllm-cpu-env"
|
||||
|
||||
- name: curl test
|
||||
run: |
|
||||
|
||||
1
.gitignore
vendored
@ -146,6 +146,7 @@ venv.bak/
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
docs/argparse
|
||||
docs/examples
|
||||
|
||||
# mypy
|
||||
|
||||
@ -21,7 +21,7 @@ repos:
|
||||
- id: ruff-format
|
||||
files: ^(.buildkite|benchmarks|examples)/.*
|
||||
- repo: https://github.com/crate-ci/typos
|
||||
rev: v1.32.0
|
||||
rev: v1.34.0
|
||||
hooks:
|
||||
- id: typos
|
||||
- repo: https://github.com/PyCQA/isort
|
||||
@ -160,10 +160,17 @@ repos:
|
||||
types: [python]
|
||||
pass_filenames: false
|
||||
additional_dependencies: [pathspec, regex]
|
||||
- id: validate-config
|
||||
name: Validate configuration has default values and that each field has a docstring
|
||||
entry: python tools/validate_config.py
|
||||
language: python
|
||||
types: [python]
|
||||
pass_filenames: true
|
||||
files: vllm/config.py|tests/test_config.py|vllm/entrypoints/openai/cli_args.py
|
||||
# Keep `suggestion` last
|
||||
- id: suggestion
|
||||
name: Suggestion
|
||||
entry: bash -c 'echo "To bypass pre-commit hooks, add --no-verify to git commit."'
|
||||
entry: bash -c 'echo "To bypass all the pre-commit hooks, add --no-verify to git commit. To skip a specific hook, prefix the commit command with SKIP=<hook-id>."'
|
||||
language: system
|
||||
verbose: true
|
||||
pass_filenames: false
|
||||
|
||||
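The updated suggestion message above points at pre-commit's SKIP mechanism. For reference, skipping a single hook for one commit looks like this (the hook id shown is just one of the ids defined in this config):

```bash
# Run all pre-commit hooks except `typos` for this commit.
SKIP=typos git commit -m "docs: fix wording"

# Bypass every hook entirely (use sparingly).
git commit --no-verify -m "hotfix"
```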
137
CMakeLists.txt
@ -45,7 +45,7 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1
|
||||
# requirements.txt files and should be kept consistent. The ROCm torch
|
||||
# versions are derived from docker/Dockerfile.rocm
|
||||
#
|
||||
set(TORCH_SUPPORTED_VERSION_CUDA "2.7.0")
|
||||
set(TORCH_SUPPORTED_VERSION_CUDA "2.7.1")
|
||||
set(TORCH_SUPPORTED_VERSION_ROCM "2.7.0")
|
||||
|
||||
#
|
||||
@ -171,7 +171,6 @@ if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND VLLM_GPU_FLAGS "--threads=${NVCC_THREADS}")
|
||||
endif()
|
||||
|
||||
|
||||
#
|
||||
# Use FetchContent for C++ dependencies that are compiled as part of vLLM's build process.
|
||||
# setup.py will override FETCHCONTENT_BASE_DIR to play nicely with sccache.
|
||||
@ -232,7 +231,6 @@ endif()
|
||||
|
||||
set(VLLM_EXT_SRC
|
||||
"csrc/mamba/mamba_ssm/selective_scan_fwd.cu"
|
||||
"csrc/mamba/causal_conv1d/causal_conv1d.cu"
|
||||
"csrc/cache_kernels.cu"
|
||||
"csrc/attention/paged_attention_v1.cu"
|
||||
"csrc/attention/paged_attention_v2.cu"
|
||||
@ -259,7 +257,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
|
||||
|
||||
# Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building.
|
||||
set(CUTLASS_REVISION "v3.9.2" CACHE STRING "CUTLASS revision to use")
|
||||
set(CUTLASS_REVISION "v4.0.0" CACHE STRING "CUTLASS revision to use")
|
||||
|
||||
# Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
|
||||
if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR})
|
||||
@ -298,7 +296,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
"csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu"
|
||||
"csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
|
||||
"csrc/cutlass_extensions/common.cpp"
|
||||
"csrc/attention/mla/cutlass_mla_entry.cu")
|
||||
"csrc/attention/mla/cutlass_mla_entry.cu"
|
||||
"csrc/quantization/fp8/per_token_group_quant.cu")
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${VLLM_EXT_SRC}"
|
||||
@ -393,7 +392,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
|
||||
# CUDA 12.0 or later
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_ARCHS)
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu"
|
||||
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu"
|
||||
@ -409,7 +408,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
|
||||
message(STATUS "Building scaled_mm_c3x_sm90 for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_ARCHS)
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building scaled_mm_c3x_sm90 as CUDA Compiler version is "
|
||||
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
|
||||
"later if you intend on running FP8 quantized models on "
|
||||
@ -420,10 +419,40 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
||||
# The cutlass_scaled_mm kernels for Geforce Blackwell SM120 (c3x, i.e. CUTLASS 3.x) require
|
||||
# CUDA 12.8 or later
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0;12.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu"
|
||||
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu"
|
||||
)
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM120=1")
|
||||
# Let scaled_mm_c2x know it doesn't need to build these arches
|
||||
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
|
||||
message(STATUS "Building scaled_mm_c3x_sm120 for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building scaled_mm_c3x_sm120 as CUDA Compiler version is "
|
||||
"not >= 12.8, we recommend upgrading to CUDA 12.8 or "
|
||||
"later if you intend on running FP8 quantized models on "
|
||||
"Blackwell.")
|
||||
else()
|
||||
message(STATUS "Not building scaled_mm_c3x_sm120 as no compatible archs found "
|
||||
"in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
||||
# The cutlass_scaled_mm kernels for Blackwell SM100 (c3x, i.e. CUTLASS 3.x)
|
||||
# require CUDA 12.8 or later
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS)
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu"
|
||||
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu"
|
||||
@ -438,7 +467,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
|
||||
message(STATUS "Building scaled_mm_c3x_sm100 for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS)
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building scaled_mm_c3x_sm100 as CUDA Compiler version is "
|
||||
"not >= 12.8, we recommend upgrading to CUDA 12.8 or "
|
||||
"later if you intend on running FP8 quantized models on "
|
||||
@ -481,7 +510,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor
|
||||
# require CUDA 12.2 or later (and only work on Hopper).
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_ARCHS)
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.2 AND SCALED_MM_ARCHS)
|
||||
set(SRCS "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
@ -490,7 +519,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SPARSE_SCALED_MM_C3X=1")
|
||||
message(STATUS "Building sparse_scaled_mm_c3x for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_ARCHS)
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.2 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building sparse_scaled_mm_c3x kernels as CUDA Compiler version is "
|
||||
"not >= 12.2, we recommend upgrading to CUDA 12.2 or later "
|
||||
"if you intend on running FP8 sparse quantized models on Hopper.")
|
||||
@ -502,7 +531,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
|
||||
# FP4 Archs and flags
|
||||
cuda_archs_loose_intersection(FP4_ARCHS "10.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND FP4_ARCHS)
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/quantization/fp4/nvfp4_quant_kernels.cu"
|
||||
"csrc/quantization/fp4/nvfp4_experts_quant.cu"
|
||||
@ -523,9 +552,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
|
||||
# CUTLASS MLA Archs and flags
|
||||
cuda_archs_loose_intersection(MLA_ARCHS "10.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND MLA_ARCHS)
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND MLA_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/attention/mla/cutlass_mla_kernels.cu")
|
||||
"csrc/attention/mla/cutlass_mla_kernels.cu"
|
||||
"csrc/attention/mla/sm100_cutlass_mla_kernel.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${MLA_ARCHS}")
|
||||
@ -548,7 +578,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# if it's possible to compile MoE kernels that use its output.
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
|
||||
set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu")
|
||||
set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm90.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||
@ -562,7 +592,27 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
"if you intend on running FP8 quantized MoE models on Hopper.")
|
||||
else()
|
||||
message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
|
||||
"in CUDA target architectures")
|
||||
"in CUDA target architectures.")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
|
||||
message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is "
|
||||
"not >= 12.8, we recommend upgrading to CUDA 12.8 or later "
|
||||
"if you intend on running FP8 quantized MoE models on Blackwell.")
|
||||
else()
|
||||
message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
|
||||
"in CUDA target architectures.")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@ -574,7 +624,37 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${CUTLASS_MOE_DATA_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
endif()
|
||||
message(STATUS "Building moe_data for archs: ${CUTLASS_MOE_DATA_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
|
||||
message(STATUS "Not building moe_data as CUDA Compiler version is "
|
||||
"not >= 12.3, we recommend upgrading to CUDA 12.3 or later "
|
||||
"if you intend on running FP8 quantized MoE models on Hopper or Blackwell.")
|
||||
else()
|
||||
message(STATUS "Not building moe_data as no compatible archs found "
|
||||
"in CUDA target architectures.")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
set(SRCS "csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
|
||||
message(STATUS "Building blockwise_scaled_group_mm_sm100 for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building blockwise_scaled_group_mm_sm100 kernels as CUDA Compiler version is "
|
||||
"not >= 12.8, we recommend upgrading to CUDA 12.8 or later "
|
||||
"if you intend on running FP8 quantized MoE models on Blackwell.")
|
||||
else()
|
||||
message(STATUS "Not building blockwise_scaled_group_mm_sm100 as no compatible archs found "
|
||||
"in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
#
|
||||
# Machete kernels
|
||||
@ -582,7 +662,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# The machete kernels only work on hopper and require CUDA 12.0 or later.
|
||||
# Only build Machete kernels if we are building for something compatible with sm90a
|
||||
cuda_archs_loose_intersection(MACHETE_ARCHS "9.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND MACHETE_ARCHS)
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND MACHETE_ARCHS)
|
||||
#
|
||||
# For the Machete kernels we automatically generate sources for various
|
||||
# preselected input type pairs and schedules.
|
||||
@ -634,7 +714,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
|
||||
message(STATUS "Building Machete kernels for archs: ${MACHETE_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0
|
||||
AND MACHETE_ARCHS)
|
||||
message(STATUS "Not building Machete kernels as CUDA Compiler version is "
|
||||
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
|
||||
@ -688,6 +768,14 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND VLLM_MOE_EXT_SRC "csrc/moe/moe_wna16.cu")
|
||||
endif()
|
||||
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
set(MOE_PERMUTE_SRC
|
||||
"csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu"
|
||||
"csrc/moe/moe_permute_unpermute_op.cu")
|
||||
|
||||
list(APPEND VLLM_MOE_EXT_SRC "${MOE_PERMUTE_SRC}")
|
||||
endif()
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${VLLM_MOE_EXT_SRC}"
|
||||
CUDA_ARCHS "${CUDA_ARCHS}")
|
||||
@ -756,17 +844,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
set(MOE_PERMUTE_SRC
|
||||
"csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu"
|
||||
"csrc/moe/moe_permute_unpermute_op.cu")
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_PERMUTE_SRC}"
|
||||
CUDA_ARCHS "${MOE_PERMUTE_ARCHS}")
|
||||
|
||||
list(APPEND VLLM_MOE_EXT_SRC "${MOE_PERMUTE_SRC}")
|
||||
endif()
|
||||
message(STATUS "Enabling moe extension.")
|
||||
define_gpu_extension_target(
|
||||
_moe_C
|
||||
|
||||
@ -63,13 +63,11 @@ vLLM is fast with:
|
||||
- Speculative decoding
|
||||
- Chunked prefill
|
||||
|
||||
**Performance benchmark**: We include a performance benchmark at the end of [our blog post](https://blog.vllm.ai/2024/09/05/perf-update.html). It compares the performance of vLLM against other LLM serving engines ([TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM), [SGLang](https://github.com/sgl-project/sglang) and [LMDeploy](https://github.com/InternLM/lmdeploy)). The implementation is under [nightly-benchmarks folder](.buildkite/nightly-benchmarks/) and you can [reproduce](https://github.com/vllm-project/vllm/issues/8176) this benchmark using our one-click runnable script.
|
||||
|
||||
vLLM is flexible and easy to use with:
|
||||
|
||||
- Seamless integration with popular Hugging Face models
|
||||
- High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more
|
||||
- Tensor parallelism and pipeline parallelism support for distributed inference
|
||||
- Tensor, pipeline, data and expert parallelism support for distributed inference
|
||||
- Streaming outputs
|
||||
- OpenAI-compatible API server
|
||||
- Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, TPU, and AWS Neuron
|
||||
|
||||
33
RELEASE.md
@ -52,3 +52,36 @@ After branch cut, we approach finalizing the release branch with clear criteria
|
||||
* Release branch specific changes (e.g. change version identifiers or CI fixes)
|
||||
|
||||
Please note: **No feature work allowed for cherry picks**. All PRs that are considered for cherry-picks need to be merged on trunk, the only exception are Release branch specific changes.
|
||||
|
||||
## Manual validations
|
||||
|
||||
### E2E Performance Validation
|
||||
|
||||
Before each release, we perform end-to-end performance validation to ensure no regressions are introduced. This validation uses the [vllm-benchmark workflow](https://github.com/pytorch/pytorch-integration-testing/actions/workflows/vllm-benchmark.yml) on PyTorch CI.
|
||||
|
||||
**Current Coverage:**
|
||||
* Models: Llama3, Llama4, and Mixtral
|
||||
* Hardware: NVIDIA H100 and AMD MI300x
|
||||
* *Note: Coverage may change based on new model releases and hardware availability*
|
||||
|
||||
**Performance Validation Process:**
|
||||
|
||||
**Step 1: Get Access**
|
||||
Request write access to the [pytorch/pytorch-integration-testing](https://github.com/pytorch/pytorch-integration-testing) repository to run the benchmark workflow.
|
||||
|
||||
**Step 2: Review Benchmark Setup**
|
||||
Familiarize yourself with the benchmark configurations:
|
||||
* [CUDA setup](https://github.com/pytorch/pytorch-integration-testing/tree/main/vllm-benchmarks/benchmarks/cuda)
|
||||
* [ROCm setup](https://github.com/pytorch/pytorch-integration-testing/tree/main/vllm-benchmarks/benchmarks/rocm)
|
||||
|
||||
**Step 3: Run the Benchmark**
|
||||
Navigate to the [vllm-benchmark workflow](https://github.com/pytorch/pytorch-integration-testing/actions/workflows/vllm-benchmark.yml) and configure:
|
||||
* **vLLM branch**: Set to the release branch (e.g., `releases/v0.9.2`)
|
||||
* **vLLM commit**: Set to the RC commit hash
|
||||
|
||||
**Step 4: Review Results**
|
||||
Once the workflow completes, benchmark results will be available on the [vLLM benchmark dashboard](https://hud.pytorch.org/benchmark/llms?repoName=vllm-project%2Fvllm) under the corresponding branch and commit.
|
||||
|
||||
**Step 5: Performance Comparison**
|
||||
Compare the current results against the previous release to verify no performance regressions have occurred. Here is an
|
||||
example of [v0.9.1 vs v0.9.2](https://hud.pytorch.org/benchmark/llms?startTime=Thu%2C%2017%20Apr%202025%2021%3A43%3A50%20GMT&stopTime=Wed%2C%2016%20Jul%202025%2021%3A43%3A50%20GMT&granularity=week&lBranch=releases/v0.9.1&lCommit=b6553be1bc75f046b00046a4ad7576364d03c835&rBranch=releases/v0.9.2&rCommit=a5dd03c1ebc5e4f56f3c9d3dc0436e9c582c978f&repoName=vllm-project%2Fvllm&benchmarkName=&modelName=All%20Models&backendName=All%20Backends&modeName=All%20Modes&dtypeName=All%20DType&deviceName=All%20Devices&archName=All%20Platforms).
|
||||
|
||||
@ -98,7 +98,7 @@ Then run the benchmarking script
|
||||
```bash
|
||||
# download dataset
|
||||
# wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
|
||||
python3 vllm/benchmarks/benchmark_serving.py \
|
||||
vllm bench serve \
|
||||
--backend vllm \
|
||||
--model NousResearch/Hermes-3-Llama-3.1-8B \
|
||||
--endpoint /v1/completions \
|
||||
@ -111,25 +111,25 @@ If successful, you will see the following output
|
||||
|
||||
```
|
||||
============ Serving Benchmark Result ============
|
||||
Successful requests: 10
|
||||
Benchmark duration (s): 5.78
|
||||
Total input tokens: 1369
|
||||
Total generated tokens: 2212
|
||||
Request throughput (req/s): 1.73
|
||||
Output token throughput (tok/s): 382.89
|
||||
Total Token throughput (tok/s): 619.85
|
||||
Successful requests: 10
|
||||
Benchmark duration (s): 5.78
|
||||
Total input tokens: 1369
|
||||
Total generated tokens: 2212
|
||||
Request throughput (req/s): 1.73
|
||||
Output token throughput (tok/s): 382.89
|
||||
Total Token throughput (tok/s): 619.85
|
||||
---------------Time to First Token----------------
|
||||
Mean TTFT (ms): 71.54
|
||||
Median TTFT (ms): 73.88
|
||||
P99 TTFT (ms): 79.49
|
||||
Mean TTFT (ms): 71.54
|
||||
Median TTFT (ms): 73.88
|
||||
P99 TTFT (ms): 79.49
|
||||
-----Time per Output Token (excl. 1st token)------
|
||||
Mean TPOT (ms): 7.91
|
||||
Median TPOT (ms): 7.96
|
||||
P99 TPOT (ms): 8.03
|
||||
Mean TPOT (ms): 7.91
|
||||
Median TPOT (ms): 7.96
|
||||
P99 TPOT (ms): 8.03
|
||||
---------------Inter-token Latency----------------
|
||||
Mean ITL (ms): 7.74
|
||||
Median ITL (ms): 7.70
|
||||
P99 ITL (ms): 8.39
|
||||
Mean ITL (ms): 7.74
|
||||
Median ITL (ms): 7.70
|
||||
P99 ITL (ms): 8.39
|
||||
==================================================
|
||||
```
|
||||
|
||||
@ -141,7 +141,7 @@ If the dataset you want to benchmark is not supported yet in vLLM, even then you
|
||||
{"prompt": "What is the capital of India?"}
|
||||
{"prompt": "What is the capital of Iran?"}
|
||||
{"prompt": "What is the capital of China?"}
|
||||
```
|
||||
```
|
||||
|
||||
```bash
|
||||
# start server
|
||||
@ -150,7 +150,7 @@ VLLM_USE_V1=1 vllm serve meta-llama/Llama-3.1-8B-Instruct --disable-log-requests
|
||||
|
||||
```bash
|
||||
# run benchmarking script
|
||||
python3 benchmarks/benchmark_serving.py --port 9001 --save-result --save-detailed \
|
||||
vllm bench serve --port 9001 --save-result --save-detailed \
|
||||
--backend vllm \
|
||||
--model meta-llama/Llama-3.1-8B-Instruct \
|
||||
--endpoint /v1/completions \
|
||||
@ -174,7 +174,7 @@ vllm serve Qwen/Qwen2-VL-7B-Instruct --disable-log-requests
|
||||
```
|
||||
|
||||
```bash
|
||||
python3 vllm/benchmarks/benchmark_serving.py \
|
||||
vllm bench serve \
|
||||
--backend openai-chat \
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \
|
||||
--endpoint /v1/chat/completions \
|
||||
@ -194,7 +194,7 @@ VLLM_USE_V1=1 vllm serve meta-llama/Meta-Llama-3-8B-Instruct \
|
||||
```
|
||||
|
||||
``` bash
|
||||
python3 benchmarks/benchmark_serving.py \
|
||||
vllm bench serve \
|
||||
--model meta-llama/Meta-Llama-3-8B-Instruct \
|
||||
--dataset-name hf \
|
||||
--dataset-path likaixin/InstructCoder \
|
||||
@ -210,7 +210,7 @@ vllm serve Qwen/Qwen2-VL-7B-Instruct --disable-log-requests
|
||||
**`lmms-lab/LLaVA-OneVision-Data`**
|
||||
|
||||
```bash
|
||||
python3 vllm/benchmarks/benchmark_serving.py \
|
||||
vllm bench serve \
|
||||
--backend openai-chat \
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \
|
||||
--endpoint /v1/chat/completions \
|
||||
@ -224,7 +224,7 @@ python3 vllm/benchmarks/benchmark_serving.py \
|
||||
**`Aeala/ShareGPT_Vicuna_unfiltered`**
|
||||
|
||||
```bash
|
||||
python3 vllm/benchmarks/benchmark_serving.py \
|
||||
vllm bench serve \
|
||||
--backend openai-chat \
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \
|
||||
--endpoint /v1/chat/completions \
|
||||
@ -237,7 +237,7 @@ python3 vllm/benchmarks/benchmark_serving.py \
|
||||
**`AI-MO/aimo-validation-aime`**
|
||||
|
||||
``` bash
|
||||
python3 vllm/benchmarks/benchmark_serving.py \
|
||||
vllm bench serve \
|
||||
--model Qwen/QwQ-32B \
|
||||
--dataset-name hf \
|
||||
--dataset-path AI-MO/aimo-validation-aime \
|
||||
@ -248,7 +248,7 @@ python3 vllm/benchmarks/benchmark_serving.py \
|
||||
**`philschmid/mt-bench`**
|
||||
|
||||
``` bash
|
||||
python3 vllm/benchmarks/benchmark_serving.py \
|
||||
vllm bench serve \
|
||||
--model Qwen/QwQ-32B \
|
||||
--dataset-name hf \
|
||||
--dataset-path philschmid/mt-bench \
|
||||
@ -261,7 +261,7 @@ When using OpenAI-compatible backends such as `vllm`, optional sampling
|
||||
parameters can be specified. Example client command:
|
||||
|
||||
```bash
|
||||
python3 vllm/benchmarks/benchmark_serving.py \
|
||||
vllm bench serve \
|
||||
--backend vllm \
|
||||
--model NousResearch/Hermes-3-Llama-3.1-8B \
|
||||
--endpoint /v1/completions \
|
||||
@ -296,7 +296,7 @@ The following arguments can be used to control the ramp-up:
|
||||
<br/>
|
||||
|
||||
```bash
|
||||
python3 vllm/benchmarks/benchmark_throughput.py \
|
||||
vllm bench throughput \
|
||||
--model NousResearch/Hermes-3-Llama-3.1-8B \
|
||||
--dataset-name sonnet \
|
||||
--dataset-path vllm/benchmarks/sonnet.txt \
|
||||
@ -314,7 +314,7 @@ Total num output tokens: 1500
|
||||
**VisionArena Benchmark for Vision Language Models**
|
||||
|
||||
``` bash
|
||||
python3 vllm/benchmarks/benchmark_throughput.py \
|
||||
vllm bench throughput \
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \
|
||||
--backend vllm-chat \
|
||||
--dataset-name hf \
|
||||
@ -336,7 +336,7 @@ Total num output tokens: 1280
|
||||
``` bash
|
||||
VLLM_WORKER_MULTIPROC_METHOD=spawn \
|
||||
VLLM_USE_V1=1 \
|
||||
python3 vllm/benchmarks/benchmark_throughput.py \
|
||||
vllm bench throughput \
|
||||
--dataset-name=hf \
|
||||
--dataset-path=likaixin/InstructCoder \
|
||||
--model=meta-llama/Meta-Llama-3-8B-Instruct \
|
||||
@ -360,7 +360,7 @@ Total num output tokens: 204800
|
||||
**`lmms-lab/LLaVA-OneVision-Data`**
|
||||
|
||||
```bash
|
||||
python3 vllm/benchmarks/benchmark_throughput.py \
|
||||
vllm bench throughput \
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \
|
||||
--backend vllm-chat \
|
||||
--dataset-name hf \
|
||||
@ -373,7 +373,7 @@ python3 vllm/benchmarks/benchmark_throughput.py \
|
||||
**`Aeala/ShareGPT_Vicuna_unfiltered`**
|
||||
|
||||
```bash
|
||||
python3 vllm/benchmarks/benchmark_throughput.py \
|
||||
vllm bench throughput \
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \
|
||||
--backend vllm-chat \
|
||||
--dataset-name hf \
|
||||
@ -385,7 +385,7 @@ python3 vllm/benchmarks/benchmark_throughput.py \
|
||||
**`AI-MO/aimo-validation-aime`**
|
||||
|
||||
```bash
|
||||
python3 benchmarks/benchmark_throughput.py \
|
||||
vllm bench throughput \
|
||||
--model Qwen/QwQ-32B \
|
||||
--backend vllm \
|
||||
--dataset-name hf \
|
||||
@ -399,7 +399,7 @@ python3 benchmarks/benchmark_throughput.py \
|
||||
``` bash
|
||||
# download dataset
|
||||
# wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
|
||||
python3 vllm/benchmarks/benchmark_throughput.py \
|
||||
vllm bench throughput \
|
||||
--model meta-llama/Llama-2-7b-hf \
|
||||
--backend vllm \
|
||||
--dataset_path <your data path>/ShareGPT_V3_unfiltered_cleaned_split.json \
|
||||
|
||||
141
benchmarks/auto_tune/README.md
Normal file
@ -0,0 +1,141 @@
|
||||
# Automated vLLM Server Parameter Tuning
|
||||
|
||||
This script automates the process of finding the optimal server parameter combination (`max-num-seqs` and `max-num-batched-tokens`) to maximize throughput for a vLLM server. It also supports additional constraints such as E2E latency and prefix cache hit rate.
|
||||
|
||||
## Table of Contents
|
||||
- [Prerequisites](#prerequisites)
|
||||
- [Configuration](#configuration)
|
||||
- [How to Run](#how-to-run)
|
||||
- [Example Use Cases](#example-use-cases)
|
||||
- [Output](#output)
|
||||
- [How It Works](#how-it-works)
|
||||
|
||||
## Prerequisites
|
||||
|
||||
Before running the script, please ensure the following steps are completed:
|
||||
|
||||
1. **Clone vLLM & Set Up Branch**: Clone the vLLM repository and check out your desired branch.
|
||||
|
||||
```bash
|
||||
git clone https://github.com/vllm-project/vllm.git
|
||||
cd vllm
|
||||
# git checkout <your-branch>
|
||||
```
|
||||
|
||||
2. **Install Environment**: Install or update the required runtime environment. For TPU usage, activate your `conda` environment and install the corresponding `torch` and `torch_xla` versions.
|
||||
|
||||
3. **Model Configuration**: If you are using a customized model, ensure its configuration files are correctly placed and accessible.
|
||||
|
||||
## Configuration
|
||||
|
||||
You must set the following variables at the top of the script before execution.
|
||||
|
||||
| Variable | Description | Example Value |
|
||||
| --- | --- | --- |
|
||||
| `BASE` | **Required.** The absolute path to the parent directory of your vLLM repository directory. | `"$HOME"` |
|
||||
| `MODEL` | **Required.** The Hugging Face model identifier to be served by vLLM. | `"meta-llama/Llama-3.1-8B-Instruct"` |
|
||||
| `SYSTEM` | **Required.** The hardware you are running on. Choices: `TPU` or `GPU`. (Profile saving may not be supported on other systems.) | `"TPU"` |
|
||||
| `TP` | **Required.** The tensor-parallelism size. | `1` |
|
||||
| `DOWNLOAD_DIR` | **Required.** Directory to download and load model weights from. | `""` (default download path) |
|
||||
| `INPUT_LEN` | **Required.** Request input length. | `4000` |
|
||||
| `OUTPUT_LEN` | **Required.** Request output length. | `16` |
|
||||
| `MAX_MODEL_LEN` | **Required.** Max model length. | `4096` |
|
||||
| `MIN_CACHE_HIT_PCT` | Prefix cache hit rate in percentage (0-100). Set to `0` to disable. See the sketch after this table for how the value is applied. | `60` |
|
||||
| `MAX_LATENCY_ALLOWED_MS` | The maximum allowed P99 end-to-end latency in milliseconds. Set to a very large number (e.g., `100000000000`) to effectively ignore the latency constraint. | `500` |
|
||||
| `NUM_SEQS_LIST` | A space-separated string of `max-num-seqs` values to test. | `"128 256"` |
|
||||
| `NUM_BATCHED_TOKENS_LIST` | A space-separated string of `max-num-batched-tokens` values to test. | `"1024 2048 4096"` |
|
||||
|
||||
**Note**: The default `NUM_SEQS_LIST` and `NUM_BATCHED_TOKENS_LIST` are set for medium-sized inputs/outputs. For very short contexts (e.g., 20 input, 20 output tokens), you may need to test larger values for `max-num-seqs`.
|
||||
|
||||
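For reference, here is a minimal Python sketch of how `MIN_CACHE_HIT_PCT` is applied: the requested hit rate is converted into a prefix length, and only the remaining tokens are generated randomly per request, so roughly that fraction of each input is intended to be served from the prefix cache. The same integer arithmetic appears in `auto_tune.sh`; the concrete numbers below are just the example values from the table.

```python
# Example values taken from the configuration table above.
INPUT_LEN = 4000
MIN_CACHE_HIT_PCT = 60

# Same integer arithmetic as in auto_tune.sh.
prefix_len = INPUT_LEN * MIN_CACHE_HIT_PCT // 100  # tokens intended to hit the prefix cache
adjusted_input_len = INPUT_LEN - prefix_len        # random input tokens passed to the benchmark

print(prefix_len, adjusted_input_len)  # 2400 1600
```
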
## How to Run
|
||||
|
||||
1. **Configure**: Edit the script and set the variables in the [Configuration](#configuration) section.
|
||||
2. **Execute**: Run the script. Since the process can take a long time, it is highly recommended to use a terminal multiplexer like `tmux` or `screen` to prevent the script from stopping if your connection is lost.
|
||||
|
||||
```bash
|
||||
cd <FOLDER_OF_THIS_SCRIPT>
|
||||
bash auto_tune.sh
|
||||
```
|
||||
|
||||
Please note that the command used to launch the script (e.g. `bash auto_tune.sh`) must not include the keyword `vllm` anywhere in its full or partial path; otherwise the script's internal `pkill -f vllm` command will also kill the script itself.
|
||||
|
||||
## Example Use Cases
|
||||
|
||||
Here are a few examples of how to configure the script for different goals:
|
||||
|
||||
### 1. Maximize Throughput (No Latency Constraint)
|
||||
- **Goal**: Find the best `max-num-seqs` and `max-num-batched-tokens` to get the highest possible throughput for 1800 input tokens and 20 output tokens.
|
||||
- **Configuration**:
|
||||
|
||||
```bash
|
||||
INPUT_LEN=1800
|
||||
OUTPUT_LEN=20
|
||||
MAX_MODEL_LEN=2048
|
||||
MIN_CACHE_HIT_PCT=0
|
||||
MAX_LATENCY_ALLOWED_MS=100000000000 # A very large number
|
||||
```
|
||||
|
||||
### 2. Maximize Throughput with a Latency Requirement
|
||||
- **Goal**: Find the best server parameters when P99 end-to-end latency must be below 500ms.
|
||||
- **Configuration**:
|
||||
|
||||
```bash
|
||||
INPUT_LEN=1800
|
||||
OUTPUT_LEN=20
|
||||
MAX_MODEL_LEN=2048
|
||||
MIN_CACHE_HIT_PCT=0
|
||||
MAX_LATENCY_ALLOWED_MS=500
|
||||
```
|
||||
|
||||
### 3. Maximize Throughput with Prefix Caching and Latency Requirements
|
||||
- **Goal**: Find the best server parameters assuming a 60% prefix cache hit rate and a latency requirement of 500ms.
|
||||
- **Configuration**:
|
||||
|
||||
```bash
|
||||
INPUT_LEN=1800
|
||||
OUTPUT_LEN=20
|
||||
MAX_MODEL_LEN=2048
|
||||
MIN_CACHE_HIT_PCT=60
|
||||
MAX_LATENCY_ALLOWED_MS=500
|
||||
```
|
||||
|
||||
## Output
|
||||
|
||||
After the script finishes, you will find the results in a new, timestamped directory created inside `$BASE/auto-benchmark/`.
|
||||
|
||||
- **Log Files**: The directory (`$BASE/auto-benchmark/YYYY_MM_DD_HH_MM/`) contains detailed logs for each run:
|
||||
- `vllm_log_...txt`: The log output from the vLLM server for each parameter combination.
|
||||
- `bm_log_...txt`: The log output from the `vllm bench serve` command for each benchmark run.
|
||||
|
||||
- **Final Result Summary**: A file named `result.txt` is created in the log directory. It contains a summary of each tested combination and concludes with the overall best parameters found.
|
||||
|
||||
```
|
||||
# Example result.txt content
|
||||
hash:a1b2c3d4...
|
||||
max_num_seqs: 128, max_num_batched_tokens: 2048, request_rate: 10.0, e2el: 450.5, throughput: 9.8, goodput: 9.8
|
||||
max_num_seqs: 128, max_num_batched_tokens: 4096 does not meet latency requirement 500
|
||||
...
|
||||
best_max_num_seqs: 256, best_num_batched_tokens: 2048, best_throughput: 12.5, profile saved in: /home/user/vllm/auto-benchmark/2024_08_01_10_30/profile
|
||||
```
|
||||
|
||||
If it cannot find the best parameters, the final row will be `best_max_num_seqs: 0, best_num_batched_tokens: 0, best_throughput: 0`. This can be due to either the server not starting properly, or the latency requirement being too strict.
|
||||
|
||||
- **Profiler Trace**: A directory named `profile` is created inside the log directory. It contains the profiler trace file (e.g., `.xplane.pb` for TPU or a `.json` trace for GPU) from the single best-performing run.
|
||||
|
||||
## How It Works
|
||||
|
||||
The script follows a systematic process to find the optimal parameters:
|
||||
|
||||
1. **Find Max GPU Memory Utilization**: The script first determines the highest safe `gpu-memory-utilization` (starting from 0.98 and decreasing) that does not cause an Out-Of-Memory (OOM) error when launching the server. This ensures the benchmark runs use the maximum available memory without crashing.
|
||||
|
||||
2. **Iterate and Benchmark**: It then enters a nested loop, iterating through every combination of `max-num-seqs` and `max-num-batched-tokens` provided in the configuration lists.
|
||||
|
||||
3. **Latency-Aware Throughput Search**: For each parameter combination:
|
||||
- The vLLM server is started.
|
||||
- A benchmark is first run with an infinite request rate (`--request-rate inf`).
|
||||
- If the resulting P99 E2E latency is within the `MAX_LATENCY_ALLOWED_MS` limit, this throughput is considered the maximum for this configuration.
|
||||
- If the latency is too high, the script performs a search by iteratively decreasing the request rate until the latency constraint is met. This finds the highest sustainable throughput for the given parameters and latency requirement (see the sketch after this list).
|
||||
|
||||
4. **Track Best Result**: Throughout the process, the script tracks the parameter combination that has yielded the highest valid throughput so far.
|
||||
|
||||
5. **Profile Collection**: For the best-performing run, the script saves the vLLM profiler output, which can be used for deep-dive performance analysis with tools like TensorBoard.
|
||||
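Below is a minimal, runnable Python sketch of the per-combination search described in steps 2-4. The `measure()` helper, its return values, and the starting back-off rate of 40 req/s are hypothetical placeholders for the real server launch and `vllm bench serve` run; the sketch illustrates the control flow only, not the script itself.

```python
# Simplified sketch of the latency-aware throughput search (illustration only).
import random

NUM_SEQS_LIST = [128, 256]
NUM_BATCHED_TOKENS_LIST = [1024, 2048, 4096]
MAX_LATENCY_ALLOWED_MS = 500.0


def measure(max_num_seqs: int, max_num_batched_tokens: int,
            request_rate: float) -> tuple[float, float]:
    """Hypothetical stand-in for one server launch plus one `vllm bench serve`
    run; returns (p99_e2e_latency_ms, throughput_req_per_s)."""
    random.seed(hash((max_num_seqs, max_num_batched_tokens, request_rate)))
    return random.uniform(200.0, 900.0), random.uniform(5.0, 15.0)


best = {"throughput": 0.0, "max_num_seqs": 0, "max_num_batched_tokens": 0}

for max_num_seqs in NUM_SEQS_LIST:
    for max_num_batched_tokens in NUM_BATCHED_TOKENS_LIST:
        # Probe with an effectively infinite request rate first.
        latency, throughput = measure(
            max_num_seqs, max_num_batched_tokens, request_rate=float("inf"))
        # If the latency constraint is violated, back off the request rate
        # step by step until it is met (or until we give up at rate 0).
        request_rate = 40.0  # hypothetical starting point for the back-off
        while latency > MAX_LATENCY_ALLOWED_MS and request_rate > 0:
            latency, throughput = measure(
                max_num_seqs, max_num_batched_tokens, request_rate)
            request_rate -= 1.0
        if latency <= MAX_LATENCY_ALLOWED_MS and throughput > best["throughput"]:
            best = {
                "throughput": throughput,
                "max_num_seqs": max_num_seqs,
                "max_num_batched_tokens": max_num_batched_tokens,
            }

print(best)
```
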
@ -1,45 +1,18 @@
|
||||
#!/bin/bash
|
||||
|
||||
# This script aims to tune the best server parameter combinations to maximize throughput for given requirement.
|
||||
# The current server parameter combination is max_num_seqs and max_num_batched_tokens
|
||||
# It also supports additional requirement: e2e latency and prefix cache.
|
||||
|
||||
# Pre-requisite:
|
||||
# 1. Checkout to your branch, install/ update the correct running env. For TPU, activate conda env and install the corresponding torch, xla version.
|
||||
# 2. If the model is customized, replace the MODEL's config with the customized config.
|
||||
# 3. Set variables (ALL REQUIRED)
|
||||
# BASE: your directory for vllm repo
|
||||
# MODEL: the model served by vllm
|
||||
# SYSTEM: the hardware, choice TPU or GPU, for other systems, "get best profile" might not support.
|
||||
# TP: ways of tensor parallelism
|
||||
# DOWNLOAD_DIR: directory to download and load model weights.
|
||||
# INPUT_LEN: request input len
|
||||
# OUTPUT_LEN: request output len
|
||||
# MIN_CACHE_HIT_PCT: prefix cache rate
|
||||
# MAX_LATENCY_ALLOWED_MS: (e2e) latency requirement. If there's no latency requirement, set it to a large number like 1000000000
|
||||
# NUM_SEQS_LIST: a list of `max-num-seqs` you want to loop with.
|
||||
# NUM_BATCHED_TOKENS_LIST: a list of `max-num-batched-tokens` you want to loop with.
|
||||
# Note that the default NUM_SEQS_LIST and NUM_BATCHED_TOKENS_LIST are set for medium size input/output len, for extra short context (such as 20:20), you might need to include larger numbers in NUM_SEQS_LIST.
|
||||
# 4. Run the script, it might take a long time, you can use tmux to avoid the script stop if disconnection happens.
|
||||
# 5. The final result will be saved in RESULT file.
|
||||
|
||||
|
||||
# Example use cases
|
||||
# 1. Given input_len=1800, output_len=20, what's the best max_num_seqs and max_num_batched_tokens to get highest throughput?
|
||||
# Use INPUT_LEN=1800, OUTPUT_LEN=20, MIN_CACHE_HIT_PCT=0, MAX_LATENCY_ALLOWED_MS=100000000000
|
||||
# 2. If we have latency requirement to be lower than 500ms, what's the best server parameter?
|
||||
# Use INPUT_LEN=1800, OUTPUT_LEN=20, MIN_CACHE_HIT_PCT=0, MAX_LATENCY_ALLOWED_MS=500
|
||||
# 3. If we want to reach 60% prefix cache, what's the best server parameter?
|
||||
# Use INPUT_LEN=1800, OUTPUT_LEN=20, MIN_CACHE_HIT_PCT=60, MAX_LATENCY_ALLOWED_MS=500
|
||||
# This script aims to tune the best server parameter combinations to maximize throughput for given requirement.
|
||||
# See details in README (benchmarks/auto_tune/README.md).
|
||||
|
||||
TAG=$(date +"%Y_%m_%d_%H_%M")
|
||||
BASE=""
|
||||
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
|
||||
BASE="$SCRIPT_DIR/../../.."
|
||||
MODEL="meta-llama/Llama-3.1-8B-Instruct"
|
||||
SYSTEM="TPU"
|
||||
TP=1
|
||||
DOWNLOAD_DIR=""
|
||||
INPUT_LEN=4000
|
||||
OUTPUT_LEN=16
|
||||
MAX_MODEL_LEN=4096
|
||||
MIN_CACHE_HIT_PCT=0
|
||||
MAX_LATENCY_ALLOWED_MS=100000000000
|
||||
NUM_SEQS_LIST="128 256"
|
||||
@ -65,6 +38,13 @@ current_hash=$(git rev-parse HEAD)
|
||||
echo "hash:$current_hash" >> "$RESULT"
|
||||
echo "current_hash: $current_hash"
|
||||
|
||||
TOTAL_LEN=$((INPUT_LEN + OUTPUT_LEN))
|
||||
RED='\033[0;31m'
|
||||
if (( TOTAL_LEN > MAX_MODEL_LEN )); then
|
||||
echo -e "${RED}FAILED: INPUT_LEN($INPUT_LEN) + OUTPUT_LEN($OUTPUT_LEN) = $TOTAL_LEN, which is > MAX_MODEL_LEN = $MAX_MODEL_LEN.\033[0m" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
best_throughput=0
|
||||
best_max_num_seqs=0
|
||||
best_num_batched_tokens=0
|
||||
@ -76,7 +56,7 @@ start_server() {
|
||||
local max_num_batched_tokens=$3
|
||||
local vllm_log=$4
|
||||
local profile_dir=$5
|
||||
|
||||
|
||||
pkill -f vllm
|
||||
|
||||
VLLM_USE_V1=1 VLLM_SERVER_DEV_MODE=1 VLLM_TORCH_PROFILER_DIR=$profile_dir vllm serve $MODEL \
|
||||
@ -89,13 +69,13 @@ start_server() {
|
||||
--enable-prefix-caching \
|
||||
--load-format dummy \
|
||||
--download-dir "$DOWNLOAD_DIR" \
|
||||
--max-model-len $(( INPUT_LEN+OUTPUT_LEN )) > "$vllm_log" 2>&1 &
|
||||
--max-model-len $MAX_MODEL_LEN > "$vllm_log" 2>&1 &
|
||||
|
||||
# wait for 10 minutes...
|
||||
server_started=0
|
||||
for i in {1..60}; do
|
||||
for i in {1..60}; do
|
||||
RESPONSE=$(curl -s -X GET "http://0.0.0.0:8004/health" -w "%{http_code}" -o /dev/stdout)
|
||||
STATUS_CODE=$(echo "$RESPONSE" | tail -n 1)
|
||||
STATUS_CODE=$(echo "$RESPONSE" | tail -n 1)
|
||||
if [[ "$STATUS_CODE" -eq 200 ]]; then
|
||||
server_started=1
|
||||
break
|
||||
@ -118,10 +98,10 @@ update_best_profile() {
|
||||
selected_profile_file=
|
||||
if [[ "$SYSTEM" == "TPU" ]]; then
|
||||
selected_profile_file="${sorted_paths[$profile_index]}/*.xplane.pb"
|
||||
fi
|
||||
fi
|
||||
if [[ "$SYSTEM" == "GPU" ]]; then
|
||||
selected_profile_file="${sorted_paths[$profile_index]}"
|
||||
fi
|
||||
fi
|
||||
rm -f $PROFILE_PATH/*
|
||||
cp $selected_profile_file $PROFILE_PATH
|
||||
}
|
||||
@ -149,17 +129,18 @@ run_benchmark() {
|
||||
echo "server started."
|
||||
fi
|
||||
echo
|
||||
|
||||
|
||||
echo "run benchmark test..."
|
||||
meet_latency_requirement=0
|
||||
# get a basic qps by using request-rate inf
|
||||
bm_log="$LOG_FOLDER/bm_log_${max_num_seqs}_${max_num_batched_tokens}_requestrate_inf.txt"
|
||||
prefix_len=$(( INPUT_LEN * MIN_CACHE_HIT_PCT / 100 ))
|
||||
python benchmarks/benchmark_serving.py \
|
||||
adjusted_input_len=$(( INPUT_LEN - prefix_len ))
|
||||
vllm bench serve \
|
||||
--backend vllm \
|
||||
--model $MODEL \
|
||||
--dataset-name random \
|
||||
--random-input-len $INPUT_LEN \
|
||||
--random-input-len $adjusted_input_len \
|
||||
--random-output-len $OUTPUT_LEN \
|
||||
--ignore-eos \
|
||||
--disable-tqdm \
|
||||
@ -188,11 +169,11 @@ run_benchmark() {
|
||||
curl -X POST http://0.0.0.0:8004/reset_prefix_cache
|
||||
sleep 5
|
||||
bm_log="$LOG_FOLDER/bm_log_${max_num_seqs}_${max_num_batched_tokens}_requestrate_${request_rate}.txt"
|
||||
python benchmarks/benchmark_serving.py \
|
||||
vllm bench serve \
|
||||
--backend vllm \
|
||||
--model $MODEL \
|
||||
--dataset-name random \
|
||||
--random-input-len $INPUT_LEN \
|
||||
--random-input-len $adjusted_input_len \
|
||||
--random-output-len $OUTPUT_LEN \
|
||||
--ignore-eos \
|
||||
--disable-tqdm \
|
||||
@ -273,4 +254,3 @@ done
|
||||
echo "finish permutations"
|
||||
echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput, profile saved in: $PROFILE_PATH"
|
||||
echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput, profile saved in: $PROFILE_PATH" >> "$RESULT"
|
||||
|
||||
@ -324,6 +324,9 @@ class RandomDataset(BenchmarkDataset):
|
||||
input_low = int(real_input_len * (1 - range_ratio))
|
||||
input_high = int(real_input_len * (1 + range_ratio))
|
||||
output_low = int(output_len * (1 - range_ratio))
|
||||
# Ensure the lower bound for output length is at least 1 to prevent
|
||||
# sampling 0 tokens, which can cause request failures.
|
||||
output_low = max(output_low, 1)
|
||||
output_high = int(output_len * (1 + range_ratio))
|
||||
|
||||
# Add logging for debugging
|
||||
@ -701,6 +704,7 @@ class HuggingFaceDataset(BenchmarkDataset):
|
||||
self,
|
||||
dataset_path: str,
|
||||
dataset_split: str,
|
||||
no_stream: bool = False,
|
||||
dataset_subset: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
@ -708,6 +712,7 @@ class HuggingFaceDataset(BenchmarkDataset):
|
||||
|
||||
self.dataset_split = dataset_split
|
||||
self.dataset_subset = dataset_subset
|
||||
self.load_stream = not no_stream
|
||||
self.load_data()
|
||||
|
||||
def load_data(self) -> None:
|
||||
@ -716,7 +721,7 @@ class HuggingFaceDataset(BenchmarkDataset):
|
||||
self.dataset_path,
|
||||
name=self.dataset_subset,
|
||||
split=self.dataset_split,
|
||||
streaming=True,
|
||||
streaming=self.load_stream,
|
||||
)
|
||||
self.data = self.data.shuffle(seed=self.random_seed)
|
||||
|
||||
|
||||
@ -11,6 +11,7 @@ from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from typing_extensions import deprecated
|
||||
|
||||
import vllm.envs as envs
|
||||
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
|
||||
@ -34,6 +35,10 @@ def save_to_pytorch_benchmark_format(
|
||||
write_to_json(pt_file, pt_records)
|
||||
|
||||
|
||||
@deprecated(
|
||||
"benchmark_latency.py is deprecated and will be removed in a "
|
||||
"future version. Please use 'vllm bench latency' instead.",
|
||||
)
|
||||
def main(args: argparse.Namespace):
|
||||
print(args)
|
||||
|
||||
|
||||
@ -30,7 +30,7 @@ import os
|
||||
import random
|
||||
import time
|
||||
import warnings
|
||||
from collections.abc import AsyncGenerator, Iterable
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal, Optional
|
||||
@ -38,6 +38,7 @@ from typing import Any, Literal, Optional
|
||||
import numpy as np
|
||||
from tqdm.asyncio import tqdm
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from backend_request_func import (
|
||||
ASYNC_REQUEST_FUNCS,
|
||||
@ -73,6 +74,7 @@ from benchmark_dataset import (
|
||||
VisionArenaDataset,
|
||||
)
|
||||
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
|
||||
from vllm.benchmarks.serve import get_request
|
||||
|
||||
MILLISECONDS_TO_SECONDS_CONVERSION = 1000
|
||||
|
||||
@ -107,101 +109,6 @@ class BenchmarkMetrics:
|
||||
percentiles_e2el_ms: list[tuple[float, float]]
|
||||
|
||||
|
||||
def _get_current_request_rate(
|
||||
ramp_up_strategy: Optional[Literal["linear", "exponential"]],
|
||||
ramp_up_start_rps: Optional[int],
|
||||
ramp_up_end_rps: Optional[int],
|
||||
request_index: int,
|
||||
total_requests: int,
|
||||
request_rate: float,
|
||||
) -> float:
|
||||
if (
|
||||
ramp_up_strategy
|
||||
and ramp_up_start_rps is not None
|
||||
and ramp_up_end_rps is not None
|
||||
):
|
||||
progress = request_index / max(total_requests - 1, 1)
|
||||
if ramp_up_strategy == "linear":
|
||||
increase = (ramp_up_end_rps - ramp_up_start_rps) * progress
|
||||
return ramp_up_start_rps + increase
|
||||
elif ramp_up_strategy == "exponential":
|
||||
ratio = ramp_up_end_rps / ramp_up_start_rps
|
||||
return ramp_up_start_rps * (ratio**progress)
|
||||
else:
|
||||
raise ValueError(f"Unknown ramp-up strategy: {ramp_up_strategy}")
|
||||
return request_rate
|
||||
|
||||
|
||||
async def get_request(
|
||||
input_requests: list[SampleRequest],
|
||||
request_rate: float,
|
||||
burstiness: float = 1.0,
|
||||
ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None,
|
||||
ramp_up_start_rps: Optional[int] = None,
|
||||
ramp_up_end_rps: Optional[int] = None,
|
||||
) -> AsyncGenerator[tuple[SampleRequest, float], None]:
|
||||
"""
|
||||
Asynchronously generates requests at a specified rate
|
||||
with OPTIONAL burstiness and OPTIONAL ramp-up strategy.
|
||||
|
||||
Args:
|
||||
input_requests:
|
||||
A list of input requests, each represented as a SampleRequest.
|
||||
request_rate:
|
||||
The rate at which requests are generated (requests/s).
|
||||
burstiness (optional):
|
||||
The burstiness factor of the request generation.
|
||||
Only takes effect when request_rate is not inf.
|
||||
Default value is 1, which follows a Poisson process.
|
||||
Otherwise, the request intervals follow a gamma distribution.
|
||||
A lower burstiness value (0 < burstiness < 1) results
|
||||
in more bursty requests, while a higher burstiness value
|
||||
(burstiness > 1) results in a more uniform arrival of requests.
|
||||
ramp_up_strategy (optional):
|
||||
The ramp-up strategy. Can be "linear" or "exponential".
|
||||
If None, uses constant request rate (specified by request_rate).
|
||||
ramp_up_start_rps (optional):
|
||||
The starting request rate for ramp-up.
|
||||
ramp_up_end_rps (optional):
|
||||
The ending request rate for ramp-up.
|
||||
"""
|
||||
assert burstiness > 0, (
|
||||
f"A positive burstiness factor is expected, but given {burstiness}."
|
||||
)
|
||||
# Convert to list to get length for ramp-up calculations
|
||||
if isinstance(input_requests, Iterable) and not isinstance(input_requests, list):
|
||||
input_requests = list(input_requests)
|
||||
|
||||
total_requests = len(input_requests)
|
||||
request_index = 0
|
||||
|
||||
for request in input_requests:
|
||||
current_request_rate = _get_current_request_rate(
|
||||
ramp_up_strategy,
|
||||
ramp_up_start_rps,
|
||||
ramp_up_end_rps,
|
||||
request_index,
|
||||
total_requests,
|
||||
request_rate,
|
||||
)
|
||||
|
||||
yield request, current_request_rate
|
||||
|
||||
request_index += 1
|
||||
|
||||
if current_request_rate == float("inf"):
|
||||
# If the request rate is infinity, then we don't need to wait.
|
||||
continue
|
||||
|
||||
theta = 1.0 / (current_request_rate * burstiness)
|
||||
|
||||
# Sample the request interval from the gamma distribution.
|
||||
# If burstiness is 1, it follows exponential distribution.
|
||||
interval = np.random.gamma(shape=burstiness, scale=theta)
|
||||
# The next request will be sent after the interval.
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
|
||||
def calculate_metrics(
|
||||
input_requests: list[SampleRequest],
|
||||
outputs: list[RequestFuncOutput],
|
||||
@ -551,7 +458,7 @@ async def benchmark(
|
||||
"total_input_tokens": metrics.total_input,
|
||||
"total_output_tokens": metrics.total_output,
|
||||
"request_throughput": metrics.request_throughput,
|
||||
"request_goodput:": metrics.request_goodput if goodput_config_dict else None,
|
||||
"request_goodput": metrics.request_goodput if goodput_config_dict else None,
|
||||
"output_throughput": metrics.output_throughput,
|
||||
"total_token_throughput": metrics.total_token_throughput,
|
||||
"input_lens": [output.prompt_len for output in outputs],
|
||||
@ -687,6 +594,10 @@ def save_to_pytorch_benchmark_format(
|
||||
write_to_json(pt_file, pt_records)
|
||||
|
||||
|
||||
@deprecated(
|
||||
"benchmark_serving.py is deprecated and will be removed in a future "
|
||||
"version. Please use 'vllm bench serve' instead.",
|
||||
)
|
||||
def main(args: argparse.Namespace):
|
||||
print(args)
|
||||
random.seed(args.seed)
|
||||
@ -825,6 +736,7 @@ def main(args: argparse.Namespace):
|
||||
dataset_subset=args.hf_subset,
|
||||
dataset_split=args.hf_split,
|
||||
random_seed=args.seed,
|
||||
no_stream=args.no_stream,
|
||||
).sample(
|
||||
num_requests=args.num_prompts,
|
||||
tokenizer=tokenizer,
|
||||
@ -1033,6 +945,11 @@ def create_argument_parser():
|
||||
help="Path to the sharegpt/sonnet dataset. "
|
||||
"Or the huggingface dataset ID if using HF dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-stream",
|
||||
action="store_true",
|
||||
help="Do not load the dataset in streaming mode.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-concurrency",
|
||||
type=int,
|
||||
|
||||
@ -15,6 +15,7 @@ import torch
|
||||
import uvloop
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from benchmark_dataset import (
|
||||
AIMODataset,
|
||||
@ -167,7 +168,8 @@ async def run_vllm_async(
|
||||
from vllm import SamplingParams
|
||||
|
||||
async with build_async_engine_client_from_engine_args(
|
||||
engine_args, disable_frontend_multiprocessing
|
||||
engine_args,
|
||||
disable_frontend_multiprocessing=disable_frontend_multiprocessing,
|
||||
) as llm:
|
||||
model_config = await llm.get_model_config()
|
||||
assert all(
|
||||
@ -356,6 +358,7 @@ def get_requests(args, tokenizer):
|
||||
elif args.dataset_name == "burstgpt":
|
||||
dataset_cls = BurstGPTDataset
|
||||
elif args.dataset_name == "hf":
|
||||
common_kwargs["no_stream"] = args.no_stream
|
||||
if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_cls = VisionArenaDataset
|
||||
common_kwargs["dataset_subset"] = None
|
||||
@ -380,6 +383,10 @@ def get_requests(args, tokenizer):
|
||||
return dataset_cls(**common_kwargs).sample(**sample_kwargs)
|
||||
|
||||
|
||||
@deprecated(
|
||||
"benchmark_throughput.py is deprecated and will be removed in a "
|
||||
"future version. Please use 'vllm bench throughput' instead.",
|
||||
)
|
||||
def main(args: argparse.Namespace):
|
||||
if args.seed is None:
|
||||
args.seed = 0
|
||||
@ -610,6 +617,11 @@ def create_argument_parser():
|
||||
help="Name of the dataset to benchmark on.",
|
||||
default="sharegpt",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-stream",
|
||||
action="store_true",
|
||||
help="Do not load the dataset in streaming mode.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
# benchmark the overhead of disaggregated prefill.
|
||||
# methodology:
|
||||
# - send all request to prefill vLLM instance. It will buffer KV cache.
|
||||
# - then send all request to decode instance.
|
||||
# - then send all request to decode instance.
|
||||
# - The TTFT of decode instance is the overhead.
|
||||
|
||||
set -ex
|
||||
@ -12,6 +12,8 @@ kill_gpu_processes() {
|
||||
# kill all processes on GPU.
|
||||
pgrep pt_main_thread | xargs -r kill -9
|
||||
pgrep python3 | xargs -r kill -9
|
||||
# vLLM now names the process with VLLM prefix after https://github.com/vllm-project/vllm/pull/21445
|
||||
pgrep VLLM | xargs -r kill -9
|
||||
sleep 10
|
||||
|
||||
# remove vllm config file
|
||||
@ -61,7 +63,7 @@ benchmark() {
|
||||
--gpu-memory-utilization 0.6 \
|
||||
--kv-transfer-config \
|
||||
'{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' &
|
||||
|
||||
|
||||
|
||||
CUDA_VISIBLE_DEVICES=1 python3 \
|
||||
-m vllm.entrypoints.openai.api_server \
|
||||
@ -76,38 +78,38 @@ benchmark() {
|
||||
wait_for_server 8200
|
||||
|
||||
# let the prefill instance finish prefill
|
||||
python3 ../benchmark_serving.py \
|
||||
--backend vllm \
|
||||
--model $model \
|
||||
--dataset-name $dataset_name \
|
||||
--dataset-path $dataset_path \
|
||||
--sonnet-input-len $input_len \
|
||||
--sonnet-output-len "$output_len" \
|
||||
--sonnet-prefix-len $prefix_len \
|
||||
--num-prompts $num_prompts \
|
||||
--port 8100 \
|
||||
--save-result \
|
||||
--result-dir $results_folder \
|
||||
--result-filename disagg_prefill_tp1.json \
|
||||
--request-rate "inf"
|
||||
vllm bench serve \
|
||||
--backend vllm \
|
||||
--model $model \
|
||||
--dataset-name $dataset_name \
|
||||
--dataset-path $dataset_path \
|
||||
--sonnet-input-len $input_len \
|
||||
--sonnet-output-len "$output_len" \
|
||||
--sonnet-prefix-len $prefix_len \
|
||||
--num-prompts $num_prompts \
|
||||
--port 8100 \
|
||||
--save-result \
|
||||
--result-dir $results_folder \
|
||||
--result-filename disagg_prefill_tp1.json \
|
||||
--request-rate "inf"
|
||||
|
||||
|
||||
# send the request to decode.
|
||||
# The TTFT of this command will be the overhead of disagg prefill impl.
|
||||
python3 ../benchmark_serving.py \
|
||||
--backend vllm \
|
||||
--model $model \
|
||||
--dataset-name $dataset_name \
|
||||
--dataset-path $dataset_path \
|
||||
--sonnet-input-len $input_len \
|
||||
--sonnet-output-len "$output_len" \
|
||||
--sonnet-prefix-len $prefix_len \
|
||||
--num-prompts $num_prompts \
|
||||
--port 8200 \
|
||||
--save-result \
|
||||
--result-dir $results_folder \
|
||||
--result-filename disagg_prefill_tp1_overhead.json \
|
||||
--request-rate "$qps"
|
||||
vllm bench serve \
|
||||
--backend vllm \
|
||||
--model $model \
|
||||
--dataset-name $dataset_name \
|
||||
--dataset-path $dataset_path \
|
||||
--sonnet-input-len $input_len \
|
||||
--sonnet-output-len "$output_len" \
|
||||
--sonnet-prefix-len $prefix_len \
|
||||
--num-prompts $num_prompts \
|
||||
--port 8200 \
|
||||
--save-result \
|
||||
--result-dir $results_folder \
|
||||
--result-filename disagg_prefill_tp1_overhead.json \
|
||||
--request-rate "$qps"
|
||||
kill_gpu_processes
|
||||
|
||||
}
|
||||
|
||||
@ -18,6 +18,8 @@ kill_gpu_processes() {
|
||||
# kill all processes on GPU.
|
||||
pgrep pt_main_thread | xargs -r kill -9
|
||||
pgrep python3 | xargs -r kill -9
|
||||
# vLLM now names the process with VLLM prefix after https://github.com/vllm-project/vllm/pull/21445
|
||||
pgrep VLLM | xargs -r kill -9
|
||||
for port in 8000 8100 8200; do lsof -t -i:$port | xargs -r kill -9; done
|
||||
sleep 1
|
||||
}
|
||||
@ -58,7 +60,7 @@ launch_chunked_prefill() {
|
||||
|
||||
|
||||
launch_disagg_prefill() {
|
||||
model="meta-llama/Meta-Llama-3.1-8B-Instruct"
|
||||
model="meta-llama/Meta-Llama-3.1-8B-Instruct"
|
||||
# disagg prefill
|
||||
CUDA_VISIBLE_DEVICES=0 python3 \
|
||||
-m vllm.entrypoints.openai.api_server \
|
||||
@ -97,20 +99,20 @@ benchmark() {
|
||||
output_len=$2
|
||||
tag=$3
|
||||
|
||||
python3 ../benchmark_serving.py \
|
||||
--backend vllm \
|
||||
--model $model \
|
||||
--dataset-name $dataset_name \
|
||||
--dataset-path $dataset_path \
|
||||
--sonnet-input-len $input_len \
|
||||
--sonnet-output-len "$output_len" \
|
||||
--sonnet-prefix-len $prefix_len \
|
||||
--num-prompts $num_prompts \
|
||||
--port 8000 \
|
||||
--save-result \
|
||||
--result-dir $results_folder \
|
||||
--result-filename "$tag"-qps-"$qps".json \
|
||||
--request-rate "$qps"
|
||||
vllm bench serve \
|
||||
--backend vllm \
|
||||
--model $model \
|
||||
--dataset-name $dataset_name \
|
||||
--dataset-path $dataset_path \
|
||||
--sonnet-input-len $input_len \
|
||||
--sonnet-output-len "$output_len" \
|
||||
--sonnet-prefix-len $prefix_len \
|
||||
--num-prompts $num_prompts \
|
||||
--port 8000 \
|
||||
--save-result \
|
||||
--result-dir $results_folder \
|
||||
--result-filename "$tag"-qps-"$qps".json \
|
||||
--request-rate "$qps"
|
||||
|
||||
sleep 2
|
||||
}
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
import copy
|
||||
import itertools
|
||||
|
||||
141
benchmarks/kernels/bench_nvfp4_gemm.py
Normal file
@ -0,0 +1,141 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
import copy
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
from weight_shapes import WEIGHT_SHAPES
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
from vllm.triton_utils import triton
|
||||
|
||||
if not current_platform.has_device_capability(100):
|
||||
raise RuntimeError("NVFP4 requires compute capability of 10.0 (Blackwell)")
|
||||
|
||||
|
||||
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
|
||||
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
|
||||
|
||||
PROVIDER_CFGS = {
|
||||
"torch-bf16": dict(enabled=True),
|
||||
"nvfp4": dict(no_a_quant=False, enabled=True),
|
||||
"nvfp4-noquant": dict(no_a_quant=True, enabled=True),
|
||||
}
|
||||
|
||||
_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]
|
||||
|
||||
|
||||
def _quant_weight_nvfp4(b: torch.Tensor, device: str):
|
||||
# Compute global scale for weight
|
||||
b_amax = torch.abs(b).max().to(torch.float32)
|
||||
b_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax
|
||||
b_fp4, scale_b_fp4 = ops.scaled_fp4_quant(b, b_global_scale)
|
||||
return b_fp4, scale_b_fp4, b_global_scale
|
||||
|
||||
|
||||
def build_nvfp4_runner(cfg, a, b, dtype, device):
|
||||
b_fp4, scale_b_fp4, b_global_scale = _quant_weight_nvfp4(b, device)
|
||||
|
||||
# Compute global scale for activation
|
||||
# NOTE: This is generally provided ahead-of-time by the model checkpoint.
|
||||
a_amax = torch.abs(a).max().to(torch.float32)
|
||||
a_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax
|
||||
|
||||
# Alpha for the GEMM operation
|
||||
alpha = 1.0 / (a_global_scale * b_global_scale)
|
||||
|
||||
if cfg["no_a_quant"]:
|
||||
# Pre-quantize activation
|
||||
a_fp4, scale_a_fp4 = ops.scaled_fp4_quant(a, a_global_scale)
|
||||
|
||||
def run():
|
||||
return ops.cutlass_scaled_fp4_mm(
|
||||
a_fp4, b_fp4, scale_a_fp4, scale_b_fp4, alpha, dtype
|
||||
)
|
||||
|
||||
return run
|
||||
|
||||
# Quantize activation on-the-fly
|
||||
def run():
|
||||
a_fp4, scale_a_fp4 = ops.scaled_fp4_quant(a, a_global_scale)
|
||||
return ops.cutlass_scaled_fp4_mm(
|
||||
a_fp4, b_fp4, scale_a_fp4, scale_b_fp4, alpha, dtype
|
||||
)
|
||||
|
||||
return run
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size"],
|
||||
x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384],
|
||||
x_log=False,
|
||||
line_arg="provider",
|
||||
line_vals=_enabled,
|
||||
line_names=_enabled,
|
||||
ylabel="TFLOP/s (larger is better)",
|
||||
plot_name="BF16 vs NVFP4 GEMMs",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(batch_size, provider, N, K):
|
||||
M = batch_size
|
||||
device = "cuda"
|
||||
dtype = torch.bfloat16
|
||||
|
||||
a = torch.randn((M, K), device=device, dtype=dtype)
|
||||
b = torch.randn((N, K), device=device, dtype=dtype)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
if provider == "torch-bf16":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: torch.nn.functional.linear(a, b), quantiles=quantiles
|
||||
)
|
||||
else:
|
||||
cfg = PROVIDER_CFGS[provider]
|
||||
run_quant = build_nvfp4_runner(cfg, a, b, dtype, device)
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: run_quant(), quantiles=quantiles
|
||||
)
|
||||
|
||||
to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
|
||||
return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)
|
||||
|
||||
|
||||
def prepare_shapes(args):
|
||||
out = []
|
||||
for model, tp_size in itertools.product(args.models, args.tp_sizes):
|
||||
for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
|
||||
KN[tp_dim] //= tp_size
|
||||
KN.append(model)
|
||||
out.append(KN)
|
||||
return out
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--models",
|
||||
nargs="+",
|
||||
type=str,
|
||||
default=["meta-llama/Llama-3.1-8B-Instruct"],
|
||||
choices=list(WEIGHT_SHAPES.keys()),
|
||||
)
|
||||
parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1])
|
||||
args = parser.parse_args()
|
||||
|
||||
for K, N, model in prepare_shapes(args):
|
||||
print(f"{model}, N={N} K={K}, BF16 vs NVFP4 GEMMs TFLOP/s:")
|
||||
benchmark.run(
|
||||
print_data=True,
|
||||
show_plots=True,
|
||||
save_path=f"bench_nvfp4_res_n{N}_k{K}",
|
||||
N=N,
|
||||
K=K,
|
||||
)
|
||||
|
||||
print("Benchmark finished!")
|
||||
98
benchmarks/kernels/bench_per_token_quant_fp8.py
Normal file
@ -0,0 +1,98 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import itertools
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.triton_utils import triton
|
||||
|
||||
|
||||
# TODO(luka): use standalone_compile utility
|
||||
def with_dyn_arg(fn: Callable, arg_index: int, dim_index: int):
|
||||
def inner(*args):
|
||||
torch._dynamo.mark_dynamic(args[arg_index], dim_index)
|
||||
return fn(*args)
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
torch._dynamo.config.recompile_limit = 8888
|
||||
compilation_config = CompilationConfig(custom_ops=["none"])
|
||||
with set_current_vllm_config(VllmConfig(compilation_config=compilation_config)):
|
||||
torch_per_token_quant_fp8 = torch.compile(
|
||||
QuantFP8(False, GroupShape.PER_TOKEN),
|
||||
fullgraph=True,
|
||||
dynamic=False, # recompile for different shapes
|
||||
)
|
||||
|
||||
# First dim is explicitly dynamic to simulate vLLM usage
|
||||
torch_per_token_quant_fp8 = with_dyn_arg(torch_per_token_quant_fp8, 0, 0)
|
||||
|
||||
|
||||
def cuda_per_token_quant_fp8(
|
||||
input: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return ops.scaled_fp8_quant(input)
|
||||
|
||||
|
||||
def calculate_diff(batch_size: int, seq_len: int):
|
||||
"""Calculate difference between Triton and CUDA implementations."""
|
||||
device = torch.device("cuda")
|
||||
x = torch.rand((batch_size * seq_len, 4096), dtype=torch.float16, device=device)
|
||||
|
||||
torch_out, torch_scale = torch_per_token_quant_fp8(x)
|
||||
cuda_out, cuda_scale = cuda_per_token_quant_fp8(x)
|
||||
|
||||
if torch.allclose(
|
||||
cuda_out.to(torch.float32), torch_out.to(torch.float32), rtol=1e-3, atol=1e-5
|
||||
) and torch.allclose(cuda_scale, torch_scale, rtol=1e-3, atol=1e-5):
|
||||
print("✅ All implementations match")
|
||||
else:
|
||||
print("❌ Implementations differ")
|
||||
|
||||
|
||||
batch_size_range = [1, 16, 32, 64, 128]
|
||||
seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
|
||||
|
||||
configs = list(itertools.product(batch_size_range, seq_len_range))
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size", "seq_len"],
|
||||
x_vals=configs,
|
||||
line_arg="provider",
|
||||
line_vals=["torch", "cuda"],
|
||||
line_names=["Torch", "CUDA"],
|
||||
styles=[("blue", "-"), ("green", "-")],
|
||||
ylabel="us",
|
||||
plot_name="per-token-dynamic-quant-fp8-performance",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark_quantization(batch_size, seq_len, provider):
|
||||
dtype = torch.float16
|
||||
device = torch.device("cuda")
|
||||
|
||||
x = torch.randn(batch_size * seq_len, 4096, device=device, dtype=dtype)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
if provider == "torch":
|
||||
fn = lambda: torch_per_token_quant_fp8(x.clone())
|
||||
elif provider == "cuda":
|
||||
fn = lambda: cuda_per_token_quant_fp8(x.clone())
|
||||
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)
|
||||
|
||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
calculate_diff(batch_size=4, seq_len=4096)
|
||||
benchmark_quantization.run(print_data=True)
|
||||
@ -113,6 +113,7 @@ def bench_run(
|
||||
w2_scale: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
per_act_token: bool,
|
||||
num_repeats: int,
|
||||
):
|
||||
for _ in range(num_repeats):
|
||||
@ -124,7 +125,8 @@ def bench_run(
|
||||
topk_ids,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
a1_scale=a_scale,
|
||||
per_act_token,
|
||||
a1_scale=None,
|
||||
)
|
||||
|
||||
def run_cutlass_from_graph(
|
||||
@ -148,7 +150,8 @@ def bench_run(
|
||||
topk_ids,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
a1_scale=a_scale,
|
||||
per_act_token,
|
||||
a1_scale=None,
|
||||
)
|
||||
|
||||
def run_triton_from_graph(
|
||||
@ -227,6 +230,7 @@ def bench_run(
|
||||
"w2_q": w2_q,
|
||||
"w1_scale": w1_scale,
|
||||
"w2_scale": w2_scale,
|
||||
"per_act_token": per_act_token,
|
||||
# cuda graph params
|
||||
"cutlass_graph": cutlass_graph,
|
||||
"triton_graph": triton_graph,
|
||||
@ -287,12 +291,13 @@ def bench_run(
|
||||
w2_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
per_act_token,
|
||||
num_warmup,
|
||||
)
|
||||
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, num_runs)", # noqa: E501
|
||||
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
|
||||
@ -234,8 +234,10 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
|
||||
|
||||
fn = lambda: ops.gptq_marlin_gemm(
|
||||
a=bt.a,
|
||||
c=None,
|
||||
b_q_weight=w_q,
|
||||
b_scales=w_s,
|
||||
global_scale=None,
|
||||
b_zeros=w_zp,
|
||||
g_idx=g_idx,
|
||||
perm=sort_indices,
|
||||
|
||||
@ -86,6 +86,9 @@ def benchmark_config(
|
||||
(num_experts, 2 * shard_intermediate_size), dtype=torch.float32
|
||||
)
|
||||
w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
|
||||
if use_deep_gemm:
|
||||
# we use the default block shape for deepgemm
|
||||
block_quant_shape = [128, 128]
|
||||
if use_fp8_w8a8:
|
||||
if block_quant_shape:
|
||||
block_n, block_k = block_quant_shape[0], block_quant_shape[1]
|
||||
@ -573,7 +576,11 @@ def main(args: argparse.Namespace):
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||
elif config.architectures[0] in ("DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"):
|
||||
elif config.architectures[0] in (
|
||||
"DeepseekV3ForCausalLM",
|
||||
"DeepseekV2ForCausalLM",
|
||||
"Glm4MoeForCausalLM",
|
||||
):
|
||||
E = config.n_routed_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
@ -583,6 +590,11 @@ def main(args: argparse.Namespace):
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||
elif config.architectures[0] in ("HunYuanMoEV1ForCausalLM"):
|
||||
E = config.num_experts
|
||||
topk = config.moe_topk[0]
|
||||
intermediate_size = config.moe_intermediate_size[0]
|
||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||
else:
|
||||
# Support for llama4
|
||||
config = config.get_text_config()
|
||||
@ -620,7 +632,7 @@ def main(args: argparse.Namespace):
|
||||
4096,
|
||||
]
|
||||
else:
|
||||
batch_sizes = [args.batch_size]
|
||||
batch_sizes = args.batch_size
|
||||
|
||||
use_deep_gemm = bool(args.use_deep_gemm)
|
||||
|
||||
@ -728,7 +740,7 @@ if __name__ == "__main__":
|
||||
)
|
||||
parser.add_argument("--use-deep-gemm", action="store_true")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--batch-size", type=int, required=False)
|
||||
parser.add_argument("--batch-size", type=int, nargs="+", required=False)
|
||||
parser.add_argument("--tune", action="store_true")
|
||||
parser.add_argument("--trust-remote-code", action="store_true")
|
||||
parser.add_argument("--model-prefix", type=str, required=False)
|
||||
|
||||
@ -5,9 +5,8 @@ import itertools
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
|
||||
moe_align_block_size_triton,
|
||||
moe_align_block_size,
|
||||
)
|
||||
from vllm.triton_utils import triton
|
||||
|
||||
@ -21,62 +20,6 @@ def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
|
||||
)
|
||||
|
||||
|
||||
def check_correctness(num_tokens, num_experts=256, block_size=256, topk=8):
|
||||
"""
|
||||
Verifies vllm vs. Triton
|
||||
"""
|
||||
topk_ids = get_topk_ids(num_tokens, num_experts, topk)
|
||||
|
||||
# 1. malloc space for triton and vllm
|
||||
# malloc enough space (max_num_tokens_padded) for the sorted ids
|
||||
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
||||
sorted_ids_triton = torch.empty(
|
||||
(max_num_tokens_padded,), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
sorted_ids_triton.fill_(topk_ids.numel()) # fill with sentinel value
|
||||
expert_ids_triton = torch.zeros(
|
||||
(max_num_tokens_padded // block_size,), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
num_tokens_post_pad_triton = torch.empty((1,), dtype=torch.int32, device="cuda")
|
||||
|
||||
sorted_ids_vllm = torch.empty_like(sorted_ids_triton)
|
||||
sorted_ids_vllm.fill_(topk_ids.numel())
|
||||
expert_ids_vllm = torch.zeros_like(expert_ids_triton)
|
||||
num_tokens_post_pad_vllm = torch.empty_like(num_tokens_post_pad_triton)
|
||||
|
||||
# 2. run implementations
|
||||
moe_align_block_size_triton(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
block_size,
|
||||
sorted_ids_triton,
|
||||
expert_ids_triton,
|
||||
num_tokens_post_pad_triton,
|
||||
)
|
||||
|
||||
ops.moe_align_block_size(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
block_size,
|
||||
sorted_ids_vllm,
|
||||
expert_ids_vllm,
|
||||
num_tokens_post_pad_vllm,
|
||||
)
|
||||
print(f"✅ VLLM implementation works with {num_experts} experts!")
|
||||
|
||||
# 3. compare results
|
||||
if torch.allclose(expert_ids_triton, expert_ids_vllm) and torch.allclose(
|
||||
num_tokens_post_pad_triton, num_tokens_post_pad_vllm
|
||||
):
|
||||
print("✅ Triton and VLLM implementations match.")
|
||||
else:
|
||||
print("❌ Triton and VLLM implementations DO NOT match.")
|
||||
print("Triton expert_ids:", expert_ids_triton)
|
||||
print("VLLM expert_ids:", expert_ids_vllm)
|
||||
print("Triton num_tokens_post_pad:", num_tokens_post_pad_triton)
|
||||
print("VLLM num_tokens_post_pad:", num_tokens_post_pad_vllm)
|
||||
|
||||
|
||||
# test configurations
|
||||
num_tokens_range = [1, 16, 256, 4096]
|
||||
num_experts_range = [16, 64, 224, 256, 280, 512]
|
||||
@ -89,8 +32,8 @@ configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range
|
||||
x_names=["num_tokens", "num_experts", "topk"],
|
||||
x_vals=configs,
|
||||
line_arg="provider",
|
||||
line_vals=["vllm", "triton"], # "triton"
|
||||
line_names=["VLLM", "Triton"], # "Triton"
|
||||
line_vals=["vllm"],
|
||||
line_names=["vLLM"],
|
||||
plot_name="moe-align-block-size-performance",
|
||||
args={},
|
||||
)
|
||||
@ -100,37 +43,11 @@ def benchmark(num_tokens, num_experts, topk, provider):
|
||||
block_size = 256
|
||||
topk_ids = get_topk_ids(num_tokens, num_experts, topk)
|
||||
|
||||
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
||||
sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda")
|
||||
sorted_ids.fill_(topk_ids.numel())
|
||||
max_num_m_blocks = max_num_tokens_padded // block_size
|
||||
expert_ids = torch.empty((max_num_m_blocks,), dtype=torch.int32, device="cuda")
|
||||
num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device="cuda")
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
if provider == "vllm":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: ops.moe_align_block_size(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
block_size,
|
||||
sorted_ids.clone(),
|
||||
expert_ids.clone(),
|
||||
num_tokens_post_pad.clone(),
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
elif provider == "triton":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: moe_align_block_size_triton(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
block_size,
|
||||
sorted_ids.clone(),
|
||||
expert_ids.clone(),
|
||||
num_tokens_post_pad.clone(),
|
||||
),
|
||||
lambda: moe_align_block_size(topk_ids, block_size, num_experts),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
|
||||
@ -154,6 +71,4 @@ if __name__ == "__main__":
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
print("Running correctness check...")
|
||||
check_correctness(num_tokens=1024, num_experts=args.num_experts, topk=args.topk)
|
||||
benchmark.run(print_data=True, show_plots=True)
|
||||
|
||||
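The buffer sizing used above (max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)) is a worst-case bound: each expert's slice of the sorted ids may be padded by up to block_size - 1 entries. A small worked sketch with arbitrary numbers:

# Worst-case padded length used to size the sorted_ids buffer (numbers are arbitrary).
num_tokens, topk, num_experts, block_size = 16, 8, 64, 256
numel = num_tokens * topk  # total (token, expert) assignments
max_num_tokens_padded = numel + num_experts * (block_size - 1)
print(max_num_tokens_padded)  # 128 + 16320 = 16448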
@ -8,12 +8,13 @@ import ray
|
||||
import torch
|
||||
from transformers import AutoConfig
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import *
|
||||
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
|
||||
_moe_permute,
|
||||
_moe_unpermute_and_reduce,
|
||||
moe_permute,
|
||||
moe_unpermute,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import *
|
||||
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import *
|
||||
from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
@ -63,18 +64,19 @@ def benchmark_permute(
|
||||
|
||||
def run():
|
||||
if use_customized_permute:
|
||||
(permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = (
|
||||
moe_permute(
|
||||
qhidden_states,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
token_expert_indices=token_expert_indices,
|
||||
topk=topk,
|
||||
n_expert=num_experts,
|
||||
n_local_expert=num_experts,
|
||||
expert_map=None,
|
||||
align_block_size=align_block_size,
|
||||
)
|
||||
(
|
||||
permuted_hidden_states,
|
||||
a1q_scale,
|
||||
first_token_off,
|
||||
inv_perm_idx,
|
||||
m_indices,
|
||||
) = moe_permute(
|
||||
qhidden_states,
|
||||
a1q_scale=None,
|
||||
topk_ids=topk_ids,
|
||||
n_expert=num_experts,
|
||||
expert_map=None,
|
||||
align_block_size=align_block_size,
|
||||
)
|
||||
else:
|
||||
(
|
||||
@ -150,18 +152,19 @@ def benchmark_unpermute(
|
||||
|
||||
def prepare():
|
||||
if use_customized_permute:
|
||||
(permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = (
|
||||
moe_permute(
|
||||
qhidden_states,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
token_expert_indices=token_expert_indices,
|
||||
topk=topk,
|
||||
n_expert=num_experts,
|
||||
n_local_expert=num_experts,
|
||||
expert_map=None,
|
||||
align_block_size=align_block_size,
|
||||
)
|
||||
(
|
||||
permuted_hidden_states,
|
||||
a1q_scale,
|
||||
first_token_off,
|
||||
inv_perm_idx,
|
||||
m_indices,
|
||||
) = moe_permute(
|
||||
qhidden_states,
|
||||
a1q_scale=None,
|
||||
topk_ids=topk_ids,
|
||||
n_expert=num_experts,
|
||||
expert_map=None,
|
||||
align_block_size=align_block_size,
|
||||
)
|
||||
# convert to fp16/bf16 as gemm output
|
||||
return (
|
||||
@ -191,16 +194,19 @@ def benchmark_unpermute(
|
||||
|
||||
def run(input: tuple):
|
||||
if use_customized_permute:
|
||||
(permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = input
|
||||
(
|
||||
permuted_hidden_states,
|
||||
first_token_off,
|
||||
inv_perm_idx,
|
||||
m_indices,
|
||||
) = input
|
||||
output = torch.empty_like(hidden_states)
|
||||
moe_unpermute(
|
||||
output,
|
||||
permuted_hidden_states,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inv_perm_idx,
|
||||
first_token_off,
|
||||
topk,
|
||||
num_experts,
|
||||
num_experts,
|
||||
)
|
||||
else:
|
||||
(
|
||||
@ -211,7 +217,11 @@ def benchmark_unpermute(
|
||||
inv_perm,
|
||||
) = input
|
||||
_moe_unpermute_and_reduce(
|
||||
output_hidden_states, permuted_hidden_states, inv_perm, topk_weights
|
||||
output_hidden_states,
|
||||
permuted_hidden_states,
|
||||
inv_perm,
|
||||
topk_weights,
|
||||
True,
|
||||
)
|
||||
|
||||
# JIT compilation & warmup
|
||||
@ -318,6 +328,7 @@ def main(args: argparse.Namespace):
|
||||
elif (
|
||||
config.architectures[0] == "DeepseekV3ForCausalLM"
|
||||
or config.architectures[0] == "DeepseekV2ForCausalLM"
|
||||
or config.architectures[0] == "Glm4MoeForCausalLM"
|
||||
):
|
||||
E = config.n_routed_experts
|
||||
topk = config.num_experts_per_tok
|
||||
|
||||
benchmarks/kernels/benchmark_trtllm_attention.py (new file, 240 lines)
@ -0,0 +1,240 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import csv
|
||||
import os
|
||||
import random
|
||||
from datetime import datetime
|
||||
|
||||
import flashinfer
|
||||
import torch
|
||||
|
||||
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
|
||||
|
||||
# KV Cache Layout for TRT-LLM
|
||||
# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)
|
||||
|
||||
|
||||
def to_float8(x, dtype=torch.float8_e4m3fn):
|
||||
finfo = torch.finfo(dtype)
|
||||
min_val, max_val = x.aminmax()
|
||||
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
|
||||
scale = finfo.max / amax * 0.1
|
||||
x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
|
||||
return x_scl_sat.to(dtype), scale.float().reciprocal()
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def benchmark_decode(
|
||||
num_seqs,
|
||||
max_seq_len,
|
||||
page_size=16,
|
||||
dtype=torch.bfloat16,
|
||||
kv_layout="HND",
|
||||
num_kv_heads=8,
|
||||
kv_cache_dtype="auto",
|
||||
head_dim=128,
|
||||
warmup=10,
|
||||
trials=20,
|
||||
):
|
||||
torch.set_default_device("cuda")
|
||||
device = "cuda"
|
||||
torch.manual_seed(0)
|
||||
|
||||
# Currently only HEAD_GRP_SIZE == 8 is supported
|
||||
HEAD_GRP_SIZE = 8
|
||||
MAX_SEQ_LEN = max_seq_len
|
||||
|
||||
# large number to reduce kv_cache reuse
|
||||
NUM_BLOCKS = int(256000 / page_size)
|
||||
|
||||
workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8, device=device)
|
||||
|
||||
# For decode, batch_size is num_decode_token
|
||||
num_qo_heads = num_kv_heads * HEAD_GRP_SIZE
|
||||
sm_scale = float(1.0 / (head_dim**0.5))
|
||||
q = torch.randn(num_seqs, num_qo_heads, head_dim, device=device, dtype=dtype)
|
||||
kv_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
|
||||
|
||||
max_kv_len = max(kv_lens)
|
||||
kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int, device=device)
|
||||
max_num_blocks_per_seq = (max_kv_len + page_size - 1) // page_size
|
||||
|
||||
block_tables = torch.randint(
|
||||
0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
|
||||
)
|
||||
|
||||
kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim)
|
||||
kv_cache = torch.randn(size=kv_cache_shape, device=device, dtype=dtype)
|
||||
k_scale = v_scale = 1.0
|
||||
|
||||
if kv_cache_dtype.startswith("fp8"):
|
||||
kv_cache, _ = to_float8(kv_cache)
|
||||
|
||||
# Benchmark TRT decode
|
||||
def trt_decode():
|
||||
return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
|
||||
q,
|
||||
kv_cache,
|
||||
workspace_buffer,
|
||||
num_qo_heads,
|
||||
num_kv_heads,
|
||||
sm_scale,
|
||||
block_tables,
|
||||
kv_lens_tensor,
|
||||
page_size,
|
||||
max_kv_len,
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
|
||||
def time_fn(fn, warmup=10, trials=20):
|
||||
torch.cuda.synchronize()
|
||||
start = torch.cuda.Event(enable_timing=True)
|
||||
end = torch.cuda.Event(enable_timing=True)
|
||||
times = []
|
||||
for i in range(warmup):
|
||||
fn()
|
||||
for i in range(trials):
|
||||
start.record()
|
||||
fn()
|
||||
end.record()
|
||||
torch.cuda.synchronize()
|
||||
times.append(start.elapsed_time(end)) # ms
|
||||
return sum(times) / len(times), torch.std(torch.tensor(times))
|
||||
|
||||
# TRT Decode
|
||||
trt_mean, trt_std = time_fn(trt_decode)
|
||||
|
||||
kv_indptr = [0]
|
||||
kv_indices = []
|
||||
kv_last_page_lens = []
|
||||
for i in range(num_seqs):
|
||||
seq_len = kv_lens[i]
|
||||
assert seq_len > 0
|
||||
num_blocks = (seq_len + page_size - 1) // page_size
|
||||
kv_indices.extend(block_tables[i, :num_blocks])
|
||||
kv_indptr.append(kv_indptr[-1] + num_blocks)
|
||||
kv_last_page_len = seq_len % page_size
|
||||
if kv_last_page_len == 0:
|
||||
kv_last_page_len = page_size
|
||||
kv_last_page_lens.append(kv_last_page_len)
|
||||
|
||||
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
|
||||
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
|
||||
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
|
||||
|
||||
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
|
||||
workspace_buffer,
|
||||
kv_layout,
|
||||
use_tensor_cores=((num_qo_heads // num_kv_heads) > 4),
|
||||
)
|
||||
|
||||
wrapper.plan(
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_lens,
|
||||
num_qo_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
page_size,
|
||||
"NONE",
|
||||
q_data_type=dtype,
|
||||
kv_data_type=torch.float8_e4m3fn if kv_cache_dtype.startswith("fp8") else dtype,
|
||||
)
|
||||
|
||||
def baseline_decode():
|
||||
return wrapper.run(q, kv_cache, sm_scale, k_scale, v_scale)
|
||||
|
||||
baseline_mean, baseline_std = time_fn(baseline_decode)
|
||||
|
||||
# Calculate percentage speedup (positive means TRT is faster)
|
||||
speedup_percent = (baseline_mean - trt_mean) / baseline_mean
|
||||
|
||||
print(
|
||||
f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.3f}\t{trt_std.item():.3f}"
|
||||
f"\t{baseline_mean:.3f}\t{baseline_std.item():.3f}\t{speedup_percent:.3f}"
|
||||
)
|
||||
|
||||
# Return results for CSV writing
|
||||
return {
|
||||
"num_seqs": num_seqs,
|
||||
"trt_mean": trt_mean,
|
||||
"trt_std": trt_std.item(),
|
||||
"baseline_mean": baseline_mean,
|
||||
"baseline_std": baseline_std.item(),
|
||||
"speedup_percent": speedup_percent,
|
||||
"q_dtype": str(dtype),
|
||||
"kv_cache_dtype": kv_cache_dtype,
|
||||
"page_size": page_size,
|
||||
"num_kv_heads": num_kv_heads,
|
||||
"head_dim": head_dim,
|
||||
"max_seq_len": max_seq_len,
|
||||
}
|
||||
|
||||
|
||||
def write_results_to_csv(results, filename=None):
|
||||
"""Write benchmark results to CSV file."""
|
||||
if filename is None:
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv"
|
||||
|
||||
fieldnames = [
|
||||
"num_seqs",
|
||||
"trt_mean",
|
||||
"trt_std",
|
||||
"baseline_mean",
|
||||
"baseline_std",
|
||||
"speedup_percent",
|
||||
"q_dtype",
|
||||
"kv_cache_dtype",
|
||||
"page_size",
|
||||
"num_kv_heads",
|
||||
"head_dim",
|
||||
"max_seq_len",
|
||||
]
|
||||
|
||||
file_exists = os.path.exists(filename)
|
||||
|
||||
with open(filename, "a", newline="") as csvfile:
|
||||
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
||||
|
||||
if not file_exists:
|
||||
writer.writeheader()
|
||||
|
||||
for result in results:
|
||||
writer.writerow(result)
|
||||
|
||||
print(f"Results written to {filename}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
num_seqs = [1, 4, 8, 16, 32, 64, 128, 256]
|
||||
max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
|
||||
all_results = []
|
||||
|
||||
print("Running benchmark for kv_cache_dtype: bfloat16")
|
||||
print(
|
||||
"\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\tbaseline_std\tspeedup_percent"
|
||||
)
|
||||
for max_seq_len in max_seq_lens:
|
||||
for bs in num_seqs:
|
||||
result = benchmark_decode(
|
||||
bs, max_seq_len, dtype=torch.bfloat16, kv_cache_dtype="auto"
|
||||
)
|
||||
all_results.append(result)
|
||||
|
||||
print("Running benchmark for q_dtype = bfloat16, kv_cache_dtype: fp8")
|
||||
print(
|
||||
"\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\tbaseline_std\tspeedup_percent"
|
||||
)
|
||||
for max_seq_len in max_seq_lens:
|
||||
for bs in num_seqs:
|
||||
result = benchmark_decode(
|
||||
bs, max_seq_len, dtype=torch.bfloat16, kv_cache_dtype="fp8"
|
||||
)
|
||||
all_results.append(result)
|
||||
|
||||
# Write all results to CSV
|
||||
write_results_to_csv(all_results)
|
||||
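The paged-KV metadata built in benchmark_decode above (kv_indptr and kv_last_page_lens) follows directly from the per-sequence KV lengths and the page size; kv_indices is just the corresponding rows of the block table. A short worked sketch with made-up lengths:

# Illustrative sketch of the indptr / last-page-length computation (made-up lengths).
page_size = 16
kv_lens = [5, 16, 33]                      # hypothetical per-sequence KV lengths
kv_indptr, kv_last_page_lens = [0], []
for seq_len in kv_lens:
    num_blocks = (seq_len + page_size - 1) // page_size
    kv_indptr.append(kv_indptr[-1] + num_blocks)
    kv_last_page_lens.append(seq_len % page_size or page_size)  # a full last page keeps page_size
print(kv_indptr)          # [0, 1, 2, 5]  -> 1, 1 and 3 pages respectively
print(kv_last_page_lens)  # [5, 16, 1]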
benchmarks/kv_cache/benchmark_block_pool.py (new file, 108 lines)
@ -0,0 +1,108 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import gc
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from tabulate import tabulate
|
||||
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
|
||||
|
||||
class Metric:
|
||||
def __init__(self) -> None:
|
||||
self.cnt: int = 0
|
||||
self.sum_v: int = 0
|
||||
self.max_v: Optional[int] = None
|
||||
|
||||
def update(self, v: int) -> None:
|
||||
self.cnt += 1
|
||||
self.sum_v += v
|
||||
if self.max_v is None:
|
||||
self.max_v = v
|
||||
else:
|
||||
self.max_v = max(self.max_v, v)
|
||||
|
||||
def avg_v(self) -> float:
|
||||
return self.sum_v * 1.0 / self.cnt
|
||||
|
||||
|
||||
def main(args):
|
||||
rows = []
|
||||
for allocate_block in args.allocate_blocks:
|
||||
# Enforce a GC collect ahead to minimize the impact among runs
|
||||
gc.collect()
|
||||
block_pool = BlockPool(num_gpu_blocks=args.num_gpu_blocks, enable_caching=True)
|
||||
|
||||
get_blocks_metric: Metric = Metric()
|
||||
free_blocks_metric: Metric = Metric()
|
||||
for _ in range(args.num_iteration):
|
||||
t1 = time.monotonic_ns()
|
||||
blocks = block_pool.get_new_blocks(allocate_block)
|
||||
t2 = time.monotonic_ns()
|
||||
block_pool.free_blocks(blocks)
|
||||
t3 = time.monotonic_ns()
|
||||
get_blocks_metric.update(t2 - t1)
|
||||
free_blocks_metric.update(t3 - t2)
|
||||
|
||||
if get_blocks_metric.max_v is not None and free_blocks_metric.max_v is not None:
|
||||
rows.append(
|
||||
[
|
||||
get_blocks_metric.cnt,
|
||||
args.num_gpu_blocks,
|
||||
allocate_block,
|
||||
get_blocks_metric.avg_v() / 1000000,
|
||||
get_blocks_metric.max_v / 1000000.0,
|
||||
free_blocks_metric.avg_v() / 1000000,
|
||||
free_blocks_metric.max_v / 1000000.0,
|
||||
]
|
||||
)
|
||||
else:
|
||||
print(
|
||||
"No valid metrics found."
|
||||
f" {get_blocks_metric.max_v=} {free_blocks_metric.max_v=}"
|
||||
)
|
||||
|
||||
print(
|
||||
tabulate(
|
||||
rows,
|
||||
headers=[
|
||||
"Iterations",
|
||||
"Total\nBlocks",
|
||||
"Allocated\nBlocks",
|
||||
"Get Blocks\nAvg (ms)",
|
||||
"Get Blocks\nMax (ms)",
|
||||
"Free Blocks\nAvg (ms)",
|
||||
"Free Blocks\nMax (ms)",
|
||||
],
|
||||
tablefmt="grid",
|
||||
floatfmt=".6f",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def invoke_main() -> None:
|
||||
parser = FlexibleArgumentParser(
|
||||
description="Benchmark the performance of BlockPool for KV Cache."
|
||||
)
|
||||
parser.add_argument("--num-gpu-blocks", type=int, default=100000)
|
||||
parser.add_argument(
|
||||
"--num-iteration",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Number of iterations to run to stablize final data readings",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--allocate-blocks",
|
||||
type=int,
|
||||
nargs="*",
|
||||
default=[10, 50, 100, 500, 1000],
|
||||
help="Number of blocks to allocate",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
invoke_main() # pragma: no cover
|
||||
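The Metric helper above accumulates count, sum, and max of nanosecond timings, and the table converts them to milliseconds by dividing by 1e6. A tiny usage sketch, assuming the Metric class above is in scope:

m = Metric()
for ns in (1_200_000, 900_000, 1_500_000):    # hypothetical nanosecond samples
    m.update(ns)
print(m.cnt, m.avg_v() / 1e6, m.max_v / 1e6)  # 3 samples, avg 1.2 ms, max 1.5 ms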
@ -12,9 +12,8 @@ endif()
|
||||
#
|
||||
# Define environment variables for special configurations
|
||||
#
|
||||
if(DEFINED ENV{VLLM_CPU_AVX512BF16})
|
||||
set(ENABLE_AVX512BF16 ON)
|
||||
endif()
|
||||
set(ENABLE_AVX512BF16 $ENV{VLLM_CPU_AVX512BF16})
|
||||
set(ENABLE_AVX512VNNI $ENV{VLLM_CPU_AVX512VNNI})
|
||||
|
||||
include_directories("${CMAKE_SOURCE_DIR}/csrc")
|
||||
|
||||
@ -59,6 +58,22 @@ function (find_isa CPUINFO TARGET OUT)
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
|
||||
function(check_sysctl TARGET OUT)
|
||||
execute_process(COMMAND sysctl -n "${TARGET}"
|
||||
RESULT_VARIABLE SYSCTL_RET
|
||||
OUTPUT_VARIABLE SYSCTL_INFO
|
||||
ERROR_QUIET
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE)
|
||||
if(SYSCTL_RET EQUAL 0 AND
|
||||
(SYSCTL_INFO STREQUAL "1" OR SYSCTL_INFO GREATER 0))
|
||||
set(${OUT} ON PARENT_SCOPE)
|
||||
else()
|
||||
set(${OUT} OFF PARENT_SCOPE)
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
|
||||
function (is_avx512_disabled OUT)
|
||||
set(DISABLE_AVX512 $ENV{VLLM_CPU_DISABLE_AVX512})
|
||||
if(DISABLE_AVX512 AND DISABLE_AVX512 STREQUAL "true")
|
||||
@ -71,7 +86,10 @@ endfunction()
|
||||
is_avx512_disabled(AVX512_DISABLED)
|
||||
|
||||
if (MACOSX_FOUND AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
|
||||
set(APPLE_SILICON_FOUND TRUE)
|
||||
message(STATUS "Apple Silicon Detected")
|
||||
set(ENABLE_NUMA OFF)
|
||||
check_sysctl(hw.optional.neon ASIMD_FOUND)
|
||||
check_sysctl(hw.optional.arm.FEAT_BF16 ARM_BF16_FOUND)
|
||||
else()
|
||||
find_isa(${CPUINFO} "avx2" AVX2_FOUND)
|
||||
find_isa(${CPUINFO} "avx512f" AVX512_FOUND)
|
||||
@ -83,7 +101,6 @@ else()
|
||||
find_isa(${CPUINFO} "S390" S390_FOUND)
|
||||
endif()
|
||||
|
||||
|
||||
if (AVX512_FOUND AND NOT AVX512_DISABLED)
|
||||
list(APPEND CXX_COMPILE_FLAGS
|
||||
"-mavx512f"
|
||||
@ -96,12 +113,30 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)
|
||||
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
|
||||
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
|
||||
list(APPEND CXX_COMPILE_FLAGS "-mavx512bf16")
|
||||
set(ENABLE_AVX512BF16 ON)
|
||||
else()
|
||||
set(ENABLE_AVX512BF16 OFF)
|
||||
message(WARNING "Disable AVX512-BF16 ISA support, requires gcc/g++ >= 12.3")
|
||||
endif()
|
||||
else()
|
||||
set(ENABLE_AVX512BF16 OFF)
|
||||
message(WARNING "Disable AVX512-BF16 ISA support, no avx512_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512BF16=1.")
|
||||
endif()
|
||||
|
||||
find_isa(${CPUINFO} "avx512_vnni" AVX512VNNI_FOUND)
|
||||
if (AVX512VNNI_FOUND OR ENABLE_AVX512VNNI)
|
||||
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
|
||||
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
|
||||
list(APPEND CXX_COMPILE_FLAGS "-mavx512vnni")
|
||||
set(ENABLE_AVX512VNNI ON)
|
||||
else()
|
||||
set(ENABLE_AVX512VNNI OFF)
|
||||
message(WARNING "Disable AVX512-VNNI ISA support, requires gcc/g++ >= 12.3")
|
||||
endif()
|
||||
else()
|
||||
set(ENABLE_AVX512VNNI OFF)
|
||||
message(WARNING "Disable AVX512-VNNI ISA support, no avx512_vnni found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512VNNI=1.")
|
||||
endif()
|
||||
|
||||
elseif (AVX2_FOUND)
|
||||
list(APPEND CXX_COMPILE_FLAGS "-mavx2")
|
||||
@ -132,9 +167,6 @@ elseif (ASIMD_FOUND)
|
||||
set(MARCH_FLAGS "-march=armv8.2-a+dotprod+fp16")
|
||||
endif()
|
||||
list(APPEND CXX_COMPILE_FLAGS ${MARCH_FLAGS})
|
||||
elseif(APPLE_SILICON_FOUND)
|
||||
message(STATUS "Apple Silicon Detected")
|
||||
set(ENABLE_NUMA OFF)
|
||||
elseif (S390_FOUND)
|
||||
message(STATUS "S390 detected")
|
||||
# Check for S390 VXE support
|
||||
@ -148,17 +180,32 @@ else()
|
||||
endif()
|
||||
|
||||
#
|
||||
# Build oneDNN for W8A8 GEMM kernels (only for x86-AVX512 platforms)
|
||||
#
|
||||
if (AVX512_FOUND AND NOT AVX512_DISABLED)
|
||||
# Build oneDNN for W8A8 GEMM kernels (only for x86-AVX512 /ARM platforms)
|
||||
# Flag to enable ACL kernels for AARCH64 platforms
|
||||
if ( VLLM_BUILD_ACL STREQUAL "ON")
|
||||
set(USE_ACL ON)
|
||||
else()
|
||||
set(USE_ACL OFF)
|
||||
endif()
|
||||
|
||||
if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND)
|
||||
FetchContent_Declare(
|
||||
oneDNN
|
||||
GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
|
||||
GIT_TAG v3.7.1
|
||||
GIT_TAG v3.8.1
|
||||
GIT_PROGRESS TRUE
|
||||
GIT_SHALLOW TRUE
|
||||
)
|
||||
|
||||
if(USE_ACL)
|
||||
find_library(ARM_COMPUTE_LIBRARY NAMES arm_compute PATHS $ENV{ACL_ROOT_DIR}/build/)
|
||||
if(NOT ARM_COMPUTE_LIBRARY)
|
||||
message(FATAL_ERROR "Could not find ARM Compute Library: please set ACL_ROOT_DIR")
|
||||
endif()
|
||||
set(ONEDNN_AARCH64_USE_ACL "ON")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ENV{ACL_ROOT_DIR}/build/")
|
||||
endif()
|
||||
|
||||
set(ONEDNN_LIBRARY_TYPE "STATIC")
|
||||
set(ONEDNN_BUILD_DOC "OFF")
|
||||
set(ONEDNN_BUILD_EXAMPLES "OFF")
|
||||
@ -231,11 +278,29 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)
|
||||
"csrc/cpu/quant.cpp"
|
||||
"csrc/cpu/shm.cpp"
|
||||
${VLLM_EXT_SRC})
|
||||
if (ENABLE_AVX512BF16 AND ENABLE_AVX512VNNI)
|
||||
set(VLLM_EXT_SRC
|
||||
"csrc/cpu/sgl-kernels/gemm.cpp"
|
||||
"csrc/cpu/sgl-kernels/gemm_int8.cpp"
|
||||
"csrc/cpu/sgl-kernels/gemm_fp8.cpp"
|
||||
"csrc/cpu/sgl-kernels/moe.cpp"
|
||||
"csrc/cpu/sgl-kernels/moe_int8.cpp"
|
||||
"csrc/cpu/sgl-kernels/moe_fp8.cpp"
|
||||
${VLLM_EXT_SRC})
|
||||
add_compile_definitions(-DCPU_CAPABILITY_AVX512)
|
||||
endif()
|
||||
elseif(POWER10_FOUND)
|
||||
set(VLLM_EXT_SRC
|
||||
"csrc/cpu/quant.cpp"
|
||||
${VLLM_EXT_SRC})
|
||||
endif()
|
||||
if (ASIMD_FOUND)
|
||||
set(VLLM_EXT_SRC
|
||||
"csrc/cpu/quant.cpp"
|
||||
${VLLM_EXT_SRC})
|
||||
endif()
|
||||
|
||||
message(STATUS "CPU extension source files: ${VLLM_EXT_SRC}")
|
||||
|
||||
#
|
||||
# Define extension targets
|
||||
|
||||
@ -38,7 +38,7 @@ else()
|
||||
FetchContent_Declare(
|
||||
vllm-flash-attn
|
||||
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
|
||||
GIT_TAG 5f3644181c7a15345ce20bfc65af117d3601b524
|
||||
GIT_TAG 1c2624e53c078854e0637ee566c72fe2107e75f4
|
||||
GIT_PROGRESS TRUE
|
||||
# Don't share the vllm-flash-attn build between build types
|
||||
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
|
||||
|
||||
@ -265,8 +265,8 @@ macro(set_gencode_flags_for_srcs)
|
||||
endmacro()
|
||||
|
||||
#
|
||||
# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form
|
||||
# `<major>.<minor>[letter]` compute the "loose intersection" with the
|
||||
# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form
|
||||
# `<major>.<minor>[letter]` compute the "loose intersection" with the
|
||||
# `TGT_CUDA_ARCHS` list of gencodes. We also support the `+PTX` suffix in
|
||||
# `SRC_CUDA_ARCHS` which indicates that the PTX code should be built when there
|
||||
# is a CUDA_ARCH in `TGT_CUDA_ARCHS` that is equal to or larger than the
|
||||
@ -278,7 +278,7 @@ endmacro()
|
||||
# in `SRC_CUDA_ARCHS` that is less or equal to the version in `TGT_CUDA_ARCHS`.
|
||||
# We have special handling for x.0a, if x.0a is in `SRC_CUDA_ARCHS` and x.0 is
|
||||
# in `TGT_CUDA_ARCHS` then we should remove x.0a from `SRC_CUDA_ARCHS` and add
|
||||
# x.0a to the result (and remove x.0 from TGT_CUDA_ARCHS).
|
||||
# x.0a to the result (and remove x.0 from TGT_CUDA_ARCHS).
|
||||
# The result is stored in `OUT_CUDA_ARCHS`.
|
||||
#
|
||||
# Example:
|
||||
@ -313,21 +313,16 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
|
||||
# if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should
|
||||
# remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS
|
||||
set(_CUDA_ARCHS)
|
||||
if ("9.0a" IN_LIST _SRC_CUDA_ARCHS)
|
||||
list(REMOVE_ITEM _SRC_CUDA_ARCHS "9.0a")
|
||||
if ("9.0" IN_LIST TGT_CUDA_ARCHS)
|
||||
list(REMOVE_ITEM _TGT_CUDA_ARCHS "9.0")
|
||||
set(_CUDA_ARCHS "9.0a")
|
||||
foreach(_arch ${_SRC_CUDA_ARCHS})
|
||||
if(_arch MATCHES "\\a$")
|
||||
list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}")
|
||||
string(REPLACE "a" "" _base "${_arch}")
|
||||
if ("${_base}" IN_LIST TGT_CUDA_ARCHS)
|
||||
list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_base}")
|
||||
list(APPEND _CUDA_ARCHS "${_arch}")
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if ("10.0a" IN_LIST _SRC_CUDA_ARCHS)
|
||||
list(REMOVE_ITEM _SRC_CUDA_ARCHS "10.0a")
|
||||
if ("10.0" IN_LIST TGT_CUDA_ARCHS)
|
||||
list(REMOVE_ITEM _TGT_CUDA_ARCHS "10.0")
|
||||
set(_CUDA_ARCHS "10.0a")
|
||||
endif()
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
list(SORT _SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
|
||||
|
||||
@ -359,7 +354,7 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
|
||||
endforeach()
|
||||
|
||||
list(REMOVE_DUPLICATES _CUDA_ARCHS)
|
||||
|
||||
|
||||
# reapply +PTX suffix to architectures that requested PTX
|
||||
set(_FINAL_ARCHS)
|
||||
foreach(_arch ${_CUDA_ARCHS})
|
||||
@ -370,7 +365,7 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
|
||||
endif()
|
||||
endforeach()
|
||||
set(_CUDA_ARCHS ${_FINAL_ARCHS})
|
||||
|
||||
|
||||
set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE)
|
||||
endfunction()
|
||||
|
||||
|
||||
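The change above generalizes the hard-coded 9.0a/10.0a handling to any architecture carrying an "a" suffix: if X.Ya is in the source list and X.Y is in the target list, X.Ya is promoted into the result. The same idea in a short Python sketch (illustrative only, not part of the build):

def promote_arch_specific(src_archs, tgt_archs):
    # Mirrors the generalized foreach loop in cuda_archs_loose_intersection above.
    src, tgt, promoted = list(src_archs), list(tgt_archs), []
    for arch in list(src):
        if arch.endswith("a"):
            src.remove(arch)
            base = arch[:-1]
            if base in tgt:
                tgt.remove(base)
                promoted.append(arch)
    return src, tgt, promoted

# promote_arch_specific(["9.0a", "10.0a", "8.0"], ["9.0", "8.0"])
# -> (["8.0"], ["8.0"], ["9.0a"])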
@ -24,6 +24,7 @@
|
||||
|
||||
#include "attention_dtypes.h"
|
||||
#include "attention_utils.cuh"
|
||||
#include "../cuda_compat.h"
|
||||
|
||||
#ifdef USE_ROCM
|
||||
#include <hip/hip_bf16.h>
|
||||
@ -33,12 +34,6 @@ typedef __hip_bfloat16 __nv_bfloat16;
|
||||
#include "../quantization/fp8/nvidia/quant_utils.cuh"
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define WARP_SIZE 32
|
||||
#else
|
||||
#define WARP_SIZE warpSize
|
||||
#endif
|
||||
|
||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
|
||||
@ -670,7 +665,6 @@ __global__ void paged_attention_v2_reduce_kernel(
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
#undef WARP_SIZE
|
||||
#undef MAX
|
||||
#undef MIN
|
||||
#undef DIVIDE_ROUND_UP
|
||||
|
||||
csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp (new file, 372 lines)
@ -0,0 +1,372 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*
|
||||
* Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929
|
||||
* by Alcanderian JieXin Liang
|
||||
*/
|
||||
|
||||
/*!
|
||||
\file
|
||||
\brief A universal device layer for CUTLASS 3.x-style kernels.
|
||||
*/
|
||||
|
||||
// clang-format off
|
||||
#pragma once
|
||||
|
||||
// common
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/device_kernel.h"
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
#include "cutlass/cluster_launch.hpp"
|
||||
#include "cutlass/trace.h"
|
||||
#endif // !defined(__CUDACC_RTC__)
|
||||
|
||||
#include "../kernel/sm100_fmha_mla_tma_warpspecialized.hpp"
|
||||
#include "../kernel/sm100_fmha_mla_reduction.hpp"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::fmha::device {
|
||||
|
||||
using namespace cute;
|
||||
using namespace cutlass::fmha::kernel;
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
class Kernel_
|
||||
>
|
||||
class MLA {
|
||||
public:
|
||||
|
||||
using Kernel = Kernel_;
|
||||
|
||||
using ReductionKernel = cutlass::fmha::kernel::Sm100FmhaMlaReductionKernel<
|
||||
typename Kernel::ElementOut,
|
||||
typename Kernel::ElementAcc,
|
||||
typename Kernel::ElementAcc,
|
||||
Kernel::TileShapeH::value,
|
||||
Kernel::TileShapeL::value,
|
||||
256 /*Max split*/
|
||||
>;
|
||||
|
||||
/// Argument structure: User API
|
||||
using KernelArguments = typename Kernel::Arguments;
|
||||
using ReductionArguments = typename ReductionKernel::Arguments;
|
||||
|
||||
using Arguments = KernelArguments;
|
||||
|
||||
/// Argument structure: Kernel API
|
||||
using KernelParams = typename Kernel::Params;
|
||||
using ReductionParams = typename ReductionKernel::Params;
|
||||
struct Params {
|
||||
KernelParams fmha_params;
|
||||
ReductionParams reduction_params;
|
||||
};
|
||||
|
||||
private:
|
||||
|
||||
/// Kernel API parameters object
|
||||
Params params_;
|
||||
|
||||
bool is_initialized(bool set = false) {
|
||||
static bool initialized = false;
|
||||
if (set) initialized = true;
|
||||
return initialized;
|
||||
}
|
||||
|
||||
static ReductionArguments to_reduction_args(Arguments const& args) {
|
||||
auto [H, K, D, B] = args.problem_shape;
|
||||
return ReductionArguments{
|
||||
nullptr, args.epilogue.ptr_o, nullptr, args.epilogue.ptr_lse,
|
||||
args.mainloop.softmax_scale, B, args.split_kv, K, args.mainloop.ptr_seq,
|
||||
args.ptr_split_kv, Kernel::TileShapeS::value
|
||||
};
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
/// Access the Params structure
|
||||
Params const& params() const {
|
||||
return params_;
|
||||
}
|
||||
|
||||
static void set_split_kv (KernelArguments& args) {
|
||||
// printf("set_split_kv start");
|
||||
if (args.split_kv >= 1) return;
|
||||
auto [H, K, D, B] = args.problem_shape;
|
||||
// std::cout << H << " " << K << " " << D << " " << B << "\n";
|
||||
int sm_count = args.hw_info.sm_count;
|
||||
// printf(" sm_count = %d\n", sm_count);
|
||||
int max_splits = ceil_div(K, 128);
|
||||
max_splits = min(16, max_splits);
|
||||
// printf(" max_splits = %d\n", max_splits);
|
||||
int sms_per_batch = max(1, sm_count / B);
|
||||
// printf(" sms_per_batch = %d\n", sms_per_batch);
|
||||
int split_heur = min(max_splits, sms_per_batch);
|
||||
int waves = ceil_div(B * split_heur, sm_count);
|
||||
int k_waves = ceil_div(max_splits, split_heur);
|
||||
int split_wave_aware = ceil_div(max_splits, k_waves);
|
||||
args.split_kv = split_wave_aware;
|
||||
// printf(" args.split_kv = %d\n", args.split_kv);
|
||||
|
||||
}
|
||||
|
||||
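set_split_kv above derives a KV-split factor from the KV length K, the batch size B, and the SM count. The same heuristic restated in Python with made-up numbers (illustrative only):

from math import ceil

def heuristic_split_kv(K, B, sm_count):
    max_splits = min(16, ceil(K / 128))        # at most 16 splits, one per 128 KV tokens
    sms_per_batch = max(1, sm_count // B)
    split_heur = min(max_splits, sms_per_batch)
    k_waves = ceil(max_splits / split_heur)
    return ceil(max_splits / k_waves)          # split_wave_aware

# e.g. heuristic_split_kv(4096, 8, 132) -> max_splits=16, sms_per_batch=16 -> split_kv=16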
/// Determines whether the GEMM can execute the given problem.
|
||||
static Status
|
||||
can_implement(Arguments const& args) {
|
||||
if (! Kernel::can_implement(args)) {
|
||||
return Status::kInvalid;
|
||||
}
|
||||
if (! ReductionKernel::can_implement(to_reduction_args(args))) {
|
||||
return Status::kInvalid;
|
||||
}
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Gets the workspace size
|
||||
static size_t
|
||||
get_workspace_size(Arguments const& args) {
|
||||
size_t workspace_bytes = 0;
|
||||
workspace_bytes += Kernel::get_workspace_size(args);
|
||||
workspace_bytes += ReductionKernel::get_workspace_size(to_reduction_args(args));
|
||||
return workspace_bytes;
|
||||
}
|
||||
|
||||
/// Computes the maximum number of active blocks per multiprocessor
|
||||
static int maximum_active_blocks(int /* smem_capacity */ = -1) {
|
||||
CUTLASS_TRACE_HOST("MLA::maximum_active_blocks()");
|
||||
int max_active_blocks = -1;
|
||||
int smem_size = Kernel::SharedStorageSize;
|
||||
|
||||
// first, account for dynamic smem capacity if needed
|
||||
cudaError_t result;
|
||||
if (smem_size >= (48 << 10)) {
|
||||
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
|
||||
result = cudaFuncSetAttribute(
|
||||
device_kernel<Kernel>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smem_size);
|
||||
if (cudaSuccess != result) {
|
||||
result = cudaGetLastError(); // to clear the error bit
|
||||
CUTLASS_TRACE_HOST(
|
||||
" cudaFuncSetAttribute() returned error: "
|
||||
<< cudaGetErrorString(result));
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
// query occupancy after setting smem size
|
||||
result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_active_blocks,
|
||||
device_kernel<Kernel>,
|
||||
Kernel::MaxThreadsPerBlock,
|
||||
smem_size);
|
||||
|
||||
if (cudaSuccess != result) {
|
||||
result = cudaGetLastError(); // to clear the error bit
|
||||
CUTLASS_TRACE_HOST(
|
||||
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: "
|
||||
<< cudaGetErrorString(result));
|
||||
return -1;
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
|
||||
return max_active_blocks;
|
||||
}
|
||||
|
||||
/// Initializes GEMM state from arguments.
|
||||
Status
|
||||
initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
CUTLASS_TRACE_HOST("MLA::initialize() - workspace "
|
||||
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
|
||||
|
||||
// Initialize the workspace
|
||||
Status status = Kernel::initialize_workspace(args, workspace, stream);
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
status = ReductionKernel::initialize_workspace(to_reduction_args(args), workspace, stream);
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
KernelParams kernel_params = Kernel::to_underlying_arguments(args, workspace);
|
||||
|
||||
ReductionArguments reduction_args = to_reduction_args(args);
|
||||
if (reduction_args.split_kv > 1) {
|
||||
reduction_args.ptr_oaccum = kernel_params.epilogue.ptr_o_acc;
|
||||
reduction_args.ptr_lseaccum = kernel_params.epilogue.ptr_lse_acc;
|
||||
}
|
||||
ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace);
|
||||
// Initialize the Params structure
|
||||
params_ = Params {kernel_params, reduction_params};
|
||||
|
||||
if (is_initialized()) return Status::kSuccess;
|
||||
|
||||
// account for dynamic smem capacity if needed
|
||||
// no dynamic smem is needed for reduction kernel
|
||||
int smem_size = Kernel::SharedStorageSize;
|
||||
if (smem_size >= (48 << 10)) {
|
||||
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
|
||||
cudaError_t result = cudaFuncSetAttribute(
|
||||
device_kernel<Kernel>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smem_size);
|
||||
if (cudaSuccess != result) {
|
||||
result = cudaGetLastError(); // to clear the error bit
|
||||
CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result));
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
is_initialized(true);
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Update API is preserved in 3.0, but does not guarantee a lightweight update of params.
|
||||
Status
|
||||
update(Arguments const& args, void* workspace = nullptr) {
|
||||
CUTLASS_TRACE_HOST("MLA()::update() - workspace: " << workspace);
|
||||
|
||||
size_t workspace_bytes = get_workspace_size(args);
|
||||
if (workspace_bytes > 0 && nullptr == workspace) {
|
||||
return Status::kErrorWorkspaceNull;
|
||||
}
|
||||
|
||||
auto fmha_params = Kernel::to_underlying_arguments(args, workspace);
|
||||
|
||||
ReductionArguments reduction_args = to_reduction_args(args);
|
||||
if (reduction_args.split_kv > 1) {
|
||||
reduction_args.ptr_oaccum = fmha_params.epilogue.ptr_o_acc;
|
||||
reduction_args.ptr_lseaccum = fmha_params.epilogue.ptr_lse_acc;
|
||||
}
|
||||
ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace);
|
||||
// Initialize the Params structure
|
||||
params_ = Params {fmha_params, reduction_params};
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Primary run() entry point API that is static, allowing users to create and manage their own params.
|
||||
/// The supplied params struct must be constructed by calling Kernel::to_underlying_arguments().
|
||||
static Status
|
||||
run(Params& params, cudaStream_t stream = nullptr) {
|
||||
CUTLASS_TRACE_HOST("MLA::run()");
|
||||
dim3 const block = Kernel::get_block_shape();
|
||||
dim3 const grid = Kernel::get_grid_shape(params.fmha_params);
|
||||
|
||||
// configure smem size and carveout
|
||||
int smem_size = Kernel::SharedStorageSize;
|
||||
|
||||
Status launch_result;
|
||||
// Use extended launch API only for mainloops that use it
|
||||
if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) {
|
||||
dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}),
|
||||
cute::size<1>(typename Kernel::ClusterShape{}),
|
||||
cute::size<2>(typename Kernel::ClusterShape{}));
|
||||
void const* kernel = (void const*) device_kernel<Kernel>;
|
||||
void* kernel_params[] = {&params.fmha_params};
|
||||
launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params);
|
||||
}
|
||||
else {
|
||||
launch_result = Status::kSuccess;
|
||||
device_kernel<Kernel><<<grid, block, smem_size, stream>>>(params.fmha_params);
|
||||
}
|
||||
|
||||
cudaError_t result = cudaGetLastError();
|
||||
if (cudaSuccess != result or Status::kSuccess != launch_result) {
|
||||
//return Status::kSuccess;
|
||||
CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result);
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
if (params.reduction_params.split_kv > 1) {
|
||||
// launch reduction kernel
|
||||
dim3 const block = ReductionKernel::get_block_shape();
|
||||
dim3 const grid = ReductionKernel::get_grid_shape(params.reduction_params);
|
||||
device_kernel<ReductionKernel><<<grid, block, 0, stream>>>(params.reduction_params);
|
||||
cudaError_t result = cudaGetLastError();
|
||||
if (cudaSuccess == result) {
|
||||
return Status::kSuccess;
|
||||
}
|
||||
else {
|
||||
CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result);
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
else {
|
||||
return Status::kSuccess;
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Non-static launch overloads that first create and set the internal params struct of this kernel handle.
|
||||
//
|
||||
|
||||
/// Launches the kernel after first constructing Params internal state from supplied arguments.
|
||||
Status
|
||||
run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
Status status = initialize(args, workspace, stream);
|
||||
if (Status::kSuccess == status) {
|
||||
status = run(params_, stream);
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
/// Launches the kernel after first constructing Params internal state from supplied arguments.
|
||||
Status
|
||||
operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
return run(args, workspace, stream);
|
||||
}
|
||||
|
||||
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
|
||||
Status
|
||||
run(cudaStream_t stream = nullptr) {
|
||||
return run(params_, stream);
|
||||
}
|
||||
|
||||
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
|
||||
Status
|
||||
operator()(cudaStream_t stream = nullptr) {
|
||||
return run(params_, stream);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::fmha::device
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,203 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*
|
||||
* Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929
|
||||
* by Alcanderian JieXin Liang
|
||||
*/
|
||||
|
||||
// clang-format off
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
namespace cutlass::fmha::kernel {
|
||||
|
||||
using namespace cute;
|
||||
template<
|
||||
class ElementOut,
|
||||
class ElementAcc,
|
||||
class ElementScale,
|
||||
size_t kNumHeads,
|
||||
size_t kHeadDimLatent,
|
||||
int kMaxSplits
|
||||
>
|
||||
struct Sm100FmhaMlaReductionKernel {
|
||||
|
||||
static const int SharedStorageSize = 0;
|
||||
static const int MaxThreadsPerBlock = 128;
|
||||
static const int MinBlocksPerMultiprocessor = 1;
|
||||
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
|
||||
static_assert(kHeadDimLatent % MaxThreadsPerBlock == 0);
|
||||
struct Arguments {
|
||||
ElementAcc* ptr_oaccum = nullptr;
|
||||
ElementOut* ptr_o = nullptr;
|
||||
ElementAcc* ptr_lseaccum = nullptr;
|
||||
ElementAcc* ptr_lse = nullptr;
|
||||
ElementScale scale = 1.f;
|
||||
int num_batches = 0;
|
||||
int split_kv = -1;
|
||||
int dim_k = -1;
|
||||
int* ptr_seq = nullptr;
|
||||
int* ptr_split_kv = nullptr;
|
||||
int tile_shape_s = 128;
|
||||
};
|
||||
using Params = Arguments;
|
||||
|
||||
static Params to_underlying_arguments(Arguments const& args, void* workspace) {
|
||||
return {args.ptr_oaccum, args.ptr_o, args.ptr_lseaccum, args.ptr_lse,
|
||||
args.scale, args.num_batches, args.split_kv, args.dim_k, args.ptr_seq,
|
||||
args.ptr_split_kv, args.tile_shape_s};
|
||||
}
|
||||
|
||||
static size_t get_workspace_size(Arguments const& /*args*/) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
static Status initialize_workspace(
|
||||
Arguments const& /*args*/, void* /*ws*/, cudaStream_t /*stream*/) {
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
static dim3 get_grid_shape(Params const& params) {
|
||||
return dim3(kNumHeads, 1, params.num_batches);
|
||||
}
|
||||
|
||||
static dim3 get_block_shape() {
|
||||
return dim3(MaxThreadsPerBlock, 1, 1);
|
||||
}
|
||||
|
||||
static bool can_implement(Arguments const& args) {
|
||||
if (args.num_batches <= 0) return false;
|
||||
if (args.split_kv <= 0) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE void operator() (Params const& params, char* smem_raw) {
|
||||
if (params.split_kv <= 1) return;
|
||||
auto blk_coord = make_coord(blockIdx.x, _0{}, blockIdx.z);
|
||||
|
||||
__shared__ ElementAcc sLseScale[kMaxSplits];
|
||||
const size_t offset_lseaccum = get<0>(blk_coord) + kNumHeads * params.split_kv * get<2>(blk_coord);
|
||||
const size_t offset_lse = get<0>(blk_coord) + kNumHeads * get<2>(blk_coord);
|
||||
|
||||
Tensor gLSEaccum = make_tensor(make_gmem_ptr(params.ptr_lseaccum + offset_lseaccum),
|
||||
make_shape(params.split_kv), Stride<Int<kNumHeads>>{});
|
||||
|
||||
Tensor gLSE = make_tensor(make_gmem_ptr(params.ptr_lse + offset_lse),
|
||||
Shape<_1>{}, Stride<_1>{});
|
||||
|
||||
auto dim_k = params.ptr_seq == nullptr ? params.dim_k : params.ptr_seq[get<2>(blk_coord)];
|
||||
auto local_split_kv = params.ptr_split_kv == nullptr ? params.split_kv : params.ptr_split_kv[get<2>(blk_coord)];
|
||||
auto k_tile_total = ceil_div(dim_k, params.tile_shape_s);
|
||||
auto k_tile_per_cta = ceil_div(k_tile_total, local_split_kv);
|
||||
local_split_kv = ceil_div(k_tile_total, k_tile_per_cta);
|
||||
|
||||
int warp_idx = cutlass::canonical_warp_idx_sync();
|
||||
if (warp_idx == 0) {
|
||||
constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32);
|
||||
|
||||
ElementAcc local_lse[kNLsePerThread];
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kNLsePerThread; ++i) {
|
||||
const int split = i * 32 + threadIdx.x;
|
||||
local_lse[i] = split < local_split_kv ? gLSEaccum(split) : -std::numeric_limits<ElementAcc>::infinity();
|
||||
}
|
||||
|
||||
ElementAcc lse_max = -std::numeric_limits<ElementAcc>::infinity();
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kNLsePerThread; ++i) {
|
||||
lse_max = max(lse_max, local_lse[i]);
|
||||
}
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int offset = 16; offset >= 1; offset /= 2) {
|
||||
lse_max = max(lse_max, __shfl_xor_sync(0xffffffff, lse_max, offset));
|
||||
}
|
||||
lse_max = lse_max == -std::numeric_limits<ElementAcc>::infinity() ? 0.0f : lse_max; // In case all local LSEs are -inf
|
||||
lse_max = __shfl_sync(0xffffffff, lse_max, 0);
|
||||
|
||||
ElementAcc sum_lse = 0;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kNLsePerThread; ++i) {
|
||||
sum_lse = sum_lse + expf(local_lse[i] - lse_max);
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int offset = 16; offset >= 1; offset /= 2) {
|
||||
sum_lse = sum_lse + __shfl_xor_sync(0xffffffff, sum_lse, offset);
|
||||
}
|
||||
|
||||
sum_lse = __shfl_sync(0xffffffff, sum_lse, 0);
|
||||
|
||||
ElementAcc global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits<ElementAcc>::infinity() : logf(sum_lse) + lse_max;
|
||||
if (threadIdx.x == 0 and params.ptr_lse != nullptr) {
|
||||
gLSE(0) = global_lse;
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kNLsePerThread; ++i) {
|
||||
const int split = i * 32 + threadIdx.x;
|
||||
if (split < local_split_kv) {
|
||||
sLseScale[split] = expf(local_lse[i] - global_lse);
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
constexpr int Elements = kHeadDimLatent / MaxThreadsPerBlock;
|
||||
const size_t offset_oaccum = kHeadDimLatent * params.split_kv * (get<0>(blk_coord) + kNumHeads * get<2>(blk_coord));
|
||||
Tensor gOaccum = make_tensor(make_gmem_ptr(params.ptr_oaccum + offset_oaccum),
|
||||
Shape<Int<kHeadDimLatent>>{}, Stride<_1>{});
|
||||
ElementAcc local_val[Elements] = {0};
|
||||
for (int split = 0; split < local_split_kv; ++split) {
|
||||
ElementAcc lse_scale = sLseScale[split];
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for(int i = 0; i < Elements; ++i) {
|
||||
local_val[i] += lse_scale * gOaccum(threadIdx.x + MaxThreadsPerBlock * i);
|
||||
}
|
||||
gOaccum.data() = gOaccum.data() + kHeadDimLatent;
|
||||
}
|
||||
auto ptr_o_local = params.ptr_o + (get<0>(blk_coord) + get<2>(blk_coord) * kNumHeads) * kHeadDimLatent;
|
||||
Tensor gO = make_tensor(make_gmem_ptr(ptr_o_local), Shape<Int<kHeadDimLatent>>{}, Stride<_1>{});
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for(int i = 0; i < Elements; ++i) {
|
||||
gO(threadIdx.x + MaxThreadsPerBlock * i) = static_cast<ElementOut>(local_val[i]);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cutlass::fmha::kernel
|
||||
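The reduction kernel above merges the per-split partial outputs using their per-split log-sum-exp (LSE) values: a global LSE is formed with the usual max-shifted sum, and each split's output is rescaled by exp(lse_i - global_lse) before being accumulated. The same math in a short PyTorch sketch (illustrative only):

import torch

def merge_splits(o_acc: torch.Tensor, lse_acc: torch.Tensor):
    # o_acc: [splits, head_dim] partial outputs; lse_acc: [splits] partial LSEs.
    lse_max = lse_acc.max()
    global_lse = torch.log(torch.exp(lse_acc - lse_max).sum()) + lse_max
    scales = torch.exp(lse_acc - global_lse)          # per-split rescale factors
    return (scales[:, None] * o_acc).sum(dim=0), global_lse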
File diff suppressed because it is too large.
@ -0,0 +1,165 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*
|
||||
* Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929
|
||||
* by Alcanderian JieXin Liang
|
||||
*/
|
||||
|
||||
// clang-format off
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/kernel_hardware_info.h"
|
||||
|
||||
namespace cutlass::fmha::kernel {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Sm100MlaIndividualTileScheduler {
|
||||
|
||||
struct Params {
|
||||
dim3 grid;
|
||||
};
|
||||
|
||||
bool valid_ = true;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
Sm100MlaIndividualTileScheduler(Params const&) {}
|
||||
|
||||
template<class ProblemShape, class ClusterShape>
|
||||
static Params to_underlying_arguments(
|
||||
ProblemShape const& problem_shape, KernelHardwareInfo hw_info,
|
||||
ClusterShape const& cluster_shape, int const& split_kv) {
|
||||
using namespace cute;
|
||||
dim3 grid(get<0>(cluster_shape), get<3>(problem_shape) /* Batch */, split_kv /*Maximum Split KV*/);
|
||||
return Params{ grid };
|
||||
}
|
||||
|
||||
static dim3 get_grid_shape(Params const& params) {
|
||||
return params.grid;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
bool is_valid() {
|
||||
return valid_;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
auto get_block_coord() {
|
||||
using namespace cute;
|
||||
return make_coord(blockIdx.x, _0{}, blockIdx.y, blockIdx.z);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
Sm100MlaIndividualTileScheduler& operator++() {
|
||||
valid_ = false;
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Sm100MlaPersistentTileScheduler {
|
||||
|
||||
struct Params {
|
||||
int num_blocks;
|
||||
FastDivmod divmod_m_block;
|
||||
FastDivmod divmod_b;
|
||||
FastDivmod divmod_split_kv;
|
||||
KernelHardwareInfo hw_info;
|
||||
};
|
||||
|
||||
int block_idx = 0;
|
||||
Params params;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
Sm100MlaPersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {}
|
||||
|
||||
template<class ProblemShape, class ClusterShape>
|
||||
static Params to_underlying_arguments(
|
||||
ProblemShape const& problem_shape, KernelHardwareInfo hw_info,
|
||||
ClusterShape const& cluster_shape, int const& split_kv) {
|
||||
using namespace cute;
|
||||
// Get SM count if needed, otherwise use user supplied SM count
|
||||
int sm_count = hw_info.sm_count;
|
||||
if (sm_count <= 1 || sm_count % size<0>(cluster_shape) != 0) {
|
||||
CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
|
||||
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
|
||||
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
|
||||
hw_info.sm_count = sm_count;
|
||||
|
||||
int num_m_blocks = size<0>(cluster_shape);
|
||||
int num_blocks = num_m_blocks * get<3>(problem_shape) /* Batch */;
|
||||
num_blocks *= split_kv; /* Maximum Split KV*/
|
||||
|
||||
return Params {
|
||||
num_blocks,
|
||||
{ num_m_blocks}, { get<3>(problem_shape) }, {split_kv},
|
||||
hw_info
|
||||
};
|
||||
}
|
||||
|
||||
static dim3 get_grid_shape(Params const& params) {
|
||||
dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1);
|
||||
return grid;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
bool is_valid() {
|
||||
return block_idx < params.num_blocks;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
auto get_block_coord() {
|
||||
using namespace cute;
|
||||
int block_decode = block_idx;
|
||||
int m_block, bidb, n_split_kv;
|
||||
params.divmod_m_block(block_decode, m_block, block_decode);
|
||||
params.divmod_b(block_decode, bidb, block_decode);
|
||||
params.divmod_split_kv(block_decode, n_split_kv, block_decode);
|
||||
return make_coord(m_block, _0{}, bidb, n_split_kv);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
Sm100MlaPersistentTileScheduler& operator++() {
|
||||
block_idx += gridDim.x;
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::fmha::kernel
|
||||
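For reference, the persistent scheduler's get_block_coord() peels a flat block index into (m_block, batch, split) with successive divide/modulo, m_block varying fastest. A scalar sketch of that decoding (illustrative only, plain / and % instead of FastDivmod):

#include <cstdio>

struct Coord { int m_block; int batch; int split; };

// Same decomposition as the FastDivmod chain above.
Coord decode(int block_idx, int num_m_blocks, int num_batches) {
  Coord c;
  c.m_block = block_idx % num_m_blocks;  block_idx /= num_m_blocks;
  c.batch   = block_idx % num_batches;   block_idx /= num_batches;
  c.split   = block_idx;  // whatever remains is the KV-split index
  return c;
}

int main() {
  Coord c = decode(/*block_idx=*/7, /*num_m_blocks=*/2, /*num_batches=*/3);
  std::printf("m_block=%d batch=%d split=%d\n", c.m_block, c.batch, c.split);  // 1 0 1
}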
csrc/attention/mla/sm100_cutlass_mla_kernel.cu (new file, 283 lines)
@ -0,0 +1,283 @@
|
||||
/*
|
||||
Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
/*
|
||||
* Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929
|
||||
* by Alcanderian JieXin Liang
|
||||
*/
|
||||
#include "core/registration.h"
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/kernel_hardware_info.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass_sm100_mla/device/sm100_mla.hpp"
|
||||
#include "cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp"
|
||||
|
||||
// clang-format off
|
||||
#if !defined(CUDA_VERSION) || CUDA_VERSION < 12040
|
||||
void sm100_cutlass_mla_decode(
|
||||
torch::Tensor const& out,
|
||||
torch::Tensor const& q_nope,
|
||||
torch::Tensor const& q_pe,
|
||||
torch::Tensor const& kv_c_and_k_pe_cache,
|
||||
torch::Tensor const& seq_lens,
|
||||
torch::Tensor const& page_table,
|
||||
torch::Tensor const& workspace,
|
||||
int64_t num_kv_splits) {
|
||||
TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_decode");
|
||||
}
|
||||
int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count, int64_t num_kv_splits) {
|
||||
TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_get_workspace_size");
|
||||
}
|
||||
#else
|
||||
|
||||
#define CUTLASS_CHECK(status) \
|
||||
{ \
|
||||
cutlass::Status error = status; \
|
||||
TORCH_CHECK(error == cutlass::Status::kSuccess, cutlassGetStatusString(error)); \
|
||||
}
|
||||
|
||||
using namespace cute;
|
||||
using namespace cutlass::fmha::kernel;
|
||||
|
||||
template <bool v>
|
||||
struct IsPersistent {
|
||||
static const bool value = v;
|
||||
};
|
||||
|
||||
template <typename T, bool IsPaged128, typename PersistenceOption = IsPersistent<true>>
|
||||
struct MlaSm100 {
|
||||
using Element = T;
|
||||
using ElementAcc = float;
|
||||
using ElementOut = T;
|
||||
|
||||
using TileShape = Shape<_128, _128, Shape<_512, _64>>;
|
||||
using TileShapeH = cute::tuple_element_t<0, TileShape>;
|
||||
using TileShapeD = cute::tuple_element_t<2, TileShape>;
|
||||
|
||||
// H K (D_latent D_rope) B
|
||||
using ProblemShape = cute::tuple<TileShapeH, int, TileShapeD, int>;
|
||||
|
||||
using StrideQ = cute::tuple<int64_t, _1, int64_t>; // H D B
|
||||
using StrideK = cute::tuple<int64_t, _1, int64_t>; // K D B
|
||||
using StrideO = StrideK; // H D B
|
||||
using StrideLSE = cute::tuple<_1, int>; // H B
|
||||
|
||||
using TileScheduler =
|
||||
std::conditional_t<PersistenceOption::value, Sm100MlaPersistentTileScheduler, Sm100MlaIndividualTileScheduler>;
|
||||
|
||||
using FmhaKernel = cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized<
|
||||
TileShape,
|
||||
Element,
|
||||
ElementAcc,
|
||||
ElementOut,
|
||||
ElementAcc,
|
||||
TileScheduler,
|
||||
/*kIsCpAsync=*/!IsPaged128>;
|
||||
using Fmha = cutlass::fmha::device::MLA<FmhaKernel>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
typename T::Fmha::Arguments args_from_options(
|
||||
at::Tensor const& out,
|
||||
at::Tensor const& q_nope,
|
||||
at::Tensor const& q_pe,
|
||||
at::Tensor const& kv_c_and_k_pe_cache,
|
||||
at::Tensor const& seq_lens,
|
||||
at::Tensor const& page_table,
|
||||
double sm_scale,
|
||||
int64_t num_kv_splits) {
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
hw_info.device_id = q_nope.device().index();
|
||||
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
|
||||
int batches = q_nope.sizes()[0];
|
||||
int page_count_per_seq = page_table.sizes()[1];
|
||||
int page_count_total = kv_c_and_k_pe_cache.sizes()[0];
|
||||
int page_size = kv_c_and_k_pe_cache.sizes()[1];
|
||||
int max_seq_len = page_size * page_count_per_seq;
|
||||
using TileShapeH = typename T::TileShapeH;
|
||||
using TileShapeD = typename T::TileShapeD;
|
||||
auto problem_shape = cute::make_tuple(TileShapeH{}, max_seq_len, TileShapeD{}, batches);
|
||||
|
||||
auto [H, K, D, B] = problem_shape;
|
||||
auto [D_latent, D_rope] = D;
|
||||
|
||||
float scale = float(sm_scale);
|
||||
|
||||
using StrideQ = typename T::StrideQ;
|
||||
using StrideK = typename T::StrideK;
|
||||
using StrideO = typename T::StrideO;
|
||||
using StrideLSE = typename T::StrideLSE;
|
||||
|
||||
StrideQ stride_Q_nope = cute::make_tuple(
|
||||
static_cast<int64_t>(q_nope.stride(1)), _1{}, static_cast<int64_t>(q_nope.stride(0)));
|
||||
StrideQ stride_Q_pe = cute::make_tuple(
|
||||
static_cast<int64_t>(q_pe.stride(1)), _1{}, static_cast<int64_t>(q_pe.stride(0)));
|
||||
|
||||
StrideK stride_C = cute::make_tuple(
|
||||
static_cast<int64_t>(0 + D_latent + D_rope), _1{}, static_cast<int64_t>(page_size * (D_latent + D_rope)));
|
||||
StrideLSE stride_PT = cute::make_stride(_1{}, page_count_per_seq);
|
||||
StrideLSE stride_LSE = cute::make_tuple(_1{}, 0 + H);
|
||||
StrideO stride_O = cute::make_tuple(static_cast<int64_t>(0 + D_latent), _1{}, static_cast<int64_t>(0 + H * D_latent));
|
||||
|
||||
using Element = typename T::Element;
|
||||
using ElementOut = typename T::ElementOut;
|
||||
using ElementAcc = typename T::ElementAcc;
|
||||
auto Q_nope_ptr = static_cast<Element*>(q_nope.data_ptr());
|
||||
auto Q_pe_ptr = static_cast<Element*>(q_pe.data_ptr());
|
||||
auto C_ptr = static_cast<Element*>(kv_c_and_k_pe_cache.data_ptr());
|
||||
typename T::Fmha::Arguments arguments{
|
||||
problem_shape,
|
||||
{scale,
|
||||
Q_nope_ptr,
|
||||
stride_Q_nope,
|
||||
Q_pe_ptr,
|
||||
stride_Q_pe,
|
||||
C_ptr,
|
||||
stride_C,
|
||||
C_ptr + D_latent,
|
||||
stride_C,
|
||||
static_cast<int*>(seq_lens.data_ptr()),
|
||||
static_cast<int*>(page_table.data_ptr()),
|
||||
stride_PT,
|
||||
page_count_total,
|
||||
page_size},
|
||||
{static_cast<ElementOut*>(out.data_ptr()), stride_O, static_cast<ElementAcc*>(nullptr), stride_LSE},
|
||||
hw_info,
|
||||
// TODO(trevor-m): Change split_kv back to -1 when
|
||||
// https://github.com/NVIDIA/cutlass/issues/2274 is fixed. Split_kv=1 will
|
||||
// perform worse with larger context length and smaller batch sizes.
|
||||
num_kv_splits, // split_kv
|
||||
nullptr, // is_var_split_kv
|
||||
};
|
||||
// TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
|
||||
// split_kv automatically based on batch size and sequence length to balance
|
||||
// workload across available SMs. Consider using var_split_kv for manual
|
||||
// control if needed.
|
||||
T::Fmha::set_split_kv(arguments);
|
||||
return arguments;
|
||||
}
|
||||
|
||||
template <typename Element, bool IsPaged128, typename PersistenceOption>
|
||||
void runMla(
|
||||
at::Tensor const& out,
|
||||
at::Tensor const& q_nope,
|
||||
at::Tensor const& q_pe,
|
||||
at::Tensor const& kv_c_and_k_pe_cache,
|
||||
at::Tensor const& seq_lens,
|
||||
at::Tensor const& page_table,
|
||||
at::Tensor const& workspace,
|
||||
double sm_scale,
|
||||
int64_t num_kv_splits,
|
||||
cudaStream_t stream) {
|
||||
using MlaSm100Type = MlaSm100<Element, IsPaged128, PersistenceOption>;
|
||||
typename MlaSm100Type::Fmha fmha;
|
||||
auto arguments = args_from_options<MlaSm100Type>(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, sm_scale, num_kv_splits);
|
||||
|
||||
CUTLASS_CHECK(fmha.can_implement(arguments));
|
||||
|
||||
CUTLASS_CHECK(fmha.initialize(arguments, workspace.data_ptr(), stream));
|
||||
|
||||
CUTLASS_CHECK(fmha.run(arguments, workspace.data_ptr(), stream));
|
||||
}
|
||||
|
||||
#define DISPATCH_BOOL(expr, const_expr, ...) \
|
||||
[&]() -> bool { \
|
||||
if (expr) { \
|
||||
constexpr bool const_expr = true; \
|
||||
return __VA_ARGS__(); \
|
||||
} else { \
|
||||
constexpr bool const_expr = false; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
}()
|
||||
|
||||
void sm100_cutlass_mla_decode(
|
||||
torch::Tensor const& out,
|
||||
torch::Tensor const& q_nope,
|
||||
torch::Tensor const& q_pe,
|
||||
torch::Tensor const& kv_c_and_k_pe_cache,
|
||||
torch::Tensor const& seq_lens,
|
||||
torch::Tensor const& page_table,
|
||||
torch::Tensor const& workspace,
|
||||
double sm_scale,
|
||||
int64_t num_kv_splits) {
|
||||
auto in_dtype = q_nope.dtype();
|
||||
at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()};
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope.get_device());
|
||||
const int page_size = kv_c_and_k_pe_cache.sizes()[1];
|
||||
|
||||
// NOTE(alcanderian): IsPersistent has bug with manual split_kv.
|
||||
// Kernel will hang if batch is too large with large num_kv_splits. (for example bs=8, num_kv_splits=8)
|
||||
// Maybe per batch split kv will fix this.
|
||||
DISPATCH_BOOL(page_size == 128, IsPaged128, [&] {
|
||||
DISPATCH_BOOL(num_kv_splits <= 1, NotManualSplitKV, [&] {
|
||||
if (in_dtype == at::ScalarType::Half) {
|
||||
runMla<cutlass::half_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
|
||||
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
|
||||
} else if (in_dtype == at::ScalarType::BFloat16) {
|
||||
runMla<cutlass::bfloat16_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
|
||||
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
|
||||
} else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
|
||||
runMla<cutlass::float_e4m3_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
|
||||
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported input data type of MLA");
|
||||
}
|
||||
return true;
|
||||
});
|
||||
return true;
|
||||
});
|
||||
}
|
||||
|
||||
int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count, int64_t num_kv_splits) {
|
||||
// Workspace size depends on ElementAcc and ElementLSE (same as ElementAcc)
|
||||
// which are float, so Element type here doesn't matter.
|
||||
using MlaSm100Type = MlaSm100<cutlass::half_t, true>;
|
||||
|
||||
// Get split kv. Requires problem shape and sm_count only.
|
||||
typename MlaSm100Type::Fmha::Arguments arguments;
|
||||
using TileShapeH = typename MlaSm100Type::TileShapeH;
|
||||
using TileShapeD = typename MlaSm100Type::TileShapeD;
|
||||
arguments.problem_shape =
|
||||
cute::make_tuple(TileShapeH{}, static_cast<int>(max_seq_len), TileShapeD{}, static_cast<int>(num_batches));
|
||||
// Assumes device 0 when getting sm_count.
|
||||
arguments.hw_info.sm_count =
|
||||
sm_count <= 0 ? cutlass::KernelHardwareInfo::query_device_multiprocessor_count(/*device_id=*/0) : sm_count;
|
||||
arguments.split_kv = num_kv_splits;
|
||||
MlaSm100Type::Fmha::set_split_kv(arguments);
|
||||
|
||||
return MlaSm100Type::Fmha::get_workspace_size(arguments);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("sm100_cutlass_mla_decode", &sm100_cutlass_mla_decode);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CatchAll, m) {
|
||||
m.impl("sm100_cutlass_mla_get_workspace_size", &sm100_cutlass_mla_get_workspace_size);
|
||||
}
|
||||
|
||||
// clang-format on
|
||||
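A hypothetical driver for these two ops, to show how they fit together: size the workspace first, allocate it on the same device, then call the decode kernel. The tensor shapes, the uint8 workspace dtype, and the num_kv_splits/sm_count choices below are assumptions for illustration, not requirements stated in this file.

// Illustrative only; mirrors the signatures defined above in this translation unit.
void run_sm100_mla_decode_example(
    torch::Tensor& out, torch::Tensor& q_nope, torch::Tensor& q_pe,
    torch::Tensor& kv_c_and_k_pe_cache, torch::Tensor& seq_lens,
    torch::Tensor& page_table, double sm_scale, int64_t max_seq_len) {
  const int64_t num_batches = q_nope.size(0);
  // sm_count <= 0 lets the helper query the device; num_kv_splits = 1 sidesteps the
  // manual-split issue described in the NOTE above.
  const int64_t ws_bytes = sm100_cutlass_mla_get_workspace_size(
      max_seq_len, num_batches, /*sm_count=*/0, /*num_kv_splits=*/1);
  auto workspace = torch::empty(
      {ws_bytes}, torch::TensorOptions().dtype(torch::kUInt8).device(q_nope.device()));
  sm100_cutlass_mla_decode(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens,
                           page_table, workspace, sm_scale, /*num_kv_splits=*/1);
}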
@ -16,14 +16,8 @@
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "attention_kernels.cuh"
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define WARP_SIZE 32
|
||||
#else
|
||||
#define WARP_SIZE warpSize
|
||||
#endif
|
||||
#include "../cuda_compat.h"
|
||||
|
||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
@ -80,7 +74,7 @@ void paged_attention_v1_launcher(
|
||||
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
|
||||
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
|
||||
|
||||
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
||||
const int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
||||
int padded_max_seq_len =
|
||||
DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
|
||||
int logits_size = padded_max_seq_len * sizeof(float);
|
||||
@ -187,7 +181,6 @@ void paged_attention_v1(
|
||||
CALL_V1_LAUNCHER_BLOCK_SIZE)
|
||||
}
|
||||
|
||||
#undef WARP_SIZE
|
||||
#undef MAX
|
||||
#undef MIN
|
||||
#undef DIVIDE_ROUND_UP
|
||||
|
||||
@ -16,14 +16,8 @@
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "attention_kernels.cuh"
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define WARP_SIZE 32
|
||||
#else
|
||||
#define WARP_SIZE warpSize
|
||||
#endif
|
||||
#include "../cuda_compat.h"
|
||||
|
||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
@ -84,7 +78,7 @@ void paged_attention_v2_launcher(
|
||||
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
|
||||
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
|
||||
|
||||
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
||||
const int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
||||
int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
|
||||
int logits_size = PARTITION_SIZE * sizeof(float);
|
||||
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
|
||||
@ -197,7 +191,6 @@ void paged_attention_v2(
|
||||
CALL_V2_LAUNCHER_BLOCK_SIZE)
|
||||
}
|
||||
|
||||
#undef WARP_SIZE
|
||||
#undef MAX
|
||||
#undef MIN
|
||||
#undef DIVIDE_ROUND_UP
|
||||
|
||||
@ -33,6 +33,8 @@ namespace vec_op {
|
||||
#endif
|
||||
|
||||
#define FORCE_INLINE __attribute__((always_inline)) inline
|
||||
// Number of elements in single ASIMD vector of given Datatype
|
||||
#define NUM_ELEMENTS_REG(vec) (sizeof(vec) / sizeof(vec[0]))
|
||||
|
||||
namespace {
|
||||
template <typename T, T... indexes, typename F>
|
||||
@ -86,8 +88,8 @@ struct FP16Vec16 : public Vec<FP16Vec16> {
|
||||
}
|
||||
|
||||
void save(void* ptr, const int elem_num) const {
|
||||
int full_blocks = elem_num / 8;
|
||||
int remainder = elem_num % 8;
|
||||
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
|
||||
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
|
||||
|
||||
if (full_blocks > 0) {
|
||||
vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]);
|
||||
@ -197,6 +199,25 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
|
||||
vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[2]), v.val[3])}) {};
|
||||
|
||||
void save(void* ptr) const { *reinterpret_cast<bfloat16x8x2_t*>(ptr) = reg; };
|
||||
void save(void* ptr, const int elem_num) const {
|
||||
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
|
||||
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
|
||||
for (int i = 0; i < full_blocks; i++)
|
||||
vst1q_bf16(
|
||||
reinterpret_cast<__bf16*>(ptr) + NUM_ELEMENTS_REG(reg.val[0]) * i,
|
||||
reg.val[i]);
|
||||
if (remainder > 0) {
|
||||
bfloat16x8_t temp = reg.val[full_blocks];
|
||||
bfloat16_t* base = reinterpret_cast<bfloat16_t*>(ptr) + full_blocks * 8;
|
||||
if (remainder > 0) base[0] = vgetq_lane_bf16(temp, 0);
|
||||
if (remainder > 1) base[1] = vgetq_lane_bf16(temp, 1);
|
||||
if (remainder > 2) base[2] = vgetq_lane_bf16(temp, 2);
|
||||
if (remainder > 3) base[3] = vgetq_lane_bf16(temp, 3);
|
||||
if (remainder > 4) base[4] = vgetq_lane_bf16(temp, 4);
|
||||
if (remainder > 5) base[5] = vgetq_lane_bf16(temp, 5);
|
||||
if (remainder > 6) base[6] = vgetq_lane_bf16(temp, 6);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
struct BF16Vec32 : public Vec<BF16Vec32> {
|
||||
@ -213,6 +234,25 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
|
||||
: reg({vec8_data.reg, vec8_data.reg, vec8_data.reg, vec8_data.reg}) {};
|
||||
|
||||
void save(void* ptr) const { *reinterpret_cast<bfloat16x8x4_t*>(ptr) = reg; };
|
||||
void save(void* ptr, const int elem_num) const {
|
||||
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
|
||||
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
|
||||
for (int i = 0; i < full_blocks; i++)
|
||||
vst1q_bf16(
|
||||
reinterpret_cast<__bf16*>(ptr) + NUM_ELEMENTS_REG(reg.val[0]) * i,
|
||||
reg.val[i]);
|
||||
if (remainder > 0) {
|
||||
bfloat16x8_t temp = reg.val[full_blocks];
|
||||
bfloat16_t* base = reinterpret_cast<bfloat16_t*>(ptr) + full_blocks * 8;
|
||||
base[0] = vgetq_lane_bf16(temp, 0);
|
||||
if (remainder > 1) base[1] = vgetq_lane_bf16(temp, 1);
|
||||
if (remainder > 2) base[2] = vgetq_lane_bf16(temp, 2);
|
||||
if (remainder > 3) base[3] = vgetq_lane_bf16(temp, 3);
|
||||
if (remainder > 4) base[4] = vgetq_lane_bf16(temp, 4);
|
||||
if (remainder > 5) base[5] = vgetq_lane_bf16(temp, 5);
|
||||
if (remainder > 6) base[6] = vgetq_lane_bf16(temp, 6);
|
||||
}
|
||||
};
|
||||
};
|
||||
#endif
|
||||
|
||||
@ -372,6 +412,48 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
|
||||
}
|
||||
};
|
||||
|
||||
struct INT32Vec16 : public Vec<INT32Vec16> {
|
||||
constexpr static int VEC_ELEM_NUM = 16;
|
||||
union AliasReg {
|
||||
int32x4x4_t reg;
|
||||
int32_t values[VEC_ELEM_NUM];
|
||||
};
|
||||
int32x4x4_t reg;
|
||||
|
||||
explicit INT32Vec16(const void* ptr) {
|
||||
reg.val[0] = vld1q_s32(reinterpret_cast<const int32_t*>(ptr));
|
||||
reg.val[1] = vld1q_s32(reinterpret_cast<const int32_t*>(ptr) + 4);
|
||||
reg.val[2] = vld1q_s32(reinterpret_cast<const int32_t*>(ptr) + 8);
|
||||
reg.val[3] = vld1q_s32(reinterpret_cast<const int32_t*>(ptr) + 12);
|
||||
}
|
||||
|
||||
void save(int32_t* ptr) const {
|
||||
vst1q_s32(ptr, reg.val[0]);
|
||||
vst1q_s32(ptr + 4, reg.val[1]);
|
||||
vst1q_s32(ptr + 8, reg.val[2]);
|
||||
vst1q_s32(ptr + 12, reg.val[3]);
|
||||
};
|
||||
|
||||
void save(int32_t* ptr, const int elem_num) const {
|
||||
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
|
||||
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
|
||||
|
||||
for (int i = 0; i < full_blocks; i++)
|
||||
vst1q_s32(
|
||||
reinterpret_cast<__int32_t*>(ptr) + NUM_ELEMENTS_REG(reg.val[0]) * i,
|
||||
reg.val[i]);
|
||||
|
||||
if (remainder > 0) {
|
||||
int32x4_t temp = reg.val[full_blocks];
|
||||
int32_t* base = reinterpret_cast<int32_t*>(ptr) + full_blocks * 4;
|
||||
if (remainder > 0) base[0] = vgetq_lane_s32(temp, 0);
|
||||
if (remainder > 1) base[1] = vgetq_lane_s32(temp, 1);
|
||||
if (remainder > 2) base[2] = vgetq_lane_s32(temp, 2);
|
||||
if (remainder > 3) base[3] = vgetq_lane_s32(temp, 3);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
constexpr static int VEC_ELEM_NUM = 16;
|
||||
union AliasReg {
|
||||
@ -434,7 +516,12 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
reg.val[2] = vcvt_f32_f16(vget_low_f16(v.reg.val[1]));
|
||||
reg.val[3] = vcvt_f32_f16(vget_high_f16(v.reg.val[1]));
|
||||
};
|
||||
|
||||
explicit FP32Vec16(const INT32Vec16& v) {
|
||||
reg.val[0] = vcvtq_f32_s32(v.reg.val[0]);
|
||||
reg.val[1] = vcvtq_f32_s32(v.reg.val[1]);
|
||||
reg.val[2] = vcvtq_f32_s32(v.reg.val[2]);
|
||||
reg.val[3] = vcvtq_f32_s32(v.reg.val[3]);
|
||||
};
|
||||
FP32Vec16 operator+(const FP32Vec16& b) const {
|
||||
return FP32Vec16(float32x4x4_t({vaddq_f32(reg.val[0], b.reg.val[0]),
|
||||
vaddq_f32(reg.val[1], b.reg.val[1]),
|
||||
@ -463,6 +550,85 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
vdivq_f32(reg.val[3], b.reg.val[3])}));
|
||||
};
|
||||
|
||||
FP32Vec16 clamp(const FP32Vec16& min, const FP32Vec16& max) const {
|
||||
return FP32Vec16(float32x4x4_t(
|
||||
{vminq_f32(max.reg.val[0], vmaxq_f32(min.reg.val[0], reg.val[0])),
|
||||
vminq_f32(max.reg.val[1], vmaxq_f32(min.reg.val[1], reg.val[1])),
|
||||
vminq_f32(max.reg.val[2], vmaxq_f32(min.reg.val[2], reg.val[2])),
|
||||
vminq_f32(max.reg.val[3], vmaxq_f32(min.reg.val[3], reg.val[3]))}));
|
||||
};
|
||||
|
||||
FP32Vec16 max(const FP32Vec16& b) const {
|
||||
return FP32Vec16(float32x4x4_t({vmaxq_f32(b.reg.val[0], reg.val[0]),
|
||||
vmaxq_f32(b.reg.val[1], reg.val[1]),
|
||||
vmaxq_f32(b.reg.val[2], reg.val[2]),
|
||||
vmaxq_f32(b.reg.val[3], reg.val[3])}));
|
||||
};
|
||||
|
||||
FP32Vec16 max(const FP32Vec16& b, const int elem_num) const {
|
||||
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
|
||||
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
|
||||
float32x4x4_t temp;
|
||||
|
||||
for (int i = 0; i < full_blocks; i++)
|
||||
temp.val[i] = vmaxq_f32(b.reg.val[i], reg.val[i]);
|
||||
|
||||
if (remainder > 0) {
|
||||
float max_v = std::max(vgetq_lane_f32(reg.val[full_blocks], 0),
|
||||
vgetq_lane_f32(b.reg.val[full_blocks], 0));
|
||||
temp.val[full_blocks] = vsetq_lane_f32(max_v, temp.val[full_blocks], 0);
|
||||
}
|
||||
if (remainder > 1) {
|
||||
float max_v = std::max(vgetq_lane_f32(reg.val[full_blocks], 1),
|
||||
vgetq_lane_f32(b.reg.val[full_blocks], 1));
|
||||
temp.val[full_blocks] = vsetq_lane_f32(max_v, temp.val[full_blocks], 1);
|
||||
}
|
||||
if (remainder > 2) {
|
||||
float max_v = std::max(vgetq_lane_f32(reg.val[full_blocks], 2),
|
||||
vgetq_lane_f32(b.reg.val[full_blocks], 2));
|
||||
temp.val[full_blocks] = vsetq_lane_f32(max_v, temp.val[full_blocks], 2);
|
||||
}
|
||||
return FP32Vec16(temp);
|
||||
};
|
||||
|
||||
FP32Vec16 min(const FP32Vec16& b) const {
|
||||
return FP32Vec16(float32x4x4_t({
|
||||
vminq_f32(b.reg.val[0], reg.val[0]),
|
||||
vminq_f32(b.reg.val[1], reg.val[1]),
|
||||
vminq_f32(b.reg.val[2], reg.val[2]),
|
||||
vminq_f32(b.reg.val[3], reg.val[3]),
|
||||
}));
|
||||
};
|
||||
FP32Vec16 min(const FP32Vec16& b, const int elem_num) const {
|
||||
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
|
||||
const int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
|
||||
float32x4x4_t temp;
|
||||
for (int i = 0; i < full_blocks; i++)
|
||||
temp.val[i] = vminq_f32(b.reg.val[i], reg.val[i]);
|
||||
|
||||
if (remainder > 0) {
|
||||
float min_v = std::min(vgetq_lane_f32(reg.val[full_blocks], 0),
|
||||
vgetq_lane_f32(b.reg.val[full_blocks], 0));
|
||||
temp.val[full_blocks] = vsetq_lane_f32(min_v, temp.val[full_blocks], 0);
|
||||
}
|
||||
if (remainder > 1) {
|
||||
float min_v = std::min(vgetq_lane_f32(reg.val[full_blocks], 1),
|
||||
vgetq_lane_f32(b.reg.val[full_blocks], 1));
|
||||
temp.val[full_blocks] = vsetq_lane_f32(min_v, temp.val[full_blocks], 1);
|
||||
}
|
||||
if (remainder > 2) {
|
||||
float min_v = std::min(vgetq_lane_f32(reg.val[full_blocks], 2),
|
||||
vgetq_lane_f32(b.reg.val[full_blocks], 2));
|
||||
temp.val[full_blocks] = vsetq_lane_f32(min_v, temp.val[full_blocks], 2);
|
||||
}
|
||||
|
||||
return FP32Vec16(temp);
|
||||
};
|
||||
FP32Vec16 abs() const {
|
||||
return FP32Vec16(
|
||||
float32x4x4_t({vabsq_f32(reg.val[0]), vabsq_f32(reg.val[1]),
|
||||
vabsq_f32(reg.val[2]), vabsq_f32(reg.val[3])}));
|
||||
}
|
||||
float reduce_sum() const {
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
@ -473,6 +639,24 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
return answer;
|
||||
};
|
||||
|
||||
float reduce_max() const {
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
float max_v = std::numeric_limits<float>::lowest();
|
||||
unroll_loop<int, VEC_ELEM_NUM>(
|
||||
[&max_v, &ar](int i) { max_v = std::max(max_v, ar.values[i]); });
|
||||
return max_v;
|
||||
}
|
||||
|
||||
float reduce_min() const {
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
float min_v = std::numeric_limits<float>::max();
|
||||
unroll_loop<int, VEC_ELEM_NUM>(
|
||||
[&min_v, &ar](int i) { min_v = std::min(min_v, ar.values[i]); });
|
||||
return min_v;
|
||||
}
|
||||
|
||||
template <int group_size>
|
||||
float reduce_sub_sum(int idx) {
|
||||
static_assert(VEC_ELEM_NUM % group_size == 0);
|
||||
@ -493,6 +677,83 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
vst1q_f32(ptr + 8, reg.val[2]);
|
||||
vst1q_f32(ptr + 12, reg.val[3]);
|
||||
};
|
||||
|
||||
void save(float* ptr, const int elem_num) const {
|
||||
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
|
||||
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
|
||||
|
||||
for (int i = 0; i < full_blocks; i++)
|
||||
vst1q_f32(
|
||||
reinterpret_cast<float32_t*>(ptr) + NUM_ELEMENTS_REG(reg.val[0]) * i,
|
||||
reg.val[i]);
|
||||
|
||||
if (remainder > 0) {
|
||||
float32x4_t temp = reg.val[full_blocks];
|
||||
float* base = reinterpret_cast<float32_t*>(ptr) +
|
||||
full_blocks * NUM_ELEMENTS_REG(reg.val[0]);
|
||||
if (remainder > 0) base[0] = vgetq_lane_f32(temp, 0);
|
||||
if (remainder > 1) base[1] = vgetq_lane_f32(temp, 1);
|
||||
if (remainder > 2) base[2] = vgetq_lane_f32(temp, 2);
|
||||
}
|
||||
}
|
||||
};
|
||||
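The elem_num overloads above exist so callers can store a row whose length is not a multiple of the vector width without writing past the end of the buffer. A small usage sketch (hypothetical helper, not from this diff):

// Writes n floats from an array of FP32Vec16 registers: full 16-wide stores for the
// body, then the partial save() for the tail lanes.
inline void store_row(const vec_op::FP32Vec16* vecs, float* dst, int n) {
  int i = 0;
  for (; i + 16 <= n; i += 16) vecs[i / 16].save(dst + i);
  if (i < n) vecs[i / 16].save(dst + i, n - i);
}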
|
||||
struct INT8Vec16 : public Vec<INT8Vec16> {
|
||||
constexpr static int VEC_ELEM_NUM = 16;
|
||||
union AliasReg {
|
||||
int8x16_t reg;
|
||||
int8_t values[VEC_ELEM_NUM];
|
||||
};
|
||||
int8x16_t reg;
|
||||
|
||||
explicit INT8Vec16(const FP32Vec16& vec) {
|
||||
// Convert each 128-bit float32 vector to int32
|
||||
int32x4_t part0 =
|
||||
vcvtq_s32_f32(vec.reg.val[0]); // Convert first 128-bit block
|
||||
int32x4_t part1 =
|
||||
vcvtq_s32_f32(vec.reg.val[1]); // Convert second 128-bit block
|
||||
int32x4_t part2 =
|
||||
vcvtq_s32_f32(vec.reg.val[2]); // Convert third 128-bit block
|
||||
int32x4_t part3 =
|
||||
vcvtq_s32_f32(vec.reg.val[3]); // Convert fourth 128-bit block
|
||||
|
||||
// Narrow each 32-bit vector to 8 bits and combine
|
||||
int8x8_t lower =
|
||||
vqmovn_s16(vcombine_s16(vqmovn_s32(part0), vqmovn_s32(part1)));
|
||||
int8x8_t upper =
|
||||
vqmovn_s16(vcombine_s16(vqmovn_s32(part2), vqmovn_s32(part3)));
|
||||
reg = vcombine_s8(lower, upper); // Combine to form a single 128-bit vector
|
||||
}
|
||||
|
||||
void save(int8_t* ptr) const { vst1q_s8(ptr, reg); };
|
||||
|
||||
void save(int8_t* ptr, const int elem_num) const {
|
||||
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg);
|
||||
int remainder = elem_num % NUM_ELEMENTS_REG(reg);
|
||||
|
||||
for (int i = 0; i < full_blocks; i++)
|
||||
vst1q_s8(reinterpret_cast<int8_t*>(ptr) + NUM_ELEMENTS_REG(reg) * i, reg);
|
||||
if (remainder > 0) {
|
||||
int8x16_t temp = reg;
|
||||
int8_t* base =
|
||||
reinterpret_cast<int8_t*>(ptr) + full_blocks * NUM_ELEMENTS_REG(reg);
|
||||
if (remainder > 0) base[0] = vgetq_lane_s8(temp, 0);
|
||||
if (remainder > 1) base[1] = vgetq_lane_s8(temp, 1);
|
||||
if (remainder > 2) base[2] = vgetq_lane_s8(temp, 2);
|
||||
if (remainder > 3) base[3] = vgetq_lane_s8(temp, 3);
|
||||
if (remainder > 4) base[4] = vgetq_lane_s8(temp, 4);
|
||||
if (remainder > 5) base[5] = vgetq_lane_s8(temp, 5);
|
||||
if (remainder > 6) base[6] = vgetq_lane_s8(temp, 6);
|
||||
if (remainder > 7) base[7] = vgetq_lane_s8(temp, 7);
|
||||
if (remainder > 8) base[8] = vgetq_lane_s8(temp, 8);
|
||||
if (remainder > 9) base[9] = vgetq_lane_s8(temp, 9);
|
||||
if (remainder > 10) base[10] = vgetq_lane_s8(temp, 10);
|
||||
if (remainder > 11) base[11] = vgetq_lane_s8(temp, 11);
|
||||
if (remainder > 12) base[12] = vgetq_lane_s8(temp, 12);
|
||||
if (remainder > 13) base[13] = vgetq_lane_s8(temp, 13);
|
||||
if (remainder > 14) base[14] = vgetq_lane_s8(temp, 14);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
|
||||
@ -57,6 +57,7 @@ class DNNLPrimitiveHelper {
|
||||
// Note: Due to the limitation of oneDNN
|
||||
// (https://github.com/oneapi-src/oneDNN/issues/1636), the quantized bias is
|
||||
// not supported.
|
||||
|
||||
template <typename OutputT, typename BiasT>
|
||||
static void gemm_s8s8_jit(const int8_t* a, const int8_t* b, OutputT* c,
|
||||
const BiasT* bias, dnnl_dim_t M, dnnl_dim_t N,
|
||||
@ -90,6 +91,27 @@ class DNNLPrimitiveHelper {
|
||||
}
|
||||
|
||||
dnnl::matmul::primitive_desc matmul_pd;
|
||||
// Create memory descriptors with format_tag::any for the primitive. This
|
||||
// enables the matmul primitive to choose memory layouts for an
|
||||
// optimized primitive implementation, and these layouts may differ from the
|
||||
// ones provided by the user.
|
||||
#ifdef __aarch64__
|
||||
auto mat_src_md = dnnl::memory::desc({M, K}, dnnl::memory::data_type::s8,
|
||||
dnnl::memory::format_tag::any);
|
||||
auto mat_weights_md = dnnl::memory::desc(
|
||||
{K, N}, dnnl::memory::data_type::s8, dnnl::memory::format_tag::any);
|
||||
auto mat_dst_md =
|
||||
dnnl::memory::desc({M, N}, OutputType, dnnl::memory::format_tag::any);
|
||||
if (bias) {
|
||||
dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1});
|
||||
matmul_pd = dnnl::matmul::primitive_desc(default_engine(), mat_src_md,
|
||||
mat_weights_md, bias_md,
|
||||
mat_dst_md, attr);
|
||||
} else {
|
||||
matmul_pd = dnnl::matmul::primitive_desc(
|
||||
default_engine(), mat_src_md, mat_weights_md, mat_dst_md, attr);
|
||||
}
|
||||
#else
|
||||
if (bias) {
|
||||
dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1});
|
||||
matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
|
||||
@ -98,6 +120,7 @@ class DNNLPrimitiveHelper {
|
||||
matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
|
||||
c_md, attr);
|
||||
}
|
||||
#endif
|
||||
dnnl::matmul matmul(matmul_pd);
|
||||
|
||||
auto& engine = default_engine();
|
||||
@ -111,24 +134,34 @@ class DNNLPrimitiveHelper {
|
||||
(void*)b_scales);
|
||||
|
||||
auto& stream = default_stream();
|
||||
|
||||
auto mat_src_mem = a_m;
|
||||
auto mat_weights_mem = b_m;
|
||||
auto mat_dst_mem = c_m;
|
||||
#ifdef __aarch64__
|
||||
if (matmul_pd.weights_desc() != b_m.get_desc()) {
|
||||
mat_weights_mem = dnnl::memory(matmul_pd.weights_desc(), engine);
|
||||
dnnl::reorder(b_m, mat_weights_mem).execute(stream, b_m, mat_weights_mem);
|
||||
}
|
||||
#endif
|
||||
if constexpr (InputNoScale) {
|
||||
if (bias) {
|
||||
dnnl::memory::desc bias_md({N}, BiasType, {1});
|
||||
dnnl::memory bias_m(bias_md, engine, (void*)bias);
|
||||
matmul.execute(
|
||||
stream, {
|
||||
{DNNL_ARG_SRC, a_m},
|
||||
{DNNL_ARG_WEIGHTS, b_m},
|
||||
{DNNL_ARG_SRC, mat_src_mem},
|
||||
{DNNL_ARG_WEIGHTS, mat_weights_mem},
|
||||
{DNNL_ARG_BIAS, bias_m},
|
||||
{DNNL_ARG_DST, c_m},
|
||||
{DNNL_ARG_DST, mat_dst_mem},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
|
||||
});
|
||||
} else {
|
||||
matmul.execute(
|
||||
stream, {
|
||||
{DNNL_ARG_SRC, a_m},
|
||||
{DNNL_ARG_WEIGHTS, b_m},
|
||||
{DNNL_ARG_DST, c_m},
|
||||
{DNNL_ARG_SRC, mat_src_mem},
|
||||
{DNNL_ARG_WEIGHTS, mat_weights_mem},
|
||||
{DNNL_ARG_DST, mat_dst_mem},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
|
||||
});
|
||||
}
|
||||
@ -138,19 +171,19 @@ class DNNLPrimitiveHelper {
|
||||
dnnl::memory bias_m(bias_md, engine, (void*)bias);
|
||||
matmul.execute(
|
||||
stream, {
|
||||
{DNNL_ARG_SRC, a_m},
|
||||
{DNNL_ARG_WEIGHTS, b_m},
|
||||
{DNNL_ARG_SRC, mat_src_mem},
|
||||
{DNNL_ARG_WEIGHTS, mat_weights_mem},
|
||||
{DNNL_ARG_BIAS, bias_m},
|
||||
{DNNL_ARG_DST, c_m},
|
||||
{DNNL_ARG_DST, mat_dst_mem},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
|
||||
});
|
||||
} else {
|
||||
matmul.execute(
|
||||
stream, {
|
||||
{DNNL_ARG_SRC, a_m},
|
||||
{DNNL_ARG_WEIGHTS, b_m},
|
||||
{DNNL_ARG_DST, c_m},
|
||||
{DNNL_ARG_SRC, mat_src_mem},
|
||||
{DNNL_ARG_WEIGHTS, mat_weights_mem},
|
||||
{DNNL_ARG_DST, mat_dst_mem},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
|
||||
});
|
||||
@ -170,5 +203,4 @@ class DNNLPrimitiveHelper {
|
||||
return stream;
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
||||
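The aarch64 branch above follows the usual oneDNN pattern: describe tensors with format_tag::any, let the primitive_desc pick its preferred weights layout, and reorder the user's weights only when the chosen layout differs. A stand-alone sketch of that idiom (hypothetical helper name; assumes dnnl.hpp is already included, as it is in this header):

inline dnnl::memory prepare_weights(const dnnl::matmul::primitive_desc& pd,
                                    dnnl::memory user_weights,
                                    dnnl::engine& eng, dnnl::stream& strm) {
  if (pd.weights_desc() == user_weights.get_desc()) {
    return user_weights;  // already in the layout the primitive wants
  }
  dnnl::memory opt_weights(pd.weights_desc(), eng);  // primitive-preferred layout
  dnnl::reorder(user_weights, opt_weights).execute(strm, user_weights, opt_weights);
  return opt_weights;
}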
|
||||
@ -36,7 +36,7 @@ struct KernelVecType<c10::Half> {
|
||||
using cvt_vec_type = vec_op::FP32Vec16;
|
||||
};
|
||||
|
||||
#ifdef __AVX512F__
|
||||
#if defined(__AVX512F__) || defined(__aarch64__)
|
||||
template <bool AZP, typename scalar_t>
|
||||
void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
|
||||
const float* scale, const int32_t* azp,
|
||||
@ -598,8 +598,9 @@ void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
|
||||
const float* scale, const int32_t* azp,
|
||||
const int num_tokens,
|
||||
const int hidden_size) {
|
||||
TORCH_CHECK(
|
||||
false, "static_scaled_int8_quant_impl requires AVX512/powerpc64 support.")
|
||||
TORCH_CHECK(false,
|
||||
"static_scaled_int8_quant_impl requires AVX512/powerpc64/AArch64 "
|
||||
"support.")
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
@ -607,9 +608,9 @@ void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
|
||||
float* scale, int32_t* azp,
|
||||
const int num_tokens,
|
||||
const int hidden_size) {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"dynamic_scaled_int8_quant_impl requires AVX512/powerpc64 support.")
|
||||
TORCH_CHECK(false,
|
||||
"dynamic_scaled_int8_quant_impl requires "
|
||||
"AVX512/powerpc64/AArch64 support.")
|
||||
}
|
||||
|
||||
template <bool PerChannel, typename scalar_t>
|
||||
@ -617,7 +618,8 @@ void static_quant_epilogue(const float* input, scalar_t* output,
|
||||
const float a_scale, const float* b_scale,
|
||||
const int32_t* azp_with_adj, const int num_tokens,
|
||||
const int hidden_size) {
|
||||
TORCH_CHECK(false, "static_quant_epilogue requires AVX512/powerpc64 support.")
|
||||
TORCH_CHECK(
|
||||
false, "static_quant_epilogue requires AVX512/powerpc64/AArch64 support.")
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
@ -626,8 +628,9 @@ void dynamic_quant_epilogue(const float* input, scalar_t* output,
|
||||
const int32_t* azp, const int32_t* azp_with_adj,
|
||||
const scalar_t* bias, const int num_tokens,
|
||||
const int hidden_size) {
|
||||
TORCH_CHECK(false,
|
||||
"dynamic_quant_epilogue requires AVX512/powerpc64 support.")
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"dynamic_quant_epilogue requires AVX512/powerpc64/AArch64 support.")
|
||||
}
|
||||
#endif
|
||||
} // namespace
|
||||
|
||||
csrc/cpu/sgl-kernels/common.h (new file, 238 lines)
@ -0,0 +1,238 @@
|
||||
// Adapted from
|
||||
// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/Parallel.h>
|
||||
#include <ATen/record_function.h>
|
||||
|
||||
// clang-format off
|
||||
|
||||
#if defined(_OPENMP)
|
||||
#include <omp.h>
|
||||
#endif
|
||||
|
||||
namespace {
|
||||
|
||||
// dispatch bool
|
||||
#define AT_DISPATCH_BOOL(BOOL_V, BOOL_NAME, ...) \
|
||||
[&] { \
|
||||
if (BOOL_V) { \
|
||||
constexpr bool BOOL_NAME = true; \
|
||||
return __VA_ARGS__(); \
|
||||
} else { \
|
||||
constexpr bool BOOL_NAME = false; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
}()
|
||||
|
||||
// dispatch: bfloat16, float16, int8_t, fp8_e4m3
|
||||
#define CPU_DISPATCH_PACKED_TYPES(TYPE, ...) \
|
||||
[&] { \
|
||||
switch (TYPE) { \
|
||||
case at::ScalarType::BFloat16 : { \
|
||||
using packed_t = at::BFloat16; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
case at::ScalarType::Half: { \
|
||||
using packed_t = at::Half; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
case at::ScalarType::Char : { \
|
||||
using packed_t = int8_t; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
case at::ScalarType::Float8_e4m3fn : { \
|
||||
using packed_t = at::Float8_e4m3fn; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
default: \
|
||||
TORCH_CHECK(false, "Unsupported floating data type.\n"); \
|
||||
} \
|
||||
}()
|
||||
|
||||
#define UNUSED(x) (void)(x)
|
||||
|
||||
#define CHECK_CPU(x) TORCH_CHECK(x.device().type() == at::kCPU, #x " must be a CPU tensor")
|
||||
|
||||
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_LAST_DIM_CONTIGUOUS(x) \
|
||||
TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x " must be contiguous at last dimension")
|
||||
|
||||
#define CHECK_INPUT(x) \
|
||||
CHECK_CPU(x); \
|
||||
CHECK_CONTIGUOUS(x)
|
||||
#define CHECK_LAST_DIM_CONTIGUOUS_INPUT(x) \
|
||||
CHECK_CPU(x); \
|
||||
CHECK_LAST_DIM_CONTIGUOUS(x)
|
||||
|
||||
#define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")
|
||||
|
||||
#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
|
||||
|
||||
// parallel routines
|
||||
constexpr int GRAIN_SIZE = 1024;
|
||||
|
||||
template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
|
||||
inline T div_up(T x, T y) { return (x + y - 1) / y; }
|
||||
|
||||
template <typename T>
|
||||
inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) {
|
||||
#if 0
|
||||
// onednn partition pattern
|
||||
T& n_my = n_end;
|
||||
if (nth <= 1 || n == 0) {
|
||||
n_start = 0;
|
||||
n_my = n;
|
||||
} else {
|
||||
T n1 = div_up(n, nth);
|
||||
T n2 = n1 - 1;
|
||||
T T1 = n - n2 * nth;
|
||||
n_my = ith < T1 ? n1 : n2;
|
||||
n_start = ith <= T1 ? ith*n1 : T1 * n1 + (ith - T1) * n2;
|
||||
}
|
||||
n_end += n_start;
|
||||
#else
|
||||
// pytorch aten partition pattern
|
||||
T n_my = div_up(n, nth);
|
||||
n_start = ith * n_my;
|
||||
n_end = std::min(n_start + n_my, n);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename func_t>
|
||||
inline void parallel_for(int n, const func_t& f) {
|
||||
#if defined(_OPENMP)
|
||||
#pragma omp parallel
|
||||
{
|
||||
int nth = omp_get_num_threads();
|
||||
int ith = omp_get_thread_num();
|
||||
int tbegin, tend;
|
||||
balance211(n, nth, ith, tbegin, tend);
|
||||
f(tbegin, tend);
|
||||
}
|
||||
#else
|
||||
f(0, n);
|
||||
#endif
|
||||
}
|
||||
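A minimal usage sketch of parallel_for (hypothetical function, not from this diff): each OpenMP thread receives one contiguous [begin, end) slice computed by balance211.

inline void halve_all(float* out, int n) {
  parallel_for(n, [&](int begin, int end) {
    for (int i = begin; i < end; ++i) out[i] *= 0.5f;
  });
}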
|
||||
// for 1d parallel, use `actual_nth`
|
||||
// for 2d parallel, use even nths, e.g. 43->42
|
||||
int inline adjust_num_threads(int m) {
|
||||
int actual_nth = at::get_num_threads();
|
||||
if (m == 1) {
|
||||
return actual_nth;
|
||||
}
|
||||
return std::max(1, (actual_nth >> 1) * 2);
|
||||
}
|
||||
|
||||
template <typename func_t>
|
||||
inline void parallel_2d(int m, int n, const func_t& f) {
|
||||
|
||||
// make sure we have even num_threads
|
||||
int nth = adjust_num_threads(m);
|
||||
|
||||
// [NOTE] thread blocking:
|
||||
//
|
||||
// 1) prefer square block per thread
|
||||
// 2) use even number of CPU cores
|
||||
// 3) use all `num_threads` cores
|
||||
//
|
||||
// we have:
|
||||
// TM * TN = T
|
||||
// BM / TM = BN / TN
|
||||
// then:
|
||||
// TM = ((BM / BN) * T) ^ 0.5
|
||||
//
|
||||
float r = float(m) / n;
|
||||
int nth_m = std::ceil(std::sqrt(r * nth));
|
||||
int nth_n = 1;
|
||||
for (; nth_m > 0; --nth_m) {
|
||||
nth_n = nth / nth_m;
|
||||
if (nth_m * nth_n == nth) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(_OPENMP)
|
||||
#pragma omp parallel num_threads(nth)
|
||||
{
|
||||
int ith = omp_get_thread_num();
|
||||
int ith_m = ith / nth_n;
|
||||
int ith_n = ith % nth_n;
|
||||
|
||||
int thread_block_m = div_up(m, nth_m);
|
||||
int thread_block_n = div_up(n, nth_n);
|
||||
|
||||
int begin_m = ith_m * thread_block_m;
|
||||
int end_m = std::min(m, begin_m + thread_block_m);
|
||||
int begin_n = ith_n * thread_block_n;
|
||||
int end_n = std::min(n, begin_n + thread_block_n);
|
||||
|
||||
f(begin_m, end_m, begin_n, end_n);
|
||||
}
|
||||
#else
|
||||
f(0, m, 0, n);
|
||||
#endif
|
||||
}
|
||||
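Worked example of the blocking rule above (illustrative numbers): with m = 64 M-blocks, n = 256 N-blocks and nth = 16 threads, r = 64/256 = 0.25, so nth_m = ceil(sqrt(0.25 * 16)) = 2 and nth_n = 16 / 2 = 8; each thread then owns a div_up(64, 2) x div_up(256, 8) = 32 x 32 tile, which is the roughly square per-thread block the note aims for.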
|
||||
template <typename T>
|
||||
int get_cache_blocks(int BLOCK_SIZE, int K) {
|
||||
// L2 2MB and ratio of 50%
|
||||
const int L2_size = 2048 * 1024 >> 1;
|
||||
return std::max(1, int(L2_size / (BLOCK_SIZE * K * sizeof(T))));
|
||||
}
|
||||
|
||||
// data indexing for dimension collapse
|
||||
template <typename T>
|
||||
inline T data_index_init(T offset) {
|
||||
return offset;
|
||||
}
|
||||
|
||||
template <typename T, typename... Args>
|
||||
inline T data_index_init(T offset, T& x, const T& X, Args&&... args) {
|
||||
offset = data_index_init(offset, std::forward<Args>(args)...);
|
||||
x = offset % X;
|
||||
return offset / X;
|
||||
}
|
||||
|
||||
inline bool data_index_step() {
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T, typename... Args>
|
||||
inline bool data_index_step(T& x, const T& X, Args&&... args) {
|
||||
if (data_index_step(std::forward<Args>(args)...)) {
|
||||
x = ((x + 1) == X) ? 0 : (x + 1);
|
||||
return x == 0;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
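Usage sketch for the dimension-collapse helpers (hypothetical loop, not from this diff): recover (b, h, w) from a flattened [B, H, W] range handed out by the 1-D partitioner, with w as the fastest-moving index.

template <typename Visit>
inline void visit_bhw(int begin, int end, int B, int H, int W, const Visit& visit) {
  int b = 0, h = 0, w = 0;
  data_index_init(begin, b, B, h, H, w, W);
  for (int i = begin; i < end; ++i) {
    visit(b, h, w);
    data_index_step(b, B, h, H, w, W);  // increments w, carrying into h and then b
  }
}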
|
||||
// forced unroll for perf critical path
|
||||
|
||||
#if __has_attribute(always_inline)
|
||||
#define ALWAYS_INLINE __attribute__((__always_inline__)) inline
|
||||
#else
|
||||
#define ALWAYS_INLINE inline
|
||||
#endif
|
||||
|
||||
template <int n>
|
||||
struct Unroll {
|
||||
template <typename Func, typename... Args>
|
||||
ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
|
||||
Unroll<n - 1>{}(f, args...);
|
||||
f(std::integral_constant<int, n - 1>{}, args...);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Unroll<1> {
|
||||
template <typename Func, typename... Args>
|
||||
ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
|
||||
f(std::integral_constant<int, 0>{}, args...);
|
||||
}
|
||||
};
|
||||
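A small usage sketch of Unroll (hypothetical function): the functor receives a std::integral_constant index, so each unrolled step can use it as a compile-time constant, which is how the AVX-512 GEMM kernels later index their register arrays.

inline int sum4(const int (&a)[4]) {
  int acc = 0;
  Unroll<4>{}([&](auto i) { acc += a[i]; });  // i is integral_constant<int, 0..3>
  return acc;
}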
|
||||
} // anonymous namespace
|
||||
csrc/cpu/sgl-kernels/gemm.cpp (new file, 464 lines)
@ -0,0 +1,464 @@
|
||||
// Adapted from
|
||||
// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu
|
||||
|
||||
#include "common.h"
|
||||
#include "vec.h"
|
||||
#include "gemm.h"
|
||||
|
||||
// clang-format off
|
||||
|
||||
namespace {
|
||||
|
||||
// packed layout:
|
||||
// quants {N, K} int8_t
|
||||
// comp {N} int32_t
|
||||
template <int BLOCK_N>
|
||||
inline void s8s8_compensation(int8_t* __restrict__ packed, int K) {
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
constexpr int COLS = BLOCK_N / 16;
|
||||
__m512i vcomp[COLS];
|
||||
|
||||
for (int col = 0; col < COLS; ++col) {
|
||||
vcomp[col] = _mm512_setzero_si512();
|
||||
}
|
||||
|
||||
const int64_t offset = BLOCK_N * K;
|
||||
const __m512i off = _mm512_set1_epi8(static_cast<char>(0x80));
|
||||
for (int k = 0; k < K / 4; ++k) {
|
||||
for (int col = 0; col < COLS; ++col) {
|
||||
__m512i vb = _mm512_loadu_si512((const __m512i *)(packed + k * BLOCK_N * 4 + col * 64));
|
||||
vcomp[col] = _mm512_dpbusd_epi32(vcomp[col], off, vb);
|
||||
}
|
||||
}
|
||||
|
||||
for (int col = 0; col < COLS; ++col) {
|
||||
_mm512_storeu_si512((__m512i *)(packed + offset + col * 64), vcomp[col]);
|
||||
}
|
||||
#else
|
||||
TORCH_CHECK(false, "s8s8_compensation not implemented!");
|
||||
#endif
|
||||
}
|
||||
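A scalar reference for the compensation term, as the AVX-512 loop above computes it (assumption: this is the usual s8s8-to-u8s8 trick, where int8 activations are biased by +128 and this precomputed term lets the epilogue remove the bias contribution):

// comp[n] = 128 * sum_k w[n][k], written against the unpacked row-major weight;
// the vectorized version accumulates the same per-column sums on the VNNI-packed buffer.
inline void s8s8_compensation_ref(const int8_t* w, int32_t* comp, int N, int K) {
  for (int n = 0; n < N; ++n) {
    int32_t acc = 0;
    for (int k = 0; k < K; ++k) acc += static_cast<int32_t>(w[n * K + k]);
    comp[n] = 128 * acc;
  }
}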
|
||||
// convert to vnni format
|
||||
// from [N, K] to [K/2, N, 2] for bfloat16 and float16
|
||||
template <typename packed_t>
|
||||
inline void pack_vnni(packed_t* __restrict__ packed, const packed_t* __restrict__ weight, int N, int K) {
|
||||
const int VNNI_BLK = 2;
|
||||
for (int n = 0; n < N; ++n) {
|
||||
for (int k = 0; k < K / VNNI_BLK; ++k) {
|
||||
for (int d = 0; d < VNNI_BLK; ++d) {
|
||||
packed[k * N * VNNI_BLK + n * VNNI_BLK + d] = weight[n * K + k * VNNI_BLK + d];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void pack_vnni<int8_t>(int8_t* __restrict__ packed, const int8_t* __restrict__ weight, int N, int K) {
|
||||
constexpr int BLOCK_N = block_size_n();
|
||||
TORCH_CHECK(N == BLOCK_N);
|
||||
|
||||
const int VNNI_BLK = 4;
|
||||
for (int n = 0; n < N; ++n) {
|
||||
for (int k = 0; k < K / VNNI_BLK; ++k) {
|
||||
for (int d = 0; d < VNNI_BLK; ++d) {
|
||||
packed[k * N * VNNI_BLK + n * VNNI_BLK + d] = weight[n * K + k * VNNI_BLK + d];
|
||||
}
|
||||
}
|
||||
}
|
||||
s8s8_compensation<BLOCK_N>(packed, K);
|
||||
}
|
||||
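Index-level example of the VNNI repack above: with VNNI_BLK = 2, element weight[n][k] lands at packed[(k / 2) * N * 2 + n * 2 + (k % 2)], so each column keeps pairs of consecutive k adjacent, which is the operand layout the bf16 dot-product instructions consume; the int8 specialization does the same with groups of 4 and then appends the compensation row.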
|
||||
template <typename scalar_t>
|
||||
inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input, int64_t size) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= size - kVecSize; d += kVecSize) {
|
||||
fVec data0 = fVec::loadu(input + d);
|
||||
fVec data1 = fVec::loadu(input + d + fVec::size());
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
|
||||
out_vec.store(out + d);
|
||||
}
|
||||
for (; d < size; ++d) {
|
||||
out[d] = static_cast<scalar_t>(input[d]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void copy_add_stub(scalar_t* __restrict__ out, const float* __restrict__ input, const float* __restrict__ bias, int64_t size) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= size - kVecSize; d += kVecSize) {
|
||||
fVec data0 = fVec::loadu(input + d) + fVec::loadu(bias + d);
|
||||
fVec data1 = fVec::loadu(input + d + fVec::size()) + fVec::loadu(bias + d + fVec::size());
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
|
||||
out_vec.store(out + d);
|
||||
}
|
||||
for (; d < size; ++d) {
|
||||
out[d] = static_cast<scalar_t>(input[d] + bias[d]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, bool has_bias, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_nn {
|
||||
static inline void apply(
|
||||
const scalar_t* __restrict__ A, const scalar_t* __restrict__ B, scalar_t* __restrict__ C,
|
||||
const float* __restrict__ bias, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) {
|
||||
TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
template <bool has_bias, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_nn<at::BFloat16, has_bias, BLOCK_M, BLOCK_N> {
|
||||
static inline void apply(
|
||||
const at::BFloat16* __restrict__ A, const at::BFloat16* __restrict__ B, at::BFloat16* __restrict__ C,
|
||||
const float* __restrict__ bias, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) {
|
||||
|
||||
constexpr int ROWS = BLOCK_M;
|
||||
constexpr int COLS = BLOCK_N / 16;
|
||||
|
||||
// prefetch distance
|
||||
constexpr int PREFETCH_SIZE_K = 0;
|
||||
|
||||
__m512bh va;
|
||||
__m512bh vb[COLS];
|
||||
__m512 vc[ROWS * COLS];
|
||||
|
||||
auto loadc = [&](auto i) {
|
||||
constexpr int col = i % COLS;
|
||||
if constexpr (has_bias) {
|
||||
vc[i] = _mm512_loadu_ps(bias + col * 16);
|
||||
} else {
|
||||
vc[i] = _mm512_set1_ps(0.f);
|
||||
}
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(loadc);
|
||||
|
||||
const int64_t K2 = K >> 1;
|
||||
const int64_t lda2 = lda >> 1;
|
||||
const int64_t ldb2 = ldb; // ldb * 2 >> 1;
|
||||
const float* a_ptr = reinterpret_cast<const float*>(A);
|
||||
const float* b_ptr = reinterpret_cast<const float*>(B);
|
||||
|
||||
auto compute = [&](auto i, int64_t k) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
|
||||
if constexpr (col == 0) {
|
||||
va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k]));
|
||||
}
|
||||
if constexpr (row == 0) {
|
||||
vb[col] = (__m512bh)(_mm512_loadu_si512(b_ptr + k * ldb2 + col * 16));
|
||||
if constexpr (PREFETCH_SIZE_K > 0) {
|
||||
_mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0);
|
||||
}
|
||||
}
|
||||
vc[i] = _mm512_dpbf16_ps(vc[i], va, vb[col]);
|
||||
};
|
||||
for (int64_t k = 0; k < K2; ++k) {
|
||||
Unroll<ROWS * COLS>{}(compute, k);
|
||||
}
|
||||
|
||||
auto storec = [&](auto i) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
// for COLS = 2, 4 use 512bit store
|
||||
// for COLS = 1, 3 use 256bit store
|
||||
if constexpr (COLS % 2 == 0) {
|
||||
if constexpr (col % 2 == 0) {
|
||||
_mm512_storeu_si512(
|
||||
reinterpret_cast<__m512i*>((C + row * ldc + col * 16)),
|
||||
(__m512i)(_mm512_cvtne2ps_pbh(vc[row * COLS + col + 1], vc[row * COLS + col])));
|
||||
}
|
||||
} else {
|
||||
_mm256_storeu_si256(
|
||||
reinterpret_cast<__m256i*>(C + row * ldc + col * 16),
|
||||
(__m256i)(_mm512_cvtneps_pbh(vc[i])));
|
||||
}
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(storec);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \
|
||||
tinygemm_kernel_nn<scalar_t, has_bias, MB_SIZE, NB_SIZE>::apply( \
|
||||
A + mb_start * lda, B + nb_start * 2, C + mb_start * ldc + nb_start, \
|
||||
has_bias ? bias + nb_start : nullptr, K, lda, ldb, ldc);
|
||||
|
||||
template <typename scalar_t, bool has_bias>
|
||||
struct brgemm {
|
||||
static inline void apply(
|
||||
const scalar_t* __restrict__ A, const scalar_t* __restrict__ B, scalar_t* __restrict__ C,
|
||||
float* __restrict__ Ctmp, const float* __restrict__ bias,
|
||||
int64_t M, int64_t N, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) {
|
||||
|
||||
constexpr int BLOCK_N = block_size_n();
|
||||
at::native::cpublas::brgemm(
|
||||
M, N, K, lda, ldb, BLOCK_N, /* add_C */false,
|
||||
A, B, Ctmp);
|
||||
|
||||
// copy from Ctmp to C
|
||||
for (int64_t m = 0; m < M; ++m) {
|
||||
if constexpr (has_bias) {
|
||||
copy_add_stub(C + m * ldc, Ctmp + m * BLOCK_N, bias, N);
|
||||
} else {
|
||||
copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename scalar_t, bool has_bias>
|
||||
void tinygemm_kernel(
|
||||
const scalar_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
float* __restrict__ Ctmp,
|
||||
const float* __restrict__ bias,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg) {
|
||||
|
||||
if (brg) {
|
||||
brgemm<scalar_t, has_bias>::apply(
|
||||
A, B, C, Ctmp, bias,
|
||||
M, N, K, lda, ldb, ldc);
|
||||
return;
|
||||
}
|
||||
|
||||
// pattern: 1-4-16
|
||||
constexpr int64_t BLOCK_M = 4;
|
||||
constexpr int64_t BLOCK_N = 64;
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
for (int mb = 0; mb < MB; ++mb) {
|
||||
int64_t mb_start = mb * BLOCK_M;
|
||||
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
|
||||
for (int64_t nb = 0; nb < NB; ++nb) {
|
||||
int64_t nb_start = nb * BLOCK_N;
|
||||
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
|
||||
|
||||
switch(mb_size << 4 | nb_size >> 4) {
|
||||
// mb_size = 1
|
||||
case 0x12: LAUNCH_TINYGEMM_KERNEL_NN(1, 32); break;
|
||||
case 0x14: LAUNCH_TINYGEMM_KERNEL_NN(1, 64); break;
|
||||
// mb_size = 2
|
||||
case 0x22: LAUNCH_TINYGEMM_KERNEL_NN(2, 32); break;
|
||||
case 0x24: LAUNCH_TINYGEMM_KERNEL_NN(2, 64); break;
|
||||
// mb_size = 3
|
||||
case 0x32: LAUNCH_TINYGEMM_KERNEL_NN(3, 32); break;
|
||||
case 0x34: LAUNCH_TINYGEMM_KERNEL_NN(3, 64); break;
|
||||
// mb_size = 4
|
||||
case 0x42: LAUNCH_TINYGEMM_KERNEL_NN(4, 32); break;
|
||||
case 0x44: LAUNCH_TINYGEMM_KERNEL_NN(4, 64); break;
|
||||
default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", nb_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
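The dispatch key packs each block size into one nibble: mb_size << 4 | nb_size >> 4, so mb_size = 2 with nb_size = 32 yields 0x22 and mb_size = 4 with nb_size = 64 yields 0x44; anything that does not reduce to one of the eight listed cases falls through to the TORCH_CHECK default.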
|
||||
template <typename scalar_t>
|
||||
void weight_packed_linear_kernel_impl(
|
||||
scalar_t* __restrict__ out,
|
||||
const scalar_t* __restrict__ mat1,
|
||||
const scalar_t* __restrict__ mat2,
|
||||
const float* __restrict__ bias,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t mat1_strideM,
|
||||
int64_t out_strideM) {
|
||||
|
||||
constexpr int64_t BLOCK_M = block_size_m();
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
|
||||
// use the avx512-bf16 tinygemm path only when both a) M is small (M <= 4) and b) dtype is bfloat16; otherwise use amx (brgemm)
|
||||
const bool use_brgemm = (M > 4) || (!std::is_same_v<scalar_t, at::BFloat16>);
|
||||
|
||||
// l2 cache block for n
|
||||
int64_t cache_blocks_nb = get_cache_blocks<scalar_t>(BLOCK_N, K);
|
||||
|
||||
// parallel on [MB, NB]
|
||||
AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
|
||||
parallel_2d(MB, NB, [&](int64_t begin_mb, int64_t end_mb, int64_t begin_nb, int64_t end_nb) {
|
||||
|
||||
// for brgemm, use float32 for accumulate
|
||||
alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
|
||||
|
||||
for (int64_t nbb = begin_nb; nbb < end_nb; nbb += cache_blocks_nb) {
|
||||
for (int64_t mb = begin_mb; mb < end_mb; ++mb) {
|
||||
for (int64_t nb = nbb; nb < std::min(nbb + cache_blocks_nb, end_nb); ++nb) {
|
||||
|
||||
int64_t mb_start = mb * BLOCK_M;
|
||||
int64_t mb_size = std::min(M - mb_start, BLOCK_M);
|
||||
int64_t nb_start = nb * BLOCK_N;
|
||||
int64_t nb_size = std::min(N - nb_start, BLOCK_N);
|
||||
|
||||
tinygemm_kernel<scalar_t, has_bias>(
|
||||
/* A */ mat1 + mb_start * mat1_strideM,
|
||||
/* B */ mat2 + nb_start * K /* nb * BLOCK_N * K */,
|
||||
/* C */ out + mb_start * out_strideM + nb_start,
|
||||
/* Ctmp*/ Ctmp,
|
||||
/* bias*/ bias + nb_start,
|
||||
/* M */ mb_size,
|
||||
/* N */ nb_size,
|
||||
/* K */ K,
|
||||
/* lda */ mat1_strideM,
|
||||
/* ldb */ nb_size,
|
||||
/* ldc */ out_strideM,
|
||||
/* brg */ use_brgemm);
|
||||
}}}
|
||||
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
});
|
||||
}
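// Blocking note (illustrative): the nbb / mb / nb loop order above keeps a group of
// cache_blocks_nb packed weight blocks resident in L2 while every M block of the
// activation streams past them, so each B block is re-read from cache rather than DRAM.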
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
// tinygemm interface
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(const scalar_t* __restrict__ A, const scalar_t* __restrict__ B, scalar_t* __restrict__ C,
|
||||
float* __restrict__ Ctmp, int64_t M, int64_t N, int64_t K, int64_t lda, int64_t ldb, int64_t ldc, bool brg) {
|
||||
tinygemm_kernel<scalar_t, false>(A, B, C, Ctmp, nullptr, M, N, K, lda, ldb, ldc, brg);
|
||||
}
|
||||
|
||||
#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \
|
||||
template void tinygemm_kernel<TYPE>( \
|
||||
const TYPE* __restrict__ A, const TYPE* __restrict__ B, TYPE* __restrict__ C, \
|
||||
float* __restrict__ Ctmp, int64_t M, int64_t N, int64_t K, int64_t lda, \
|
||||
int64_t ldb, int64_t ldc, bool brg)
|
||||
|
||||
INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16);
|
||||
INSTANTIATE_TINYGEMM_TEMPLATE(at::Half);
|
||||
|
||||
at::Tensor convert_weight_packed(at::Tensor& weight) {
|
||||
// for 3d moe weights
|
||||
// weight : [E, OC, IC]
|
||||
// w1 : [E, 2N, K]
|
||||
// w2 : [E, K, N]
|
||||
CHECK_INPUT(weight);
|
||||
|
||||
const int64_t ndim = weight.ndimension();
|
||||
TORCH_CHECK(ndim == 2 || ndim == 3, "expect weight to be 2d or 3d, got ", ndim, "d tensor.");
|
||||
const auto st = weight.scalar_type();
|
||||
const int64_t E = ndim == 3 ? weight.size(0) : 1;
|
||||
const int64_t OC = ndim == 3 ? weight.size(1) : weight.size(0);
|
||||
const int64_t IC = ndim == 3 ? weight.size(2) : weight.size(1);
|
||||
|
||||
// we handle 2 TILE_N at a time.
|
||||
TORCH_CHECK(OC % TILE_N == 0, "invalid weight out features ", OC);
|
||||
TORCH_CHECK(IC % TILE_K == 0, "invalid weight input features ", IC);
|
||||
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
const int64_t NB = div_up(OC, BLOCK_N);
|
||||
|
||||
// keep the logical sizes [E, OC, IC] on the tensor; within each expert, [OC, IC] is physically stored as [IC / 2, OC, 2] (VNNI)
|
||||
auto packed_weight = at::empty({}, weight.options());
|
||||
const int64_t stride = OC * IC;
|
||||
|
||||
TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf || st == at::kChar || st == at::kFloat8_e4m3fn,
|
||||
"expect weight to be bfloat16, float16, int8 or fp8_e4m3.");
|
||||
|
||||
CPU_DISPATCH_PACKED_TYPES(st, [&] {
|
||||
// adjust the innermost dimension size
|
||||
const int packed_row_size = get_row_size<packed_t>(IC);
|
||||
auto sizes = weight.sizes().vec();
|
||||
sizes[ndim - 1] = packed_row_size;
|
||||
packed_weight.resize_(sizes);
|
||||
|
||||
const packed_t* w_data = weight.data_ptr<packed_t>();
|
||||
packed_t* packed_data = packed_weight.data_ptr<packed_t>();
|
||||
|
||||
// parallel on {E, NB}
|
||||
at::parallel_for(0, E * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
int64_t e{0}, nb{0};
|
||||
data_index_init(begin, e, E, nb, NB);
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
UNUSED(i);
|
||||
|
||||
int64_t n = nb * BLOCK_N;
|
||||
int64_t n_size = std::min(BLOCK_N, OC - n);
|
||||
pack_vnni<packed_t>(
|
||||
packed_data + e * OC * packed_row_size + n * packed_row_size,
|
||||
w_data + e * stride + n * IC,
|
||||
n_size,
|
||||
IC);
|
||||
|
||||
// move to the next index
|
||||
data_index_step(e, E, nb, NB);
|
||||
}
|
||||
});
|
||||
});
|
||||
return packed_weight;
|
||||
}
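// Layout sketch (illustrative, not part of the original file): for 16-bit weights the
// VNNI repack performed above turns each [OC, IC] matrix into [IC / 2, OC, 2], pairing
// two consecutive K elements per output column so AMX / VNNI dot products can consume
// them directly. A scalar reference of that permutation, assuming a hypothetical
// pack_vnni_ref helper over plain row-major bf16 data:
//
//   void pack_vnni_ref(at::BFloat16* dst, const at::BFloat16* src, int64_t OC, int64_t IC) {
//     for (int64_t k2 = 0; k2 < IC / 2; ++k2)
//       for (int64_t n = 0; n < OC; ++n)
//         for (int64_t v = 0; v < 2; ++v)
//           dst[(k2 * OC + n) * 2 + v] = src[n * IC + 2 * k2 + v];
//   }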
|
||||
|
||||
// mat1 : [M, K]
|
||||
// mat2 : [N, K]
|
||||
// bias : [N]
|
||||
// out : [M, N]
|
||||
//
|
||||
at::Tensor weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2,
|
||||
const std::optional<at::Tensor>& bias, bool is_vnni) {
|
||||
RECORD_FUNCTION(
|
||||
"sgl-kernel::weight_packed_linear", std::vector<c10::IValue>({mat1, mat2, bias}));
|
||||
|
||||
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
|
||||
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1);
|
||||
CHECK_INPUT(mat2);
|
||||
|
||||
int64_t M = mat1.size(0);
|
||||
int64_t N = mat2.size(0);
|
||||
int64_t K = mat2.size(1);
|
||||
CHECK_EQ(mat1.size(1), K);
|
||||
CHECK_DIM(2, mat1);
|
||||
CHECK_DIM(2, mat2);
|
||||
|
||||
auto out = at::empty({M, N}, mat1.options());
|
||||
|
||||
// strides
|
||||
int64_t mat1_strideM = mat1.stride(0);
|
||||
int64_t out_strideM = out.stride(0);
|
||||
|
||||
const bool has_bias = bias.has_value();
|
||||
const float* bias_data = nullptr;
|
||||
if (has_bias) {
|
||||
CHECK_EQ(bias.value().size(0), N);
|
||||
bias_data = bias.value().data_ptr<float>();
|
||||
}
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(mat1.scalar_type(), "weight_packed_linear_kernel_impl", [&] {
|
||||
weight_packed_linear_kernel_impl<scalar_t>(
|
||||
out.data_ptr<scalar_t>(),
|
||||
mat1.data_ptr<scalar_t>(),
|
||||
packed_w.data_ptr<scalar_t>(),
|
||||
bias_data,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
mat1_strideM,
|
||||
out_strideM);
|
||||
});
|
||||
|
||||
return out;
|
||||
}
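// Usage sketch (illustrative only; tensor shapes are placeholders and the ops are
// assumed to be registered by the extension):
//
//   at::Tensor x = at::randn({8, 4096}, at::kBFloat16);      // activations [M, K]
//   at::Tensor w = at::randn({11008, 4096}, at::kBFloat16);  // weights     [N, K]
//   at::Tensor packed = convert_weight_packed(w);            // one-time VNNI repack
//   at::Tensor y = weight_packed_linear(x, packed, /*bias=*/std::nullopt, /*is_vnni=*/true);
//   // y : [M, N], same dtype as x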
|
||||
csrc/cpu/sgl-kernels/gemm.h (new file, 266 lines)
#pragma once
|
||||
|
||||
#include <ATen/native/CPUBlas.h>
|
||||
|
||||
// clang-format off
|
||||
|
||||
// amx-bf16
|
||||
#define TILE_M 16
|
||||
#define TILE_N 16
|
||||
#define TILE_K 32
|
||||
|
||||
// block size for AMX gemm
|
||||
constexpr int block_size_m() { return 2 * TILE_M; }
|
||||
constexpr int block_size_n() { return 2 * TILE_N; }
|
||||
|
||||
// define threshold using brgemm (intel AMX)
|
||||
template <typename T> inline bool can_use_brgemm(int M);
|
||||
template <> inline bool can_use_brgemm<at::BFloat16>(int M) { return M > 4; }
|
||||
template <> inline bool can_use_brgemm<at::Half>(int M) { return true; }
|
||||
// TODO: add u8s8 brgemm, this requires PyTorch 2.7
|
||||
template <> inline bool can_use_brgemm<int8_t>(int M) { return false; }
|
||||
template <> inline bool can_use_brgemm<at::Float8_e4m3fn>(int M) { return M > 4; }
|
||||
template <> inline bool can_use_brgemm<at::quint4x2>(int M) { return M > 4; }
|
||||
|
||||
// work around compiler internal error
|
||||
#define BLOCK_K 128 // 4 * TILE_K
|
||||
|
||||
// adjust leading dimension size for K
|
||||
template <typename T>
|
||||
inline int64_t get_row_size(int64_t K) {
|
||||
return K;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline int64_t get_row_size<int8_t>(int64_t K) {
|
||||
return K + sizeof(int32_t);
|
||||
}
|
||||
|
||||
inline int64_t get_row_size(int64_t K, bool use_int8_w8a8) {
|
||||
return use_int8_w8a8 ? K + sizeof(int32_t) : K;
|
||||
}
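// Layout note (illustrative, not in the original header): for int8 w8a8 each packed
// weight row stores the K quantized int8 values followed by a 4-byte int32
// compensation term (128 * the row sum, see the s8s8 note in gemm_int8.cpp),
// which is why get_row_size<int8_t>(K) == K + sizeof(int32_t):
//
//   row n: [ b_0 b_1 ... b_(K-1) | int32 comp_n = 128 * sum_k(b_k) ]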
|
||||
|
||||
// pack weight to vnni format
|
||||
at::Tensor convert_weight_packed(at::Tensor& weight);
|
||||
|
||||
// moe implementations for int8 w8a8
|
||||
template <typename scalar_t>
|
||||
void fused_experts_int8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic1,
|
||||
scalar_t* __restrict__ ic2,
|
||||
uint8_t* __restrict__ A_tmp,
|
||||
float* __restrict__ C_tmp,
|
||||
uint8_t* __restrict__ Aq_tmp,
|
||||
float* __restrict__ As_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const int8_t* __restrict__ packed_w1,
|
||||
const int8_t* __restrict__ packed_w2,
|
||||
const float* __restrict__ w1s,
|
||||
const float* __restrict__ w2s,
|
||||
const float* __restrict__ topk_weights,
|
||||
const int32_t* __restrict__ sorted_ids,
|
||||
const int32_t* __restrict__ expert_ids,
|
||||
const int32_t* __restrict__ offsets,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t E,
|
||||
int64_t topk,
|
||||
int64_t num_tokens_post_pad);
|
||||
|
||||
// moe implementations for fp8 w8a16
|
||||
template <typename scalar_t>
|
||||
void fused_experts_fp8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic0,
|
||||
scalar_t* __restrict__ ic1,
|
||||
scalar_t* __restrict__ ic2,
|
||||
scalar_t* __restrict__ A_tmp,
|
||||
scalar_t* __restrict__ B_tmp,
|
||||
float* __restrict__ C_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w1,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w2,
|
||||
const float* __restrict__ w1s,
|
||||
const float* __restrict__ w2s,
|
||||
int64_t block_size_N,
|
||||
int64_t block_size_K,
|
||||
const float* __restrict__ topk_weights,
|
||||
const int32_t* __restrict__ sorted_ids,
|
||||
const int32_t* __restrict__ expert_ids,
|
||||
const int32_t* __restrict__ offsets,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t E,
|
||||
int64_t topk,
|
||||
int64_t num_tokens_post_pad);
|
||||
|
||||
// moe implementations for int4 w4a16
|
||||
template <typename scalar_t>
|
||||
void fused_experts_int4_w4a16_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic0,
|
||||
scalar_t* __restrict__ ic1,
|
||||
scalar_t* __restrict__ ic2,
|
||||
scalar_t* __restrict__ A_tmp,
|
||||
scalar_t* __restrict__ B_tmp,
|
||||
float* __restrict__ C_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const at::quint4x2* __restrict__ packed_w1,
|
||||
const at::quint4x2* __restrict__ packed_w2,
|
||||
const uint8_t* __restrict__ w1z,
|
||||
const uint8_t* __restrict__ w2z,
|
||||
const scalar_t* __restrict__ w1s,
|
||||
const scalar_t* __restrict__ w2s,
|
||||
int group_size,
|
||||
const float* __restrict__ topk_weights,
|
||||
const int32_t* __restrict__ sorted_ids,
|
||||
const int32_t* __restrict__ expert_ids,
|
||||
const int32_t* __restrict__ offsets,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t E,
|
||||
int64_t topk,
|
||||
int64_t num_tokens_post_pad);
|
||||
|
||||
// shared expert implementation for int8 w8a8
|
||||
template <typename scalar_t>
|
||||
void shared_expert_int8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic1,
|
||||
float* __restrict__ C_tmp,
|
||||
uint8_t* __restrict__ Aq_tmp,
|
||||
float* __restrict__ As_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const int8_t* __restrict__ packed_w1,
|
||||
const int8_t* __restrict__ packed_w2,
|
||||
const float* __restrict__ w1s,
|
||||
const float* __restrict__ w2s,
|
||||
const scalar_t* __restrict__ fused_experts_out,
|
||||
float routed_scaling_factor,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K);
|
||||
|
||||
template <typename scalar_t>
|
||||
void shared_expert_fp8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic0,
|
||||
scalar_t* __restrict__ ic1,
|
||||
scalar_t* __restrict__ B_tmp,
|
||||
float* __restrict__ C_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w1,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w2,
|
||||
const float* __restrict__ w1s,
|
||||
const float* __restrict__ w2s,
|
||||
int64_t block_size_N,
|
||||
int64_t block_size_K,
|
||||
const scalar_t* __restrict__ fused_experts_out,
|
||||
float routed_scaling_factor,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K);
|
||||
|
||||
// tinygemm interface
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(
|
||||
const scalar_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
float* __restrict__ Ctmp,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg);
|
||||
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(
|
||||
const uint8_t* __restrict__ A,
|
||||
const int8_t* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
int32_t* __restrict__ Ctmp,
|
||||
const float* __restrict__ As,
|
||||
const float* __restrict__ Bs,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg);
|
||||
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(
|
||||
const scalar_t* __restrict__ A,
|
||||
const at::Float8_e4m3fn* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
scalar_t* __restrict__ Btmp,
|
||||
float* __restrict__ Ctmp,
|
||||
const float* __restrict__ scale,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg,
|
||||
int64_t block_size_K);
|
||||
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(
|
||||
const scalar_t* __restrict__ A,
|
||||
const at::quint4x2* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
const uint8_t* __restrict__ Bz,
|
||||
const scalar_t* __restrict__ Bs,
|
||||
scalar_t* __restrict__ Btmp,
|
||||
float* __restrict__ Ctmp,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int group_size,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
int64_t strideBz,
|
||||
int64_t strideBs,
|
||||
bool brg);
|
||||
|
||||
// TODO: debug print, remove me later
|
||||
inline void print_16x32i(const __m512i x) {
|
||||
int32_t a[16];
|
||||
_mm512_storeu_si512((__m512i *)a, x);
|
||||
|
||||
for (int i = 0; i < 16; i++){
|
||||
std::cout << a[i] << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
inline void print_16x32(const __m512 x) {
|
||||
float a[16];
|
||||
_mm512_storeu_ps(a, x);
|
||||
|
||||
for (int i = 0; i < 16; i++){
|
||||
std::cout << a[i] << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
|
||||
inline void print_32x8u(const __m256i x) {
|
||||
uint8_t a[32];
|
||||
_mm256_storeu_si256((__m256i *)a, x);
|
||||
|
||||
for (int i = 0; i < 32; ++i) {
|
||||
std::cout << int32_t(a[i]) << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
csrc/cpu/sgl-kernels/gemm_fp8.cpp (new file, 530 lines)
// Adapted from
|
||||
// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu
|
||||
|
||||
#include "common.h"
|
||||
#include "vec.h"
|
||||
#include "gemm.h"
|
||||
|
||||
// clang-format off
|
||||
|
||||
// we use 4x32 for BLOCK_M
|
||||
#define BLOCK_SIZE_M_SCALE 4
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input, int64_t size) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= size - kVecSize; d += kVecSize) {
|
||||
fVec data0 = fVec::loadu(input + d);
|
||||
fVec data1 = fVec::loadu(input + d + fVec::size());
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
|
||||
out_vec.store(out + d);
|
||||
}
|
||||
for (; d < size; ++d) {
|
||||
out[d] = static_cast<scalar_t>(input[d]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void copy_add_stub(scalar_t* __restrict__ out, const float* __restrict__ input, const float* __restrict__ bias, int64_t size) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= size - kVecSize; d += kVecSize) {
|
||||
fVec data0 = fVec::loadu(input + d) + fVec::loadu(bias + d);
|
||||
fVec data1 = fVec::loadu(input + d + fVec::size()) + fVec::loadu(bias + d + fVec::size());
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
|
||||
out_vec.store(out + d);
|
||||
}
|
||||
for (; d < size; ++d) {
|
||||
out[d] = static_cast<scalar_t>(input[d] + bias[d]);
|
||||
}
|
||||
}
|
||||
|
||||
inline void unpack_B(
|
||||
at::BFloat16* __restrict__ Btmp,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_B,
|
||||
int N,
|
||||
int K,
|
||||
int ldb,
|
||||
int ldb_tmp,
|
||||
float scale) {
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
// [K/2, N, 2]
|
||||
const int K2 = K >> 1;
|
||||
const int ldb2 = ldb; // ldb * 2 >> 1;
|
||||
const uint16_t* b_ptr = reinterpret_cast<const uint16_t*>(packed_B);
|
||||
const __m512 vd = _mm512_set1_ps(scale);
|
||||
|
||||
constexpr int BLOCK_N = block_size_n();
|
||||
static_assert(BLOCK_N == 32);
|
||||
|
||||
// prefetch distance
|
||||
constexpr int PREFETCH_SIZE_K = 64;
|
||||
|
||||
#pragma GCC unroll 4
|
||||
for (int k = 0; k < K2; ++k) {
|
||||
__m512i b8 = _mm512_loadu_si512(b_ptr + k * ldb2);
|
||||
if constexpr (PREFETCH_SIZE_K > 0) {
|
||||
_mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2, _MM_HINT_T0);
|
||||
}
|
||||
|
||||
__m256i b8_0 = _mm512_extracti32x8_epi32(b8, 0);
|
||||
__m256i b8_1 = _mm512_extracti32x8_epi32(b8, 1);
|
||||
|
||||
__m512bh bf16_0 = CVT_FP8_TO_BF16(b8_0);
|
||||
__m512bh bf16_1 = CVT_FP8_TO_BF16(b8_1);
|
||||
|
||||
// Apply scale
|
||||
__m512 f0_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 0));
|
||||
__m512 f0_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 1));
|
||||
__m512 f1_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 0));
|
||||
__m512 f1_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 1));
|
||||
|
||||
f0_lo = _mm512_mul_ps(f0_lo, vd);
|
||||
f0_hi = _mm512_mul_ps(f0_hi, vd);
|
||||
f1_lo = _mm512_mul_ps(f1_lo, vd);
|
||||
f1_hi = _mm512_mul_ps(f1_hi, vd);
|
||||
|
||||
bf16_0 = _mm512_cvtne2ps_pbh(f0_hi, f0_lo);
|
||||
bf16_1 = _mm512_cvtne2ps_pbh(f1_hi, f1_lo);
|
||||
|
||||
_mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + 0, (__m512i)bf16_0);
|
||||
_mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + 32, (__m512i)bf16_1);
|
||||
}
|
||||
#else
|
||||
TORCH_CHECK(false, "unpack_B: scalar path not implemented!");
|
||||
#endif
|
||||
}
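// In scalar terms (illustrative, not part of the original file), unpack_B widens each
// fp8 value of the [K/2, N, 2] VNNI tile to bf16 and applies the per-[BLOCK_K x BLOCK_N]
// block scale so brgemm can run on plain bf16 inputs:
//
//   Btmp[k][n] = static_cast<at::BFloat16>(static_cast<float>(packed_B[k][n]) * scale);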
|
||||
|
||||
template <typename scalar_t, typename packed_t, bool has_bias, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_nn {
|
||||
static inline void apply(
|
||||
const scalar_t* __restrict__ A, const packed_t* __restrict__ B, scalar_t* __restrict__ C,
|
||||
const float* __restrict__ bias, const float* __restrict__ scale, int K, int lda, int ldb, int ldc, int64_t block_size_K) {
|
||||
TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
template <bool has_bias, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_nn<at::BFloat16, at::Float8_e4m3fn, has_bias, BLOCK_M, BLOCK_N> {
|
||||
static inline void apply(
|
||||
const at::BFloat16* __restrict__ A, const at::Float8_e4m3fn* __restrict__ B, at::BFloat16* __restrict__ C,
|
||||
const float* __restrict__ bias, const float* __restrict__ scale, int K, int lda, int ldb, int ldc, int64_t block_size_K) {
|
||||
|
||||
constexpr int ROWS = BLOCK_M;
|
||||
constexpr int COLS = BLOCK_N / 16;
|
||||
|
||||
const int KB = div_up(K, BLOCK_K);
|
||||
|
||||
// prefetch distance
|
||||
constexpr int PREFETCH_SIZE_K = 64;
|
||||
constexpr int PREFETCH_SIZE_KB = 1;
|
||||
|
||||
__m512bh va;
|
||||
__m512bh vb[COLS];
|
||||
__m512 vc[ROWS * COLS];
|
||||
__m512 vsum[ROWS * COLS];
|
||||
|
||||
// block quant scale
|
||||
__m512 vscale;
|
||||
|
||||
auto loadc = [&](auto i) {
|
||||
constexpr int col = i % COLS;
|
||||
if constexpr (has_bias) {
|
||||
vc[i] = _mm512_loadu_ps(bias + col * 16);
|
||||
} else {
|
||||
vc[i] = _mm512_setzero_ps();
|
||||
}
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(loadc);
|
||||
|
||||
const int lda2 = lda >> 1;
|
||||
const int ldb2 = ldb; // ldb * 2 >> 1;
|
||||
const float* a_ptr = reinterpret_cast<const float*>(A);
|
||||
const uint16_t* b_ptr = reinterpret_cast<const uint16_t*>(B);
|
||||
|
||||
auto compute = [&](auto i, int k) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
|
||||
if constexpr (col == 0) {
|
||||
va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k]));
|
||||
if constexpr (PREFETCH_SIZE_K > 0) {
|
||||
_mm_prefetch(a_ptr + row * lda2 + k + PREFETCH_SIZE_K, _MM_HINT_T0);
|
||||
}
|
||||
}
|
||||
if constexpr (row == 0) {
|
||||
if constexpr (col % 2 == 0) {
|
||||
__m512i b8 = _mm512_loadu_si512(b_ptr + k * ldb2 + col * 16);
|
||||
if constexpr (PREFETCH_SIZE_K > 0) {
|
||||
_mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0);
|
||||
}
|
||||
vb[col + 0] = CVT_FP8_TO_BF16(_mm512_extracti32x8_epi32(b8, 0));
|
||||
vb[col + 1] = CVT_FP8_TO_BF16(_mm512_extracti32x8_epi32(b8, 1));
|
||||
}
|
||||
}
|
||||
vsum[i] = _mm512_dpbf16_ps(vsum[i], va, vb[col]);
|
||||
};
|
||||
|
||||
constexpr int BLOCK_K2 = BLOCK_K >> 1;
|
||||
for (int kb = 0; kb < KB; ++kb) {
|
||||
int kb_start = kb * BLOCK_K2;
|
||||
int kb_end = std::min(K, kb_start + BLOCK_K2);
|
||||
// 1. load scale vector
|
||||
vscale = _mm512_set1_ps(scale[kb]);
|
||||
if constexpr (PREFETCH_SIZE_KB > 0) {
|
||||
_mm_prefetch(scale + kb + PREFETCH_SIZE_KB, _MM_HINT_T0);
|
||||
}
|
||||
// 2. zero vsum for each block
|
||||
Unroll<ROWS * COLS>{}([&](auto i) {
|
||||
vsum[i] = _mm512_setzero_ps();
|
||||
});
|
||||
// 3. accumulate across each block
|
||||
for (int k = kb_start; k < kb_end; ++k) {
|
||||
Unroll<ROWS * COLS>{}(compute, k);
|
||||
}
|
||||
// 4. apply scale
|
||||
Unroll<ROWS * COLS>{}([&](auto i) {
|
||||
vc[i] = _mm512_fmadd_ps(vsum[i], vscale, vc[i]);
|
||||
});
|
||||
}
|
||||
|
||||
auto storec = [&](auto i) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
// for COLS = 2,4 use 512bit store
|
||||
if constexpr (col % 2 == 0) {
|
||||
_mm512_storeu_si512(
|
||||
reinterpret_cast<__m512i*>((C + row * ldc + col * 16)),
|
||||
(__m512i)(_mm512_cvtne2ps_pbh(vc[row * COLS + col + 1], vc[row * COLS + col])));
|
||||
}
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(storec);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \
|
||||
tinygemm_kernel_nn<scalar_t, at::Float8_e4m3fn, has_bias, MB_SIZE, NB_SIZE>::apply( \
|
||||
A + mb_start * lda, B + nb_start * 2, C + mb_start * ldc + nb_start, \
|
||||
has_bias ? bias + nb_start : nullptr, scale, K, lda, ldb, ldc, block_size_K);
|
||||
|
||||
template <typename scalar_t, typename packed_t, bool has_bias>
|
||||
struct brgemm {
|
||||
static inline void apply(
|
||||
const scalar_t* __restrict__ A,
|
||||
const packed_t* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
scalar_t* __restrict__ Btmp,
|
||||
float* __restrict__ Ctmp,
|
||||
const float* __restrict__ bias,
|
||||
const float* __restrict__ scale,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int lda,
|
||||
int ldb,
|
||||
int ldc) {
|
||||
TORCH_CHECK(false, "struct brgemm: primary template not implemented!");
|
||||
}
|
||||
};
|
||||
|
||||
template <bool has_bias>
|
||||
struct brgemm<at::BFloat16, at::Float8_e4m3fn, has_bias> {
|
||||
static inline void apply(
|
||||
const at::BFloat16* __restrict__ A,
|
||||
const at::Float8_e4m3fn* __restrict__ B,
|
||||
at::BFloat16* __restrict__ C,
|
||||
at::BFloat16* __restrict__ Btmp,
|
||||
float* __restrict__ Ctmp,
|
||||
const float* __restrict__ bias,
|
||||
const float* __restrict__ scale,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int lda,
|
||||
int ldb,
|
||||
int ldc) {
|
||||
|
||||
constexpr int BLOCK_N = block_size_n();
|
||||
|
||||
// [K, BLOCK_N] -> [K / 2, BLOCK_N * 2]
|
||||
const int ldb_tmp = BLOCK_N;
|
||||
|
||||
for (int k = 0; k < K; k += BLOCK_K) {
|
||||
int kb_size = std::min(BLOCK_K, K - k);
|
||||
|
||||
int idx = k >> 7; // k / BLOCK_K where BLOCK_K = 128
|
||||
unpack_B(Btmp + k * ldb_tmp, B + k * ldb, N, kb_size, ldb, ldb_tmp, scale[idx]);
|
||||
}
|
||||
|
||||
at::native::cpublas::brgemm(
|
||||
M, N, K, lda, ldb_tmp, BLOCK_N, /* add_C */ false, A, Btmp, Ctmp);
|
||||
|
||||
// copy from Ctmp to C
|
||||
for (int m = 0; m < M; ++m) {
|
||||
if constexpr (has_bias) {
|
||||
copy_add_stub(C + m * ldc, Ctmp + m * BLOCK_N, bias, N);
|
||||
} else {
|
||||
copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename scalar_t, bool has_bias>
|
||||
void tinygemm_kernel(
|
||||
const scalar_t* __restrict__ A,
|
||||
const at::Float8_e4m3fn* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
scalar_t* __restrict__ Btmp,
|
||||
float* __restrict__ Ctmp,
|
||||
const float* __restrict__ scale,
|
||||
const float* __restrict__ bias,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg,
|
||||
int64_t block_size_K) {
|
||||
|
||||
if (brg) {
|
||||
brgemm<scalar_t, at::Float8_e4m3fn, has_bias>::apply(
|
||||
A, B, C, Btmp, Ctmp, bias, scale, M, N, K, lda, ldb, ldc);
|
||||
return;
|
||||
}
|
||||
|
||||
// pattern: 1-4-16
|
||||
constexpr int64_t BLOCK_M = 4;
|
||||
constexpr int64_t BLOCK_N = 64;
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
for (int mb = 0; mb < MB; ++mb) {
|
||||
int64_t mb_start = mb * BLOCK_M;
|
||||
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
|
||||
for (int64_t nb = 0; nb < NB; ++nb) {
|
||||
int64_t nb_start = nb * BLOCK_N;
|
||||
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
|
||||
|
||||
switch(mb_size << 4 | nb_size >> 4) {
|
||||
case 0x12: LAUNCH_TINYGEMM_KERNEL_NN(1, 32); break;
|
||||
case 0x22: LAUNCH_TINYGEMM_KERNEL_NN(2, 32); break;
|
||||
case 0x32: LAUNCH_TINYGEMM_KERNEL_NN(3, 32); break;
|
||||
case 0x42: LAUNCH_TINYGEMM_KERNEL_NN(4, 32); break;
|
||||
default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", nb_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void fp8_scaled_mm_kernel_impl(
|
||||
scalar_t* __restrict__ out,
|
||||
const scalar_t* __restrict__ mat1,
|
||||
const at::Float8_e4m3fn* __restrict__ mat2,
|
||||
const float* __restrict__ scales2,
|
||||
const float* __restrict__ bias,
|
||||
scalar_t* __restrict__ buffer,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t mat1_strideM,
|
||||
int64_t out_strideM,
|
||||
int64_t block_size_N,
|
||||
int64_t block_size_K,
|
||||
int64_t buffer_size_per_thread) {
|
||||
|
||||
constexpr int64_t BLOCK_M = block_size_m() * BLOCK_SIZE_M_SCALE;
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
|
||||
const int64_t scale_size_K = div_up(K, block_size_K);
|
||||
const int64_t blocks_n_per_group = block_size_N / BLOCK_N;
|
||||
|
||||
const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(M);
|
||||
|
||||
// parallel on [MB, NB]
|
||||
AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
|
||||
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
int64_t mb{0}, nb{0};
|
||||
data_index_init(begin, mb, MB, nb, NB);
|
||||
|
||||
int tid = at::get_thread_num();
|
||||
scalar_t* __restrict__ Btmp = buffer + tid * buffer_size_per_thread;
|
||||
float* __restrict__ Ctmp = (float*)((void*)(Btmp + BLOCK_N * K));
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
UNUSED(i);
|
||||
const float* scale_ptr = scales2 + (nb / blocks_n_per_group) * scale_size_K;
|
||||
|
||||
int64_t mb_start = mb * BLOCK_M;
|
||||
int64_t mb_size = std::min(M - mb_start, BLOCK_M);
|
||||
int64_t nb_start = nb * BLOCK_N;
|
||||
int64_t nb_size = std::min(N - nb_start, BLOCK_N);
|
||||
|
||||
tinygemm_kernel<scalar_t, has_bias>(
|
||||
/* A */ mat1 + mb_start * mat1_strideM,
|
||||
/* B */ mat2 + nb_start * K, // nb * BLOCK_N * K
|
||||
/* C */ out + mb_start * out_strideM + nb_start,
|
||||
/* Btmp */ Btmp,
|
||||
/* Ctmp */ Ctmp,
|
||||
/* scale */ scale_ptr,
|
||||
/* bias */ bias + nb_start,
|
||||
/* M */ mb_size,
|
||||
/* N */ nb_size,
|
||||
/* K */ K,
|
||||
/* lda */ mat1_strideM,
|
||||
/* ldb */ nb_size,
|
||||
/* ldc */ out_strideM,
|
||||
/* brg */ use_brgemm,
|
||||
/* block_size_K */ block_size_K);
|
||||
|
||||
// move to the next index
|
||||
data_index_step(mb, MB, nb, NB);
|
||||
}
|
||||
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
// tinygemm interface
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(
|
||||
const scalar_t* __restrict__ A,
|
||||
const at::Float8_e4m3fn* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
scalar_t* __restrict__ Btmp,
|
||||
float* __restrict__ Ctmp,
|
||||
const float* __restrict__ scale,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg,
|
||||
int64_t block_size_K) {
|
||||
tinygemm_kernel<scalar_t, false>(A, B, C, Btmp, Ctmp, scale, nullptr, M, N, K, lda, ldb, ldc, brg, block_size_K);
|
||||
}
|
||||
|
||||
#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \
|
||||
template void tinygemm_kernel<TYPE>( \
|
||||
const TYPE* __restrict__ A, \
|
||||
const at::Float8_e4m3fn* __restrict__ B, \
|
||||
TYPE* __restrict__ C, \
|
||||
TYPE* __restrict__ Btmp, \
|
||||
float* __restrict__ Ctmp, \
|
||||
const float* __restrict__ scale, \
|
||||
int64_t M, \
|
||||
int64_t N, \
|
||||
int64_t K, \
|
||||
int64_t lda, \
|
||||
int64_t ldb, \
|
||||
int64_t ldc, \
|
||||
bool brg, \
|
||||
int64_t block_size_K)
|
||||
|
||||
INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16);
|
||||
INSTANTIATE_TINYGEMM_TEMPLATE(at::Half);
|
||||
|
||||
at::Tensor fp8_scaled_mm_cpu(at::Tensor& mat1, at::Tensor& mat2, at::Tensor& scales2,
|
||||
std::vector<int64_t> block_size, std::optional<at::Tensor>& bias,
|
||||
at::ScalarType out_dtype, bool is_vnni) {
|
||||
RECORD_FUNCTION("sgl-kernel::fp8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales2, block_size, bias}));
|
||||
|
||||
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
|
||||
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1);
|
||||
CHECK_INPUT(mat2);
|
||||
CHECK_INPUT(scales2);
|
||||
TORCH_CHECK(scales2.scalar_type() == at::kFloat,
|
||||
"fp8_scaled_mm_cpu: expect scales2 to be float32.");
|
||||
|
||||
int64_t M = mat1.size(0);
|
||||
int64_t N = mat2.size(0);
|
||||
int64_t K = mat2.size(1);
|
||||
|
||||
CHECK_EQ(mat1.size(1), K);
|
||||
CHECK_DIM(2, mat1);
|
||||
CHECK_DIM(2, mat2);
|
||||
|
||||
TORCH_CHECK(block_size.size() == 2,
|
||||
"fp8_scaled_mm_cpu: expect block_size.size() to be 2.");
|
||||
|
||||
int64_t block_size_N = block_size[0];
|
||||
int64_t block_size_K = block_size[1];
|
||||
|
||||
constexpr int64_t BLOCK_M = block_size_m() * BLOCK_SIZE_M_SCALE;
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
TORCH_CHECK(block_size_N % BLOCK_N == 0, "fp8_scaled_mm_cpu: expect block_size_N to be a multiple of BLOCK_N");
|
||||
TORCH_CHECK(block_size_K == BLOCK_K, "fp8_scaled_mm_cpu: expect block_size_K to equal BLOCK_K");
|
||||
CHECK_EQ(scales2.size(0), div_up(N, block_size_N));
|
||||
CHECK_EQ(scales2.size(1), div_up(K, block_size_K));
|
||||
|
||||
const auto st = mat1.scalar_type();
|
||||
TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf,
|
||||
"fp8_scaled_mm_cpu: expect A to be bfloat16 or half.");
|
||||
TORCH_CHECK(st == out_dtype,
|
||||
"fp8_scaled_mm_cpu: expect A has same dtype with out_dtype.");
|
||||
TORCH_CHECK(mat2.scalar_type() == at::kFloat8_e4m3fn,
|
||||
"fp8_scaled_mm_cpu: expect mat2 to be fp8_e4m3.");
|
||||
TORCH_CHECK(scales2.scalar_type() == at::kFloat,
|
||||
"fp8_scaled_mm_cpu: expect scales to be float32.");
|
||||
auto out = at::empty({M, N}, mat1.options().dtype(out_dtype));
|
||||
|
||||
// strides
|
||||
int64_t mat1_strideM = mat1.stride(0);
|
||||
int64_t out_strideM = out.stride(0);
|
||||
|
||||
const bool has_bias = bias.has_value();
|
||||
const float* bias_data = nullptr;
|
||||
if (has_bias) {
|
||||
CHECK_EQ(bias.value().size(0), N);
|
||||
bias_data = bias.value().data_ptr<float>();
|
||||
}
|
||||
|
||||
// Btmp : [T, BLOCK_N * K]
|
||||
// Ctmp : [T, BLOCK_M * BLOCK_N]
|
||||
int num_threads = at::get_num_threads();
|
||||
int64_t size_per_thread = BLOCK_N * K + BLOCK_M * BLOCK_N * 2;
|
||||
auto buffer = at::empty({num_threads, size_per_thread}, mat1.options());
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "fp8_scaled_mm_kernel_impl", [&] {
|
||||
fp8_scaled_mm_kernel_impl<scalar_t>(
|
||||
out.data_ptr<scalar_t>(),
|
||||
mat1.data_ptr<scalar_t>(),
|
||||
packed_w.data_ptr<at::Float8_e4m3fn>(),
|
||||
scales2.data_ptr<float>(),
|
||||
bias_data,
|
||||
buffer.data_ptr<scalar_t>(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
mat1_strideM,
|
||||
out_strideM,
|
||||
block_size_N,
|
||||
block_size_K,
|
||||
size_per_thread);
|
||||
});
|
||||
|
||||
return out;
|
||||
}
|
||||
csrc/cpu/sgl-kernels/gemm_int8.cpp (new file, 440 lines)
// Adapted from
|
||||
// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu
|
||||
|
||||
#include "common.h"
|
||||
#include "vec.h"
|
||||
#include "gemm.h"
|
||||
|
||||
// clang-format off
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename scalar_t, bool has_bias, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_nn {
|
||||
static inline void apply(
|
||||
const uint8_t* __restrict__ A, const int8_t* __restrict__ B, scalar_t* __restrict__ C,
|
||||
const float* __restrict__ As, const float* __restrict__ Bs, const int32_t* __restrict__ Bcomp,
|
||||
const float* __restrict__ bias, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) {
|
||||
TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
template <bool has_bias, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_nn<at::BFloat16, has_bias, BLOCK_M, BLOCK_N> {
|
||||
static inline void apply(
|
||||
const uint8_t* __restrict__ A, const int8_t* __restrict__ B, at::BFloat16* __restrict__ C,
|
||||
const float* __restrict__ As, const float* __restrict__ Bs, const int32_t* __restrict__ Bcomp,
|
||||
const float* __restrict__ bias, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) {
|
||||
|
||||
constexpr int ROWS = BLOCK_M;
|
||||
constexpr int COLS = BLOCK_N / 16;
|
||||
static_assert(COLS % 2 == 0);
|
||||
|
||||
// prefetch distance
|
||||
constexpr int PREFETCH_SIZE_K = 0;
|
||||
|
||||
__m512i va;
|
||||
__m512i vb[COLS];
|
||||
__m512i vc[ROWS * COLS];
|
||||
__m512i vcomp[COLS];
|
||||
__m512 vd0;
|
||||
__m512 vd1[COLS];
|
||||
|
||||
// oops! 4x4 spills but luckily we use 4x2
|
||||
__m512 vbias[COLS];
|
||||
|
||||
// [NOTE]: s8s8 igemm compensation in avx512-vnni
|
||||
//
|
||||
// avx512-vnni has no s8s8, so we need to change s8s8 to u8s8 with compensate:
|
||||
//
|
||||
// a * b = (a + 128) * b - 128 * b
|
||||
// s s u s u s
|
||||
//
|
||||
// 1) 128 * b is pre-computed when packing B to vnni formats
|
||||
// 2) a + 128 is fused when dynamically quantize A
|
||||
//
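// Worked example of the identity above (illustrative): with a = -3 (s8) and b = 7 (s8),
//   (a + 128) * b - 128 * b = 125 * 7 - 896 = 875 - 896 = -21 = a * b.
// The u8s8 dpbusd below accumulates (a + 128) * b; the pre-computed 128 * b term is
// subtracted via vcomp in storec before the scales are applied.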
|
||||
auto loadc = [&](auto i) {
|
||||
vc[i] = _mm512_set1_epi32(0);
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(loadc);
|
||||
|
||||
const int64_t K4 = K >> 2;
|
||||
const int64_t lda4 = lda >> 2;
|
||||
const int64_t ldb4 = ldb; // ldb * 4 >> 2;
|
||||
const int32_t* a_ptr = reinterpret_cast<const int32_t*>(A);
|
||||
const int32_t* b_ptr = reinterpret_cast<const int32_t*>(B);
|
||||
|
||||
auto compute = [&](auto i, int64_t k) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
|
||||
if constexpr (col == 0) {
|
||||
va = _mm512_set1_epi32(a_ptr[row * lda4 + k]);
|
||||
}
|
||||
if constexpr (row == 0) {
|
||||
vb[col] = _mm512_loadu_si512(b_ptr + k * ldb4 + col * 16);
|
||||
if constexpr (PREFETCH_SIZE_K > 0) {
|
||||
_mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb4 + col * 16, _MM_HINT_T0);
|
||||
}
|
||||
}
|
||||
vc[i] = _mm512_dpbusd_epi32(vc[i], va, vb[col]);
|
||||
};
|
||||
for (int64_t k = 0; k < K4; ++k) {
|
||||
Unroll<ROWS * COLS>{}(compute, k);
|
||||
}
|
||||
|
||||
auto storec = [&](auto i) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
|
||||
// load a scale
|
||||
if constexpr(col == 0) {
|
||||
vd0 = _mm512_set1_ps(As[row]);
|
||||
}
|
||||
// load b scale and vcomp per 2 vectors
|
||||
// also load bias if any
|
||||
if constexpr (row == 0) {
|
||||
if constexpr (col % 2 == 0) {
|
||||
vd1[col + 0] = _mm512_loadu_ps(Bs + col * 16);
|
||||
vd1[col + 1] = _mm512_loadu_ps(Bs + col * 16 + 16);
|
||||
vcomp[col + 0] = _mm512_loadu_si512(Bcomp + col * 16);
|
||||
vcomp[col + 1] = _mm512_loadu_si512(Bcomp + col * 16 + 16);
|
||||
if constexpr (has_bias) {
|
||||
vbias[col + 0] = _mm512_loadu_ps(bias + col * 16);
|
||||
vbias[col + 1] = _mm512_loadu_ps(bias + col * 16 + 16);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// for COLS = 2, 4 use 512bit store
|
||||
if constexpr (col % 2 == 0) {
|
||||
__m512 vc0 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[row * COLS + col + 0], vcomp[col + 0]));
|
||||
__m512 vc1 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[row * COLS + col + 1], vcomp[col + 1]));
|
||||
if constexpr (has_bias) {
|
||||
vc0 = _mm512_fmadd_ps(_mm512_mul_ps(vc0, vd0), vd1[col + 0], vbias[col + 0]);
|
||||
vc1 = _mm512_fmadd_ps(_mm512_mul_ps(vc1, vd0), vd1[col + 1], vbias[col + 1]);
|
||||
} else {
|
||||
vc0 = _mm512_mul_ps(_mm512_mul_ps(vc0, vd0), vd1[col + 0]);
|
||||
vc1 = _mm512_mul_ps(_mm512_mul_ps(vc1, vd0), vd1[col + 1]);
|
||||
}
|
||||
|
||||
_mm512_storeu_si512(
|
||||
reinterpret_cast<__m512i*>((C + row * ldc + col * 16)),
|
||||
(__m512i)(_mm512_cvtne2ps_pbh(vc1, vc0)));
|
||||
}
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(storec);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \
|
||||
tinygemm_kernel_nn<scalar_t, has_bias, MB_SIZE, NB_SIZE>::apply( \
|
||||
A + mb_start * lda, B + nb_start * 4, C + mb_start * ldc + nb_start, \
|
||||
As + mb_start, Bs + nb_start, Bcomp + nb_start, \
|
||||
has_bias ? bias + nb_start : nullptr, K, lda, ldb, ldc);
|
||||
|
||||
template <typename scalar_t, bool has_bias>
|
||||
void tinygemm_kernel(
|
||||
const uint8_t* __restrict__ A,
|
||||
const int8_t* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
int32_t* __restrict__ Ctmp,
|
||||
const float* __restrict__ As,
|
||||
const float* __restrict__ Bs,
|
||||
const float* __restrict__ bias,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg) {
|
||||
|
||||
// B compensation
|
||||
const int32_t* Bcomp = reinterpret_cast<const int32_t*>(B + block_size_n() * K);
|
||||
|
||||
// pattern: 1-4-16
|
||||
constexpr int64_t BLOCK_M = 4;
|
||||
constexpr int64_t BLOCK_N = 64;
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
for (int64_t mb = 0; mb < MB; ++mb) {
|
||||
int64_t mb_start = mb * BLOCK_M;
|
||||
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
|
||||
for (int64_t nb = 0; nb < NB; ++nb) {
|
||||
int64_t nb_start = nb * BLOCK_N;
|
||||
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
|
||||
|
||||
switch(mb_size << 4 | nb_size >> 4) {
|
||||
// mb_size = 1
|
||||
case 0x12: LAUNCH_TINYGEMM_KERNEL_NN(1, 32); break;
|
||||
case 0x14: LAUNCH_TINYGEMM_KERNEL_NN(1, 64); break;
|
||||
// mb_size = 2
|
||||
case 0x22: LAUNCH_TINYGEMM_KERNEL_NN(2, 32); break;
|
||||
case 0x24: LAUNCH_TINYGEMM_KERNEL_NN(2, 64); break;
|
||||
// mb_size = 3
|
||||
case 0x32: LAUNCH_TINYGEMM_KERNEL_NN(3, 32); break;
|
||||
case 0x34: LAUNCH_TINYGEMM_KERNEL_NN(3, 64); break;
|
||||
// mb_size = 4
|
||||
case 0x42: LAUNCH_TINYGEMM_KERNEL_NN(4, 32); break;
|
||||
case 0x44: LAUNCH_TINYGEMM_KERNEL_NN(4, 64); break;
|
||||
default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", nb_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
void int8_scaled_mm_kernel_impl(
|
||||
scalar_t* __restrict__ out,
|
||||
const uint8_t* __restrict__ mat1,
|
||||
const int8_t* __restrict__ mat2,
|
||||
const float* __restrict__ scales1,
|
||||
const float* __restrict__ scales2,
|
||||
const float* __restrict__ bias,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K) {
|
||||
|
||||
constexpr int64_t BLOCK_M = block_size_m();
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
|
||||
// TODO: brgemm u8s8 depends on PyTorch 2.7 release.
|
||||
const bool use_brgemm = false;
|
||||
|
||||
// K + 4 after compensation
|
||||
const int64_t packed_row_size = get_row_size<int8_t>(K);
|
||||
|
||||
AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
|
||||
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
int64_t mb{0}, nb{0};
|
||||
data_index_init(begin, mb, MB, nb, NB);
|
||||
|
||||
// for brgemm, use int32_t for accumulate
|
||||
alignas(64) int32_t Ctmp[BLOCK_M * BLOCK_N];
|
||||
|
||||
for (int i = begin; i < end; ++i) {
|
||||
UNUSED(i);
|
||||
int mb_start = mb * BLOCK_M;
|
||||
int mb_size = std::min(M - mb_start, BLOCK_M);
|
||||
int nb_start = nb * BLOCK_N;
|
||||
int nb_size = std::min(N - nb_start, BLOCK_N);
|
||||
|
||||
tinygemm_kernel<scalar_t, has_bias>(
|
||||
/* A */ mat1 + mb_start * K,
|
||||
/* B */ mat2 + nb_start * packed_row_size /* nb * BLOCK_N * (K + 4) */,
|
||||
/* C */ out + mb_start * N + nb_start,
|
||||
/* Ctmp*/ Ctmp,
|
||||
/* As */ scales1 + mb_start,
|
||||
/* Bs */ scales2 + nb_start,
|
||||
/* bias*/ bias + nb_start,
|
||||
/* M */ mb_size,
|
||||
/* N */ nb_size,
|
||||
/* K */ K,
|
||||
/* lda */ K,
|
||||
/* ldb */ nb_size,
|
||||
/* ldc */ N,
|
||||
/* brg */ use_brgemm);
|
||||
|
||||
// move to the next index
|
||||
data_index_step(mb, MB, nb, NB);
|
||||
}
|
||||
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
// tinygemm interface
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(const uint8_t* __restrict__ A, const int8_t* __restrict__ B, scalar_t* __restrict__ C,
|
||||
int32_t* __restrict__ Ctmp, const float* __restrict__ As, const float* __restrict__ Bs,
|
||||
int64_t M, int64_t N, int64_t K, int64_t lda, int64_t ldb, int64_t ldc, bool brg) {
|
||||
tinygemm_kernel<scalar_t, false>(A, B, C, Ctmp, As, Bs, nullptr, M, N, K, lda, ldb, ldc, brg);
|
||||
}
|
||||
|
||||
#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \
|
||||
template void tinygemm_kernel<TYPE>( \
|
||||
const uint8_t* __restrict__ A, const int8_t* __restrict__ B, TYPE* __restrict__ C, \
|
||||
int32_t* __restrict__ Ctmp, const float* __restrict__ As, const float* __restrict__ Bs, \
|
||||
int64_t M, int64_t N, int64_t K, int64_t lda, int64_t ldb, int64_t ldc, bool brg)
|
||||
|
||||
INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16);
|
||||
INSTANTIATE_TINYGEMM_TEMPLATE(at::Half);
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> per_token_quant_int8_cpu(at::Tensor& A) {
|
||||
RECORD_FUNCTION("sgl-kernel::per_token_quant_int8_cpu", std::vector<c10::IValue>({A}));
|
||||
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(A);
|
||||
CHECK_DIM(2, A);
|
||||
|
||||
int64_t M = A.size(0);
|
||||
int64_t K = A.size(1);
|
||||
int64_t lda = A.stride(0);
|
||||
|
||||
const auto st = A.scalar_type();
|
||||
TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf,
|
||||
"per_token_quant_int8: expect A to be bfloat16 or half.");
|
||||
|
||||
auto Aq = at::empty({M, K}, A.options().dtype(at::kByte));
|
||||
auto As = at::empty({M}, A.options().dtype(at::kFloat));
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "per_token_quant_int8", [&] {
|
||||
uint8_t* __restrict__ Aq_data = Aq.data_ptr<uint8_t>();
|
||||
float* __restrict__ As_data = As.data_ptr<float>();
|
||||
const scalar_t* __restrict__ A_data = A.data_ptr<scalar_t>();
|
||||
|
||||
at::parallel_for(0, M, 0, [&] (int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
quantize_row_int8<scalar_t>(
|
||||
Aq_data + m * K,
|
||||
As_data[m],
|
||||
A_data + m * lda,
|
||||
K);
|
||||
}
|
||||
});
|
||||
});
|
||||
return std::make_tuple(Aq, As);
|
||||
}
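// Usage sketch (illustrative only; packed_w, w_scales and bias are placeholders, and the
// quantization detail is an assumption based on the s8s8 -> u8s8 note in this file,
// i.e. quantize_row_int8 is assumed to store round(x / s) + 128 as uint8 with a
// per-row scale s):
//
//   at::Tensor A = at::randn({M, K}, at::kBFloat16);
//   auto [Aq, As] = per_token_quant_int8_cpu(A);   // Aq: [M, K] uint8, As: [M] float
//   at::Tensor y = int8_scaled_mm_cpu(Aq, packed_w, As, w_scales, bias,
//                                     at::kBFloat16, /*is_vnni=*/true);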
|
||||
|
||||
// weight : static, per-channel, symmetric
|
||||
// activation : dynamic, per-token, symmetric
|
||||
//
|
||||
// mat1 : [M, K]
|
||||
// mat2 : [N, K]
|
||||
// scales1 : [M]
|
||||
// scales2 : [N]
|
||||
// bias : [N]
|
||||
// out : [M, N]
|
||||
//
|
||||
at::Tensor int8_scaled_mm_cpu(at::Tensor& mat1, at::Tensor& mat2,
|
||||
at::Tensor& scales1, at::Tensor& scales2,
|
||||
std::optional<at::Tensor>& bias, at::ScalarType out_dtype, bool is_vnni) {
|
||||
RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales1, scales2, bias}));
|
||||
|
||||
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
|
||||
|
||||
CHECK_INPUT(mat1);
|
||||
CHECK_INPUT(mat2);
|
||||
CHECK_INPUT(scales1);
|
||||
CHECK_INPUT(scales2);
|
||||
CHECK_DIM(2, mat1);
|
||||
CHECK_DIM(2, mat2);
|
||||
|
||||
int64_t M = mat1.size(0);
|
||||
int64_t N = mat2.size(0);
|
||||
int64_t K = mat1.size(1);
|
||||
|
||||
// see [NOTE]: s8s8 igemm compensation in avx512-vnni
|
||||
CHECK_EQ(mat2.size(1), (int64_t)(is_vnni ? K + sizeof(int32_t) : K));
|
||||
CHECK_EQ(scales1.numel(), M);
|
||||
CHECK_EQ(scales2.numel(), N);
|
||||
|
||||
TORCH_CHECK(mat1.scalar_type() == at::kByte, "int8_scaled_mm: expect mat1 to be uint8.");
|
||||
TORCH_CHECK(mat2.scalar_type() == at::kChar, "int8_scaled_mm: expect mat2 to be int8.");
|
||||
TORCH_CHECK(scales1.scalar_type() == at::kFloat && scales2.scalar_type() == at::kFloat,
|
||||
"int8_scaled_mm: expect scales to be float32.");
|
||||
|
||||
auto out = at::empty({M, N}, mat1.options().dtype(out_dtype));
|
||||
|
||||
const bool has_bias = bias.has_value();
|
||||
const float* bias_data = nullptr;
|
||||
if (has_bias) {
|
||||
CHECK_EQ(bias.value().size(0), N);
|
||||
bias_data = bias.value().data_ptr<float>();
|
||||
}
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "int8_scaled_mm_kernel_impl", [&] {
|
||||
int8_scaled_mm_kernel_impl<scalar_t>(
|
||||
out.data_ptr<scalar_t>(),
|
||||
mat1.data_ptr<uint8_t>(),
|
||||
packed_w.data_ptr<int8_t>(),
|
||||
scales1.data_ptr<float>(),
|
||||
scales2.data_ptr<float>(),
|
||||
bias_data,
|
||||
M,
|
||||
N,
|
||||
K);
|
||||
});
|
||||
return out;
|
||||
}
|
||||
|
||||
// fused `per_token_quant_int8_cpu` and `int8_scaled_mm_cpu`
|
||||
at::Tensor int8_scaled_mm_with_quant(at::Tensor& mat1, at::Tensor& mat2, at::Tensor& scales2,
|
||||
const std::optional<at::Tensor>& bias, at::ScalarType out_dtype, bool is_vnni) {
|
||||
RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales2, bias}));
|
||||
|
||||
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
|
||||
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1);
|
||||
CHECK_INPUT(mat2);
|
||||
CHECK_INPUT(scales2);
|
||||
CHECK_DIM(2, mat1);
|
||||
CHECK_DIM(2, mat2);
|
||||
|
||||
int64_t M = mat1.size(0);
|
||||
int64_t N = mat2.size(0);
|
||||
int64_t K = mat1.size(1);
|
||||
int64_t lda = mat1.stride(0);
|
||||
|
||||
// see [NOTE]: s8s8 igemm compensation in avx512-vnni
|
||||
CHECK_EQ(mat2.size(1), (int64_t)(is_vnni ? K + sizeof(int32_t) : K));
|
||||
CHECK_EQ(scales2.numel(), N);
|
||||
|
||||
const auto st = mat1.scalar_type();
|
||||
TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf,
|
||||
"int8_scaled_mm_with_quant: expect A to be bfloat16 or half.");
|
||||
TORCH_CHECK(st == out_dtype,
|
||||
"int8_scaled_mm_with_quant: expect A has same dtype with out_dtype.");
|
||||
TORCH_CHECK(mat2.scalar_type() == at::kChar,
|
||||
"int8_scaled_mm_with_quant: expect mat2 to be int8.");
|
||||
TORCH_CHECK(scales2.scalar_type() == at::kFloat,
|
||||
"int8_scaled_mm_with_quant: expect scales to be float32.");
|
||||
|
||||
const int64_t buffer_size = M * K + M * sizeof(float);
|
||||
auto buffer = at::empty({buffer_size}, mat1.options().dtype(at::kByte));
|
||||
auto out = at::empty({M, N}, mat1.options().dtype(out_dtype));
|
||||
|
||||
const bool has_bias = bias.has_value();
|
||||
const float* bias_data = nullptr;
|
||||
if (has_bias) {
|
||||
CHECK_EQ(bias.value().size(0), N);
|
||||
bias_data = bias.value().data_ptr<float>();
|
||||
}
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "int8_scaled_mm_with_quant_kernel_impl", [&] {
|
||||
uint8_t* __restrict__ Aq_data = buffer.data_ptr<uint8_t>();
|
||||
float* __restrict__ As_data = (float*)((void*)(Aq_data + M * K));
|
||||
const scalar_t* __restrict__ A_data = mat1.data_ptr<scalar_t>();
|
||||
|
||||
at::parallel_for(0, M, 0, [&] (int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
quantize_row_int8<scalar_t>(
|
||||
Aq_data + m * K,
|
||||
As_data[m],
|
||||
A_data + m * lda,
|
||||
K);
|
||||
}
|
||||
});
|
||||
|
||||
int8_scaled_mm_kernel_impl<scalar_t>(
|
||||
out.data_ptr<scalar_t>(),
|
||||
Aq_data,
|
||||
packed_w.data_ptr<int8_t>(),
|
||||
As_data,
|
||||
scales2.data_ptr<float>(),
|
||||
bias_data,
|
||||
M,
|
||||
N,
|
||||
K);
|
||||
});
|
||||
return out;
|
||||
}
|
||||
csrc/cpu/sgl-kernels/moe.cpp (new file, 1330 lines; diff suppressed because it is too large)
csrc/cpu/sgl-kernels/moe_fp8.cpp (new file, 502 lines)
// Adapted from
|
||||
// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu
|
||||
|
||||
#include "common.h"
|
||||
#include "gemm.h"
|
||||
#include "vec.h"
|
||||
|
||||
// clang-format off
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t size) {
|
||||
using Vec = at::vec::Vectorized<scalar_t>;
|
||||
// no remainder
|
||||
#pragma GCC unroll 4
|
||||
for (int64_t d = 0; d < size; d += Vec::size()) {
|
||||
Vec data = Vec::loadu(input + d);
|
||||
data.store(out + d);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void copy_mul_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, float weight, int64_t size) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
const fVec weight_vec = fVec(weight);
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= size - kVecSize; d += kVecSize) {
|
||||
bVec x = bVec::loadu(input + d);
|
||||
fVec x0, x1;
|
||||
std::tie(x0, x1) = at::vec::convert_to_float(x);
|
||||
x0 = x0 * weight_vec;
|
||||
x1 = x1 * weight_vec;
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(x0, x1);
|
||||
out_vec.store(out + d);
|
||||
}
|
||||
for (; d < size; ++d) {
|
||||
out[d] = static_cast<scalar_t>(input[d] * weight);
|
||||
}
|
||||
}
|
||||
|
||||
// acc from [topk, K] to [K]
|
||||
template <typename scalar_t>
|
||||
inline void sum_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t topk, int64_t K) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
if (topk == 1) {
|
||||
// do copy for topk = 1
|
||||
copy_stub(out, input, K);
|
||||
} else {
|
||||
// do sum for topk != 1
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= K - kVecSize; d += kVecSize) {
|
||||
fVec sum_fvec0 = fVec(0.f);
|
||||
fVec sum_fvec1 = fVec(0.f);
|
||||
for (int t = 0; t < topk; ++t) {
|
||||
bVec x_bvec = bVec::loadu(input + t * K + d);
|
||||
fVec x_fvec0, x_fvec1;
|
||||
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
|
||||
|
||||
sum_fvec0 += x_fvec0;
|
||||
sum_fvec1 += x_fvec1;
|
||||
}
|
||||
bVec out_bvec = convert_from_float_ext<scalar_t>(sum_fvec0, sum_fvec1);
|
||||
out_bvec.store(out + d);
|
||||
}
|
||||
for (; d < K; ++d) {
|
||||
float sum_val = 0.f;
|
||||
for (int t = 0; t < topk; ++t) {
|
||||
sum_val += static_cast<float>(input[t * K + d]);
|
||||
}
|
||||
out[d] = static_cast<scalar_t>(sum_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// out = input + input2 * scale
|
||||
template <typename scalar_t>
|
||||
inline void add_mul_stub(
|
||||
scalar_t* __restrict__ out,
|
||||
const scalar_t* __restrict__ input,
|
||||
const scalar_t* __restrict__ input2,
|
||||
float scale,
|
||||
int64_t size) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
const fVec s_vec = fVec(scale);
|
||||
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= size - kVecSize; d += kVecSize) {
|
||||
bVec x_bvec = bVec::loadu(input + d);
|
||||
fVec x0, x1;
|
||||
std::tie(x0, x1) = at::vec::convert_to_float(x_bvec);
|
||||
|
||||
bVec y_bvec = bVec::loadu(input2 + d);
|
||||
fVec y0, y1;
|
||||
std::tie(y0, y1) = at::vec::convert_to_float(y_bvec);
|
||||
|
||||
x0 = x0 + y0 * s_vec;
|
||||
x1 = x1 + y1 * s_vec;
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(x0, x1);
|
||||
out_vec.store(out + d);
|
||||
}
|
||||
for (; d < size; ++d) {
|
||||
out[d] = static_cast<scalar_t>(input[d] + float(input2[d]) * scale);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void silu_and_mul_stub(
|
||||
scalar_t* __restrict__ out,
|
||||
const scalar_t* __restrict__ input,
|
||||
const scalar_t* __restrict__ input2,
|
||||
int64_t size) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
const fVec one = fVec(1.f);
|
||||
|
||||
// no remainder
|
||||
#pragma GCC unroll 4
|
||||
for (int64_t d = 0; d < size; d += bVec::size()) {
|
||||
bVec x = bVec::loadu(input + d);
|
||||
fVec x0, x1;
|
||||
std::tie(x0, x1) = at::vec::convert_to_float(x);
|
||||
bVec y = bVec::loadu(input2 + d);
|
||||
fVec y0, y1;
|
||||
std::tie(y0, y1) = at::vec::convert_to_float(y);
|
||||
x0 = x0 / (one + x0.neg().exp_u20());
|
||||
x1 = x1 / (one + x1.neg().exp_u20());
|
||||
x0 = x0 * y0;
|
||||
x1 = x1 * y1;
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(x0, x1);
|
||||
out_vec.store(out + d);
|
||||
}
|
||||
}
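// In scalar terms: out[d] = silu(x[d]) * y[d] = (x[d] / (1 + exp(-x[d]))) * y[d],
// computed in fp32 and converted back to the activation dtype.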
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
template <typename scalar_t>
|
||||
void fused_experts_fp8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic0,
|
||||
scalar_t* __restrict__ ic1,
|
||||
scalar_t* __restrict__ ic2,
|
||||
scalar_t* __restrict__ A_tmp,
|
||||
scalar_t* __restrict__ B_tmp,
|
||||
float* __restrict__ C_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w1,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w2,
|
||||
const float* __restrict__ w1s,
|
||||
const float* __restrict__ w2s,
|
||||
int64_t block_size_N,
|
||||
int64_t block_size_K,
|
||||
const float* __restrict__ topk_weights,
|
||||
const int32_t* __restrict__ sorted_ids,
|
||||
const int32_t* __restrict__ expert_ids,
|
||||
const int32_t* __restrict__ offsets,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t E,
|
||||
int64_t topk,
|
||||
int64_t num_tokens_post_pad) {
|
||||
|
||||
constexpr int64_t BLOCK_M = block_size_m();
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
|
||||
// stage 1: intermediate_cache0 = hidden_states @ w1
|
||||
const int64_t MB = div_up(num_tokens_post_pad, BLOCK_M);
|
||||
const int64_t NB = div_up(2 * N, BLOCK_N);
|
||||
int64_t scale_size_N = div_up(2 * N, block_size_N);
|
||||
int64_t scale_size_K = div_up(K, block_size_K);
|
||||
int64_t blocks_n_per_group = block_size_N / BLOCK_N;
|
||||
|
||||
const int64_t stride_e = 2 * N * K;
|
||||
const int64_t stride_n = K;
|
||||
|
||||
// parallel over [MB, NB]: NB covers the full 2 * N columns of w1; silu_and_mul is applied separately in stage 1.5
|
||||
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
// get local pointers
|
||||
int tid = at::get_thread_num();
|
||||
scalar_t* __restrict__ A = A_tmp + tid * BLOCK_M * K;
|
||||
|
||||
bool is_brgemm_used = false;
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB;
|
||||
int64_t nb = i % NB;
|
||||
|
||||
int64_t n_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
// B shape [K, n_size] in vnni format
|
||||
int32_t expert_id = expert_ids[mb];
|
||||
const at::Float8_e4m3fn* __restrict__ B = packed_w1 + expert_id * stride_e + nb * BLOCK_N * stride_n;
|
||||
const float* __restrict__ Bs = w1s + expert_id * scale_size_N * scale_size_K + (nb / blocks_n_per_group) * scale_size_K;
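// w1s is laid out as [E, scale_size_N, scale_size_K]: one float per (block_size_N x block_size_K)
// weight tile, so blocks_n_per_group consecutive BLOCK_N tiles share the same row of K-block scales.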
|
||||
|
||||
// 1.a load A
|
||||
const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
|
||||
int64_t m_size = offsets[mb + 1] - offsets[mb];
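// sorted_ids holds indices into the flattened [M, topk] token/slot space, grouped by expert and
// padded to BLOCK_M; dividing by topk recovers the source row in `input`, while
// offsets[mb + 1] - offsets[mb] gives the number of valid (non-padding) rows in this block.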
|
||||
|
||||
const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(m_size);
|
||||
is_brgemm_used = is_brgemm_used || use_brgemm;
|
||||
|
||||
for (int64_t m = 0; m < m_size; ++m) {
|
||||
int32_t index = A_ids[m] / topk;
|
||||
copy_stub(A + m * K, input + index * K, K);
|
||||
}
|
||||
|
||||
const int64_t offset = offsets[mb];
|
||||
tinygemm_kernel<scalar_t>(
|
||||
/* A */ A,
|
||||
/* B */ B,
|
||||
/* C */ ic0 + offset * 2 * N + nb * BLOCK_N,
|
||||
/* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N),
|
||||
/* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
|
||||
/* scale */ Bs,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ K,
|
||||
/* lda */ K,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ 2 * N,
|
||||
/* brg */ use_brgemm,
|
||||
/* block_size_K */ block_size_K);
|
||||
}
|
||||
|
||||
if (is_brgemm_used) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
|
||||
// stage 1.5: intermediate_cache1 = silu(intermediate_cache0[:, :N]) * intermediate_cache0[:, N:]
|
||||
at::parallel_for(0, M * topk, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
silu_and_mul_stub(
|
||||
ic1 + m * N,
|
||||
ic0 + m * 2 * N,
|
||||
ic0 + m * 2 * N + N,
|
||||
N);
|
||||
}
|
||||
});
|
||||
|
||||
// stage 2: intermediate_cache2 = intermediate_cache1 @ w2
|
||||
// w2 : [E, K, N] as [E, OC, IC]
|
||||
const int64_t OC = K; // rename K as OC
|
||||
const int64_t IC = N; // rename N as IC
|
||||
const int64_t MB2 = MB;
|
||||
const int64_t NB2 = div_up(OC, BLOCK_N);
|
||||
scale_size_N = div_up(K, block_size_N);
|
||||
scale_size_K = div_up(N, block_size_K);
|
||||
const int64_t stride_e2 = OC * IC;
|
||||
const int64_t stride_oc = IC;
|
||||
|
||||
// parallel on [MB2, NB2]
|
||||
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
|
||||
int tid = at::get_thread_num();
|
||||
alignas(64) scalar_t C[BLOCK_M * BLOCK_K];
|
||||
|
||||
bool is_brgemm_used = false;
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB2;
|
||||
int64_t nb = i % NB2;
|
||||
|
||||
int64_t m_size = offsets[mb + 1] - offsets[mb];
|
||||
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(m_size);
|
||||
is_brgemm_used = is_brgemm_used || use_brgemm;
|
||||
|
||||
// A ptr from ic1 of [M * topk, N] in sorted order
|
||||
// so as to avoid copy A to tmp buffer again
|
||||
const scalar_t* __restrict__ A = ic1 + offsets[mb] * N;
|
||||
const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
|
||||
|
||||
// B shape [IC, n_size] in vnni format
|
||||
int32_t expert_id = expert_ids[mb];
|
||||
const at::Float8_e4m3fn* __restrict__ B = packed_w2 + expert_id * stride_e2 + nb * BLOCK_N * stride_oc;
|
||||
const float* __restrict__ Bs = w2s + expert_id * scale_size_N * scale_size_K + (nb / blocks_n_per_group) * scale_size_K;
|
||||
|
||||
tinygemm_kernel<scalar_t>(
|
||||
/* A */ A,
|
||||
/* B */ B,
|
||||
/* C */ C,
|
||||
/* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N),
|
||||
/* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
|
||||
/* scale */ Bs,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ IC,
|
||||
/* lda */ IC,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ BLOCK_N,
|
||||
/* brg */ use_brgemm,
|
||||
/* block_size_K */ block_size_K);
|
||||
|
||||
// 2.b copy from C to ic2 in original order
|
||||
// and also mul topk_weights in float32
|
||||
for (int64_t m = 0; m < m_size; ++m) {
|
||||
int32_t index = A_ids[m];
|
||||
float weight = topk_weights[index];
|
||||
copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size);
|
||||
}
|
||||
}
|
||||
|
||||
if (is_brgemm_used) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
|
||||
// stage 3: out = intermediate_cache2.sum(dim=1)
|
||||
// from [M, topk, K] to [M, K]
|
||||
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
sum_stub(output + m * K, ic2 + m * topk * K, topk, K);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#define INSTANTIATE_MOE_FP8_TEMPLATE(TYPE) \
|
||||
template void fused_experts_fp8_kernel_impl<TYPE>( \
|
||||
TYPE* __restrict__ output, \
|
||||
TYPE* __restrict__ ic0, \
|
||||
TYPE* __restrict__ ic1, \
|
||||
TYPE* __restrict__ ic2, \
|
||||
TYPE* __restrict__ A_tmp, \
|
||||
TYPE* __restrict__ B_tmp, \
|
||||
float* __restrict__ C_tmp, \
|
||||
const TYPE* __restrict__ input, \
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w1, \
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w2, \
|
||||
const float* __restrict__ w1s, \
|
||||
const float* __restrict__ w2s, \
|
||||
int64_t block_size_N, \
|
||||
int64_t block_size_K, \
|
||||
const float* __restrict__ topk_weights, \
|
||||
const int32_t* __restrict__ sorted_ids, \
|
||||
const int32_t* __restrict__ expert_ids, \
|
||||
const int32_t* __restrict__ offsets, \
|
||||
int64_t M, \
|
||||
int64_t N, \
|
||||
int64_t K, \
|
||||
int64_t E, \
|
||||
int64_t topk, \
|
||||
int64_t num_tokens_post_pad)
|
||||
|
||||
INSTANTIATE_MOE_FP8_TEMPLATE(at::BFloat16);
|
||||
INSTANTIATE_MOE_FP8_TEMPLATE(at::Half);
|
||||
|
||||
template <typename scalar_t>
|
||||
void shared_expert_fp8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic0,
|
||||
scalar_t* __restrict__ ic1,
|
||||
scalar_t* __restrict__ B_tmp,
|
||||
float* __restrict__ C_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w1,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w2,
|
||||
const float* __restrict__ w1s,
|
||||
const float* __restrict__ w2s,
|
||||
int64_t block_size_N,
|
||||
int64_t block_size_K,
|
||||
const scalar_t* __restrict__ fused_experts_out,
|
||||
float routed_scaling_factor,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K) {
|
||||
|
||||
constexpr int64_t BLOCK_M = block_size_m();
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
|
||||
// stage 1: intermediate_cache0 = hidden_states @ w1
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(2 * N, BLOCK_N);
|
||||
int64_t scale_size_K = div_up(K, block_size_K);
|
||||
int64_t blocks_n_per_group = block_size_N / BLOCK_N;
|
||||
|
||||
const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(M);
|
||||
|
||||
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
int tid = at::get_thread_num();
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB;
|
||||
int64_t nb = i % NB;
|
||||
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
|
||||
int64_t n_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
tinygemm_kernel<scalar_t>(
|
||||
/* A */ input + mb * BLOCK_M * K,
|
||||
/* B */ packed_w1 + nb * BLOCK_N * K,
|
||||
/* C */ ic0 + mb * BLOCK_M * 2 * N + nb * BLOCK_N,
|
||||
/* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N),
|
||||
/* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
|
||||
/* scale */ w1s + (nb / blocks_n_per_group) * scale_size_K,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ K,
|
||||
/* lda */ K,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ 2 * N,
|
||||
/* brg */ use_brgemm,
|
||||
/* block_size_K */ block_size_K);
|
||||
}
|
||||
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
|
||||
// stage 1.5: intermediate_cache1 = silu(intermediate_cache0[:, :N]) * intermediate_cache0[:, N:]
|
||||
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
silu_and_mul_stub(
|
||||
ic1 + m * N,
|
||||
ic0 + m * 2 * N,
|
||||
ic0 + m * 2 * N + N,
|
||||
N);
|
||||
}
|
||||
});
|
||||
|
||||
// stage 2: intermediate_cache2 = intermediate_cache1 @ w2
|
||||
// w2 : [K, N] as [OC, IC]
|
||||
const int64_t OC = K; // rename K as OC
|
||||
const int64_t IC = N; // rename N as IC
|
||||
const int64_t MB2 = MB;
|
||||
const int64_t NB2 = div_up(K, BLOCK_N);
|
||||
scale_size_K = div_up(N, block_size_K);
|
||||
|
||||
// parallel on [MB2, NB2]
|
||||
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
|
||||
int tid = at::get_thread_num();
|
||||
alignas(64) scalar_t C[BLOCK_M * BLOCK_K];
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB2;
|
||||
int64_t nb = i % NB2;
|
||||
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
|
||||
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
// 2.a gemm: C = A @ B
|
||||
tinygemm_kernel<scalar_t>(
|
||||
/* A */ ic1 + mb * BLOCK_M * N,
|
||||
/* B */ packed_w2 + nb * BLOCK_N * N,
|
||||
/* C */ C,
|
||||
/* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N),
|
||||
/* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
|
||||
/* scale */ w2s + (nb / blocks_n_per_group) * scale_size_K,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ IC,
|
||||
/* lda */ IC,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ BLOCK_N,
|
||||
/* brg */ use_brgemm,
|
||||
/* block_size_K */ block_size_K);
|
||||
|
||||
// 2.b copy from C to output and add fused_experts_out
|
||||
scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N;
|
||||
const scalar_t* __restrict__ fused_out = fused_experts_out + mb * BLOCK_M * K + nb * BLOCK_N;
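// output = shared_expert_result + routed_scaling_factor * fused_experts_out: only the routed
// experts' contribution is scaled, the shared expert output is added as-is.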
|
||||
for (int64_t m = 0; m < m_size; ++m) {
|
||||
add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
}
|
||||
|
||||
#define INSTANTIATE_SHARED_EXPERT_FP8_TEMPLATE(TYPE) \
|
||||
template void shared_expert_fp8_kernel_impl<TYPE>( \
|
||||
TYPE* __restrict__ output, \
|
||||
TYPE* __restrict__ ic0, \
|
||||
TYPE* __restrict__ ic1, \
|
||||
TYPE* __restrict__ B_tmp, \
|
||||
float* __restrict__ C_tmp, \
|
||||
const TYPE* __restrict__ input, \
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w1, \
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w2, \
|
||||
const float* __restrict__ w1s, \
|
||||
const float* __restrict__ w2s, \
|
||||
int64_t block_size_N, \
|
||||
int64_t block_size_K, \
|
||||
const TYPE* __restrict__ fused_experts_out, \
|
||||
float routed_scaling_factor, \
|
||||
int64_t M, \
|
||||
int64_t N, \
|
||||
int64_t K)
|
||||
|
||||
INSTANTIATE_SHARED_EXPERT_FP8_TEMPLATE(at::BFloat16);
|
||||
INSTANTIATE_SHARED_EXPERT_FP8_TEMPLATE(at::Half);
|
||||
csrc/cpu/sgl-kernels/moe_int8.cpp (new file, 769 lines)
@@ -0,0 +1,769 @@
|
||||
// Adapted from
|
||||
// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu
|
||||
|
||||
#include "common.h"
|
||||
#include "vec.h"
|
||||
#include "gemm.h"
|
||||
|
||||
// clang-format off
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t size) {
|
||||
using Vec = at::vec::Vectorized<scalar_t>;
|
||||
// no remainder
|
||||
#pragma GCC unroll 4
|
||||
for (int64_t d = 0; d < size; d += Vec::size()) {
|
||||
Vec data = Vec::loadu(input + d);
|
||||
data.store(out + d);
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void copy_stub<uint8_t>(uint8_t* __restrict__ out, const uint8_t* __restrict__ input, int64_t size) {
|
||||
// size might be 64x + 32 (not a multiple of the uint8 vector width), so fall back to memcpy
|
||||
std::memcpy(out, input, size * sizeof(uint8_t));
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void copy_mul_stub(scalar_t* __restrict__ out, const float* __restrict__ input, float weight, int64_t size) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
const fVec weight_vec = fVec(weight);
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= size - kVecSize; d += kVecSize) {
|
||||
fVec data0 = fVec::loadu(input + d) * weight_vec;
|
||||
fVec data1 = fVec::loadu(input + d + fVec::size()) * weight_vec;
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
|
||||
out_vec.store(out + d);
|
||||
}
|
||||
for (; d < size; ++d) {
|
||||
out[d] = static_cast<scalar_t>(input[d] * weight);
|
||||
}
|
||||
}
|
||||
|
||||
// acc from [topk, K] to [K]
|
||||
template <typename scalar_t>
|
||||
inline void sum_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t topk, int64_t K) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
if (topk == 1) {
|
||||
// do copy for topk = 1
|
||||
copy_stub(out, input, K);
|
||||
} else {
|
||||
// do sum for topk != 1
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= K - kVecSize; d += kVecSize) {
|
||||
fVec sum_fvec0 = fVec(0.f);
|
||||
fVec sum_fvec1 = fVec(0.f);
|
||||
for (int t = 0; t < topk; ++t) {
|
||||
bVec x_bvec = bVec::loadu(input + t * K + d);
|
||||
fVec x_fvec0, x_fvec1;
|
||||
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
|
||||
|
||||
sum_fvec0 += x_fvec0;
|
||||
sum_fvec1 += x_fvec1;
|
||||
}
|
||||
bVec out_bvec = convert_from_float_ext<scalar_t>(sum_fvec0, sum_fvec1);
|
||||
out_bvec.store(out + d);
|
||||
}
|
||||
for (; d < K; ++d) {
|
||||
float sum_val = 0.f;
|
||||
for (int t = 0; t < topk; ++t) {
|
||||
sum_val += static_cast<float>(input[t * K + d]);
|
||||
}
|
||||
out[d] = static_cast<scalar_t>(sum_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// out = input + input2 * scale
|
||||
template <typename scalar_t>
|
||||
inline void add_mul_stub(scalar_t* __restrict__ out, const float* __restrict__ input,
|
||||
const scalar_t* __restrict__ input2, float scale, int64_t size) {
|
||||
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
const fVec s_vec = fVec(scale);
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= size - kVecSize; d += kVecSize) {
|
||||
fVec x0 = fVec::loadu(input + d);
|
||||
fVec x1 = fVec::loadu(input + d + fVec::size());
|
||||
|
||||
bVec y_bvec = bVec::loadu(input2 + d);
|
||||
fVec y0, y1;
|
||||
std::tie(y0, y1) = at::vec::convert_to_float(y_bvec);
|
||||
|
||||
x0 = x0 + y0 * s_vec;
|
||||
x1 = x1 + y1 * s_vec;
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(x0, x1);
|
||||
out_vec.store(out + d);
|
||||
}
|
||||
for (; d < size; ++d) {
|
||||
out[d] = static_cast<scalar_t>(input[d] + float(input2[d]) * scale);
|
||||
}
|
||||
}
|
||||
|
||||
/// gemm for w13
|
||||
template <typename scalar_t, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_vnni {
|
||||
static inline void apply(
|
||||
const uint8_t* __restrict__ A, const int8_t* __restrict__ B0, const int8_t* __restrict__ B1, scalar_t* __restrict__ C,
|
||||
const float* __restrict__ As, const float* __restrict__ Bs0, const float* __restrict__ Bs1,
|
||||
const int32_t* __restrict__ Bcomp0, const int32_t* __restrict__ Bcomp1,
|
||||
int64_t K, int64_t lda, int64_t ldb, int64_t ldc) {
|
||||
TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
template <int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_vnni<at::BFloat16, BLOCK_M, BLOCK_N> {
|
||||
static inline void apply(
|
||||
const uint8_t* __restrict__ A, const int8_t* __restrict__ B0, const int8_t* __restrict__ B1, at::BFloat16* __restrict__ C,
|
||||
const float* __restrict__ As, const float* __restrict__ Bs0, const float* __restrict__ Bs1,
|
||||
const int32_t* __restrict__ Bcomp0, const int32_t* __restrict__ Bcomp1,
|
||||
int64_t K, int64_t lda, int64_t ldb, int64_t ldc) {
|
||||
|
||||
constexpr int ROWS = BLOCK_M;
|
||||
constexpr int COLS = BLOCK_N / 16;
|
||||
static_assert(COLS % 2 == 0);
|
||||
|
||||
__m512i va;
|
||||
__m512i vb0[COLS];
|
||||
__m512i vb1[COLS];
|
||||
__m512i vc0[ROWS * COLS];
|
||||
__m512i vc1[ROWS * COLS];
|
||||
__m512i vcomp0[COLS];
|
||||
__m512i vcomp1[COLS];
|
||||
__m512 was;
|
||||
__m512 vbs0[COLS];
|
||||
__m512 vbs1[COLS];
|
||||
|
||||
auto loadc = [&](auto i) {
|
||||
vc0[i] = _mm512_set1_epi32(0);
|
||||
vc1[i] = _mm512_set1_epi32(0);
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(loadc);
|
||||
|
||||
const int64_t K4 = K >> 2;
|
||||
const int64_t lda4 = lda >> 2;
|
||||
const int64_t ldb4 = ldb; // ldb * 4 >> 2;
|
||||
const int32_t* a_ptr = reinterpret_cast<const int32_t*>(A);
|
||||
const int32_t* b0_ptr = reinterpret_cast<const int32_t*>(B0);
|
||||
const int32_t* b1_ptr = reinterpret_cast<const int32_t*>(B1);
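// VNNI int8 layout: each 32-bit lane packs 4 consecutive K values, so A and B are walked as
// int32 (K4 = K / 4) and _mm512_dpbusd_epi32 accumulates four u8 * s8 products per lane.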
|
||||
|
||||
auto compute = [&](auto i, int64_t k) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
|
||||
if constexpr (col == 0) {
|
||||
va = _mm512_set1_epi32(a_ptr[row * lda4 + k]);
|
||||
}
|
||||
if constexpr (row == 0) {
|
||||
vb0[col] = _mm512_loadu_si512(b0_ptr + k * ldb4 + col * 16);
|
||||
vb1[col] = _mm512_loadu_si512(b1_ptr + k * ldb4 + col * 16);
|
||||
}
|
||||
vc0[i] = _mm512_dpbusd_epi32(vc0[i], va, vb0[col]);
|
||||
vc1[i] = _mm512_dpbusd_epi32(vc1[i], va, vb1[col]);
|
||||
};
|
||||
for (int64_t k = 0; k < K4; ++k) {
|
||||
Unroll<ROWS * COLS>{}(compute, k);
|
||||
}
|
||||
|
||||
auto scalec = [&](auto i) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
|
||||
// load a scale
|
||||
if constexpr(col == 0) {
|
||||
was = _mm512_set1_ps(As[row]);
|
||||
}
|
||||
// load b scale and vcomp
|
||||
if constexpr (row == 0) {
|
||||
vbs0[col] = _mm512_loadu_ps(Bs0 + col * 16);
|
||||
vbs1[col] = _mm512_loadu_ps(Bs1 + col * 16);
|
||||
vcomp0[col] = _mm512_loadu_si512(Bcomp0 + col * 16);
|
||||
vcomp1[col] = _mm512_loadu_si512(Bcomp1 + col * 16);
|
||||
}
|
||||
__m512 c0 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc0[i], vcomp0[col]));
|
||||
__m512 c1 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc1[i], vcomp1[col]));
|
||||
vc0[i] = _mm512_castps_si512(_mm512_mul_ps(_mm512_mul_ps(c0, was), vbs0[col]));
|
||||
vc1[i] = _mm512_castps_si512(_mm512_mul_ps(_mm512_mul_ps(c1, was), vbs1[col]));
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(scalec);
|
||||
|
||||
using Vec = at::vec::Vectorized<float>;
|
||||
const Vec one = Vec(1.f);
|
||||
auto storec = [&](auto i) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
// for COLS = 2, 4 use 512bit store
|
||||
if constexpr (col % 2 == 0) {
|
||||
Vec x0 = _mm512_castsi512_ps(vc0[row * COLS + col + 0]);
|
||||
Vec x1 = _mm512_castsi512_ps(vc0[row * COLS + col + 1]);
|
||||
Vec y0 = _mm512_castsi512_ps(vc1[row * COLS + col + 0]);
|
||||
Vec y1 = _mm512_castsi512_ps(vc1[row * COLS + col + 1]);
|
||||
// silu
|
||||
x0 = x0 / (one + x0.neg().exp_u20());
|
||||
x1 = x1 / (one + x1.neg().exp_u20());
|
||||
// mul
|
||||
x0 = x0 * y0;
|
||||
x1 = x1 * y1;
|
||||
|
||||
_mm512_storeu_si512(
|
||||
reinterpret_cast<__m512i*>((C + row * ldc + col * 16)),
|
||||
(__m512i)(_mm512_cvtne2ps_pbh(__m512(x1), __m512(x0))));
|
||||
}
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(storec);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
#define LAUNCH_TINYGEMM_KERNEL_VNNI(MB_SIZE, NB_SIZE) \
|
||||
tinygemm_kernel_vnni<scalar_t, MB_SIZE, NB_SIZE>::apply( \
|
||||
A + mb_start * lda, B0 + nb_start * 4, B1 + nb_start * 4, \
|
||||
C + mb_start * ldc + nb_start, As + mb_start, \
|
||||
Bs0 + nb_start, Bs1 + nb_start, Bcomp0 + nb_start, Bcomp1 + nb_start,\
|
||||
K, lda, ldb, ldc);
|
||||
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(
|
||||
const uint8_t* __restrict__ A,
|
||||
const int8_t* __restrict__ B0,
|
||||
const int8_t* __restrict__ B1,
|
||||
scalar_t* __restrict__ C,
|
||||
const float* __restrict__ As,
|
||||
const float* __restrict__ Bs0,
|
||||
const float* __restrict__ Bs1,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc) {
|
||||
|
||||
const int32_t* Bcomp0 = reinterpret_cast<const int32_t*>(B0 + block_size_n() * K);
|
||||
const int32_t* Bcomp1 = reinterpret_cast<const int32_t*>(B1 + block_size_n() * K);
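// A is quantized to uint8 with a +128 offset (see quantize_row_int8 in vec.h). Bcomp0/Bcomp1 are
// compensation terms stored right after the packed int8 weights (presumably 128 * the per-column
// sums of B); they are subtracted from the int32 accumulator to cancel that offset before
// scaling by As * Bs.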
|
||||
|
||||
// pattern: 1-(2+2)-(8+8)
|
||||
constexpr int64_t BLOCK_M = 4;
|
||||
constexpr int64_t BLOCK_N = 32;
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
for (int mb = 0; mb < MB; ++mb) {
|
||||
int64_t mb_start = mb * BLOCK_M;
|
||||
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
|
||||
for (int64_t nb = 0; nb < NB; ++nb) {
|
||||
int64_t nb_start = nb * BLOCK_N;
|
||||
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
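// dispatch key: mb_size (1..4) in the high nibble, nb_size / 16 in the low nibble,
// e.g. a full 4 x 32 tile maps to 0x42; only 32-column tiles are handled below.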
|
||||
|
||||
switch(mb_size << 4 | nb_size >> 4) {
|
||||
case 0x12: LAUNCH_TINYGEMM_KERNEL_VNNI(1, 32); break;
|
||||
case 0x22: LAUNCH_TINYGEMM_KERNEL_VNNI(2, 32); break;
|
||||
case 0x32: LAUNCH_TINYGEMM_KERNEL_VNNI(3, 32); break;
|
||||
case 0x42: LAUNCH_TINYGEMM_KERNEL_VNNI(4, 32); break;
|
||||
default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", nb_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// gemm for w2
|
||||
template <typename scalar_t, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_vnni2 {
|
||||
static inline void apply(
|
||||
const uint8_t* __restrict__ A, const int8_t* __restrict__ B, float* __restrict__ C,
|
||||
const float* __restrict__ As, const float* __restrict__ Bs, const int32_t* __restrict__ Bcomp,
|
||||
int64_t K, int64_t lda, int64_t ldb, int64_t ldc) {
|
||||
TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
template <int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_vnni2<at::BFloat16, BLOCK_M, BLOCK_N> {
|
||||
static inline void apply(
|
||||
const uint8_t* __restrict__ A, const int8_t* __restrict__ B, float* __restrict__ C,
|
||||
const float* __restrict__ As, const float* __restrict__ Bs, const int32_t* __restrict__ Bcomp,
|
||||
int64_t K, int64_t lda, int64_t ldb, int64_t ldc) {
|
||||
|
||||
constexpr int ROWS = BLOCK_M;
|
||||
constexpr int COLS = BLOCK_N / 16;
|
||||
static_assert(COLS % 2 == 0);
|
||||
|
||||
__m512i va;
|
||||
__m512i vb[COLS];
|
||||
__m512i vc[ROWS * COLS];
|
||||
__m512i vcomp[COLS];
|
||||
__m512 was;
|
||||
__m512 vbs[COLS];
|
||||
|
||||
auto loadc = [&](auto i) {
|
||||
vc[i] = _mm512_set1_epi32(0);
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(loadc);
|
||||
|
||||
const int64_t K4 = K >> 2;
|
||||
const int64_t lda4 = lda >> 2;
|
||||
const int64_t ldb4 = ldb; // ldb * 4 >> 2;
|
||||
const int32_t* a_ptr = reinterpret_cast<const int32_t*>(A);
|
||||
const int32_t* b_ptr = reinterpret_cast<const int32_t*>(B);
|
||||
|
||||
auto compute = [&](auto i, int64_t k) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
|
||||
if constexpr (col == 0) {
|
||||
va = _mm512_set1_epi32(a_ptr[row * lda4 + k]);
|
||||
}
|
||||
if constexpr (row == 0) {
|
||||
vb[col] = _mm512_loadu_si512(b_ptr + k * ldb4 + col * 16);
|
||||
}
|
||||
vc[i] = _mm512_dpbusd_epi32(vc[i], va, vb[col]);
|
||||
};
|
||||
for (int64_t k = 0; k < K4; ++k) {
|
||||
Unroll<ROWS * COLS>{}(compute, k);
|
||||
}
|
||||
|
||||
auto storec = [&](auto i) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
|
||||
// load a scale
|
||||
if constexpr(col == 0) {
|
||||
was = _mm512_set1_ps(As[row]);
|
||||
}
|
||||
// load b scale and vcomp per 2 vectors
|
||||
// also load bias if any
|
||||
if constexpr (row == 0) {
|
||||
if constexpr (col % 2 == 0) {
|
||||
vbs[col + 0] = _mm512_loadu_ps(Bs + col * 16);
|
||||
vbs[col + 1] = _mm512_loadu_ps(Bs + col * 16 + 16);
|
||||
vcomp[col + 0] = _mm512_loadu_si512(Bcomp + col * 16);
|
||||
vcomp[col + 1] = _mm512_loadu_si512(Bcomp + col * 16 + 16);
|
||||
}
|
||||
}
|
||||
__m512 x = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[i], vcomp[col]));
|
||||
x = _mm512_mul_ps(_mm512_mul_ps(x, was), vbs[col]);
|
||||
_mm512_storeu_ps(reinterpret_cast<__m512*>(C + row * ldc + col * 16), x);
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(storec);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
#define LAUNCH_TINYGEMM_KERNEL_VNNI2(MB_SIZE, NB_SIZE) \
|
||||
tinygemm_kernel_vnni2<scalar_t, MB_SIZE, NB_SIZE>::apply( \
|
||||
A + mb_start * lda, B + nb_start * 4, C + mb_start * ldc + nb_start, \
|
||||
As + mb_start, Bs + nb_start, Bcomp + nb_start, \
|
||||
K, lda, ldb, ldc);
|
||||
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(
|
||||
const uint8_t* __restrict__ A,
|
||||
const int8_t* __restrict__ B,
|
||||
float* __restrict__ C,
|
||||
const float* __restrict__ As,
|
||||
const float* __restrict__ Bs,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc) {
|
||||
|
||||
// B compensation
|
||||
const int32_t* Bcomp = reinterpret_cast<const int32_t*>(B + block_size_n() * K);
|
||||
|
||||
// pattern: 1-4-16
|
||||
constexpr int64_t BLOCK_M = 4;
|
||||
constexpr int64_t BLOCK_N = 64;
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
for (int64_t mb = 0; mb < MB; ++mb) {
|
||||
int64_t mb_start = mb * BLOCK_M;
|
||||
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
|
||||
for (int64_t nb = 0; nb < NB; ++nb) {
|
||||
int64_t nb_start = nb * BLOCK_N;
|
||||
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
|
||||
|
||||
switch(mb_size << 4 | nb_size >> 4) {
|
||||
case 0x12: LAUNCH_TINYGEMM_KERNEL_VNNI2(1, 32); break;
|
||||
case 0x22: LAUNCH_TINYGEMM_KERNEL_VNNI2(2, 32); break;
|
||||
case 0x32: LAUNCH_TINYGEMM_KERNEL_VNNI2(3, 32); break;
|
||||
case 0x42: LAUNCH_TINYGEMM_KERNEL_VNNI2(4, 32); break;
|
||||
default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", nb_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
template <typename scalar_t>
|
||||
void fused_experts_int8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic1,
|
||||
scalar_t* __restrict__ ic2,
|
||||
uint8_t* __restrict__ A_tmp,
|
||||
float* __restrict__ C_tmp,
|
||||
uint8_t* __restrict__ Aq_tmp,
|
||||
float* __restrict__ As_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const int8_t* __restrict__ packed_w1,
|
||||
const int8_t* __restrict__ packed_w2,
|
||||
const float* __restrict__ w1s,
|
||||
const float* __restrict__ w2s,
|
||||
const float* __restrict__ topk_weights,
|
||||
const int32_t* __restrict__ sorted_ids,
|
||||
const int32_t* __restrict__ expert_ids,
|
||||
const int32_t* __restrict__ offsets,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t E,
|
||||
int64_t topk,
|
||||
int64_t num_tokens_post_pad) {
|
||||
|
||||
// handle 2 tiles per block
|
||||
constexpr int64_t BLOCK_M = block_size_m();
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
|
||||
// stage 0: quantize input to uint8, [M, K]
|
||||
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
quantize_row_int8<scalar_t>(
|
||||
Aq_tmp + m * K,
|
||||
As_tmp[m],
|
||||
input + m * K,
|
||||
K);
|
||||
}
|
||||
});
|
||||
|
||||
// stage 1: intermediate_cache1 = silu(hidden_states @ w1)
|
||||
const int64_t MB = div_up(num_tokens_post_pad, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
|
||||
// strides for w1: [E, 2N, K]
|
||||
TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not a multiple of ", BLOCK_N);
|
||||
|
||||
// K and N are packed for int8
|
||||
const int64_t packed_K = get_row_size<int8_t>(K);
|
||||
const int64_t packed_N = get_row_size<int8_t>(N);
|
||||
|
||||
const int64_t stride_e = 2 * N * packed_K;
|
||||
const int64_t stride_n = packed_K;
|
||||
// here we only parallel on half of 2N to fuse silu_and_mul with gemm
|
||||
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
// get local pointers
|
||||
int tid = at::get_thread_num();
|
||||
uint8_t* __restrict__ A = A_tmp + tid * BLOCK_M * K;
|
||||
|
||||
alignas(64) float As[BLOCK_M];
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB;
|
||||
int64_t nb = i % NB;
|
||||
|
||||
// nb0 from top half and nb1 from bottom half
|
||||
int64_t nb0 = nb, nb1 = nb + NB;
|
||||
int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N);
|
||||
|
||||
// B shape [K, n_size] in vnni format
|
||||
int32_t expert_id = expert_ids[mb];
|
||||
const int8_t* __restrict__ B0 = packed_w1 + expert_id * stride_e + nb0 * BLOCK_N * stride_n;
|
||||
const int8_t* __restrict__ B1 = packed_w1 + expert_id * stride_e + nb1 * BLOCK_N * stride_n;
|
||||
const float* __restrict__ Bs0 = w1s + expert_id * 2 * N + nb0 * BLOCK_N;
|
||||
const float* __restrict__ Bs1 = w1s + expert_id * 2 * N + nb1 * BLOCK_N;
|
||||
|
||||
// 1.a load A
|
||||
const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
|
||||
int64_t m_size = offsets[mb + 1] - offsets[mb];
|
||||
|
||||
for (int64_t m = 0; m < m_size; ++m) {
|
||||
int32_t index = A_ids[m] / topk;
|
||||
copy_stub(A + m * K, Aq_tmp + index * K, K);
|
||||
As[m] = As_tmp[index];
|
||||
}
|
||||
|
||||
// fused 1.b: silu_and_mul(A @ B0, A @ B1)
|
||||
const int64_t offset = offsets[mb];
|
||||
tinygemm_kernel(
|
||||
/* A */ A,
|
||||
/* B0 */ B0,
|
||||
/* B1 */ B1,
|
||||
/* C */ ic1 + offset * N + nb * BLOCK_N,
|
||||
/* As */ As,
|
||||
/* Bs0 */ Bs0,
|
||||
/* Bs1 */ Bs1,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ K,
|
||||
/* lda */ K,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ N);
|
||||
}
|
||||
});
|
||||
|
||||
// stage 1.5: quantize ic1 to uint8, [M * topk, N]
|
||||
at::parallel_for(0, M * topk, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
quantize_row_int8<scalar_t>(
|
||||
Aq_tmp + m * N,
|
||||
As_tmp[m],
|
||||
ic1 + m * N,
|
||||
N);
|
||||
}
|
||||
});
|
||||
|
||||
// stage 2: intermediate_cache2 = intermediate_cache1 @ w2
|
||||
// w2 : [E, K, N] as [E, OC, IC]
|
||||
const int64_t OC = K; // rename K as OC
|
||||
const int64_t IC = N; // rename N as IC
|
||||
const int64_t MB2 = MB;
|
||||
const int64_t NB2 = div_up(OC, BLOCK_N);
|
||||
const int64_t stride_e2 = OC * packed_N;
|
||||
const int64_t stride_oc = packed_N;
|
||||
|
||||
// parallel on [MB2, NB2]
|
||||
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
|
||||
// get local pointers
|
||||
int tid = at::get_thread_num();
|
||||
// we won't be using C1 for gemm2
|
||||
float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB2;
|
||||
int64_t nb = i % NB2;
|
||||
|
||||
int64_t m_size = offsets[mb + 1] - offsets[mb];
|
||||
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
// A ptr from ic1 of [M * topk, N] in sorted order
|
||||
// so as to avoid copy A to tmp buffer again
|
||||
const uint8_t* __restrict__ A = Aq_tmp + offsets[mb] * N;
|
||||
const float* __restrict__ As = As_tmp + offsets[mb];
|
||||
const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
|
||||
|
||||
// B shape [IC, n_size] in vnni format
|
||||
int32_t expert_id = expert_ids[mb];
|
||||
const int8_t* __restrict__ B = packed_w2 + expert_id * stride_e2 + nb * BLOCK_N * stride_oc;
|
||||
const float* __restrict__ Bs = w2s + expert_id * K + nb * BLOCK_N;
|
||||
|
||||
// 2.a gemm: C = A @ B
|
||||
tinygemm_kernel<scalar_t>(
|
||||
/* A */ A,
|
||||
/* B */ B,
|
||||
/* C */ C,
|
||||
/* As */ As,
|
||||
/* Bs */ Bs,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ IC,
|
||||
/* lda */ IC,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ BLOCK_N);
|
||||
|
||||
// 2.b copy from C to ic2 in original order
|
||||
// and also mul topk_weights in float32
|
||||
for (int64_t m = 0; m < m_size; ++m) {
|
||||
int32_t index = A_ids[m];
|
||||
float weight = topk_weights[index];
|
||||
copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// stage 3: out = intermediate_cache2.sum(dim=1)
|
||||
// from [M, topk, K] to [M, K]
|
||||
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
sum_stub(output + m * K, ic2 + m * topk * K, topk, K);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#define INSTANTIATE_MOE_INT8_TEMPLATE(TYPE) \
|
||||
template void fused_experts_int8_kernel_impl<TYPE> ( \
|
||||
TYPE* __restrict__ output, TYPE* __restrict__ ic1, \
|
||||
TYPE* __restrict__ ic2, uint8_t* __restrict__ A_tmp, \
|
||||
float* __restrict__ C_tmp, uint8_t* __restrict__ Aq_tmp, \
|
||||
float* __restrict__ As_tmp, const TYPE* __restrict__ input, \
|
||||
const int8_t* __restrict__ packed_w1, const int8_t* __restrict__ packed_w2, \
|
||||
const float* __restrict__ w1s, const float* __restrict__ w2s, \
|
||||
const float* __restrict__ topk_weights, const int32_t* __restrict__ sorted_ids, \
|
||||
const int32_t* __restrict__ expert_ids, const int32_t* __restrict__ offsets, \
|
||||
int64_t M, int64_t N, int64_t K, int64_t E, int64_t topk, int64_t num_tokens_post_pad)
|
||||
|
||||
INSTANTIATE_MOE_INT8_TEMPLATE(at::BFloat16);
|
||||
INSTANTIATE_MOE_INT8_TEMPLATE(at::Half);
|
||||
|
||||
template <typename scalar_t>
|
||||
void shared_expert_int8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic1,
|
||||
float* __restrict__ C_tmp,
|
||||
uint8_t* __restrict__ Aq_tmp,
|
||||
float* __restrict__ As_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const int8_t* __restrict__ packed_w1,
|
||||
const int8_t* __restrict__ packed_w2,
|
||||
const float* __restrict__ w1s,
|
||||
const float* __restrict__ w2s,
|
||||
const scalar_t* __restrict__ fused_experts_out,
|
||||
float routed_scaling_factor,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K) {
|
||||
|
||||
// handle 2 tiles per block
|
||||
constexpr int64_t BLOCK_M = block_size_m();
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
|
||||
// stage 0: quantize input to uint8, [M, K]
|
||||
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
quantize_row_int8<scalar_t>(
|
||||
Aq_tmp + m * K,
|
||||
As_tmp[m],
|
||||
input + m * K,
|
||||
K);
|
||||
}
|
||||
});
|
||||
|
||||
// stage 1: intermediate_cache1 = silu(hidden_states @ w1)
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
|
||||
TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not a multiple of ", BLOCK_N);
|
||||
|
||||
// K and N are packed for int8
|
||||
const int64_t packed_K = get_row_size<int8_t>(K);
|
||||
const int64_t packed_N = get_row_size<int8_t>(N);
|
||||
const int64_t stride_n = packed_K;
|
||||
|
||||
// here we only parallel on half of 2N to fuse silu_and_mul with gemm
|
||||
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB;
|
||||
int64_t nb = i % NB;
|
||||
|
||||
// nb0 from top half and nb1 from bottom half
|
||||
int64_t nb0 = nb, nb1 = nb + NB;
|
||||
int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N);
|
||||
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
|
||||
|
||||
// A shape [m_size, K]
|
||||
const uint8_t* A = Aq_tmp + mb * BLOCK_M * K;
|
||||
const float* As = As_tmp + mb * BLOCK_M;
|
||||
|
||||
// B shape [K, n_size] in vnni format
|
||||
const int8_t* __restrict__ B0 = packed_w1 + nb0 * BLOCK_N * stride_n;
|
||||
const int8_t* __restrict__ B1 = packed_w1 + nb1 * BLOCK_N * stride_n;
|
||||
const float* __restrict__ Bs0 = w1s + nb0 * BLOCK_N;
|
||||
const float* __restrict__ Bs1 = w1s + nb1 * BLOCK_N;
|
||||
|
||||
// fused 1.b: silu_and_mul(A @ B0, A @ B1)
|
||||
tinygemm_kernel(
|
||||
/* A */ A,
|
||||
/* B0 */ B0,
|
||||
/* B1 */ B1,
|
||||
/* C */ ic1 + mb * BLOCK_M * N + nb * BLOCK_N,
|
||||
/* As */ As,
|
||||
/* Bs0 */ Bs0,
|
||||
/* Bs1 */ Bs1,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ K,
|
||||
/* lda */ K,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ N);
|
||||
}
|
||||
});
|
||||
|
||||
// stage 1.5: quantize ic1 to uint8, [M, N]
|
||||
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
quantize_row_int8<scalar_t>(
|
||||
Aq_tmp + m * N,
|
||||
As_tmp[m],
|
||||
ic1 + m * N,
|
||||
N);
|
||||
}
|
||||
});
|
||||
|
||||
// stage 2: intermediate_cache2 = intermediate_cache1 @ w2
|
||||
// w2 : [K, N] as [OC, IC]
|
||||
const int64_t OC = K; // rename K as OC
|
||||
const int64_t IC = N; // rename N as IC
|
||||
const int64_t MB2 = MB;
|
||||
const int64_t NB2 = div_up(OC, BLOCK_N);
|
||||
const int64_t stride_oc = packed_N;
|
||||
|
||||
// parallel on [MB2, NB2]
|
||||
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
|
||||
// get local pointers
|
||||
int tid = at::get_thread_num();
|
||||
// we won't be using C1 for gemm2
|
||||
float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB2;
|
||||
int64_t nb = i % NB2;
|
||||
|
||||
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
|
||||
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
// A shape [m_size, IC]
|
||||
const uint8_t* __restrict__ A = Aq_tmp + mb * BLOCK_M * N;
|
||||
const float* __restrict__ As = As_tmp + mb * BLOCK_M;
|
||||
|
||||
// B shape [IC, n_size] in vnni format
|
||||
const int8_t* __restrict__ B = packed_w2 + nb * BLOCK_N * stride_oc;
|
||||
const float* __restrict__ Bs = w2s + nb * BLOCK_N;
|
||||
|
||||
// 2.a gemm: C = A @ B
|
||||
tinygemm_kernel<scalar_t>(
|
||||
/* A */ A,
|
||||
/* B */ B,
|
||||
/* C */ C,
|
||||
/* As */ As,
|
||||
/* Bs */ Bs,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ IC,
|
||||
/* lda */ IC,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ BLOCK_N);
|
||||
|
||||
// 2.b copy from C to output and add fused_experts_out
|
||||
scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N;
|
||||
const scalar_t* __restrict__ fused_out = fused_experts_out + mb * BLOCK_M * K + nb * BLOCK_N;
|
||||
for (int64_t m = 0; m < m_size; ++m) {
|
||||
add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#define INSTANTIATE_SHARED_EXPERT_INT8_TEMPLATE(TYPE) \
|
||||
template void shared_expert_int8_kernel_impl<TYPE> ( \
|
||||
TYPE* __restrict__ output, TYPE* __restrict__ ic1, \
|
||||
float* __restrict__ C_tmp, uint8_t* __restrict__ Aq_tmp, \
|
||||
float* __restrict__ As_tmp, const TYPE* __restrict__ input, \
|
||||
const int8_t* __restrict__ packed_w1, const int8_t* __restrict__ packed_w2, \
|
||||
const float* __restrict__ w1s, const float* __restrict__ w2s, \
|
||||
const TYPE* __restrict__ fused_experts_out, float routed_scaling_factor, \
|
||||
int64_t M, int64_t N, int64_t K)
|
||||
|
||||
INSTANTIATE_SHARED_EXPERT_INT8_TEMPLATE(at::BFloat16);
|
||||
INSTANTIATE_SHARED_EXPERT_INT8_TEMPLATE(at::Half);
|
||||
csrc/cpu/sgl-kernels/vec.h (new file, 308 lines)
@@ -0,0 +1,308 @@
|
||||
// Adapted from
|
||||
// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu
|
||||
|
||||
#pragma once
|
||||
|
||||
// clang-format off
|
||||
|
||||
#if defined(__AVX512F__) && defined(__AVX512BF16__) && defined(__AMX_BF16__)
|
||||
#define CPU_CAPABILITY_AVX512
|
||||
#endif
|
||||
|
||||
#include <ATen/cpu/vec/functional.h>
|
||||
#include <ATen/cpu/vec/vec.h>
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace at::vec;
|
||||
|
||||
template <typename scalar_t,
|
||||
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
|
||||
inline Vectorized<scalar_t> convert_from_float_ext(const Vectorized<float>& a, const Vectorized<float>& b) {
|
||||
return at::vec::convert_from_float<scalar_t>(a, b);
|
||||
}
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
|
||||
// `at::vec::convert_from_float<>` from PyTorch doesn't have avx512-bf16 intrinsics
|
||||
// use native instruction for bfloat16->float32 conversion
|
||||
template <>
|
||||
inline Vectorized<at::BFloat16> convert_from_float_ext<at::BFloat16>(const Vectorized<float>& a, const Vectorized<float>& b) {
|
||||
return (__m512i)(_mm512_cvtne2ps_pbh(__m512(b), __m512(a)));
|
||||
}
|
||||
|
||||
#define CVT_BF16_TO_FP32(a) \
|
||||
_mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16))
|
||||
|
||||
#define CVT_FP16_TO_FP32(a) \
|
||||
_mm512_cvt_roundph_ps(a, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC))
|
||||
|
||||
// this doesn't handle NaN.
|
||||
inline __m512bh cvt_e4m3_bf16_intrinsic_no_nan(__m256i fp8_vec) {
|
||||
const __m512i x = _mm512_cvtepu8_epi16(fp8_vec);
|
||||
|
||||
const __m512i mant = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x07)), 4);
|
||||
const __m512i raw_exp = _mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x78)), 3);
|
||||
const __m512i exp = _mm512_slli_epi16(_mm512_add_epi16(raw_exp, _mm512_set1_epi16(120)), 7);
|
||||
const __m512i nonsign = _mm512_or_si512(exp, mant);
|
||||
|
||||
const __m512i sign = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x80)), 8);
|
||||
const __m512i combined = _mm512_or_si512(nonsign, sign);
|
||||
|
||||
const __mmask32 is_nonzero = _mm512_cmpneq_epi16_mask(x, _mm512_setzero_si512());
|
||||
return (__m512bh)_mm512_maskz_mov_epi16(is_nonzero, combined);
|
||||
}
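// Bit trick above: fp8 e4m3 is S.EEEE.MMM with exponent bias 7 and bf16 is S.EEEEEEEE.MMMMMMM with
// bias 127, so the 3 mantissa bits are shifted left by 4 and the exponent is rebiased by
// +120 (= 127 - 7) before being shifted into place; the final mask keeps a zero input mapped to
// zero (denormals and NaN are not handled here, as the function name says).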
|
||||
|
||||
inline __m512bh cvt_e4m3_bf16_intrinsic_without_denorm(__m256i fp8_vec) {
|
||||
// The following conversion is without denorm behavior, that is to say,
|
||||
// Max subnorm : S.0000.111 = 0.875 ∗ 2**(−6)
|
||||
// Min subnorm : S.0000.001 = 2**(−9)
|
||||
// 0.0019 ~ 0.0137 cannot be converted correctly.
|
||||
__m512i x = _mm512_cvtepu8_epi16(fp8_vec);
|
||||
auto mask = _mm512_cmpneq_epi16_mask(
|
||||
_mm512_and_si512(x, _mm512_set1_epi16(127)),
|
||||
_mm512_setzero_si512()); // mask = (x & 0x7f) != 0
|
||||
auto mask_nan = _mm512_cmpneq_epi16_mask(
|
||||
_mm512_and_si512(x, _mm512_set1_epi16(127)),
|
||||
_mm512_set1_epi16(127)); // mask_nan = (x & 0x7f) != 0x7f
|
||||
auto mantissa = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(7)), 4); // mantissa = (x & 7) << 4
|
||||
auto exponent = _mm512_add_epi16(
|
||||
_mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(120)), 3),
|
||||
_mm512_set1_epi16(120)); // exponent = (((x >> 3) & 15) + 120)
|
||||
auto nonsign = _mm512_maskz_mov_epi16(mask, _mm512_or_si512(mantissa, _mm512_slli_epi16(exponent, 7)));
|
||||
nonsign = _mm512_mask_mov_epi16(_mm512_set1_epi16(0x7fff), mask_nan, nonsign); // deal with Nan
|
||||
return (__m512bh)(_mm512_or_si512(
|
||||
nonsign,
|
||||
_mm512_slli_epi16(
|
||||
_mm512_and_si512(x, _mm512_set1_epi16(128)),
|
||||
8))); // add sign (x & 128) << 8
|
||||
}
|
||||
|
||||
inline __m512bh cvt_e4m3_bf16_intrinsic_with_denorm(__m256i fp8_vec) {
|
||||
__m512i x = _mm512_cvtepu8_epi16(fp8_vec);
|
||||
__m512i lg2mant = _mm512_mask_mov_epi16(
|
||||
_mm512_mask_mov_epi16(
|
||||
_mm512_setzero_si512(), _mm512_test_epi16_mask(x, _mm512_set1_epi16(2)), _mm512_set1_epi16(1)),
|
||||
_mm512_test_epi16_mask(x, _mm512_set1_epi16(4)),
|
||||
_mm512_set1_epi16(2));
|
||||
return (__m512bh)(_mm512_or_si512(
|
||||
_mm512_maskz_mov_epi16(
|
||||
_mm512_cmpneq_epi16_mask(_mm512_and_si512(x, _mm512_set1_epi16(127)), _mm512_setzero_si512()),
|
||||
_mm512_mask_blend_epi16(
|
||||
_mm512_test_epi16_mask(x, _mm512_set1_epi16(120)),
|
||||
_mm512_or_si512(
|
||||
_mm512_and_si512(
|
||||
_mm512_sllv_epi16(
|
||||
_mm512_and_si512(x, _mm512_set1_epi16(3)), _mm512_sub_epi16(_mm512_set1_epi16(7), lg2mant)),
|
||||
_mm512_set1_epi16(0x007f)),
|
||||
_mm512_slli_epi16(_mm512_add_epi16(lg2mant, _mm512_set1_epi16(118)), 7)),
|
||||
_mm512_or_si512(
|
||||
_mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(7)), 4),
|
||||
_mm512_slli_epi16(
|
||||
_mm512_add_epi16(
|
||||
_mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(120)), 3), _mm512_set1_epi16(120)),
|
||||
7)))),
|
||||
_mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(128)), 8)));
|
||||
}
|
||||
|
||||
inline __m512bh CVT_FP8_TO_BF16(__m256i a) {
|
||||
#ifdef SGLANG_CPU_FP8_CVT_FTZ
|
||||
return cvt_e4m3_bf16_intrinsic_no_nan(a);
|
||||
#else
|
||||
return cvt_e4m3_bf16_intrinsic_with_denorm(a);
|
||||
#endif
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
// vector to scalar reduction
|
||||
#if defined(CPU_CAPABILITY_AVX512) && 0
|
||||
inline float vec_reduce_sum(const Vectorized<float>& a) {
|
||||
return _mm512_reduce_add_ps(__m512(a));
|
||||
}
|
||||
|
||||
inline float vec_reduce_max(const Vectorized<float>& a) {
|
||||
return _mm512_reduce_max_ps(__m512(a));
|
||||
}
|
||||
#else
|
||||
inline float vec_reduce_sum(const Vectorized<float>& a) {
|
||||
return vec_reduce_all([](Vectorized<float>& x, Vectorized<float>& y) { return x + y; }, a);
|
||||
}
|
||||
|
||||
inline float vec_reduce_max(const Vectorized<float>& a) {
|
||||
return vec_reduce_all([](Vectorized<float>& x, Vectorized<float>& y) { return maximum(x, y); }, a);
|
||||
}
|
||||
#endif
|
||||
|
||||
// https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282
|
||||
template <typename scalar_t>
|
||||
inline void quantize_row_int8(uint8_t* __restrict__ Aq, float& As,
|
||||
const scalar_t* __restrict__ A, int64_t K, float eps = 1e-7) {
|
||||
|
||||
float amax = 0.f; // absolute max
|
||||
for (int64_t k = 0; k < K; ++k) {
|
||||
const float val = static_cast<float>(A[k]);
|
||||
amax = std::max(amax, std::abs(val));
|
||||
}
|
||||
|
||||
amax = std::max(amax, eps);
|
||||
const float scale = amax / 127;
|
||||
const float inv_scale = 127 / amax;
|
||||
|
||||
for (int64_t k = 0; k < K; ++k) {
|
||||
const float val = static_cast<float>(A[k]) * inv_scale;
|
||||
Aq[k] = (uint8_t)(std::round(val)) + 128;
|
||||
}
|
||||
As = scale;
|
||||
}
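// Worked example: for amax = 2.54 the scale As is 2.54 / 127 = 0.02 and inv_scale is 50, so an
// input of 1.27 is stored as round(1.27 * 50) + 128 = 192. The +128 shift keeps the data unsigned
// for vpdpbusd; it is cancelled on the weight side by the compensation terms subtracted in the
// int8 GEMM (see Bcomp in moe_int8.cpp), and As rescales the int32 accumulator back to float.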
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
template <>
|
||||
inline void quantize_row_int8<at::BFloat16>(uint8_t* __restrict__ Aq, float& As,
|
||||
const at::BFloat16* __restrict__ A, int64_t K, float eps) {
|
||||
|
||||
const __m512 signBit = _mm512_set1_ps(-0.0f);
|
||||
const __m512i off = _mm512_set1_epi32(128);
|
||||
|
||||
// K is 32x, no remainder
|
||||
float amax = 0.f;
|
||||
__m512 vamax0 = _mm512_set1_ps(0.f);
|
||||
__m512 vamax1 = _mm512_set1_ps(0.f);
|
||||
for (int64_t k = 0; k < K; k += 32) {
|
||||
__m512i va = _mm512_loadu_si512((void*)(A + k));
|
||||
__m512 va0 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 0));
|
||||
__m512 va1 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 1));
|
||||
vamax0 = _mm512_max_ps(vamax0, _mm512_andnot_ps(signBit, va0));
|
||||
vamax1 = _mm512_max_ps(vamax1, _mm512_andnot_ps(signBit, va1));
|
||||
}
|
||||
amax = _mm512_reduce_max_ps(_mm512_max_ps(vamax0, vamax1));
|
||||
amax = std::max(amax, eps);
|
||||
const float scale = amax / 127;
|
||||
const float inv_scale = 127 / amax;
|
||||
const __m512 vd = _mm512_set1_ps(inv_scale);
|
||||
|
||||
for (int64_t k = 0; k < K; k += 32) {
|
||||
__m512i va = _mm512_loadu_si512((void*)(A + k));
|
||||
__m512 va0 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 0));
|
||||
__m512 va1 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 1));
|
||||
va0 = _mm512_mul_ps(va0, vd);
|
||||
va1 = _mm512_mul_ps(va1, vd);
|
||||
va0 = _mm512_roundscale_ps(va0, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
|
||||
va1 = _mm512_roundscale_ps(va1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
|
||||
__m128i i0 = _mm512_cvtepi32_epi8(_mm512_add_epi32(_mm512_cvtps_epi32(va0), off));
|
||||
__m128i i1 = _mm512_cvtepi32_epi8(_mm512_add_epi32(_mm512_cvtps_epi32(va1), off));
|
||||
_mm256_storeu_si256(reinterpret_cast<__m256i*>(Aq + k), _mm256_set_m128i(i1, i0));
|
||||
}
|
||||
As = scale;
|
||||
}
|
||||
#endif
|
||||
|
||||
// transpose utils
|
||||
// taken from my PR in ggml: https://github.com/ggml-org/llama.cpp/pull/8998
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
inline void transpose_16x16_32bit(__m512i * v) {
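// In-place 16x16 transpose of 32-bit elements held in 16 zmm registers: one round of 32-bit
// unpacks, one round of 64-bit unpacks, then two rounds of 128-bit lane shuffles.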
|
||||
__m512i v1[16];
|
||||
v1[0] = _mm512_unpacklo_epi32(v[0], v[1]);
|
||||
v1[1] = _mm512_unpackhi_epi32(v[0], v[1]);
|
||||
v1[2] = _mm512_unpacklo_epi32(v[2], v[3]);
|
||||
v1[3] = _mm512_unpackhi_epi32(v[2], v[3]);
|
||||
v1[4] = _mm512_unpacklo_epi32(v[4], v[5]);
|
||||
v1[5] = _mm512_unpackhi_epi32(v[4], v[5]);
|
||||
v1[6] = _mm512_unpacklo_epi32(v[6], v[7]);
|
||||
v1[7] = _mm512_unpackhi_epi32(v[6], v[7]);
|
||||
v1[8] = _mm512_unpacklo_epi32(v[8], v[9]);
|
||||
v1[9] = _mm512_unpackhi_epi32(v[8], v[9]);
|
||||
v1[10] = _mm512_unpacklo_epi32(v[10], v[11]);
|
||||
v1[11] = _mm512_unpackhi_epi32(v[10], v[11]);
|
||||
v1[12] = _mm512_unpacklo_epi32(v[12], v[13]);
|
||||
v1[13] = _mm512_unpackhi_epi32(v[12], v[13]);
|
||||
v1[14] = _mm512_unpacklo_epi32(v[14], v[15]);
|
||||
v1[15] = _mm512_unpackhi_epi32(v[14], v[15]);
|
||||
|
||||
v[0] = _mm512_unpacklo_epi64(v1[0], v1[2]);
|
||||
v[1] = _mm512_unpackhi_epi64(v1[0], v1[2]);
|
||||
v[2] = _mm512_unpacklo_epi64(v1[1], v1[3]);
|
||||
v[3] = _mm512_unpackhi_epi64(v1[1], v1[3]);
|
||||
v[4] = _mm512_unpacklo_epi64(v1[4], v1[6]);
|
||||
v[5] = _mm512_unpackhi_epi64(v1[4], v1[6]);
|
||||
v[6] = _mm512_unpacklo_epi64(v1[5], v1[7]);
|
||||
v[7] = _mm512_unpackhi_epi64(v1[5], v1[7]);
|
||||
v[8] = _mm512_unpacklo_epi64(v1[8], v1[10]);
|
||||
v[9] = _mm512_unpackhi_epi64(v1[8], v1[10]);
|
||||
v[10] = _mm512_unpacklo_epi64(v1[9], v1[11]);
|
||||
v[11] = _mm512_unpackhi_epi64(v1[9], v1[11]);
|
||||
v[12] = _mm512_unpacklo_epi64(v1[12], v1[14]);
|
||||
v[13] = _mm512_unpackhi_epi64(v1[12], v1[14]);
|
||||
v[14] = _mm512_unpacklo_epi64(v1[13], v1[15]);
|
||||
v[15] = _mm512_unpackhi_epi64(v1[13], v1[15]);
|
||||
|
||||
v1[0] = _mm512_shuffle_i32x4(v[0], v[4], 0x88);
|
||||
v1[1] = _mm512_shuffle_i32x4(v[1], v[5], 0x88);
|
||||
v1[2] = _mm512_shuffle_i32x4(v[2], v[6], 0x88);
|
||||
v1[3] = _mm512_shuffle_i32x4(v[3], v[7], 0x88);
|
||||
v1[4] = _mm512_shuffle_i32x4(v[0], v[4], 0xdd);
|
||||
v1[5] = _mm512_shuffle_i32x4(v[1], v[5], 0xdd);
|
||||
v1[6] = _mm512_shuffle_i32x4(v[2], v[6], 0xdd);
|
||||
v1[7] = _mm512_shuffle_i32x4(v[3], v[7], 0xdd);
|
||||
v1[8] = _mm512_shuffle_i32x4(v[8], v[12], 0x88);
|
||||
v1[9] = _mm512_shuffle_i32x4(v[9], v[13], 0x88);
|
||||
v1[10] = _mm512_shuffle_i32x4(v[10], v[14], 0x88);
|
||||
v1[11] = _mm512_shuffle_i32x4(v[11], v[15], 0x88);
|
||||
v1[12] = _mm512_shuffle_i32x4(v[8], v[12], 0xdd);
|
||||
v1[13] = _mm512_shuffle_i32x4(v[9], v[13], 0xdd);
|
||||
v1[14] = _mm512_shuffle_i32x4(v[10], v[14], 0xdd);
|
||||
v1[15] = _mm512_shuffle_i32x4(v[11], v[15], 0xdd);
|
||||
|
||||
v[0] = _mm512_shuffle_i32x4(v1[0], v1[8], 0x88);
|
||||
v[1] = _mm512_shuffle_i32x4(v1[1], v1[9], 0x88);
|
||||
v[2] = _mm512_shuffle_i32x4(v1[2], v1[10], 0x88);
|
||||
v[3] = _mm512_shuffle_i32x4(v1[3], v1[11], 0x88);
|
||||
v[4] = _mm512_shuffle_i32x4(v1[4], v1[12], 0x88);
|
||||
v[5] = _mm512_shuffle_i32x4(v1[5], v1[13], 0x88);
|
||||
v[6] = _mm512_shuffle_i32x4(v1[6], v1[14], 0x88);
|
||||
v[7] = _mm512_shuffle_i32x4(v1[7], v1[15], 0x88);
|
||||
v[8] = _mm512_shuffle_i32x4(v1[0], v1[8], 0xdd);
|
||||
v[9] = _mm512_shuffle_i32x4(v1[1], v1[9], 0xdd);
|
||||
v[10] = _mm512_shuffle_i32x4(v1[2], v1[10], 0xdd);
|
||||
v[11] = _mm512_shuffle_i32x4(v1[3], v1[11], 0xdd);
|
||||
v[12] = _mm512_shuffle_i32x4(v1[4], v1[12], 0xdd);
|
||||
v[13] = _mm512_shuffle_i32x4(v1[5], v1[13], 0xdd);
|
||||
v[14] = _mm512_shuffle_i32x4(v1[6], v1[14], 0xdd);
|
||||
v[15] = _mm512_shuffle_i32x4(v1[7], v1[15], 0xdd);
|
||||
}
|
||||
|
||||
// remove warning : ignoring attributes on template argument ‘__m512i’ [-Wignored-attributes]
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wignored-attributes"
|
||||
|
||||
// transpose from [2, 32] to [32, 2]
|
||||
inline std::tuple<__m512i, __m512i> transpose_2x32_16bit(__m512i r0, __m512i r1) {
|
||||
// r0: {a0, a1, ..., a31}
|
||||
// r1: {b0, b1, ..., b31}
|
||||
//
|
||||
// d0: {a0, b0, ..., a15, b15}
|
||||
// d1: {a16, b16, ..., a31, b31}
|
||||
//
|
||||
__m512i d0 = _mm512_unpacklo_epi16(r0, r1);
|
||||
__m512i d1 = _mm512_unpackhi_epi16(r0, r1);
|
||||
r0 = _mm512_shuffle_i32x4(d0, d1, 0x88);
|
||||
r1 = _mm512_shuffle_i32x4(d0, d1, 0xdd);
|
||||
d0 = _mm512_shuffle_i32x4(r0, r1, 0x88);
|
||||
d1 = _mm512_shuffle_i32x4(r0, r1, 0xdd);
|
||||
return std::make_tuple(d0, d1);
|
||||
}
|
||||
#pragma GCC diagnostic pop
|
||||
|
||||
#endif
|
||||
|
||||
// TODO: debug print, remove me later
|
||||
template<typename scalar_t>
|
||||
void print_array(scalar_t* ptr, int size) {
|
||||
for (int d = 0; d < size; ++d) {
|
||||
if (d % 16 == 0) { std::cout << std::endl; }
|
||||
std::cout << ptr[d] << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
csrc/cpu/shm.cpp
@@ -7,9 +7,10 @@
|
||||
|
||||
namespace {
|
||||
#define MAX_SHM_RANK_NUM 8
|
||||
#define MAX_THREAD_NUM 12
|
||||
#define PER_THREAD_SHM_BUFFER_BYTES (4 * 1024 * 1024)
|
||||
#define MIN_THREAD_PROCESS_SIZE (8 * 1024)
|
||||
static_assert(PER_THREAD_SHM_BUFFER_BYTES % 2 == 0);
|
||||
#define PER_THREAD_SHM_BUFFER_OFFSET (PER_THREAD_SHM_BUFFER_BYTES >> 1)
|
||||
#define MIN_THREAD_PROCESS_SIZE (256)
|
||||
#define MAX_P2P_SEND_TENSOR_NUM 8
|
||||
|
||||
template <typename scalar_t>
|
||||
@@ -32,10 +33,11 @@ struct KernelVecType<c10::Half> {
|
||||
using scalar_vec_t = vec_op::FP16Vec16;
|
||||
};
|
||||
|
||||
enum class ThreadSHMStat : char { THREAD_READY = 0, SHM_DATA_READY, DONE };
|
||||
|
||||
struct ThreadSHMContext {
|
||||
volatile ThreadSHMStat thread_stats[MAX_SHM_RANK_NUM];
|
||||
volatile char _curr_thread_stamp[2];
|
||||
volatile char _ready_thread_stamp[2];
|
||||
int local_stamp_buffer_idx;
|
||||
int remote_stamp_buffer_idx;
|
||||
int thread_id;
|
||||
int thread_num;
|
||||
int rank;
|
||||
@@ -44,10 +46,14 @@ struct ThreadSHMContext {
|
||||
int swizzled_ranks[MAX_SHM_RANK_NUM];
|
||||
void* thread_shm_ptrs[MAX_SHM_RANK_NUM];
|
||||
ThreadSHMContext* shm_contexts[MAX_SHM_RANK_NUM];
|
||||
size_t _thread_buffer_mask[2];
|
||||
char _padding2[40];
|
||||
|
||||
ThreadSHMContext(const int thread_id, const int thread_num, const int rank,
|
||||
const int group_size, void* thread_shm_ptr)
|
||||
: thread_id(thread_id),
|
||||
: local_stamp_buffer_idx(0),
|
||||
remote_stamp_buffer_idx(0),
|
||||
thread_id(thread_id),
|
||||
thread_num(thread_num),
|
||||
rank(rank),
|
||||
group_size(group_size),
|
||||
@@ -56,15 +62,25 @@ struct ThreadSHMContext {
|
||||
TORCH_CHECK(group_size <= MAX_SHM_RANK_NUM);
|
||||
TORCH_CHECK((size_t)this % 64 == 0);
|
||||
TORCH_CHECK((size_t)thread_shm_ptr % 64 == 0);
|
||||
_curr_thread_stamp[0] = 1;
|
||||
_curr_thread_stamp[1] = 1;
|
||||
_ready_thread_stamp[0] = 0;
|
||||
_ready_thread_stamp[1] = 0;
|
||||
_thread_buffer_mask[0] = 0;
|
||||
_thread_buffer_mask[1] = 0;
|
||||
for (int i = 0; i < MAX_SHM_RANK_NUM; ++i) {
|
||||
shm_contexts[i] = nullptr;
|
||||
thread_shm_ptrs[i] = nullptr;
|
||||
swizzled_ranks[i] = (i + rank) % group_size;
|
||||
thread_stats[i] = ThreadSHMStat::DONE;
|
||||
}
|
||||
set_context(rank, this, thread_shm_ptr);
|
||||
}
|
||||
|
||||
void set_stamp_buffer_idx(int local, int remote) {
|
||||
local_stamp_buffer_idx = local;
|
||||
remote_stamp_buffer_idx = remote;
|
||||
}
|
||||
|
||||
void set_context(int rank, ThreadSHMContext* ptr, void* thread_shm_ptr) {
|
||||
TORCH_CHECK(rank < MAX_SHM_RANK_NUM);
|
||||
TORCH_CHECK(ptr);
|
||||
@@ -77,59 +93,71 @@ struct ThreadSHMContext {
|
||||
|
||||
template <typename T>
|
||||
T* get_thread_shm_ptr(int rank) {
|
||||
return reinterpret_cast<T*>(thread_shm_ptrs[rank]);
|
||||
return reinterpret_cast<T*>(
|
||||
reinterpret_cast<int8_t*>(thread_shm_ptrs[rank]) +
|
||||
(PER_THREAD_SHM_BUFFER_OFFSET &
|
||||
_thread_buffer_mask[local_stamp_buffer_idx]));
|
||||
}
|
||||
|
||||
void next_buffer() {
|
||||
_thread_buffer_mask[local_stamp_buffer_idx] ^= 0xFFFFFFFFFFFFFFFF;
|
||||
}
|
||||
|
||||
char get_curr_stamp(int idx) const { return _curr_thread_stamp[idx]; }
|
||||
|
||||
char get_ready_stamp(int idx) const { return _ready_thread_stamp[idx]; }
|
||||
|
||||
void next_stamp() {
|
||||
_mm_mfence();
|
||||
_curr_thread_stamp[local_stamp_buffer_idx] += 1;
|
||||
}
|
||||
|
||||
void commit_ready_stamp() {
|
||||
_mm_mfence();
|
||||
_ready_thread_stamp[local_stamp_buffer_idx] =
|
||||
_curr_thread_stamp[local_stamp_buffer_idx];
|
||||
}
|
||||
|
||||
int get_swizzled_rank(int idx) { return swizzled_ranks[idx]; }
|
||||
|
||||
void wait_for_all(ThreadSHMStat prev_stat) {
|
||||
for (int idx = 0; idx < group_size; ++idx) {
|
||||
template <typename Cond>
|
||||
void wait_for_all(Cond&& cond) {
|
||||
for (int idx = 1; idx < group_size; ++idx) {
|
||||
int rank = get_swizzled_rank(idx);
|
||||
while (thread_stats[rank] == prev_stat) {
|
||||
++_spinning_count;
|
||||
_mm_pause();
|
||||
}
|
||||
wait_for_one(rank, std::forward<Cond>(cond));
|
||||
}
|
||||
vec_op::mem_barrier();
|
||||
}
|
||||
|
||||
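The buffer selection above is compact but easy to miss: `_thread_buffer_mask` is either all-zero or all-one bits, so ANDing it with `PER_THREAD_SHM_BUFFER_OFFSET` yields either 0 or the offset of the second half of the per-thread region, and `next_buffer()` flips halves with a single XOR. A minimal standalone sketch of the same idiom; the constants and the `DoubleBuffer` type here are illustrative stand-ins, not the ones from shm.cpp:

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdio>

// Illustrative constants; the real values live in csrc/cpu/shm.cpp.
constexpr size_t kBufferBytes = 4 * 1024 * 1024;   // whole per-thread region
constexpr size_t kHalfOffset = kBufferBytes >> 1;  // second half starts here

struct DoubleBuffer {
  int8_t* base;      // start of the per-thread shared region
  size_t mask = 0;   // 0x00...0 selects half A, 0xFF...F selects half B

  int8_t* current() const {
    // mask & kHalfOffset is either 0 or kHalfOffset, so this alternates
    // between the two halves without any branch.
    return base + (kHalfOffset & mask);
  }
  void flip() { mask ^= ~size_t{0}; }
};

int main() {
  static int8_t storage[kBufferBytes];
  DoubleBuffer buf{storage};
  int8_t* a = buf.current();
  buf.flip();
  int8_t* b = buf.current();
  assert(b - a == static_cast<ptrdiff_t>(kHalfOffset));
  buf.flip();
  assert(buf.current() == a);  // back to the first half
  std::printf("half A at %p, half B at %p\n", (void*)a, (void*)b);
  return 0;
}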
  void wait_for_one(int rank, ThreadSHMStat prev_stat) {
    while (thread_stats[rank] == prev_stat) {
  template <typename Cond>
  void wait_for_one(int rank, Cond&& cond) {
    ThreadSHMContext* rank_ctx = shm_contexts[rank];
    for (;;) {
      char local_curr_stamp = get_curr_stamp(local_stamp_buffer_idx);
      char local_ready_stamp = get_ready_stamp(local_stamp_buffer_idx);
      char rank_curr_stamp = rank_ctx->get_curr_stamp(remote_stamp_buffer_idx);
      char rank_ready_stamp =
          rank_ctx->get_ready_stamp(remote_stamp_buffer_idx);
      if (cond(local_curr_stamp, local_ready_stamp, rank_curr_stamp,
               rank_ready_stamp)) {
        break;
      }
      ++_spinning_count;
      _mm_pause();
    }
    vec_op::mem_barrier();
  }

  void set_thread_stat(ThreadSHMStat stat) {
    for (int idx = 0; idx < group_size; ++idx) {
      int rank = get_swizzled_rank(idx);
      shm_contexts[rank]->thread_stats[this->rank] = stat;
    }
  static bool check_no_buffer_conflict(char local_curr_stamp,
                                       char local_ready_stamp,
                                       char rank_curr_stamp,
                                       char rank_ready_stamp) {
    char temp = rank_curr_stamp + 2;
    return local_curr_stamp != temp;
  }

  void set_thread_stat(int target_rank, ThreadSHMStat stat) {
    for (int idx = 0; idx < group_size; ++idx) {
      int rank = get_swizzled_rank(idx);
      shm_contexts[rank]->thread_stats[target_rank] = stat;
    }
  }

  // barrier for all ranks in the group, used for all2all ops
  // DONE -> THREAD_READY -> SHM_DATA_READY -> DONE -> ...
  void barrier(ThreadSHMStat next_stat) {
    if (next_stat == ThreadSHMStat::THREAD_READY) {
      set_thread_stat(ThreadSHMStat::THREAD_READY);
      wait_for_all(ThreadSHMStat::DONE);
    } else if (next_stat == ThreadSHMStat::SHM_DATA_READY) {
      set_thread_stat(ThreadSHMStat::SHM_DATA_READY);
      wait_for_all(ThreadSHMStat::THREAD_READY);
    } else if (next_stat == ThreadSHMStat::DONE) {
      set_thread_stat(ThreadSHMStat::DONE);
      wait_for_all(ThreadSHMStat::SHM_DATA_READY);
    } else {
      TORCH_CHECK(false, "Invalid next_stat to barrier.");
    }
  static bool check_stamp_ready(char local_curr_stamp, char local_ready_stamp,
                                char rank_curr_stamp, char rank_ready_stamp) {
    char temp = local_curr_stamp + 1;
    return (local_curr_stamp == rank_ready_stamp) || (temp == rank_ready_stamp);
  }

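If I read the new scheme correctly, each thread bumps `_curr_thread_stamp` once per buffer iteration and publishes `_ready_thread_stamp` only after its data is fully written to shared memory; the two predicates then compare a local stamp against a peer's stamps to decide when a buffer half may be overwritten (`check_no_buffer_conflict`) or read (`check_stamp_ready`). A small single-process sketch of that producer/consumer handshake using plain atomics instead of the shared-memory context; all names here are illustrative:

#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

// Illustrative stand-in for one buffer plus the stamps that guard it.
// "ready" mirrors commit_ready_stamp(); "consumed" plays the role that the
// peer's stamps play in check_no_buffer_conflict().
struct StampedSlot {
  std::vector<int> data = std::vector<int>(8);
  std::atomic<int> ready{0};     // last iteration whose data is fully written
  std::atomic<int> consumed{0};  // last iteration the reader has finished with
};

int main() {
  StampedSlot slot;
  constexpr int kIters = 3;

  std::thread producer([&] {
    for (int iter = 1; iter <= kIters; ++iter) {
      // check_no_buffer_conflict: do not overwrite data the reader still needs.
      while (slot.consumed.load(std::memory_order_acquire) < iter - 1) {}
      for (auto& v : slot.data) v = iter;                 // fill the buffer
      slot.ready.store(iter, std::memory_order_release);  // commit ready stamp
    }
  });

  std::thread consumer([&] {
    for (int iter = 1; iter <= kIters; ++iter) {
      // check_stamp_ready: wait until the writer's ready stamp catches up.
      while (slot.ready.load(std::memory_order_acquire) < iter) {}
      std::printf("iteration %d visible, first element = %d\n", iter,
                  slot.data[0]);
      slot.consumed.store(iter, std::memory_order_release);
    }
  });

  producer.join();
  consumer.join();
  return 0;
}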
std::string to_string() const {
|
||||
@ -164,7 +192,7 @@ class SHMManager {
|
||||
const int group_size)
|
||||
: _rank(rank),
|
||||
_group_size(group_size),
|
||||
_thread_num(std::min(torch::get_num_threads(), MAX_THREAD_NUM)),
|
||||
_thread_num(torch::get_num_threads()),
|
||||
_shm_names({""}),
|
||||
_shared_mem_ptrs({nullptr}),
|
||||
_shm_ctx(nullptr) {
|
||||
@ -326,7 +354,8 @@ void shm_cc_loop(ThreadSHMContext* ctx, int64_t elem_num, F&& inner_func) {
|
||||
(total_units_num + thread_num - 1) / thread_num;
|
||||
int64_t per_unit_elem_num = MIN_THREAD_PROCESS_SIZE / sizeof(scalar_t);
|
||||
int64_t max_per_thread_iteration_elem_num =
|
||||
PER_THREAD_SHM_BUFFER_BYTES / sizeof(scalar_t);
|
||||
(PER_THREAD_SHM_BUFFER_BYTES >> 1) /
|
||||
sizeof(scalar_t); // Note: double buffer
|
||||
int64_t per_thread_elem_num = per_unit_elem_num * per_thread_units_num;
|
||||
|
||||
#pragma omp parallel for schedule(static, 1)
|
||||
@ -336,15 +365,27 @@ void shm_cc_loop(ThreadSHMContext* ctx, int64_t elem_num, F&& inner_func) {
|
||||
int64_t curr_elem_num =
|
||||
std::min(max_per_thread_iteration_elem_num, end - offset);
|
||||
ThreadSHMContext* thread_ctx = ctx + i;
|
||||
bool fast_mode = ((end - offset) <= max_per_thread_iteration_elem_num);
|
||||
|
||||
while (curr_elem_num > 0) {
|
||||
inner_func(thread_ctx, offset, curr_elem_num);
|
||||
inner_func(thread_ctx, offset, curr_elem_num, fast_mode);
|
||||
|
||||
thread_ctx->next_stamp();
|
||||
thread_ctx->next_buffer();
|
||||
offset += max_per_thread_iteration_elem_num;
|
||||
curr_elem_num = std::min(max_per_thread_iteration_elem_num, end - offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void reset_threads_stamp_buffer_idx(ThreadSHMContext* ctx, int local,
|
||||
int remote) {
|
||||
int thread_num = ctx->thread_num;
|
||||
for (int i = 0; i < thread_num; ++i) {
|
||||
ThreadSHMContext* thread_ctx = ctx + i;
|
||||
thread_ctx->set_stamp_buffer_idx(local, remote);
|
||||
}
|
||||
}
|
||||
}; // namespace shm_cc_ops
|
||||
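The `shm_cc_loop` above sizes each thread's work from the constants at the top of the file: the payload is split into MIN_THREAD_PROCESS_SIZE-byte units, the units are divided evenly across threads, and because the buffer is now double-buffered only half of PER_THREAD_SHM_BUFFER_BYTES is usable per iteration; `fast_mode` is taken when a thread's whole slice fits in a single iteration. A worked sketch of that arithmetic follows; the constants are copied for illustration and the `total_units_num` formula is my assumption, since it is computed outside the hunk shown here:

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Constants mirroring the #defines in csrc/cpu/shm.cpp (illustrative copy).
constexpr int64_t kPerThreadShmBufferBytes = 4 * 1024 * 1024;
constexpr int64_t kMinThreadProcessSize = 256;

// Sketch: how one thread's slice and per-iteration chunk could be derived for
// an all-reduce of `elem_num` elements over `thread_num` OpenMP threads.
void describe_partition(int64_t elem_num, int thread_num, int64_t elem_size) {
  int64_t per_unit_elem_num = kMinThreadProcessSize / elem_size;
  int64_t total_units_num =
      (elem_num + per_unit_elem_num - 1) / per_unit_elem_num;  // assumed
  int64_t per_thread_units_num =
      (total_units_num + thread_num - 1) / thread_num;
  int64_t per_thread_elem_num = per_unit_elem_num * per_thread_units_num;
  // Only half the buffer is usable per iteration: the other half may still be
  // read by peers (double buffering).
  int64_t max_iter_elem_num = (kPerThreadShmBufferBytes >> 1) / elem_size;

  for (int i = 0; i < thread_num; ++i) {
    int64_t offset = i * per_thread_elem_num;
    int64_t end = std::min(elem_num, offset + per_thread_elem_num);
    if (offset >= end) break;
    bool fast_mode = (end - offset) <= max_iter_elem_num;
    std::printf("thread %d: [%lld, %lld), %s\n", i, (long long)offset,
                (long long)end,
                fast_mode ? "single iteration (fast mode)" : "multi-iteration");
  }
}

int main() {
  describe_partition(/*elem_num=*/8 * 1024 * 1024, /*thread_num=*/8,
                     /*elem_size=*/2);  // e.g. a bf16 payload
  return 0;
}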
|
||||
namespace shm_cc_ops {
|
||||
@ -397,7 +438,7 @@ void all_reduce_sum_impl(ThreadSHMContext* ctx, scalar_t* data,
|
||||
shm_cc_ops::shm_cc_loop<scalar_t>(
|
||||
ctx, elem_num,
|
||||
[&](ThreadSHMContext* thread_ctx, int64_t data_offset,
|
||||
int64_t data_elem_num) {
|
||||
int64_t data_elem_num, bool fast_mode) {
|
||||
int rank = thread_ctx->rank;
|
||||
scalar_t* thread_shm_ptr =
|
||||
thread_ctx->get_thread_shm_ptr<scalar_t>(rank);
|
||||
@ -410,16 +451,17 @@ void all_reduce_sum_impl(ThreadSHMContext* ctx, scalar_t* data,
|
||||
thread_ctx->get_swizzled_rank(idx + 1));
|
||||
});
|
||||
|
||||
thread_ctx->barrier(ThreadSHMStat::THREAD_READY);
|
||||
if (!fast_mode) {
|
||||
thread_ctx->wait_for_all(ThreadSHMContext::check_no_buffer_conflict);
|
||||
}
|
||||
|
||||
shm_cc_ops::memcpy_to_shm(thread_shm_ptr, thread_data_ptr,
|
||||
thread_data_elem_num);
|
||||
|
||||
thread_ctx->barrier(ThreadSHMStat::SHM_DATA_READY);
|
||||
|
||||
thread_ctx->commit_ready_stamp();
|
||||
int64_t aligned_data_elem_num =
|
||||
(data_elem_num / vec_elem_num) * vec_elem_num;
|
||||
int64_t i = 0;
|
||||
thread_ctx->wait_for_all(ThreadSHMContext::check_stamp_ready);
|
||||
#pragma GCC unroll 4
|
||||
for (; i < aligned_data_elem_num; i += vec_elem_num) {
|
||||
vec_t local_data(thread_data_ptr + i); // load from cache
|
||||
@ -447,8 +489,6 @@ void all_reduce_sum_impl(ThreadSHMContext* ctx, scalar_t* data,
|
||||
reduced_data.save(thread_data_ptr + i,
|
||||
data_elem_num - aligned_data_elem_num);
|
||||
}
|
||||
|
||||
thread_ctx->barrier(ThreadSHMStat::DONE);
|
||||
});
|
||||
|
||||
return;
|
||||
@ -488,18 +528,18 @@ void shm_gather_impl(ThreadSHMContext* ctx, scalar_t* data, size_t elem_num,
|
||||
shm_cc_ops::shm_cc_loop<scalar_t>(
|
||||
ctx, elem_num,
|
||||
[&](ThreadSHMContext* thread_ctx, int64_t data_offset,
|
||||
int64_t data_elem_num) {
|
||||
int64_t data_elem_num, bool fast_mode) {
|
||||
int rank = thread_ctx->rank;
|
||||
scalar_t* thread_shm_ptr =
|
||||
thread_ctx->get_thread_shm_ptr<scalar_t>(rank);
|
||||
|
||||
thread_ctx->barrier(ThreadSHMStat::THREAD_READY);
|
||||
|
||||
shm_cc_ops::memcpy_to_shm(thread_shm_ptr, data + data_offset,
|
||||
data_elem_num * sizeof(scalar_t));
|
||||
|
||||
thread_ctx->barrier(ThreadSHMStat::SHM_DATA_READY);
|
||||
if (!fast_mode) {
|
||||
thread_ctx->wait_for_all(ThreadSHMContext::check_no_buffer_conflict);
|
||||
}
|
||||
|
||||
shm_cc_ops::memcpy(thread_shm_ptr, data + data_offset,
|
||||
data_elem_num * sizeof(scalar_t));
|
||||
thread_ctx->commit_ready_stamp();
|
||||
if (rank == dst) {
|
||||
shm_cc_ops::memcpy(outputs[rank] + data_offset, data + data_offset,
|
||||
data_elem_num * sizeof(scalar_t));
|
||||
@ -508,12 +548,12 @@ void shm_gather_impl(ThreadSHMContext* ctx, scalar_t* data, size_t elem_num,
|
||||
scalar_t* src_ptr =
|
||||
thread_ctx->get_thread_shm_ptr<scalar_t>(src_rank); // shm
|
||||
scalar_t* dst_ptr = outputs[src_rank] + data_offset;
|
||||
shm_cc_ops::memcpy_from_shm(dst_ptr, src_ptr,
|
||||
data_elem_num * sizeof(scalar_t));
|
||||
thread_ctx->wait_for_one(src_rank,
|
||||
ThreadSHMContext::check_stamp_ready);
|
||||
shm_cc_ops::memcpy(dst_ptr, src_ptr,
|
||||
data_elem_num * sizeof(scalar_t));
|
||||
}
|
||||
}
|
||||
|
||||
thread_ctx->barrier(ThreadSHMStat::DONE);
|
||||
});
|
||||
|
||||
return;
|
||||
@ -599,7 +639,7 @@ struct TensorListMeta {
|
||||
int8_t _padding[40];
|
||||
};
|
||||
|
||||
void shm_send_tensor_list_impl(ThreadSHMContext* ctx,
|
||||
void shm_send_tensor_list_impl(ThreadSHMContext* ctx, int64_t dst,
|
||||
const std::vector<torch::Tensor>& tensor_list) {
|
||||
CPU_KERNEL_GUARD_IN(shm_send_tensor_list_impl)
|
||||
std::vector<torch::Tensor> tensor_list_with_metadata;
|
||||
@ -617,15 +657,15 @@ void shm_send_tensor_list_impl(ThreadSHMContext* ctx,
|
||||
TensorListMeta* metadata = new (metadata_tensor.data_ptr()) TensorListMeta();
|
||||
metadata->bind_tensor_list(tensor_list_with_metadata);
|
||||
|
||||
shm_cc_ops::reset_threads_stamp_buffer_idx(ctx, 0, 1);
|
||||
shm_cc_ops::shm_cc_loop<int8_t>(
|
||||
ctx, metadata->total_bytes,
|
||||
[&](ThreadSHMContext* thread_ctx, int64_t data_offset,
|
||||
int64_t data_elem_num) {
|
||||
int64_t data_elem_num, bool fast_mode) {
|
||||
int rank = thread_ctx->rank;
|
||||
// Wait until the receiver set the stat to DONE
|
||||
thread_ctx->wait_for_one(rank, ThreadSHMStat::SHM_DATA_READY);
|
||||
|
||||
int64_t curr_shm_offset = 0;
|
||||
thread_ctx->wait_for_one(dst,
|
||||
ThreadSHMContext::check_no_buffer_conflict);
|
||||
while (curr_shm_offset < data_elem_num) {
|
||||
MemPiece frag = metadata->get_data(data_offset + curr_shm_offset);
|
||||
frag.size = std::min(frag.size, data_elem_num - curr_shm_offset);
|
||||
@ -634,8 +674,7 @@ void shm_send_tensor_list_impl(ThreadSHMContext* ctx,
|
||||
frag.ptr, frag.size);
|
||||
curr_shm_offset += frag.size;
|
||||
}
|
||||
|
||||
thread_ctx->set_thread_stat(rank, ThreadSHMStat::SHM_DATA_READY);
|
||||
thread_ctx->commit_ready_stamp();
|
||||
});
|
||||
}
|
||||
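The fragment loop above copies an arbitrary byte range of the logically concatenated tensor list by repeatedly asking the metadata for the piece that contains the current offset and clamping it to what is left. A standalone sketch of that slicing pattern over a plain vector of buffers; `Piece` and `get_data` mimic MemPiece / TensorListMeta::get_data and are not the real types:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

// Mimic of MemPiece for a flat list of buffers.
struct Piece {
  int8_t* ptr;
  int64_t size;  // bytes available in this buffer from the requested offset
};

struct FlatList {
  std::vector<std::vector<int8_t>> buffers;

  // Return the piece covering global byte offset `off` in the concatenation.
  Piece get_data(int64_t off) {
    for (auto& b : buffers) {
      int64_t n = static_cast<int64_t>(b.size());
      if (off < n) return {b.data() + off, n - off};
      off -= n;
    }
    return {nullptr, 0};
  }
};

// Copy bytes [offset, offset + len) of the concatenation into dst,
// fragment by fragment, like the shm send/recv loops do.
void copy_range(FlatList& list, int64_t offset, int64_t len, int8_t* dst) {
  int64_t copied = 0;
  while (copied < len) {
    Piece frag = list.get_data(offset + copied);
    frag.size = std::min(frag.size, len - copied);  // clamp to what's left
    std::memcpy(dst + copied, frag.ptr, frag.size);
    copied += frag.size;
  }
}

int main() {
  FlatList list;
  list.buffers = {{1, 2, 3}, {4, 5}, {6, 7, 8, 9}};
  int8_t out[4];
  copy_range(list, /*offset=*/2, /*len=*/4, out);  // expect 3 4 5 6
  for (int8_t v : out) std::printf("%d ", v);
  std::printf("\n");
  return 0;
}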
|
||||
@ -646,8 +685,8 @@ std::vector<torch::Tensor> shm_recv_tensor_list_impl(ThreadSHMContext* ctx,
|
||||
torch::Tensor metadata_tensor =
|
||||
torch::empty({sizeof(TensorListMeta)}, options);
|
||||
|
||||
// Wait until the sender set the stat of the thread 0 to SHM_DATA_READY
|
||||
ctx->wait_for_one(src, ThreadSHMStat::DONE);
|
||||
shm_cc_ops::reset_threads_stamp_buffer_idx(ctx, 1, 0);
|
||||
ctx->wait_for_one(src, ThreadSHMContext::check_stamp_ready);
|
||||
shm_cc_ops::memcpy(metadata_tensor.data_ptr(),
|
||||
ctx->get_thread_shm_ptr<void>(src),
|
||||
sizeof(TensorListMeta));
|
||||
@ -664,9 +703,8 @@ std::vector<torch::Tensor> shm_recv_tensor_list_impl(ThreadSHMContext* ctx,
|
||||
shm_cc_ops::shm_cc_loop<int8_t>(
|
||||
ctx, metadata.total_bytes,
|
||||
[&](ThreadSHMContext* thread_ctx, int64_t data_offset,
|
||||
int64_t data_elem_num) {
|
||||
// Wait until the sender set the stat to SHM_DATA_READY
|
||||
thread_ctx->wait_for_one(src, ThreadSHMStat::DONE);
|
||||
int64_t data_elem_num, bool fast_mode) {
|
||||
thread_ctx->wait_for_one(src, ThreadSHMContext::check_stamp_ready);
|
||||
int64_t curr_shm_offset = 0;
|
||||
while (curr_shm_offset < data_elem_num) {
|
||||
MemPiece frag = metadata.get_data(data_offset + curr_shm_offset);
|
||||
@ -677,8 +715,6 @@ std::vector<torch::Tensor> shm_recv_tensor_list_impl(ThreadSHMContext* ctx,
|
||||
frag.size);
|
||||
curr_shm_offset += frag.size;
|
||||
}
|
||||
|
||||
thread_ctx->set_thread_stat(src, ThreadSHMStat::DONE);
|
||||
});
|
||||
|
||||
std::vector<torch::Tensor> tensor_list;
|
||||
@ -756,7 +792,8 @@ void shm_send_tensor_list(int64_t handle,
|
||||
int64_t dst) {
|
||||
CPU_KERNEL_GUARD_IN(shm_send_tensor_list)
|
||||
shm_send_tensor_list_impl(
|
||||
SHMManager::get_singleton_instance(handle)->get_shm_ctx(), tensor_list);
|
||||
SHMManager::get_singleton_instance(handle)->get_shm_ctx(), dst,
|
||||
tensor_list);
|
||||
CPU_KERNEL_GUARD_OUT(shm_send_tensor_list)
|
||||
}
|
||||
|
||||
@ -778,4 +815,4 @@ std::string join_shm_manager(int64_t handle, const std::string& name) {
|
||||
TORCH_CHECK(shm_manager);
|
||||
shm_manager->join(name);
|
||||
return shm_manager->get_shm_ctx()->to_string();
|
||||
}
|
||||
}
|
||||
|
||||
@ -50,6 +50,27 @@ void shm_send_tensor_list(int64_t handle,
|
||||
|
||||
std::vector<torch::Tensor> shm_recv_tensor_list(int64_t handle, int64_t src);
|
||||
|
||||
at::Tensor weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
bool is_vnni);
|
||||
|
||||
at::Tensor convert_weight_packed(at::Tensor& weight);
|
||||
|
||||
at::Tensor fused_experts_cpu(
|
||||
at::Tensor& hidden_states, at::Tensor& w1, at::Tensor& w2,
|
||||
at::Tensor& topk_weights, at::Tensor& topk_ids, bool inplace,
|
||||
bool use_int8_w8a8, bool use_fp8_w8a16,
|
||||
const std::optional<at::Tensor>& w1_scale,
|
||||
const std::optional<at::Tensor>& w2_scale,
|
||||
const std::optional<std::vector<int64_t>> block_size,
|
||||
const std::optional<at::Tensor>& a1_scale,
|
||||
const std::optional<at::Tensor>& a2_scale, bool is_vnni);
|
||||
|
||||
at::Tensor int8_scaled_mm_with_quant(at::Tensor& mat1, at::Tensor& mat2,
|
||||
at::Tensor& scales2,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
at::ScalarType out_dtype, bool is_vnni);
|
||||
|
||||
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
// vLLM custom ops
|
||||
|
||||
@ -130,8 +151,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding);
|
||||
|
||||
// Quantization
|
||||
#ifdef __AVX512F__
|
||||
#if defined(__AVX512F__) || (defined(__aarch64__) && !defined(__APPLE__))
|
||||
at::Tag stride_tag = at::Tag::needs_fixed_stride_order;
|
||||
|
||||
// Compute int8 quantized tensor for given scaling factor.
|
||||
ops.def(
|
||||
"static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
|
||||
@ -214,6 +236,28 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.def("shm_recv_tensor_list(int handle, int src) -> Tensor[](a)",
|
||||
&shm_recv_tensor_list);
|
||||
#endif
|
||||
|
||||
// sgl-kernels
|
||||
#if defined(__AVX512BF16__) && defined(__AVX512F__) && defined(__AVX512VNNI__)
|
||||
ops.def(
|
||||
"weight_packed_linear(Tensor(a0!) mat1, Tensor(a1!) mat2, Tensor(a2!)? "
|
||||
"bias, bool is_vnni) -> Tensor");
|
||||
ops.impl("weight_packed_linear", torch::kCPU, &weight_packed_linear);
|
||||
ops.def("convert_weight_packed(Tensor! weight) -> Tensor");
|
||||
ops.impl("convert_weight_packed", torch::kCPU, &convert_weight_packed);
|
||||
ops.def(
|
||||
"fused_experts_cpu(Tensor! hidden_states, Tensor w1, Tensor w2, Tensor "
|
||||
"topk_weights, Tensor topk_ids, bool inplace, bool use_int8_w8a8, bool "
|
||||
"use_fp8_w8a16, Tensor? w1_scale, Tensor? w2_scale, SymInt[]? "
|
||||
"block_size, Tensor? a1_scale, Tensor? a2_scale, bool is_vnni) -> "
|
||||
"Tensor");
|
||||
ops.impl("fused_experts_cpu", torch::kCPU, &fused_experts_cpu);
|
||||
ops.def(
|
||||
"int8_scaled_mm_with_quant(Tensor mat1, Tensor mat2, Tensor scales2, "
|
||||
"Tensor? bias, ScalarType out_dtype, bool is_vnni) -> Tensor");
|
||||
ops.impl("int8_scaled_mm_with_quant", torch::kCPU,
|
||||
&int8_scaled_mm_with_quant);
|
||||
#endif
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
|
||||
@ -4,10 +4,37 @@
|
||||
#include <hip/hip_runtime.h>
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define WARP_SIZE 32
|
||||
#ifdef USE_ROCM
|
||||
struct Utils {
|
||||
static __host__ int get_warp_size() {
|
||||
static bool is_cached = false;
|
||||
static int result;
|
||||
|
||||
if (!is_cached) {
|
||||
int device_id;
|
||||
cudaDeviceProp deviceProp;
|
||||
cudaGetDevice(&device_id);
|
||||
cudaGetDeviceProperties(&deviceProp, device_id);
|
||||
|
||||
result = deviceProp.warpSize;
|
||||
is_cached = true;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
static __device__ constexpr int get_warp_size() {
|
||||
#ifdef __GFX9__
|
||||
return 64;
|
||||
#else
|
||||
return 32;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
#define WARP_SIZE Utils::get_warp_size()
|
||||
#else
|
||||
#define WARP_SIZE warpSize
|
||||
#define WARP_SIZE 32
|
||||
#endif
|
||||
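On ROCm the host side can no longer treat the warp size as a fixed 32, so the patch queries `warpSize` from the device properties once and caches it, while device code keeps a per-architecture constexpr. A minimal sketch of the same query-once-and-cache idiom in plain C++, with the expensive device query stubbed out (the real code uses an explicit bool flag rather than the static-initializer form shown here):

#include <cstdio>

// Stand-in for cudaGetDeviceProperties(): pretend this call is expensive.
static int query_warp_size_from_driver() {
  std::puts("querying device properties...");  // should print only once
  return 64;  // e.g. a gfx9 ROCm device
}

// Mirrors the host-side Utils::get_warp_size(): the first call pays the
// query, later calls return the cached value.
static int get_warp_size() {
  static const int cached = query_warp_size_from_driver();
  return cached;
}

int main() {
  for (int i = 0; i < 3; ++i) {
    std::printf("warp size = %d\n", get_warp_size());
  }
  return 0;
}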
|
||||
#ifndef USE_ROCM
|
||||
|
||||
@ -153,7 +153,7 @@ struct ScaledEpilogueBias
|
||||
cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
|
||||
|
||||
using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::homogeneous_multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
@ -210,7 +210,7 @@ struct ScaledEpilogueBiasAzp
|
||||
EVTComputeAzp>;
|
||||
|
||||
using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::homogeneous_multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
@ -288,7 +288,7 @@ struct ScaledEpilogueBiasAzpToken
|
||||
EVTComputeAcc>;
|
||||
|
||||
using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::homogeneous_multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
|
||||
@ -195,7 +195,7 @@ struct ScaledEpilogueBias
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
|
||||
|
||||
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::homogeneous_multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
@ -238,7 +238,7 @@ struct ScaledEpilogueColumnBias
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
|
||||
|
||||
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::homogeneous_multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
@ -295,7 +295,7 @@ struct ScaledEpilogueBiasAzp
|
||||
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAzp>;
|
||||
|
||||
using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::homogeneous_multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
@ -371,7 +371,7 @@ struct ScaledEpilogueBiasAzpToken
|
||||
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAcc>;
|
||||
|
||||
using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::homogeneous_multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
|
||||
@ -45,7 +45,6 @@
|
||||
#include "cute/algorithm/functional.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cute/algorithm/gemm.hpp"
|
||||
#include "cute/tensor_predicate.hpp"
|
||||
#include "cute/numeric/arithmetic_tuple.hpp"
|
||||
|
||||
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
|
||||
|
||||
@ -15,15 +15,16 @@ namespace vllm {
|
||||
// TODO(woosuk): Further optimize this kernel.
|
||||
template <typename scalar_t>
|
||||
__global__ void rms_norm_kernel(
|
||||
scalar_t* __restrict__ out, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
scalar_t* __restrict__ out, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
const int64_t input_stride,
|
||||
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||
const float epsilon, const int num_tokens, const int hidden_size) {
|
||||
__shared__ float s_variance;
|
||||
float variance = 0.0f;
|
||||
|
||||
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
|
||||
const float x = (float)input[blockIdx.x * hidden_size + idx];
|
||||
const float x = (float)input[blockIdx.x * input_stride + idx];
|
||||
variance += x * x;
|
||||
}
|
||||
|
||||
@ -37,7 +38,7 @@ __global__ void rms_norm_kernel(
|
||||
__syncthreads();
|
||||
|
||||
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
|
||||
float x = (float)input[blockIdx.x * hidden_size + idx];
|
||||
float x = (float)input[blockIdx.x * input_stride + idx];
|
||||
out[blockIdx.x * hidden_size + idx] =
|
||||
((scalar_t)(x * s_variance)) * weight[idx];
|
||||
}
|
||||
@ -50,7 +51,8 @@ __global__ void rms_norm_kernel(
|
||||
template <typename scalar_t, int width>
|
||||
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
|
||||
fused_add_rms_norm_kernel(
|
||||
scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
const int64_t input_stride,
|
||||
scalar_t* __restrict__ residual, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||
const float epsilon, const int num_tokens, const int hidden_size) {
|
||||
@ -59,6 +61,7 @@ fused_add_rms_norm_kernel(
|
||||
static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);
|
||||
|
||||
const int vec_hidden_size = hidden_size / width;
|
||||
const int64_t vec_input_stride = input_stride / width;
|
||||
__shared__ float s_variance;
|
||||
float variance = 0.0f;
|
||||
/* These and the argument pointers are all declared `restrict` as they are
|
||||
@ -73,7 +76,8 @@ fused_add_rms_norm_kernel(
|
||||
|
||||
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
|
||||
int id = blockIdx.x * vec_hidden_size + idx;
|
||||
_f16Vec<scalar_t, width> temp = input_v[id];
|
||||
int64_t strided_id = blockIdx.x * vec_input_stride + idx;
|
||||
_f16Vec<scalar_t, width> temp = input_v[strided_id];
|
||||
temp += residual_v[id];
|
||||
variance += temp.sum_squares();
|
||||
residual_v[id] = temp;
|
||||
@ -90,10 +94,11 @@ fused_add_rms_norm_kernel(
|
||||
|
||||
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
|
||||
int id = blockIdx.x * vec_hidden_size + idx;
|
||||
int64_t strided_id = blockIdx.x * vec_input_stride + idx;
|
||||
_f16Vec<scalar_t, width> temp = residual_v[id];
|
||||
temp *= s_variance;
|
||||
temp *= weight_v[idx];
|
||||
input_v[id] = temp;
|
||||
input_v[strided_id] = temp;
|
||||
}
|
||||
}
|
||||
|
||||
@ -103,7 +108,8 @@ fused_add_rms_norm_kernel(
|
||||
template <typename scalar_t, int width>
|
||||
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
|
||||
fused_add_rms_norm_kernel(
|
||||
scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
const int64_t input_stride,
|
||||
scalar_t* __restrict__ residual, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||
const float epsilon, const int num_tokens, const int hidden_size) {
|
||||
@ -111,7 +117,7 @@ fused_add_rms_norm_kernel(
|
||||
float variance = 0.0f;
|
||||
|
||||
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
|
||||
scalar_t z = input[blockIdx.x * hidden_size + idx];
|
||||
scalar_t z = input[blockIdx.x * input_stride + idx];
|
||||
z += residual[blockIdx.x * hidden_size + idx];
|
||||
float x = (float)z;
|
||||
variance += x * x;
|
||||
@ -129,7 +135,7 @@ fused_add_rms_norm_kernel(
|
||||
|
||||
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
|
||||
float x = (float)residual[blockIdx.x * hidden_size + idx];
|
||||
input[blockIdx.x * hidden_size + idx] =
|
||||
input[blockIdx.x * input_stride + idx] =
|
||||
((scalar_t)(x * s_variance)) * weight[idx];
|
||||
}
|
||||
}
|
||||
@ -141,11 +147,12 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
|
||||
torch::Tensor& weight, // [hidden_size]
|
||||
double epsilon) {
|
||||
TORCH_CHECK(out.is_contiguous());
|
||||
TORCH_CHECK(input.is_contiguous());
|
||||
TORCH_CHECK(input.stride(-1) == 1);
|
||||
TORCH_CHECK(weight.is_contiguous());
|
||||
|
||||
int hidden_size = input.size(-1);
|
||||
int num_tokens = input.numel() / hidden_size;
|
||||
int64_t input_stride = input.stride(-2);
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(hidden_size, 1024));
|
||||
@ -153,26 +160,29 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
|
||||
vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
|
||||
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), input_stride,
|
||||
weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
|
||||
});
|
||||
}
|
||||
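The kernel change above only affects how a row's base address is computed: the input may now have a row stride larger than hidden_size (for example, a view of a padded buffer), so reads use `row * input_stride` while the contiguous output keeps `row * hidden_size`. A small CPU reference of that indexing, which could double as a check against the kernel; this helper is hypothetical and not part of the repo:

#include <cmath>
#include <cstdio>
#include <vector>

// CPU reference for RMSNorm over `num_tokens` rows, where input rows are
// `input_stride` floats apart (input_stride >= hidden_size) and the output is
// densely packed.
void rms_norm_ref(float* out, const float* input, long input_stride,
                  const float* weight, float epsilon, int num_tokens,
                  int hidden_size) {
  for (int row = 0; row < num_tokens; ++row) {
    const float* in_row = input + row * input_stride;              // strided read
    float* out_row = out + static_cast<long>(row) * hidden_size;   // dense write
    float sum_sq = 0.f;
    for (int i = 0; i < hidden_size; ++i) sum_sq += in_row[i] * in_row[i];
    float inv_rms = 1.f / std::sqrt(sum_sq / hidden_size + epsilon);
    for (int i = 0; i < hidden_size; ++i)
      out_row[i] = in_row[i] * inv_rms * weight[i];
  }
}

int main() {
  const int num_tokens = 2, hidden_size = 4;
  const long input_stride = 6;  // padded rows: 2 unused floats per row
  std::vector<float> input(num_tokens * input_stride, 0.f);
  std::vector<float> weight(hidden_size, 1.f);
  std::vector<float> out(num_tokens * hidden_size, 0.f);
  for (int r = 0; r < num_tokens; ++r)
    for (int i = 0; i < hidden_size; ++i) input[r * input_stride + i] = i + 1;
  rms_norm_ref(out.data(), input.data(), input_stride, weight.data(), 1e-6f,
               num_tokens, hidden_size);
  std::printf("out[0..3] = %.3f %.3f %.3f %.3f\n", out[0], out[1], out[2],
              out[3]);
  return 0;
}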
|
||||
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
|
||||
VLLM_DISPATCH_FLOATING_TYPES( \
|
||||
input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \
|
||||
vllm::fused_add_rms_norm_kernel<scalar_t, width> \
|
||||
<<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(), \
|
||||
residual.data_ptr<scalar_t>(), \
|
||||
weight.data_ptr<scalar_t>(), epsilon, \
|
||||
num_tokens, hidden_size); \
|
||||
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
|
||||
VLLM_DISPATCH_FLOATING_TYPES( \
|
||||
input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \
|
||||
vllm::fused_add_rms_norm_kernel<scalar_t, width> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
input.data_ptr<scalar_t>(), input_stride, \
|
||||
residual.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(), \
|
||||
epsilon, num_tokens, hidden_size); \
|
||||
});
|
||||
|
||||
void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
|
||||
torch::Tensor& residual, // [..., hidden_size]
|
||||
torch::Tensor& weight, // [hidden_size]
|
||||
double epsilon) {
|
||||
TORCH_CHECK(residual.is_contiguous());
|
||||
TORCH_CHECK(weight.is_contiguous());
|
||||
int hidden_size = input.size(-1);
|
||||
int64_t input_stride = input.stride(-2);
|
||||
int num_tokens = input.numel() / hidden_size;
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
@ -194,9 +204,16 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
|
||||
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
|
||||
auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
|
||||
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
|
||||
bool ptrs_are_aligned =
|
||||
inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
|
||||
if (ptrs_are_aligned && hidden_size % 8 == 0) {
|
||||
constexpr int vector_width = 8;
|
||||
constexpr int req_alignment_bytes =
|
||||
vector_width * 2; // vector_width * sizeof(bfloat16 or float16) (float32
|
||||
// falls back to non-vectorized version anyway)
|
||||
bool ptrs_are_aligned = inp_ptr % req_alignment_bytes == 0 &&
|
||||
res_ptr % req_alignment_bytes == 0 &&
|
||||
wt_ptr % req_alignment_bytes == 0;
|
||||
bool offsets_are_multiple_of_vector_width =
|
||||
hidden_size % vector_width == 0 && input_stride % vector_width == 0;
|
||||
if (ptrs_are_aligned && offsets_are_multiple_of_vector_width) {
|
||||
LAUNCH_FUSED_ADD_RMS_NORM(8);
|
||||
} else {
|
||||
LAUNCH_FUSED_ADD_RMS_NORM(0);
|
||||
|
||||
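The launcher now gates the width-8 vectorized kernel on three conditions: 16-byte pointer alignment (8 elements times 2 bytes for half/bf16), hidden_size divisible by the vector width, and, new in this patch, input_stride divisible by the vector width so that every strided row base stays aligned. A compact sketch of that eligibility test; the kernel launch itself is elided and the helper name is mine:

#include <cstdint>
#include <cstdio>

// Decide whether the 8-wide vectorized fused_add_rms_norm path may be used.
// Element size 2 covers fp16/bf16 (fp32 falls back to the scalar kernel).
bool can_use_vectorized_path(const void* input, const void* residual,
                             const void* weight, long hidden_size,
                             long input_stride) {
  constexpr int kVectorWidth = 8;
  constexpr int kReqAlignmentBytes = kVectorWidth * 2;
  auto aligned = [](const void* p) {
    return reinterpret_cast<std::uintptr_t>(p) % kReqAlignmentBytes == 0;
  };
  bool ptrs_ok = aligned(input) && aligned(residual) && aligned(weight);
  bool sizes_ok =
      hidden_size % kVectorWidth == 0 && input_stride % kVectorWidth == 0;
  return ptrs_ok && sizes_ok;
}

int main() {
  alignas(16) static std::uint16_t buf[1024];
  // Contiguous case: stride == hidden_size, everything aligned -> fast path.
  std::printf("%d\n", can_use_vectorized_path(buf, buf, buf, 128, 128));  // 1
  // Padded case with a stride that is not a multiple of 8 -> fall back.
  std::printf("%d\n", can_use_vectorized_path(buf, buf, buf, 128, 132));  // 0
  return 0;
}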
@ -23,8 +23,9 @@ namespace vllm {
|
||||
// TODO(woosuk): Further optimize this kernel.
|
||||
template <typename scalar_t, typename fp8_type>
|
||||
__global__ void rms_norm_static_fp8_quant_kernel(
|
||||
fp8_type* __restrict__ out, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
fp8_type* __restrict__ out, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
const int input_stride,
|
||||
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||
const float* __restrict__ scale, // [1]
|
||||
const float epsilon, const int num_tokens, const int hidden_size) {
|
||||
@ -32,7 +33,7 @@ __global__ void rms_norm_static_fp8_quant_kernel(
|
||||
float variance = 0.0f;
|
||||
|
||||
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
|
||||
const float x = (float)input[blockIdx.x * hidden_size + idx];
|
||||
const float x = (float)input[blockIdx.x * input_stride + idx];
|
||||
variance += x * x;
|
||||
}
|
||||
|
||||
@ -49,7 +50,7 @@ __global__ void rms_norm_static_fp8_quant_kernel(
|
||||
float const scale_inv = 1.0f / *scale;
|
||||
|
||||
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
|
||||
float x = (float)input[blockIdx.x * hidden_size + idx];
|
||||
float x = (float)input[blockIdx.x * input_stride + idx];
|
||||
float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx];
|
||||
out[blockIdx.x * hidden_size + idx] =
|
||||
scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
|
||||
@ -63,8 +64,9 @@ __global__ void rms_norm_static_fp8_quant_kernel(
|
||||
template <typename scalar_t, int width, typename fp8_type>
|
||||
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
|
||||
fused_add_rms_norm_static_fp8_quant_kernel(
|
||||
fp8_type* __restrict__ out, // [..., hidden_size]
|
||||
scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
fp8_type* __restrict__ out, // [..., hidden_size]
|
||||
scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
const int input_stride,
|
||||
scalar_t* __restrict__ residual, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||
const float* __restrict__ scale, // [1]
|
||||
@ -74,6 +76,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
|
||||
static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);
|
||||
|
||||
const int vec_hidden_size = hidden_size / width;
|
||||
const int vec_input_stride = input_stride / width;
|
||||
__shared__ float s_variance;
|
||||
float variance = 0.0f;
|
||||
/* These and the argument pointers are all declared `restrict` as they are
|
||||
@ -87,8 +90,9 @@ fused_add_rms_norm_static_fp8_quant_kernel(
|
||||
reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);
|
||||
|
||||
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
|
||||
int stride_id = blockIdx.x * vec_input_stride + idx;
|
||||
int id = blockIdx.x * vec_hidden_size + idx;
|
||||
_f16Vec<scalar_t, width> temp = input_v[id];
|
||||
_f16Vec<scalar_t, width> temp = input_v[stride_id];
|
||||
temp += residual_v[id];
|
||||
variance += temp.sum_squares();
|
||||
residual_v[id] = temp;
|
||||
@ -125,8 +129,9 @@ fused_add_rms_norm_static_fp8_quant_kernel(
|
||||
template <typename scalar_t, int width, typename fp8_type>
|
||||
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
|
||||
fused_add_rms_norm_static_fp8_quant_kernel(
|
||||
fp8_type* __restrict__ out, // [..., hidden_size]
|
||||
scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
fp8_type* __restrict__ out, // [..., hidden_size]
|
||||
scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
const int input_stride,
|
||||
scalar_t* __restrict__ residual, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||
const float* __restrict__ scale, // [1]
|
||||
@ -135,7 +140,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
|
||||
float variance = 0.0f;
|
||||
|
||||
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
|
||||
scalar_t z = input[blockIdx.x * hidden_size + idx];
|
||||
scalar_t z = input[blockIdx.x * input_stride + idx];
|
||||
z += residual[blockIdx.x * hidden_size + idx];
|
||||
float x = (float)z;
|
||||
variance += x * x;
|
||||
@ -169,7 +174,9 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
|
||||
torch::Tensor& weight, // [hidden_size]
|
||||
torch::Tensor& scale, // [1]
|
||||
double epsilon) {
|
||||
TORCH_CHECK(out.is_contiguous());
|
||||
int hidden_size = input.size(-1);
|
||||
int input_stride = input.stride(-2);
|
||||
int num_tokens = input.numel() / hidden_size;
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
@ -183,8 +190,9 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
|
||||
vllm::rms_norm_static_fp8_quant_kernel<scalar_t, fp8_t>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
|
||||
weight.data_ptr<scalar_t>(), scale.data_ptr<float>(),
|
||||
epsilon, num_tokens, hidden_size);
|
||||
input_stride, weight.data_ptr<scalar_t>(),
|
||||
scale.data_ptr<float>(), epsilon, num_tokens,
|
||||
hidden_size);
|
||||
});
|
||||
});
|
||||
}
|
||||
@ -198,7 +206,7 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
|
||||
width, fp8_t> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(), \
|
||||
residual.data_ptr<scalar_t>(), \
|
||||
input_stride, residual.data_ptr<scalar_t>(), \
|
||||
weight.data_ptr<scalar_t>(), scale.data_ptr<float>(), \
|
||||
epsilon, num_tokens, hidden_size); \
|
||||
}); \
|
||||
@ -210,7 +218,10 @@ void fused_add_rms_norm_static_fp8_quant(
|
||||
torch::Tensor& weight, // [hidden_size]
|
||||
torch::Tensor& scale, // [1]
|
||||
double epsilon) {
|
||||
TORCH_CHECK(out.is_contiguous());
|
||||
TORCH_CHECK(residual.is_contiguous());
|
||||
int hidden_size = input.size(-1);
|
||||
int input_stride = input.stride(-2);
|
||||
int num_tokens = input.numel() / hidden_size;
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
@ -234,7 +245,7 @@ void fused_add_rms_norm_static_fp8_quant(
|
||||
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
|
||||
bool ptrs_are_aligned =
|
||||
inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
|
||||
if (ptrs_are_aligned && hidden_size % 8 == 0) {
|
||||
if (ptrs_are_aligned && hidden_size % 8 == 0 && input_stride % 8 == 0) {
|
||||
LAUNCH_FUSED_ADD_RMS_NORM(8);
|
||||
} else {
|
||||
LAUNCH_FUSED_ADD_RMS_NORM(0);
|
||||
|
||||
@ -1,656 +0,0 @@
|
||||
// clang-format off
|
||||
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_fwd.cu
|
||||
// and https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_update.cu
|
||||
#include <torch/all.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "causal_conv1d.h"
|
||||
#include <c10/util/BFloat16.h>
|
||||
#include <c10/util/Half.h>
|
||||
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
||||
|
||||
#include <cub/block/block_load.cuh>
|
||||
#include <cub/block/block_store.cuh>
|
||||
|
||||
#ifdef USE_ROCM
|
||||
namespace cub = hipcub;
|
||||
#endif
|
||||
|
||||
#include "static_switch.h"
|
||||
|
||||
|
||||
|
||||
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
||||
|
||||
#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
|
||||
if (ITYPE == at::ScalarType::Half) { \
|
||||
using input_t = at::Half; \
|
||||
using weight_t = at::Half; \
|
||||
__VA_ARGS__(); \
|
||||
} else if (ITYPE == at::ScalarType::BFloat16) { \
|
||||
using input_t = at::BFloat16; \
|
||||
using weight_t = at::BFloat16; \
|
||||
__VA_ARGS__(); \
|
||||
} else if (ITYPE == at::ScalarType::Float) { \
|
||||
using input_t = float; \
|
||||
using weight_t = float; \
|
||||
__VA_ARGS__(); \
|
||||
} else { \
|
||||
AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
|
||||
}
|
||||
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void causal_conv1d_fwd_cuda(ConvParamsBase &params, cudaStream_t stream);
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream);
|
||||
|
||||
void set_conv_params_fwd(ConvParamsBase &params,
|
||||
// sizes
|
||||
const size_t batch,
|
||||
const size_t dim,
|
||||
const size_t seqlen,
|
||||
const size_t width,
|
||||
// device pointers
|
||||
const at::Tensor x,
|
||||
const at::Tensor weight,
|
||||
const at::Tensor out,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
bool silu_activation,
|
||||
int64_t pad_slot_id,
|
||||
const std::optional<at::Tensor>& query_start_loc = std::nullopt,
|
||||
const std::optional<at::Tensor>& cache_indices = std::nullopt,
|
||||
const std::optional<at::Tensor>& has_initial_state = std::nullopt) {
|
||||
|
||||
// Reset the parameters
|
||||
memset(&params, 0, sizeof(params));
|
||||
|
||||
params.batch = batch;
|
||||
params.dim = dim;
|
||||
params.seqlen = seqlen;
|
||||
params.width = width;
|
||||
params.pad_slot_id = pad_slot_id;
|
||||
|
||||
params.silu_activation = silu_activation;
|
||||
|
||||
// Set the pointers and strides.
|
||||
params.x_ptr = x.data_ptr();
|
||||
params.weight_ptr = weight.data_ptr();
|
||||
params.bias_ptr = bias.has_value() ? bias.value().data_ptr() : nullptr;
|
||||
params.out_ptr = out.data_ptr();
|
||||
// All stride are in elements, not bytes.
|
||||
params.query_start_loc_ptr = query_start_loc.has_value() ? query_start_loc.value().data_ptr() : nullptr;
|
||||
params.cache_indices_ptr = cache_indices.has_value() ? cache_indices.value().data_ptr() : nullptr;
|
||||
params.has_initial_state_ptr = has_initial_state.has_value() ? has_initial_state.value().data_ptr() : nullptr;
|
||||
const bool varlen = params.query_start_loc_ptr != nullptr;
|
||||
params.x_batch_stride = x.stride(varlen ? 1 : 0);
|
||||
params.x_c_stride = x.stride(varlen ? 0 : 1);
|
||||
params.x_l_stride = x.stride(varlen ? 1 : -1);
|
||||
params.weight_c_stride = weight.stride(0);
|
||||
params.weight_width_stride = weight.stride(1);
|
||||
params.out_batch_stride = out.stride(varlen ? 1 : 0);
|
||||
params.out_c_stride = out.stride(varlen ? 0 : 1);
|
||||
params.out_l_stride = out.stride(varlen ? 1 : -1);
|
||||
}
|
||||
|
||||
|
||||
void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
|
||||
const std::optional<at::Tensor> &bias_,
|
||||
const std::optional<at::Tensor> &conv_states,
|
||||
const std::optional<at::Tensor> &query_start_loc,
|
||||
const std::optional<at::Tensor> &cache_indices,
|
||||
const std::optional<at::Tensor> &has_initial_state,
|
||||
bool silu_activation,
|
||||
// used to identify padding entries if cache_indices provided
|
||||
// in case of padding, the kernel will return early
|
||||
int64_t pad_slot_id) {
|
||||
auto input_type = x.scalar_type();
|
||||
auto weight_type = weight.scalar_type();
|
||||
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
||||
TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
|
||||
|
||||
TORCH_CHECK(x.is_cuda());
|
||||
TORCH_CHECK(weight.is_cuda());
|
||||
|
||||
const bool varlen = query_start_loc.has_value() ? true : false;
|
||||
const auto sizes = x.sizes();
|
||||
const int batch_size = varlen ? query_start_loc.value().sizes()[0] - 1 : sizes[0];
|
||||
const int dim = varlen ? sizes[0] : sizes[1];
|
||||
const int seqlen = varlen ? sizes[1] : sizes[2];
|
||||
const int width = weight.size(-1);
|
||||
if (varlen){
|
||||
CHECK_SHAPE(x, dim, seqlen);
|
||||
}
|
||||
else {
|
||||
CHECK_SHAPE(x, batch_size, dim, seqlen);
|
||||
}
|
||||
CHECK_SHAPE(weight, dim, width);
|
||||
|
||||
|
||||
|
||||
if (bias_.has_value()) {
|
||||
auto bias = bias_.value();
|
||||
TORCH_CHECK(bias.scalar_type() == weight_type);
|
||||
TORCH_CHECK(bias.is_cuda());
|
||||
TORCH_CHECK(bias.stride(-1) == 1);
|
||||
CHECK_SHAPE(bias, dim);
|
||||
}
|
||||
|
||||
|
||||
if (has_initial_state.has_value()) {
|
||||
auto has_initial_state_ = has_initial_state.value();
|
||||
TORCH_CHECK(has_initial_state_.scalar_type() == at::ScalarType::Bool);
|
||||
TORCH_CHECK(has_initial_state_.is_cuda());
|
||||
CHECK_SHAPE(has_initial_state_, batch_size);
|
||||
}
|
||||
|
||||
|
||||
if (query_start_loc.has_value()) {
|
||||
auto query_start_loc_ = query_start_loc.value();
|
||||
TORCH_CHECK(query_start_loc_.scalar_type() == at::ScalarType::Int);
|
||||
TORCH_CHECK(query_start_loc_.is_cuda());
|
||||
}
|
||||
|
||||
|
||||
if (cache_indices.has_value()) {
|
||||
auto cache_indices_ = cache_indices.value();
|
||||
TORCH_CHECK(cache_indices_.scalar_type() == at::ScalarType::Int);
|
||||
TORCH_CHECK(cache_indices_.is_cuda());
|
||||
CHECK_SHAPE(cache_indices_, batch_size);
|
||||
}
|
||||
|
||||
at::Tensor out = x;
|
||||
|
||||
ConvParamsBase params;
|
||||
set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
|
||||
bias_,
|
||||
silu_activation,
|
||||
pad_slot_id,
|
||||
query_start_loc,
|
||||
cache_indices,
|
||||
has_initial_state
|
||||
);
|
||||
|
||||
if (conv_states.has_value()) {
|
||||
auto conv_states_ = conv_states.value();
|
||||
TORCH_CHECK(conv_states_.scalar_type() == input_type);
|
||||
TORCH_CHECK(conv_states_.is_cuda());
|
||||
params.conv_states_ptr = conv_states_.data_ptr();
|
||||
params.conv_states_batch_stride = conv_states_.stride(0);
|
||||
params.conv_states_c_stride = conv_states_.stride(1);
|
||||
params.conv_states_l_stride = conv_states_.stride(2);
|
||||
} else {
|
||||
params.conv_states_ptr = nullptr;
|
||||
}
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] {
|
||||
causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
void causal_conv1d_update(const at::Tensor &x,
|
||||
const at::Tensor &conv_state,
|
||||
const at::Tensor &weight,
|
||||
const std::optional<at::Tensor> &bias_,
|
||||
bool silu_activation,
|
||||
const std::optional<at::Tensor> &cache_seqlens_,
|
||||
const std::optional<at::Tensor> &conv_state_indices_,
|
||||
// used to identify padding entries if cache_indices provided
|
||||
// in case of padding, the kernel will return early
|
||||
int64_t pad_slot_id) {
|
||||
auto input_type = x.scalar_type();
|
||||
auto weight_type = weight.scalar_type();
|
||||
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
||||
TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
|
||||
TORCH_CHECK(weight_type == input_type, "weight type must equal to input type, other variations are disabled due to binary size limitations");
|
||||
TORCH_CHECK(conv_state.scalar_type() == input_type);
|
||||
|
||||
TORCH_CHECK(x.is_cuda());
|
||||
TORCH_CHECK(conv_state.is_cuda());
|
||||
TORCH_CHECK(weight.is_cuda());
|
||||
|
||||
const auto sizes = x.sizes();
|
||||
const int batch_size = sizes[0];
|
||||
const int dim = sizes[1];
|
||||
const int seqlen = sizes[2];
|
||||
const int width = weight.size(-1);
|
||||
const int conv_state_len = conv_state.size(2);
|
||||
TORCH_CHECK(conv_state_len >= width - 1);
|
||||
|
||||
CHECK_SHAPE(x, batch_size, dim, seqlen);
|
||||
CHECK_SHAPE(weight, dim, width);
|
||||
|
||||
TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
|
||||
|
||||
if (bias_.has_value()) {
|
||||
auto bias = bias_.value();
|
||||
TORCH_CHECK(bias.scalar_type() == weight_type);
|
||||
TORCH_CHECK(bias.is_cuda());
|
||||
TORCH_CHECK(bias.stride(-1) == 1);
|
||||
CHECK_SHAPE(bias, dim);
|
||||
}
|
||||
|
||||
at::Tensor out = x;
|
||||
|
||||
ConvParamsBase params;
|
||||
set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
|
||||
bias_,
|
||||
silu_activation,
|
||||
pad_slot_id);
|
||||
params.conv_state_ptr = conv_state.data_ptr();
|
||||
params.conv_state_len = conv_state_len;
|
||||
// All stride are in elements, not bytes.
|
||||
params.conv_state_batch_stride = conv_state.stride(0);
|
||||
params.conv_state_c_stride = conv_state.stride(1);
|
||||
params.conv_state_l_stride = conv_state.stride(2);
|
||||
|
||||
if (cache_seqlens_.has_value()) {
|
||||
auto cache_seqlens = cache_seqlens_.value();
|
||||
TORCH_CHECK(cache_seqlens.scalar_type() == torch::kInt32);
|
||||
TORCH_CHECK(cache_seqlens.is_cuda());
|
||||
TORCH_CHECK(cache_seqlens.stride(-1) == 1);
|
||||
CHECK_SHAPE(cache_seqlens, batch_size);
|
||||
params.cache_seqlens = cache_seqlens.data_ptr<int32_t>();
|
||||
} else {
|
||||
params.cache_seqlens = nullptr;
|
||||
}
|
||||
|
||||
if (conv_state_indices_.has_value()) {
|
||||
auto conv_state_indices = conv_state_indices_.value();
|
||||
TORCH_CHECK(conv_state_indices.scalar_type() == torch::kInt32)
|
||||
TORCH_CHECK(conv_state_indices.is_cuda());
|
||||
TORCH_CHECK(conv_state_indices.stride(0) == 1)
|
||||
CHECK_SHAPE(conv_state_indices, batch_size);
|
||||
|
||||
int conv_state_entries = conv_state.size(0);
|
||||
CHECK_SHAPE(conv_state, conv_state_entries, dim, conv_state_len);
|
||||
|
||||
params.conv_state_indices_ptr = conv_state_indices.data_ptr<int32_t>();
|
||||
} else {
|
||||
CHECK_SHAPE(conv_state, batch_size, dim, conv_state_len);
|
||||
params.conv_state_indices_ptr = nullptr;
|
||||
}
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] {
|
||||
causal_conv1d_update_cuda<input_t, weight_t>(params, stream);
|
||||
});
|
||||
}
|
||||
|
||||
template<int kNThreads_, int kWidth_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
|
||||
struct Causal_conv1d_fwd_kernel_traits {
|
||||
using input_t = input_t_;
|
||||
using weight_t = weight_t_;
|
||||
static constexpr int kNThreads = kNThreads_;
|
||||
static constexpr int kWidth = kWidth_;
|
||||
static constexpr int kNBytes = sizeof(input_t);
|
||||
static_assert(kNBytes == 2 || kNBytes == 4);
|
||||
static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
|
||||
static_assert(kWidth <= kNElts);
|
||||
static constexpr bool kIsVecLoad = kIsVecLoad_;
|
||||
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
|
||||
using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
||||
using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>;
|
||||
using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
||||
using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, 1, cub::BLOCK_STORE_DIRECT>;
|
||||
static constexpr int kSmemIOSize = kIsVecLoad
|
||||
? 0
|
||||
: custom_max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)});
|
||||
static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts;
|
||||
static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize;
|
||||
};
|
||||
|
||||
template<typename Ktraits>
|
||||
__global__ __launch_bounds__(Ktraits::kNThreads)
|
||||
void causal_conv1d_fwd_kernel(ConvParamsBase params) {
|
||||
constexpr int kWidth = Ktraits::kWidth;
|
||||
constexpr int kNThreads = Ktraits::kNThreads;
|
||||
constexpr int kNElts = Ktraits::kNElts;
|
||||
constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
|
||||
using input_t = typename Ktraits::input_t;
|
||||
using vec_t = typename Ktraits::vec_t;
|
||||
using weight_t = typename Ktraits::weight_t;
|
||||
|
||||
// Shared memory.
|
||||
extern __shared__ char smem_[];
|
||||
auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
|
||||
auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_);
|
||||
auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
|
||||
auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
|
||||
vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize);
|
||||
|
||||
const bool kVarlen = params.query_start_loc_ptr != nullptr;
|
||||
const int tidx = threadIdx.x;
|
||||
const int batch_id = blockIdx.x;
|
||||
const int channel_id = blockIdx.y;
|
||||
const int *query_start_loc = kVarlen ? reinterpret_cast<int *>(params.query_start_loc_ptr) : nullptr;
|
||||
const int sequence_start_index = kVarlen ? query_start_loc[batch_id] : batch_id;
|
||||
const int seqlen = kVarlen ? query_start_loc[batch_id + 1] - sequence_start_index : params.seqlen;
|
||||
|
||||
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + sequence_start_index * params.x_batch_stride
|
||||
+ channel_id * params.x_c_stride;
|
||||
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
|
||||
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + sequence_start_index * params.out_batch_stride
|
||||
+ channel_id * params.out_c_stride;
|
||||
float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
|
||||
|
||||
bool has_initial_state = params.has_initial_state_ptr == nullptr ? false
|
||||
: reinterpret_cast<bool *>(params.has_initial_state_ptr)[batch_id];
|
||||
|
||||
int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
|
||||
: reinterpret_cast<int *>(params.cache_indices_ptr);
|
||||
int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
|
||||
// cache_index == params.pad_slot_id is defined as padding, so we exit early
|
||||
if (cache_index == params.pad_slot_id){
|
||||
return;
|
||||
}
|
||||
input_t *conv_states = params.conv_states_ptr == nullptr ? nullptr
|
||||
: reinterpret_cast<input_t *>(params.conv_states_ptr) + cache_index * params.conv_states_batch_stride + channel_id * params.conv_states_c_stride;
|
||||
|
||||
// Thread 0 will load the last elements of the previous chunk, so we initialize those to 0.
|
||||
if (tidx == 0) {
|
||||
input_t initial_state[kNElts] = {0};
|
||||
if (has_initial_state) {
|
||||
#pragma unroll
|
||||
for (int w = 0; w < kWidth - 1; ++w){ initial_state[kNElts - 1 - (kWidth - 2) + w ] = conv_states[w]; }
|
||||
}
|
||||
smem_exchange[kNThreads - 1] = reinterpret_cast<vec_t *>(initial_state)[0];
|
||||
}
|
||||
|
||||
float weight_vals[kWidth];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
|
||||
|
||||
constexpr int kChunkSize = kNThreads * kNElts;
|
||||
const int n_chunks = (seqlen + kChunkSize - 1) / kChunkSize;
|
||||
for (int chunk = 0; chunk < n_chunks; ++chunk) {
|
||||
input_t x_vals_load[2 * kNElts] = {0};
|
||||
if constexpr(kIsVecLoad) {
|
||||
typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (seqlen - chunk * kChunkSize) / kNElts);
|
||||
} else {
|
||||
__syncthreads();
|
||||
typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), seqlen - chunk * kChunkSize);
|
||||
}
|
||||
x += kChunkSize;
|
||||
__syncthreads();
|
||||
// Thread kNThreads - 1 don't write yet, so that thread 0 can read
|
||||
// the last elements of the previous chunk.
|
||||
if (tidx < kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
|
||||
__syncthreads();
|
||||
reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1];
|
||||
__syncthreads();
|
||||
// Now thread kNThreads - 1 can write the last elements of the current chunk.
|
||||
if (tidx == kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
|
||||
|
||||
float x_vals[2 * kNElts];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); }
|
||||
|
||||
float out_vals[kNElts];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNElts; ++i) {
|
||||
out_vals[i] = bias_val;
|
||||
#pragma unroll
|
||||
for (int w = 0; w < kWidth; ++w) {
|
||||
out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)];
|
||||
}
|
||||
}
|
||||
|
||||
if (params.silu_activation) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNElts; ++i) {
|
||||
out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i]));
|
||||
}
|
||||
}
|
||||
|
||||
input_t out_vals_store[kNElts];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; }
|
||||
if constexpr(kIsVecLoad) {
|
||||
typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(out), reinterpret_cast<vec_t (&)[1]>(out_vals_store), (seqlen - chunk * kChunkSize) / kNElts);
|
||||
} else {
|
||||
typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, seqlen - chunk * kChunkSize);
|
||||
}
|
||||
out += kChunkSize;
|
||||
|
||||
int final_state_position = ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize);
|
||||
// in case the final state is separated between the last "smem_exchange" and
|
||||
// and the one before it (chunk = n_chunks - 1 and chunk = n_chunks - 2),
|
||||
// (which occurs when `final_state_position` is a non-positive index)
|
||||
// we load the correct data from smem_exchange from both chunks, the last chunk iteration and the one before it
|
||||
if (conv_states != nullptr && final_state_position < 0 && seqlen > kWidth){
|
||||
input_t vals_load[kNElts] = {0};
|
||||
if ((chunk == n_chunks - 2) && (tidx == kNThreads - 1)){
|
||||
// chunk = n_chunks - 2, a segment of the final state sits in the last index
|
||||
reinterpret_cast<vec_t *>(vals_load)[0] = smem_exchange[kNThreads - 1];
|
||||
#pragma unroll
|
||||
for (int w = 0; w < -final_state_position; ++w){
|
||||
conv_states[w] = vals_load[kNElts + final_state_position + w];
|
||||
}
|
||||
}
|
||||
if ((chunk == n_chunks - 1) && tidx == 0){
|
||||
// chunk = n_chunks - 1, the second segment of the final state first positions
|
||||
reinterpret_cast<vec_t *>(vals_load)[0] = smem_exchange[0];
|
||||
for (int w = -final_state_position; w < kWidth - 1; ++w){
|
||||
conv_states[w] = vals_load[w + final_state_position];
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Final state is stored in the smem_exchange last token slot,
|
||||
// in case seqlen < kWidth, we would need to take the final state from the
|
||||
// initial state which is stored in conv_states
|
||||
// in case seqlen > kWidth, we would need to load the last kWidth - 1 data
|
||||
// and load it into conv_state accordingly
|
||||
int last_thread = ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize) / kNElts;
    if (conv_states != nullptr && tidx == last_thread) {
        input_t x_vals_load[kNElts * 2] = {0};
        // in case we are on the first kWidth tokens
        if (last_thread == 0 && seqlen < kWidth) {
            // Need to take the initial state
            reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[0];
            const int offset = seqlen - (kWidth - 1);
            #pragma unroll
            for (int w = 0; w < kWidth - 1; ++w) {
                // pad the existing state
                if ((w - seqlen) >= 0 && has_initial_state) { conv_states[w - seqlen] = conv_states[w]; }
                else if ((w - seqlen) >= 0 && !has_initial_state) { conv_states[w - seqlen] = input_t(0.0f); }
            }
            #pragma unroll
            for (int w = 0; w < kWidth - 1; ++w) {
                if (offset + w >= 0) { conv_states[w] = x_vals_load[offset + w]; }
            }
        }
        else {
            // in case the final state spans two threads' data
            const int offset = ((seqlen - (kWidth - 1)) % (kNElts));
            if ((offset + kWidth - 2) >= kNElts && (last_thread + 1 < kNThreads)) {
                // If last_thread == kNThreads - 1, accessing last_thread + 1 would be an
                // illegal access (observed on H100).
                // Therefore, we only access last_thread + 1 if the final state data sits there.
                reinterpret_cast<vec_t *>(x_vals_load)[1] = smem_exchange[last_thread + 1];
            }
            reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[last_thread];
            #pragma unroll
            for (int w = 0; w < kWidth - 1; ++w) {
                conv_states[w] = x_vals_load[offset + w];
            }
        }
    }
}


template<int kNThreads, int kWidth, typename input_t, typename weight_t>
void causal_conv1d_fwd_launch(ConvParamsBase &params, cudaStream_t stream) {
    static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
    const bool kVarlen = params.query_start_loc_ptr != nullptr;
    BOOL_SWITCH(params.seqlen % kNElts == 0 && !kVarlen, kIsVecLoad, [&] {
        using Ktraits = Causal_conv1d_fwd_kernel_traits<kNThreads, kWidth, kIsVecLoad, input_t, weight_t>;
        constexpr int kSmemSize = Ktraits::kSmemSize;
        dim3 grid(params.batch, params.dim);

        auto kernel = &causal_conv1d_fwd_kernel<Ktraits>;

        if (kSmemSize >= 48 * 1024) {
            C10_CUDA_CHECK(cudaFuncSetAttribute(
                (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
            std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU, which is currently a no-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
        }
        kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);

        C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
}

template<typename input_t, typename weight_t>
void causal_conv1d_fwd_cuda(ConvParamsBase &params, cudaStream_t stream) {
    if (params.width == 2) {
        causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream);
    } else if (params.width == 3) {
        causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream);
    } else if (params.width == 4) {
        causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream);
    }
}

template void causal_conv1d_fwd_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_fwd_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
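
// Illustrative only (not part of this diff): a host-side caller would typically
// pick one of the explicit instantiations above based on the input tensor's
// dtype, e.g. (the dispatch shape here is a sketch, names hypothetical):
//   if (x.scalar_type() == at::ScalarType::Half) {
//       causal_conv1d_fwd_cuda<at::Half, at::Half>(params, stream);
//   } else if (x.scalar_type() == at::ScalarType::BFloat16) {
//       causal_conv1d_fwd_cuda<at::BFloat16, at::BFloat16>(params, stream);
//   } else {
//       causal_conv1d_fwd_cuda<float, float>(params, stream);
//   }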




template<int kNThreads_, int kWidth_, typename input_t_, typename weight_t_>
struct Causal_conv1d_update_kernel_traits {
    using input_t = input_t_;
    using weight_t = weight_t_;
    static constexpr int kNThreads = kNThreads_;
    static constexpr int kWidth = kWidth_;
    static constexpr int kNBytes = sizeof(input_t);
    static_assert(kNBytes == 2 || kNBytes == 4);
};

template<typename Ktraits, bool kIsCircularBuffer>
__global__ __launch_bounds__(Ktraits::kNThreads)
void causal_conv1d_update_kernel(ConvParamsBase params) {
    constexpr int kWidth = Ktraits::kWidth;
    constexpr int kNThreads = Ktraits::kNThreads;
    using input_t = typename Ktraits::input_t;
    using weight_t = typename Ktraits::weight_t;

    const int tidx = threadIdx.x;
    const int batch_id = blockIdx.x;
    const int channel_id = blockIdx.y * kNThreads + tidx;
    if (channel_id >= params.dim) return;

    input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
        + channel_id * params.x_c_stride;

    // If params.conv_state_batch_indices is set, then the conv state is gathered from the conv state tensor
    // along the batch axis. Otherwise, the conv state coordinate is the same as the batch id.
    const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr
        ? batch_id
        : params.conv_state_indices_ptr[batch_id];
    // conv_state_batch_coord == params.pad_slot_id is defined as padding, so we exit early
    if (conv_state_batch_coord == params.pad_slot_id) {
        return;
    }
    input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr)
        + conv_state_batch_coord * params.conv_state_batch_stride
        + channel_id * params.conv_state_c_stride;

    weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
    input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
        + channel_id * params.out_c_stride;
    float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);

    int state_len = params.conv_state_len;
    int advance_len = params.seqlen;
    int cache_seqlen = kIsCircularBuffer ? params.cache_seqlens[batch_id] % state_len : 0;
    int update_idx = cache_seqlen - (kWidth - 1);
    update_idx = update_idx < 0 ? update_idx + state_len : update_idx;
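    // Illustrative example (not from the original source): with state_len = 16,
    // kWidth = 4 and cache_seqlens[batch_id] = 18, cache_seqlen = 18 % 16 = 2 and
    // update_idx = 2 - 3 = -1, which wraps to 15; the kWidth - 1 previous tokens
    // are then read from circular-buffer slots 15, 0 and 1 before the new tokens
    // are written in.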

    float weight_vals[kWidth] = {0};
    #pragma unroll
    for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }

    float x_vals[kWidth] = {0};
    if constexpr (!kIsCircularBuffer) {
        #pragma unroll 2
        for (int i = 0; i < state_len - advance_len - (kWidth - 1); ++i) {
            conv_state[i * params.conv_state_l_stride] = conv_state[(i + advance_len) * params.conv_state_l_stride];
        }
        #pragma unroll
        for (int i = 0; i < kWidth - 1; ++i) {
            input_t state_val = conv_state[(state_len - (kWidth - 1) + i) * params.conv_state_l_stride];
            if (i < advance_len + (kWidth - 1) && state_len - advance_len - (kWidth - 1) + i >= 0) {
                conv_state[(state_len - advance_len - (kWidth - 1) + i) * params.conv_state_l_stride] = state_val;
            }
            x_vals[i] = float(state_val);
        }
    } else {
        #pragma unroll
        for (int i = 0; i < kWidth - 1; ++i, update_idx = update_idx + 1 >= state_len ? update_idx + 1 - state_len : update_idx + 1) {
            input_t state_val = conv_state[update_idx * params.conv_state_l_stride];
            x_vals[i] = float(state_val);
        }
    }
    #pragma unroll 2
    for (int i = 0; i < params.seqlen; ++i) {
        input_t x_val = x[i * params.x_l_stride];
        if constexpr (!kIsCircularBuffer) {
            if (i < advance_len && state_len - advance_len + i >= 0) {
                conv_state[(state_len - advance_len + i) * params.conv_state_l_stride] = x_val;
            }
        } else {
            conv_state[update_idx * params.conv_state_l_stride] = x_val;
            ++update_idx;
            update_idx = update_idx >= state_len ? update_idx - state_len : update_idx;
        }
        x_vals[kWidth - 1] = float(x_val);
        float out_val = bias_val;
        #pragma unroll
        for (int j = 0; j < kWidth; ++j) { out_val += weight_vals[j] * x_vals[j]; }
        if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); }
        out[i * params.out_l_stride] = input_t(out_val);
        // Shift the input buffer by 1
        #pragma unroll
        for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = x_vals[i + 1]; }
    }
}

template<int kNThreads, int kWidth, typename input_t, typename weight_t>
void causal_conv1d_update_launch(ConvParamsBase &params, cudaStream_t stream) {
    using Ktraits = Causal_conv1d_update_kernel_traits<kNThreads, kWidth, input_t, weight_t>;
    dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads);
    auto kernel = params.cache_seqlens == nullptr
        ? &causal_conv1d_update_kernel<Ktraits, false>
        : &causal_conv1d_update_kernel<Ktraits, true>;
    kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template<typename input_t, typename weight_t>
void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream) {
    if (params.width == 2) {
        causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream);
    } else if (params.width == 3) {
        causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream);
    } else if (params.width == 4) {
        causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream);
    }
}

template void causal_conv1d_update_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
@@ -1,159 +0,0 @@
/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/
// clang-format off
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d.h
#pragma once

#include <cuda_bf16.h>
#include <cuda_fp16.h>
////////////////////////////////////////////////////////////////////////////////////////////////////

struct ConvParamsBase {
    using index_t = uint32_t;

    int batch, dim, seqlen, width;
    int64_t pad_slot_id;
    bool silu_activation;

    index_t x_batch_stride;
    index_t x_c_stride;
    index_t x_l_stride;
    index_t weight_c_stride;
    index_t weight_width_stride;
    index_t out_batch_stride;
    index_t out_c_stride;
    index_t out_l_stride;

    int conv_state_len;
    index_t conv_state_batch_stride;
    index_t conv_state_c_stride;
    index_t conv_state_l_stride;

    // Common data pointers.
    void *__restrict__ x_ptr;
    void *__restrict__ weight_ptr;
    void *__restrict__ bias_ptr;
    void *__restrict__ out_ptr;

    void *__restrict__ conv_state_ptr;
    void *__restrict__ query_start_loc_ptr;
    void *__restrict__ has_initial_state_ptr;
    void *__restrict__ cache_indices_ptr;
    int32_t *__restrict__ cache_seqlens;

    // For the continuous batching case. Makes it so that the mamba state for
    // the current batch doesn't need to be a contiguous tensor.
    int32_t *__restrict__ conv_state_indices_ptr;

    void *__restrict__ seq_idx_ptr;

    // No __restrict__ since initial_states could be the same as final_states.
    void * initial_states_ptr;
    index_t initial_states_batch_stride;
    index_t initial_states_l_stride;
    index_t initial_states_c_stride;

    void * final_states_ptr;
    index_t final_states_batch_stride;
    index_t final_states_l_stride;
    index_t final_states_c_stride;

    void * conv_states_ptr;
    index_t conv_states_batch_stride;
    index_t conv_states_l_stride;
    index_t conv_states_c_stride;
};
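
// Illustrative only (not part of this header): for a contiguous input tensor x of
// shape [batch, dim, seqlen], the strides above would typically be filled in as
// element counts, e.g. (values hypothetical):
//   params.x_batch_stride = dim * seqlen;
//   params.x_c_stride     = seqlen;
//   params.x_l_stride     = 1;
// with the analogous pattern for the out_* and conv_state_* strides.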


#ifndef USE_ROCM
    #include <cuda_bf16.h>

    template<typename T>
    __device__ inline T shuffle_xor(T val, int offset) {
        return __shfl_xor_sync(uint32_t(-1), val, offset);
    }

    constexpr size_t custom_max(std::initializer_list<size_t> ilist)
    {
        return std::max(ilist);
    }

    template<typename T>
    constexpr T constexpr_min(T a, T b) {
        return std::min(a, b);
    }

#else
    #include <hip/hip_bf16.h>

    template<typename T>
    __device__ inline T shuffle_xor(T val, int offset) {
        return __shfl_xor(val, offset);
    }
    constexpr size_t custom_max(std::initializer_list<size_t> ilist)
    {
        return *std::max_element(ilist.begin(), ilist.end());
    }

    template<typename T>
    constexpr T constexpr_min(T a, T b) {
        return a < b ? a : b;
    }
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int BYTES> struct BytesToType {};

template<> struct BytesToType<16> {
    using Type = uint4;
    static_assert(sizeof(Type) == 16);
};

template<> struct BytesToType<8> {
    using Type = uint64_t;
    static_assert(sizeof(Type) == 8);
};

template<> struct BytesToType<4> {
    using Type = uint32_t;
    static_assert(sizeof(Type) == 4);
};

template<> struct BytesToType<2> {
    using Type = uint16_t;
    static_assert(sizeof(Type) == 2);
};

template<> struct BytesToType<1> {
    using Type = uint8_t;
    static_assert(sizeof(Type) == 1);
};
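
// Illustrative only (not part of this header): the kernel traits use this mapping
// to pick a plain integer type wide enough to move kNElts elements in one
// vectorized load/store, e.g. 8 half-precision values as a single 16-byte uint4:
//   using vec_t = typename BytesToType<sizeof(at::Half) * 8>::Type;  // uint4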

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
struct SumOp {
    __device__ inline T operator()(T const & x, T const & y) { return x + y; }
};

template<int THREADS>
struct Allreduce {
    static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
    template<typename T, typename Operator>
    static __device__ inline T run(T x, Operator &op) {
        constexpr int OFFSET = THREADS / 2;
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
        return Allreduce<OFFSET>::run(x, op);
    }
};

template<>
struct Allreduce<2> {
    template<typename T, typename Operator>
    static __device__ inline T run(T x, Operator &op) {
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
        return x;
    }
};
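
// Illustrative usage (not part of this header): a butterfly all-reduce over a
// 32-lane warp so that every lane ends up holding the warp-wide sum:
//   SumOp<float> op;
//   float warp_sum = Allreduce<32>::run(thread_val, op);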

@@ -1,28 +0,0 @@
// Inspired by
// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
// clang-format off
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/static_switch.h

#pragma once

/// @param COND - a boolean expression to switch by
/// @param CONST_NAME - a name given for the constexpr bool variable.
/// @param ... - code to execute for true and false
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
///     some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...)              \
    [&] {                                               \
        if (COND) {                                     \
            static constexpr bool CONST_NAME = true;    \
            return __VA_ARGS__();                       \
        } else {                                        \
            static constexpr bool CONST_NAME = false;   \
            return __VA_ARGS__();                       \
        }                                               \
    }()
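
// Illustrative only (not from this file): the launch functions in this diff nest
// BOOL_SWITCH to turn runtime flags into compile-time template parameters, so each
// combination is compiled once and selected at runtime, e.g. (taken in shape from
// the causal_conv1d launcher above):
//   BOOL_SWITCH(params.seqlen % kNElts == 0 && !kVarlen, kIsVecLoad, [&] {
//       using Ktraits = Causal_conv1d_fwd_kernel_traits<kNThreads, kWidth, kIsVecLoad, input_t, weight_t>;
//       ...
//   });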
|
||||
@ -7,7 +7,11 @@
|
||||
|
||||
#include <c10/util/BFloat16.h>
|
||||
#include <c10/util/Half.h>
|
||||
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
||||
#ifdef USE_ROCM
|
||||
#include <c10/hip/HIPException.h> // For C10_HIP_CHECK and C10_HIP_KERNEL_LAUNCH_CHECK
|
||||
#else
|
||||
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cub/block/block_load.cuh>
|
||||
@@ -312,19 +316,25 @@ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
|
||||
// kIsVariableB, kIsVariableC and kHasZ are all set to True to reduce binary size
|
||||
constexpr bool kIsVariableB = true;
|
||||
constexpr bool kIsVariableC = true;
|
||||
constexpr bool kHasZ = true;
|
||||
BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
|
||||
BOOL_SWITCH(params.query_start_loc_ptr != nullptr , kVarlen, [&] {
|
||||
using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kVarlen, input_t, weight_t>;
|
||||
constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
|
||||
dim3 grid(params.batch, params.dim / kNRows);
|
||||
auto kernel = &selective_scan_fwd_kernel<Ktraits>;
|
||||
if (kSmemSize >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
(void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
||||
}
|
||||
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] {
|
||||
BOOL_SWITCH(params.query_start_loc_ptr != nullptr , kVarlen, [&] {
|
||||
using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kVarlen, input_t, weight_t>;
|
||||
constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
|
||||
dim3 grid(params.batch, params.dim / kNRows);
|
||||
auto kernel = &selective_scan_fwd_kernel<Ktraits>;
|
||||
if (kSmemSize >= 48 * 1024) {
|
||||
#ifdef USE_ROCM
|
||||
C10_HIP_CHECK(hipFuncSetAttribute(
|
||||
reinterpret_cast<const void*>(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
||||
#else
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
||||
#endif
|
||||
}
|
||||
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
@ -612,19 +622,20 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
|
||||
|
||||
at::Tensor z, out_z;
|
||||
const bool has_z = z_.has_value();
|
||||
TORCH_CHECK(has_z, "has_z = False is disabled in favor of reduced binary size")
|
||||
z = z_.value();
|
||||
TORCH_CHECK(z.scalar_type() == input_type);
|
||||
TORCH_CHECK(z.is_cuda());
|
||||
TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
|
||||
if (varlen){
|
||||
CHECK_SHAPE(z, dim, seqlen);
|
||||
} else {
|
||||
CHECK_SHAPE(z, batch_size, dim, seqlen);
|
||||
if (has_z) {
|
||||
z = z_.value();
|
||||
TORCH_CHECK(z.scalar_type() == input_type);
|
||||
TORCH_CHECK(z.is_cuda());
|
||||
TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
|
||||
if (varlen){
|
||||
CHECK_SHAPE(z, dim, seqlen);
|
||||
} else {
|
||||
CHECK_SHAPE(z, batch_size, dim, seqlen);
|
||||
}
|
||||
|
||||
out_z = z;
|
||||
}
|
||||
|
||||
out_z = z;
|
||||
|
||||
// Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
|
||||
at::Tensor out = delta;
|
||||
TORCH_CHECK(ssm_states.scalar_type() == input_type);
|
||||
@ -653,4 +664,3 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
|
||||
selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@ -1255,8 +1255,6 @@ __global__ void Marlin(
|
||||
if constexpr (has_zp && !is_zp_float) {
|
||||
if (is_new_zp) {
|
||||
if constexpr (group_blocks == -1) is_first_matmul_in_slice = false;
|
||||
FragB frag_zp_0;
|
||||
FragB frag_zp_1;
|
||||
int zp_quant_0, zp_quant_1;
|
||||
|
||||
if constexpr (w_type.size_bits() == 4) {
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
#include <torch/all.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cub/cub.cuh>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/Atomic.cuh>
|
||||
@ -19,9 +20,14 @@ __global__ void moe_align_block_size_kernel(
|
||||
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
|
||||
int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts,
|
||||
int32_t padded_num_experts, int32_t experts_per_warp, int32_t block_size,
|
||||
size_t numel, int32_t* __restrict__ cumsum) {
|
||||
size_t numel, int32_t* __restrict__ cumsum, int32_t max_num_tokens_padded) {
|
||||
extern __shared__ int32_t shared_counts[];
|
||||
|
||||
// Initialize sorted_token_ids with numel
|
||||
for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) {
|
||||
sorted_token_ids[it] = numel;
|
||||
}
|
||||
|
||||
const int warp_id = threadIdx.x / WARP_SIZE;
|
||||
const int my_expert_start = warp_id * experts_per_warp;
|
||||
|
||||
@ -45,18 +51,27 @@ __global__ void moe_align_block_size_kernel(
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
cumsum[0] = 0;
|
||||
for (int i = 1; i <= num_experts; ++i) {
|
||||
int expert_count = 0;
|
||||
int warp_idx = (i - 1) / experts_per_warp;
|
||||
int expert_offset = (i - 1) % experts_per_warp;
|
||||
expert_count = shared_counts[warp_idx * experts_per_warp + expert_offset];
|
||||
// Compute prefix sum over token counts per expert
|
||||
using BlockScan = cub::BlockScan<int32_t, 1024>;
|
||||
__shared__ typename BlockScan::TempStorage temp_storage;
|
||||
|
||||
cumsum[i] =
|
||||
cumsum[i - 1] + CEILDIV(expert_count, block_size) * block_size;
|
||||
}
|
||||
*total_tokens_post_pad = cumsum[num_experts];
|
||||
int expert_count = 0;
|
||||
int expert_id = threadIdx.x;
|
||||
if (expert_id < num_experts) {
|
||||
int warp_idx = expert_id / experts_per_warp;
|
||||
int expert_offset = expert_id % experts_per_warp;
|
||||
expert_count = shared_counts[warp_idx * experts_per_warp + expert_offset];
|
||||
expert_count = CEILDIV(expert_count, block_size) * block_size;
|
||||
}
|
||||
|
||||
int cumsum_val;
|
||||
BlockScan(temp_storage).ExclusiveSum(expert_count, cumsum_val);
|
||||
if (expert_id <= num_experts) {
|
||||
cumsum[expert_id] = cumsum_val;
|
||||
}
|
||||
|
||||
if (expert_id == num_experts) {
|
||||
*total_tokens_post_pad = cumsum_val;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
@ -67,6 +82,13 @@ __global__ void moe_align_block_size_kernel(
|
||||
expert_ids[i / block_size] = threadIdx.x;
|
||||
}
|
||||
}
|
||||
|
||||
// Fill remaining expert_ids with 0
|
||||
const size_t fill_start_idx = cumsum[num_experts] / block_size + threadIdx.x;
|
||||
const size_t expert_ids_size = CEILDIV(max_num_tokens_padded, block_size);
|
||||
for (size_t i = fill_start_idx; i < expert_ids_size; i += blockDim.x) {
|
||||
expert_ids[i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
@ -105,7 +127,12 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
|
||||
const scalar_t* __restrict__ topk_ids,
|
||||
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
|
||||
int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts,
|
||||
int32_t block_size, size_t numel) {
|
||||
int32_t block_size, size_t numel, int32_t max_num_tokens_padded) {
|
||||
// Initialize sorted_token_ids with numel
|
||||
for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) {
|
||||
sorted_token_ids[it] = numel;
|
||||
}
|
||||
|
||||
const size_t tid = threadIdx.x;
|
||||
const size_t stride = blockDim.x;
|
||||
|
||||
@ -153,6 +180,13 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
// Fill remaining expert_ids with 0
|
||||
const size_t fill_start_idx = cumsum[num_experts] / block_size + threadIdx.x;
|
||||
const size_t expert_ids_size = CEILDIV(max_num_tokens_padded, block_size);
|
||||
for (size_t i = fill_start_idx; i < expert_ids_size; i += blockDim.x) {
|
||||
expert_ids[i] = 0;
|
||||
}
|
||||
|
||||
for (size_t i = tid; i < numel; i += stride) {
|
||||
int32_t expert_id = topk_ids[i];
|
||||
int32_t rank_post_pad =
|
||||
@ -179,13 +213,17 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
||||
int threads = 1024;
|
||||
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
|
||||
|
||||
// BlockScan uses 1024 threads and assigns one thread per expert.
|
||||
TORCH_CHECK(padded_num_experts < 1024,
|
||||
"padded_num_experts must be less than 1024");
|
||||
|
||||
VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
|
||||
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
|
||||
// calc needed amount of shared mem for `cumsum` tensors
|
||||
auto options_int =
|
||||
torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device());
|
||||
torch::Tensor cumsum_buffer =
|
||||
torch::zeros({num_experts + 1}, options_int);
|
||||
torch::empty({num_experts + 1}, options_int);
|
||||
bool small_batch_expert_mode =
|
||||
(topk_ids.numel() < 1024) && (num_experts <= 64);
|
||||
|
||||
@ -203,7 +241,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
||||
sorted_token_ids.data_ptr<int32_t>(),
|
||||
experts_ids.data_ptr<int32_t>(),
|
||||
num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
|
||||
topk_ids.numel());
|
||||
topk_ids.numel(), sorted_token_ids.size(0));
|
||||
} else {
|
||||
auto align_kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;
|
||||
|
||||
@ -217,7 +255,8 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
||||
experts_ids.data_ptr<int32_t>(),
|
||||
num_tokens_post_pad.data_ptr<int32_t>(), num_experts,
|
||||
padded_num_experts, experts_per_warp, block_size,
|
||||
topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>());
|
||||
topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>(),
|
||||
sorted_token_ids.size(0));
|
||||
|
||||
const int block_threads = std::min(256, (int)threads);
|
||||
const int num_blocks =
|
||||
@ -239,7 +278,7 @@ void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size]
|
||||
torch::Tensor& output) // [num_tokens, hidden_size]
|
||||
{
|
||||
const int hidden_size = input.size(-1);
|
||||
const int num_tokens = output.numel() / hidden_size;
|
||||
const auto num_tokens = output.numel() / hidden_size;
|
||||
const int topk = input.size(1);
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
|
||||
@ -10,32 +10,28 @@
|
||||
|
||||
void moe_permute(
|
||||
const torch::Tensor& input, // [n_token, hidden]
|
||||
const torch::Tensor& topk_weights, //[n_token, topk]
|
||||
torch::Tensor& topk_ids, // [n_token, topk]
|
||||
const torch::Tensor& topk_ids, // [n_token, topk]
|
||||
const torch::Tensor& token_expert_indices, // [n_token, topk]
|
||||
const std::optional<torch::Tensor>& expert_map, // [n_expert]
|
||||
int64_t n_expert, int64_t n_local_expert, int64_t topk,
|
||||
const std::optional<int64_t>& align_block_size,
|
||||
torch::Tensor&
|
||||
permuted_input, // [topk * n_token/align_block_size_m, hidden]
|
||||
torch::Tensor& permuted_input, // [permuted_size, hidden]
|
||||
torch::Tensor& expert_first_token_offset, // [n_local_expert + 1]
|
||||
torch::Tensor& src_row_id2dst_row_id_map, // [n_token, topk]
|
||||
torch::Tensor& inv_permuted_idx, // [n_token, topk]
|
||||
torch::Tensor& permuted_idx, // [permute_size]
|
||||
torch::Tensor& m_indices) { // [align_expand_m]
|
||||
TORCH_CHECK(topk_weights.scalar_type() == at::ScalarType::Float,
|
||||
"topk_weights must be float32");
|
||||
TORCH_CHECK(expert_first_token_offset.scalar_type() == at::ScalarType::Long,
|
||||
"expert_first_token_offset must be int64");
|
||||
TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int,
|
||||
"topk_ids must be int32");
|
||||
TORCH_CHECK(token_expert_indices.scalar_type() == at::ScalarType::Int,
|
||||
"token_expert_indices must be int32");
|
||||
TORCH_CHECK(src_row_id2dst_row_id_map.scalar_type() == at::ScalarType::Int,
|
||||
"src_row_id2dst_row_id_map must be int32");
|
||||
TORCH_CHECK(inv_permuted_idx.scalar_type() == at::ScalarType::Int,
|
||||
"inv_permuted_idx must be int32");
|
||||
TORCH_CHECK(expert_first_token_offset.size(0) == n_local_expert + 1,
|
||||
"expert_first_token_offset shape != n_local_expert+1")
|
||||
TORCH_CHECK(
|
||||
src_row_id2dst_row_id_map.sizes() == token_expert_indices.sizes(),
|
||||
"token_expert_indices shape must be same as src_row_id2dst_row_id_map");
|
||||
TORCH_CHECK(inv_permuted_idx.sizes() == token_expert_indices.sizes(),
|
||||
"token_expert_indices shape must be same as inv_permuted_idx");
|
||||
auto n_token = input.sizes()[0];
|
||||
auto n_hidden = input.sizes()[1];
|
||||
auto align_block_size_value =
|
||||
@ -46,8 +42,9 @@ void moe_permute(
|
||||
auto sort_workspace = torch::empty(
|
||||
{sorter_size},
|
||||
torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
|
||||
auto copy_topk_ids = topk_ids.clone(); // copy topk_ids for preprocess
|
||||
auto permuted_experts_id = torch::empty_like(topk_ids);
|
||||
auto dst_row_id2src_row_id_map = torch::empty_like(src_row_id2dst_row_id_map);
|
||||
auto sorted_row_idx = torch::empty_like(inv_permuted_idx);
|
||||
auto align_expert_first_token_offset =
|
||||
torch::zeros_like(expert_first_token_offset);
|
||||
|
||||
@ -67,24 +64,22 @@ void moe_permute(
|
||||
const int* expert_map_ptr = get_ptr<int>(expert_map.value());
|
||||
valid_num_ptr =
|
||||
get_ptr<int64_t>(expert_first_token_offset) + n_local_expert;
|
||||
preprocessTopkIdLauncher(get_ptr<int>(topk_ids), n_token * topk,
|
||||
preprocessTopkIdLauncher(get_ptr<int>(copy_topk_ids), n_token * topk,
|
||||
expert_map_ptr, n_expert, stream);
|
||||
}
|
||||
// expert sort topk expert id and scan expert id get expert_first_token_offset
|
||||
sortAndScanExpert(get_ptr<int>(topk_ids), get_ptr<int>(token_expert_indices),
|
||||
get_ptr<int>(permuted_experts_id),
|
||||
get_ptr<int>(dst_row_id2src_row_id_map),
|
||||
get_ptr<int64_t>(expert_first_token_offset), n_token,
|
||||
n_expert, n_local_expert, topk, sorter,
|
||||
get_ptr<int>(sort_workspace), stream);
|
||||
sortAndScanExpert(
|
||||
get_ptr<int>(copy_topk_ids), get_ptr<int>(token_expert_indices),
|
||||
get_ptr<int>(permuted_experts_id), get_ptr<int>(sorted_row_idx),
|
||||
get_ptr<int64_t>(expert_first_token_offset), n_token, n_expert,
|
||||
n_local_expert, topk, sorter, get_ptr<int>(sort_workspace), stream);
|
||||
|
||||
// dispatch expandInputRowsKernelLauncher
|
||||
MOE_DISPATCH(input.scalar_type(), [&] {
|
||||
expandInputRowsKernelLauncher<scalar_t>(
|
||||
get_ptr<scalar_t>(input), get_ptr<scalar_t>(permuted_input),
|
||||
get_ptr<float>(topk_weights), get_ptr<int>(permuted_experts_id),
|
||||
get_ptr<int>(dst_row_id2src_row_id_map),
|
||||
get_ptr<int>(src_row_id2dst_row_id_map),
|
||||
get_ptr<int>(permuted_experts_id), get_ptr<int>(sorted_row_idx),
|
||||
get_ptr<int>(inv_permuted_idx), get_ptr<int>(permuted_idx),
|
||||
get_ptr<int64_t>(expert_first_token_offset), n_token, valid_num_ptr,
|
||||
n_hidden, topk, n_local_expert, align_block_size_value, stream);
|
||||
});
|
||||
@ -101,32 +96,34 @@ void moe_permute(
|
||||
}
|
||||
|
||||
void moe_unpermute(
|
||||
const torch::Tensor& permuted_hidden_states, // [n_token * topk, hidden]
|
||||
const torch::Tensor& topk_weights, //[n_token, topk]
|
||||
const torch::Tensor& topk_ids, // [n_token, topk]
|
||||
const torch::Tensor& src_row_id2dst_row_id_map, // [n_token, topk]
|
||||
const torch::Tensor& expert_first_token_offset, // [n_local_expert+1]
|
||||
int64_t n_expert, int64_t n_local_expert, int64_t topk,
|
||||
const torch::Tensor& permuted_hidden_states, // [n_token * topk, hidden]
|
||||
const torch::Tensor& topk_weights, // [n_token, topk]
|
||||
const torch::Tensor& inv_permuted_idx, // [n_token, topk]
|
||||
const std::optional<torch::Tensor>&
|
||||
expert_first_token_offset, // [n_local_expert+1]
|
||||
int64_t topk,
|
||||
torch::Tensor& hidden_states // [n_token, hidden]
|
||||
) {
|
||||
TORCH_CHECK(src_row_id2dst_row_id_map.sizes() == topk_ids.sizes(),
|
||||
"topk_ids shape must be same as src_row_id2dst_row_id_map");
|
||||
TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int,
|
||||
"topk_ids must be int32");
|
||||
TORCH_CHECK(
|
||||
permuted_hidden_states.scalar_type() == hidden_states.scalar_type(),
|
||||
"topk_ids dtype must be same as src_row_id2dst_row_id_map");
|
||||
"permuted_hidden_states dtype must be same as hidden_states");
|
||||
auto n_token = hidden_states.size(0);
|
||||
auto n_hidden = hidden_states.size(1);
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
const int64_t* valid_ptr =
|
||||
get_ptr<int64_t>(expert_first_token_offset) + n_local_expert;
|
||||
|
||||
int64_t const* valid_ptr = nullptr;
|
||||
if (expert_first_token_offset.has_value()) {
|
||||
int n_local_expert = expert_first_token_offset.value().size(0) - 1;
|
||||
valid_ptr =
|
||||
get_ptr<int64_t>(expert_first_token_offset.value()) + n_local_expert;
|
||||
}
|
||||
|
||||
MOE_DISPATCH(hidden_states.scalar_type(), [&] {
|
||||
finalizeMoeRoutingKernelLauncher<scalar_t, scalar_t>(
|
||||
get_ptr<scalar_t>(permuted_hidden_states),
|
||||
get_ptr<scalar_t>(hidden_states), get_ptr<float>(topk_weights),
|
||||
get_ptr<int>(src_row_id2dst_row_id_map), get_ptr<int>(topk_ids),
|
||||
n_token, n_hidden, topk, valid_ptr, stream);
|
||||
get_ptr<int>(inv_permuted_idx), n_token, n_hidden, topk, valid_ptr,
|
||||
stream);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@ -177,7 +177,7 @@ __global__ void getMIndicesKernel(int64_t* expert_first_token_offset,
|
||||
int tidx = threadIdx.x;
|
||||
extern __shared__ int64_t smem_expert_first_token_offset[];
|
||||
for (int i = tidx; i <= num_local_expert; i += blockDim.x) {
|
||||
smem_expert_first_token_offset[tidx] = __ldg(expert_first_token_offset + i);
|
||||
smem_expert_first_token_offset[i] = __ldg(expert_first_token_offset + i);
|
||||
}
|
||||
__syncthreads();
|
||||
auto last_token_offset = smem_expert_first_token_offset[eidx + 1];
|
||||
|
||||
@ -57,31 +57,19 @@ void sortAndScanExpert(int* expert_for_source_row, const int* source_rows,
|
||||
|
||||
template <typename T>
|
||||
void expandInputRowsKernelLauncher(
|
||||
T const* unpermuted_input, T* permuted_output,
|
||||
const float* unpermuted_scales, int* sorted_experts,
|
||||
T const* unpermuted_input, T* permuted_output, int* sorted_experts,
|
||||
int const* expanded_dest_row_to_expanded_source_row,
|
||||
int* expanded_source_row_to_expanded_dest_row,
|
||||
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
|
||||
int64_t* expert_first_token_offset, int64_t const num_rows,
|
||||
int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k,
|
||||
int num_local_experts, const int& align_block_size, cudaStream_t stream);
|
||||
|
||||
// Final kernel to unpermute and scale
|
||||
// This kernel unpermutes the original data, does the k-way reduction and
|
||||
// performs the final skip connection.
|
||||
template <typename T, typename OutputType, bool CHECK_SKIPPED>
|
||||
__global__ void finalizeMoeRoutingKernel(
|
||||
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
|
||||
float const* scales, int const* expanded_source_row_to_expanded_dest_row,
|
||||
int const* expert_for_source_row, int64_t const orig_cols, int64_t const k,
|
||||
int64_t const* num_valid_ptr);
|
||||
|
||||
template <class T, class OutputType>
|
||||
void finalizeMoeRoutingKernelLauncher(
|
||||
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
|
||||
float const* scales, int const* expanded_source_row_to_expanded_dest_row,
|
||||
int const* expert_for_source_row, int64_t const num_rows,
|
||||
int64_t const cols, int64_t const k, int64_t const* num_valid_ptr,
|
||||
cudaStream_t stream);
|
||||
int64_t const num_rows, int64_t const cols, int64_t const k,
|
||||
int64_t const* num_valid_ptr, cudaStream_t stream);
|
||||
|
||||
void preprocessTopkIdLauncher(int* topk_id_ptr, int size,
|
||||
const int* expert_map_ptr, int num_experts,
|
||||
|
||||
@ -2,10 +2,9 @@
|
||||
|
||||
template <typename T, bool CHECK_SKIPPED, bool ALIGN_BLOCK_SIZE>
|
||||
__global__ void expandInputRowsKernel(
|
||||
T const* unpermuted_input, T* permuted_output,
|
||||
const float* unpermuted_scales, int* sorted_experts,
|
||||
T const* unpermuted_input, T* permuted_output, int* sorted_experts,
|
||||
int const* expanded_dest_row_to_expanded_source_row,
|
||||
int* expanded_source_row_to_expanded_dest_row,
|
||||
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
|
||||
int64_t* expert_first_token_offset, int64_t const num_rows,
|
||||
int64_t const* num_dest_rows, int64_t const cols, int64_t k,
|
||||
int num_local_experts, int align_block_size) {
|
||||
@ -54,6 +53,10 @@ __global__ void expandInputRowsKernel(
|
||||
assert(expanded_dest_row <= INT32_MAX);
|
||||
expanded_source_row_to_expanded_dest_row[expanded_source_row] =
|
||||
static_cast<int>(expanded_dest_row);
|
||||
// skip non local expert token
|
||||
if (!CHECK_SKIPPED || blockIdx.x < *num_dest_rows) {
|
||||
permuted_idx[expanded_dest_row] = expanded_source_row;
|
||||
}
|
||||
}
|
||||
|
||||
if (!CHECK_SKIPPED || blockIdx.x < *num_dest_rows) {
|
||||
@ -62,7 +65,7 @@ __global__ void expandInputRowsKernel(
|
||||
using DataElem = cutlass::Array<T, ELEM_PER_THREAD>;
|
||||
|
||||
// Duplicate and permute rows
|
||||
int64_t const source_row = expanded_source_row % num_rows;
|
||||
int64_t const source_row = expanded_source_row / k;
|
||||
|
||||
auto const* source_row_ptr =
|
||||
reinterpret_cast<DataElem const*>(unpermuted_input + source_row * cols);
|
||||
@ -82,10 +85,9 @@ __global__ void expandInputRowsKernel(
|
||||
|
||||
template <typename T>
|
||||
void expandInputRowsKernelLauncher(
|
||||
T const* unpermuted_input, T* permuted_output,
|
||||
const float* unpermuted_scales, int* sorted_experts,
|
||||
T const* unpermuted_input, T* permuted_output, int* sorted_experts,
|
||||
int const* expanded_dest_row_to_expanded_source_row,
|
||||
int* expanded_source_row_to_expanded_dest_row,
|
||||
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
|
||||
int64_t* expert_first_token_offset, int64_t const num_rows,
|
||||
int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k,
|
||||
int num_local_experts, const int& align_block_size, cudaStream_t stream) {
|
||||
@ -105,11 +107,11 @@ void expandInputRowsKernelLauncher(
|
||||
int64_t smem_size = sizeof(int64_t) * (num_local_experts + 1);
|
||||
|
||||
func<<<blocks, threads, smem_size, stream>>>(
|
||||
unpermuted_input, permuted_output, unpermuted_scales, sorted_experts,
|
||||
unpermuted_input, permuted_output, sorted_experts,
|
||||
expanded_dest_row_to_expanded_source_row,
|
||||
expanded_source_row_to_expanded_dest_row, expert_first_token_offset,
|
||||
num_rows, num_valid_tokens_ptr, cols, k, num_local_experts,
|
||||
align_block_size);
|
||||
expanded_source_row_to_expanded_dest_row, permuted_idx,
|
||||
expert_first_token_offset, num_rows, num_valid_tokens_ptr, cols, k,
|
||||
num_local_experts, align_block_size);
|
||||
}
|
||||
|
||||
template <class T, class U>
|
||||
@ -128,11 +130,9 @@ template <typename T, typename OutputType, bool CHECK_SKIPPED>
|
||||
__global__ void finalizeMoeRoutingKernel(
|
||||
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
|
||||
float const* scales, int const* expanded_source_row_to_expanded_dest_row,
|
||||
int const* expert_for_source_row, int64_t const orig_cols, int64_t const k,
|
||||
int64_t const* num_valid_ptr) {
|
||||
int64_t const orig_cols, int64_t const k, int64_t const* num_valid_ptr) {
|
||||
assert(orig_cols % 4 == 0);
|
||||
int64_t const original_row = blockIdx.x;
|
||||
int64_t const num_rows = gridDim.x;
|
||||
auto const offset = original_row * orig_cols;
|
||||
OutputType* reduced_row_ptr = reduced_unpermuted_output + offset;
|
||||
int64_t const num_valid = *num_valid_ptr;
|
||||
@ -159,14 +159,13 @@ __global__ void finalizeMoeRoutingKernel(
|
||||
ComputeElem thread_output;
|
||||
thread_output.fill(0);
|
||||
for (int k_idx = 0; k_idx < k; ++k_idx) {
|
||||
int64_t const expanded_original_row = original_row + k_idx * num_rows;
|
||||
int64_t const expanded_original_row = original_row * k + k_idx;
|
||||
int64_t const expanded_permuted_row =
|
||||
expanded_source_row_to_expanded_dest_row[expanded_original_row];
|
||||
|
||||
int64_t const k_offset = original_row * k + k_idx;
|
||||
float const row_scale = scales[k_offset];
|
||||
|
||||
// Check after row_rescale has accumulated
|
||||
if (CHECK_SKIPPED && expanded_permuted_row >= num_valid) {
|
||||
continue;
|
||||
}
|
||||
@ -189,9 +188,8 @@ template <class T, class OutputType>
|
||||
void finalizeMoeRoutingKernelLauncher(
|
||||
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
|
||||
float const* scales, int const* expanded_source_row_to_expanded_dest_row,
|
||||
int const* expert_for_source_row, int64_t const num_rows,
|
||||
int64_t const cols, int64_t const k, int64_t const* num_valid_ptr,
|
||||
cudaStream_t stream) {
|
||||
int64_t const num_rows, int64_t const cols, int64_t const k,
|
||||
int64_t const* num_valid_ptr, cudaStream_t stream) {
|
||||
int64_t const blocks = num_rows;
|
||||
int64_t const threads = 256;
|
||||
bool const check_finished = num_valid_ptr != nullptr;
|
||||
@ -201,6 +199,5 @@ void finalizeMoeRoutingKernelLauncher(
|
||||
auto* const kernel = func_map[check_finished];
|
||||
kernel<<<blocks, threads, 0, stream>>>(
|
||||
expanded_permuted_rows, reduced_unpermuted_output, scales,
|
||||
expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k,
|
||||
num_valid_ptr);
|
||||
expanded_source_row_to_expanded_dest_row, cols, k, num_valid_ptr);
|
||||
}
|
||||
|
||||
@ -190,8 +190,8 @@ __launch_bounds__(TPB) __global__ void moeTopK(
|
||||
2) This implementation assumes k is small, but will work for any k.
|
||||
*/
|
||||
|
||||
template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG, typename IndType>
|
||||
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
|
||||
template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG, int WARP_SIZE_PARAM, typename IndType>
|
||||
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
|
||||
void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, IndType* indices,
|
||||
int* source_rows, const int k, const int start_expert, const int end_expert)
|
||||
{
|
||||
@ -209,12 +209,12 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
|
||||
|
||||
// Restrictions based on previous section.
|
||||
static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg");
|
||||
static_assert(WARP_SIZE % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp");
|
||||
static_assert(WARP_SIZE_PARAM % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp");
|
||||
static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW), "THREADS_PER_ROW must be power of 2");
|
||||
static_assert(THREADS_PER_ROW <= WARP_SIZE, "THREADS_PER_ROW can be at most warp size");
|
||||
static_assert(THREADS_PER_ROW <= WARP_SIZE_PARAM, "THREADS_PER_ROW can be at most warp size");
|
||||
|
||||
// We have NUM_EXPERTS elements per row. We specialize for small #experts
|
||||
static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT;
|
||||
static constexpr int ELTS_PER_WARP = WARP_SIZE_PARAM * VPT;
|
||||
static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW;
|
||||
static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP;
|
||||
|
||||
@ -393,41 +393,51 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
|
||||
namespace detail
|
||||
{
|
||||
// Constructs some constants needed to partition the work across threads at compile time.
|
||||
template <int EXPERTS, int BYTES_PER_LDG>
|
||||
template <int EXPERTS, int BYTES_PER_LDG, int WARP_SIZE_PARAM>
|
||||
struct TopkConstants
|
||||
{
|
||||
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
|
||||
static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, "");
|
||||
static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
|
||||
static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE_PARAM) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE_PARAM) == 0, "");
|
||||
static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE_PARAM));
|
||||
static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
|
||||
static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
|
||||
static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
|
||||
static const int ROWS_PER_WARP = WARP_SIZE_PARAM / THREADS_PER_ROW;
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
template <int EXPERTS, int WARPS_PER_TB, typename IndType>
|
||||
template <int EXPERTS, int WARPS_PER_TB, int WARP_SIZE_PARAM, typename IndType>
|
||||
void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, IndType* indices,
|
||||
int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
|
||||
{
|
||||
static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
|
||||
|
||||
static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
|
||||
using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG>;
|
||||
using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG, WARP_SIZE_PARAM>;
|
||||
static constexpr int VPT = Constants::VPT;
|
||||
static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
|
||||
const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
|
||||
const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;
|
||||
|
||||
dim3 block_dim(WARP_SIZE, WARPS_PER_TB);
|
||||
topkGatingSoftmax<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG><<<num_blocks, block_dim, 0, stream>>>(
|
||||
dim3 block_dim(WARP_SIZE_PARAM, WARPS_PER_TB);
|
||||
topkGatingSoftmax<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG, WARP_SIZE_PARAM><<<num_blocks, block_dim, 0, stream>>>(
|
||||
input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert);
|
||||
}
|
||||
|
||||
#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB) \
|
||||
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB>( \
|
||||
gating_output, nullptr, topk_weights, topk_indices, \
|
||||
token_expert_indices, num_tokens, topk, 0, num_experts, \
|
||||
stream);
|
||||
#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB) \
|
||||
switch (warpSize) { \
|
||||
case 32: \
|
||||
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 32>( \
|
||||
gating_output, nullptr, topk_weights, topk_indices, \
|
||||
token_expert_indices, num_tokens, topk, 0, num_experts, stream); \
|
||||
break; \
|
||||
case 64: \
|
||||
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 64>( \
|
||||
gating_output, nullptr, topk_weights, topk_indices, \
|
||||
token_expert_indices, num_tokens, topk, 0, num_experts, stream); \
|
||||
break; \
|
||||
default: \
|
||||
TORCH_CHECK(false, "Unsupported warp size: ", warpSize); \
|
||||
}
|
||||
|
||||
template <typename IndType>
|
||||
void topkGatingSoftmaxKernelLauncher(
|
||||
@ -441,6 +451,7 @@ void topkGatingSoftmaxKernelLauncher(
|
||||
const int topk,
|
||||
cudaStream_t stream) {
|
||||
static constexpr int WARPS_PER_TB = 4;
|
||||
auto warpSize = WARP_SIZE;
|
||||
switch (num_experts) {
|
||||
case 1:
|
||||
LAUNCH_SOFTMAX(1, WARPS_PER_TB);
|
||||
@ -492,7 +503,7 @@ void topk_softmax(
|
||||
torch::Tensor& gating_output) // [num_tokens, num_experts]
|
||||
{
|
||||
const int num_experts = gating_output.size(-1);
|
||||
const int num_tokens = gating_output.numel() / num_experts;
|
||||
const auto num_tokens = gating_output.numel() / num_experts;
|
||||
const int topk = topk_weights.size(-1);
|
||||
|
||||
const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
|
||||
|
||||
@ -56,18 +56,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
||||
" -> Tensor");
|
||||
|
||||
m.def(
|
||||
"moe_permute(Tensor input, Tensor topk_weight, Tensor! topk_ids,"
|
||||
"moe_permute(Tensor input, Tensor topk_ids,"
|
||||
"Tensor token_expert_indices, Tensor? expert_map, int n_expert,"
|
||||
"int n_local_expert,"
|
||||
"int topk, int? align_block_size,Tensor! permuted_input, Tensor! "
|
||||
"expert_first_token_offset, Tensor! src_row_id2dst_row_id_map, Tensor! "
|
||||
"m_indices)->()");
|
||||
"expert_first_token_offset, Tensor! inv_permuted_idx, Tensor! "
|
||||
"permuted_idx, Tensor! m_indices)->()");
|
||||
|
||||
m.def(
|
||||
"moe_unpermute(Tensor permuted_hidden_states, Tensor topk_weights,"
|
||||
"Tensor topk_ids,Tensor src_row_id2dst_row_id_map, Tensor "
|
||||
"expert_first_token_offset, int n_expert, int n_local_expert,int "
|
||||
"topk, Tensor! hidden_states)->()");
|
||||
"Tensor inv_permuted_idx, Tensor? expert_first_token_offset, "
|
||||
"int topk, Tensor! hidden_states)->()");
|
||||
|
||||
m.def("moe_permute_unpermute_supported() -> bool");
|
||||
m.impl("moe_permute_unpermute_supported", &moe_permute_unpermute_supported);
|
||||
|
||||
csrc/ops.h
@ -287,6 +287,16 @@ void scaled_fp4_experts_quant(
|
||||
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
||||
torch::Tensor const& input_offset_by_experts,
|
||||
torch::Tensor const& output_scale_offset_by_experts);
|
||||
|
||||
void per_token_group_quant_fp8(const torch::Tensor& input,
|
||||
torch::Tensor& output_q, torch::Tensor& output_s,
|
||||
int64_t group_size, double eps, double fp8_min,
|
||||
double fp8_max, bool scale_ue8m0);
|
||||
|
||||
void per_token_group_quant_int8(const torch::Tensor& input,
|
||||
torch::Tensor& output_q,
|
||||
torch::Tensor& output_s, int64_t group_size,
|
||||
double eps, double int8_min, double int8_max);
|
||||
#endif
|
||||
|
||||
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||
@ -326,22 +336,6 @@ void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
|
||||
const std::optional<torch::Tensor>& has_initial_state,
|
||||
const torch::Tensor& ssm_states, int64_t pad_slot_id);
|
||||
|
||||
void causal_conv1d_update(const at::Tensor& x, const at::Tensor& conv_state,
|
||||
const at::Tensor& weight,
|
||||
const std::optional<at::Tensor>& bias_,
|
||||
bool silu_activation,
|
||||
const std::optional<at::Tensor>& cache_seqlens_,
|
||||
const std::optional<at::Tensor>& conv_state_indices_,
|
||||
int64_t pad_slot_id);
|
||||
|
||||
void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
|
||||
const std::optional<at::Tensor>& bias_,
|
||||
const std::optional<at::Tensor>& conv_states,
|
||||
const std::optional<at::Tensor>& query_start_loc,
|
||||
const std::optional<at::Tensor>& cache_indices,
|
||||
const std::optional<at::Tensor>& has_initial_state,
|
||||
bool silu_activation, int64_t pad_slot_id);
|
||||
|
||||
using fptr_t = int64_t;
|
||||
fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
|
||||
torch::Tensor& rank_data, int64_t rank,
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
|
||||
#include <cmath>
|
||||
#include "core/math.hpp"
|
||||
#include "cuda_compat.h"
|
||||
#include "../cuda_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
#include "quantization/fp8/common.cuh"
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "../per_token_group_quant_8bit.h"
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include "../../dispatch_utils.h"
|
||||
@ -162,10 +164,11 @@ __global__ void dynamic_scaled_int8_quant_kernel(
|
||||
|
||||
// calculate for absmax
|
||||
float thread_max = 0.f;
|
||||
for (int i = tid; i < hidden_size; i += stride) {
|
||||
const auto v = fabsf(static_cast<float>(row_in[i]));
|
||||
thread_max = fmaxf(thread_max, v);
|
||||
}
|
||||
vectorize_read_with_alignment<16>(
|
||||
row_in, hidden_size, tid, stride, [&] __device__(const scalar_t& src) {
|
||||
const float v = fabsf(static_cast<float>(src));
|
||||
thread_max = fmaxf(thread_max, v);
|
||||
});
|
||||
using BlockReduce = cub::BlockReduce<float, 256>;
|
||||
__shared__ typename BlockReduce::TempStorage tmp;
|
||||
float block_max = BlockReduce(tmp).Reduce(thread_max, cub::Max{}, blockDim.x);
|
||||
@ -232,9 +235,10 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
|
||||
|
||||
// 1. calculate min & max
|
||||
MinMax thread_mm;
|
||||
for (int i = tid; i < hidden_size; i += stride) {
|
||||
thread_mm += static_cast<float>(row_in[i]);
|
||||
}
|
||||
vectorize_read_with_alignment<16>(row_in, hidden_size, tid, stride,
|
||||
[&] __device__(const scalar_t& src) {
|
||||
thread_mm += static_cast<float>(src);
|
||||
});
|
||||
|
||||
using BlockReduce = cub::BlockReduce<MinMax, 256>;
|
||||
__shared__ typename BlockReduce::TempStorage tmp;
|
||||
@ -334,3 +338,11 @@ void dynamic_scaled_int8_quant(
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void per_token_group_quant_int8(const torch::Tensor& input,
|
||||
torch::Tensor& output_q,
|
||||
torch::Tensor& output_s, int64_t group_size,
|
||||
double eps, double int8_min, double int8_max) {
|
||||
per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
|
||||
int8_min, int8_max);
|
||||
}
|
||||
@ -51,7 +51,8 @@ struct cutlass_3x_gemm {
|
||||
// These are the minimum alignments needed for the kernels to compile
|
||||
static constexpr int AlignmentAB =
|
||||
128 / cutlass::sizeof_bits<ElementAB>::value;
|
||||
static constexpr int AlignmentCD = 4;
|
||||
static constexpr int AlignmentCD =
|
||||
128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
@ -144,4 +145,65 @@ struct cutlass_3x_gemm_sm100 {
|
||||
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
|
||||
};
|
||||
|
||||
template <typename ElementAB_, typename ElementD_,
|
||||
template <typename, typename, typename> typename Epilogue_,
|
||||
typename TileShape, typename ClusterShape, typename KernelSchedule,
|
||||
typename EpilogueSchedule>
|
||||
struct cutlass_3x_gemm_sm120 {
|
||||
using ElementAB = ElementAB_;
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
static constexpr int AlignmentA =
|
||||
128 / cutlass::sizeof_bits<ElementAB>::value;
|
||||
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
static constexpr int AlignmentB =
|
||||
128 / cutlass::sizeof_bits<ElementAB>::value;
|
||||
|
||||
using ElementC = void;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
static constexpr int AlignmentC =
|
||||
128 / cutlass::sizeof_bits<ElementD_>::value;
|
||||
|
||||
using ElementD = ElementD_;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
static constexpr int AlignmentD = AlignmentC;
|
||||
|
||||
using ElementAcc =
|
||||
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
|
||||
float>::type;
|
||||
using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>;
|
||||
|
||||
// MMA type
|
||||
using ElementAccumulator = float;
|
||||
|
||||
// Epilogue types
|
||||
using ElementBias = cutlass::half_t;
|
||||
using ElementCompute = float;
|
||||
using ElementAux = ElementD;
|
||||
using LayoutAux = LayoutD;
|
||||
using ElementAmax = float;
|
||||
|
||||
using EVTCompute = typename Epilogue::EVTCompute;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, TileShape,
|
||||
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC,
|
||||
ElementD, LayoutD, AlignmentD, EpilogueSchedule,
|
||||
EVTCompute>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, ElementAB,
|
||||
LayoutA, AlignmentA, ElementAB, LayoutB, AlignmentB,
|
||||
ElementAccumulator, TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
|
||||
};
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
@ -36,6 +36,12 @@ void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_sm120_fp8(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
|
||||