Compare commits
596 Commits
woosuk/fa3...copilot/fi
| SHA1 | Author | Date |
|---|---|---|
| fb0089c536 | |||
| b8d7f55dbd | |||
| 2a167b2eeb | |||
| 0ff902f3b4 | |||
| a9082a4d14 | |||
| e0329ed4b4 | |||
| 6879cd80ae | |||
| e269be2ba2 | |||
| 5c4b6e66fe | |||
| d0a4a3f645 | |||
| ebafb0936d | |||
| 0cb7b065c3 | |||
| 2da02dd0d8 | |||
| d765cf01fe | |||
| 712d0f88d8 | |||
| 49ab23b3cc | |||
| c9abb10489 | |||
| 787cdb3829 | |||
| a5203d04df | |||
| 99f8094400 | |||
| 170e8ea9ea | |||
| a71e4765cc | |||
| 39971db3aa | |||
| 504d914314 | |||
| 47455c424f | |||
| c7fc6b1354 | |||
| ad78868450 | |||
| e2db1164a1 | |||
| 416f05929a | |||
| 5e021b4981 | |||
| 1b9b16649c | |||
| e76e233540 | |||
| a75277285b | |||
| 9dc30b7068 | |||
| 053278a5dc | |||
| c55c028998 | |||
| 65197a5fb3 | |||
| b8f17f5d98 | |||
| d9a55204ba | |||
| b4e9fd811f | |||
| 308fa287a8 | |||
| fa78de9dc3 | |||
| f6818a92cb | |||
| 23c939fd30 | |||
| add1adfec7 | |||
| c80c53a30f | |||
| 24d0c9e6ed | |||
| cc7ae5e7ca | |||
| 0313cf854d | |||
| 0483fabc74 | |||
| da65bec309 | |||
| 4645024d3a | |||
| cd7a3df26f | |||
| 32d2b4064f | |||
| 22cf679aad | |||
| b6d7d34fc6 | |||
| 341923b982 | |||
| 424fb7a5d2 | |||
| 88491c1b6b | |||
| 613a23b57f | |||
| 51a215300b | |||
| ebe14621e3 | |||
| 325aa3dee9 | |||
| a073be6d87 | |||
| 695e7adcd2 | |||
| 281710ef9a | |||
| 808d2e9aa0 | |||
| 285178b3b8 | |||
| 88016c372a | |||
| 998720859c | |||
| 0ba1b54ac6 | |||
| 53415653ff | |||
| 17373dcd93 | |||
| 5964069367 | |||
| de9c085e17 | |||
| 111692bb8c | |||
| 394591e343 | |||
| 3ac849665d | |||
| 0b9cc56fac | |||
| 8896eb72eb | |||
| 19fe1a0510 | |||
| 480bdf5a7b | |||
| 5368f76855 | |||
| 8ef6b8a38c | |||
| 3bbe11cc13 | |||
| c5041f899f | |||
| 8b5fe6eb51 | |||
| 800349c2a5 | |||
| 044931f97b | |||
| 1d353b6352 | |||
| 3496274663 | |||
| 8a19303173 | |||
| 603fbbbce0 | |||
| 10f535c086 | |||
| 48bfb0c9b7 | |||
| f8ce022948 | |||
| 0278f1ac3a | |||
| a482e4e769 | |||
| e0b056e443 | |||
| 79f05e4436 | |||
| f8daddcc4c | |||
| c8e33c72c6 | |||
| d70a16625d | |||
| 5cc54f7c5b | |||
| 0c6e40bbaa | |||
| 2e2000f352 | |||
| 31282401b6 | |||
| 0c31e28e95 | |||
| f571ff8eb6 | |||
| f64ee61d9e | |||
| 8993073dc1 | |||
| 655a09f653 | |||
| f94bf9b924 | |||
| 3663870c72 | |||
| 2461d9e562 | |||
| 7be5d113d8 | |||
| b029de9902 | |||
| bbea1cefdd | |||
| f5aa307d77 | |||
| 4b795020ed | |||
| c86af22f31 | |||
| 10cc12ba66 | |||
| a4fbb32fab | |||
| 1b125004be | |||
| 4fbda0b20c | |||
| 4e51fa8cba | |||
| bf7c99dfc4 | |||
| b95697d731 | |||
| 582bbe6bd7 | |||
| 0cdbf5e61c | |||
| ebe56a0064 | |||
| f77a0802b7 | |||
| c4477f55e5 | |||
| dfd2382039 | |||
| 3b11b26b50 | |||
| d6d13bd49e | |||
| 5efd6905bc | |||
| b17109beea | |||
| 4449235843 | |||
| 38217877aa | |||
| c6d80a7a96 | |||
| 7cd17e22d7 | |||
| 50df09fe13 | |||
| 68fcd3fa73 | |||
| 83e69a09d6 | |||
| 3aa8c10038 | |||
| 103f1ec8d3 | |||
| d983769c41 | |||
| 8fd920924c | |||
| de7b67a023 | |||
| f729023272 | |||
| 1a3079a15e | |||
| 941f56858a | |||
| a634733f67 | |||
| 64ab3c7253 | |||
| e58c5a9768 | |||
| d46d417b58 | |||
| 0167efe20d | |||
| c32e6ad1f6 | |||
| 1630cc8d0f | |||
| 14e2b0730b | |||
| 0f4f0191d8 | |||
| a38b8af4c3 | |||
| 21dce80ea9 | |||
| e61bac87ee | |||
| 80141bbf2f | |||
| b94faf9d50 | |||
| 5b5f350d67 | |||
| f7cf5b512e | |||
| 03d4235fd2 | |||
| d6a1a20973 | |||
| a70d0bd0a3 | |||
| 24f4d1a224 | |||
| 4f510bc2a1 | |||
| 1298c67795 | |||
| 4d9c61993a | |||
| b87cb97a53 | |||
| f856c33ce9 | |||
| 03752dba8f | |||
| 40f26734b9 | |||
| 2c3f557f08 | |||
| 21bcc8263f | |||
| 5bfe0dea7a | |||
| 31fd3265c8 | |||
| 31436e8b4f | |||
| 4efd43e9b4 | |||
| 3c8a787247 | |||
| 01a08739e0 | |||
| fda9537c5e | |||
| 90bbe0a5ad | |||
| e75f342261 | |||
| 78dba404ad | |||
| e9d6a3db69 | |||
| a4454e9401 | |||
| 14006840ea | |||
| 6603288736 | |||
| 95e3095136 | |||
| c9b38be8aa | |||
| 0dd3f4f5ab | |||
| 498259ccce | |||
| 6d25e3fd6e | |||
| ac6eb49de3 | |||
| bf756321c7 | |||
| 0e3bb543f0 | |||
| 569aefd134 | |||
| d3f71f1224 | |||
| 5a30bd10d8 | |||
| 27e8d1ea3e | |||
| 5c79b0d648 | |||
| 5f5664b3e4 | |||
| 89657a557c | |||
| 08d5f7113a | |||
| b2fd0b81e0 | |||
| 9f1c642254 | |||
| 7be3a59d8e | |||
| 8ea0c2753a | |||
| 0fc8fa751a | |||
| 21e39436c8 | |||
| 6d243efeda | |||
| c55bc1db26 | |||
| 292084e72a | |||
| 16bff144be | |||
| fe0411fc6f | |||
| 4d4061b6e7 | |||
| 87f48623a5 | |||
| 5c32143b9d | |||
| 94096a47c9 | |||
| a258ad8bcc | |||
| bf7f470b22 | |||
| 4fc722eca4 | |||
| 3253ae765e | |||
| 000cceca8c | |||
| 68373d3126 | |||
| 52ce1420e9 | |||
| 829bbd7882 | |||
| 4dff91c93d | |||
| de9cb61763 | |||
| 2dbccce8a6 | |||
| 933f45334a | |||
| cc826a202b | |||
| 6d3da472bc | |||
| 78863f8c5c | |||
| 5157827cfc | |||
| 7caec10e7b | |||
| 1f83e7d849 | |||
| e4e37ded56 | |||
| f6b5040590 | |||
| fbd88728b3 | |||
| 070da660c1 | |||
| ad0297d113 | |||
| 236b864e4f | |||
| 3e2f7985a2 | |||
| c280066f9d | |||
| b9dc9d2607 | |||
| 1fc375dc05 | |||
| 76144adf76 | |||
| f5d412bafb | |||
| 177e55e3bd | |||
| 1723ef1aae | |||
| 00d6cba0cf | |||
| 7f89ed248f | |||
| 8a87cd27d9 | |||
| a344a1a7da | |||
| 79899b63f6 | |||
| 6e670778cd | |||
| df5afa82e5 | |||
| 6cd69f51bf | |||
| 8ad7285ea2 | |||
| 48b01fd4d4 | |||
| 993d3d122b | |||
| 68af77e51c | |||
| 6b04039a72 | |||
| 1c859a1387 | |||
| 74f441f4b5 | |||
| a0632a3e03 | |||
| e8b40c7fa2 | |||
| 48f4636927 | |||
| 75531a6c13 | |||
| 22341b996e | |||
| 49252cf59e | |||
| 3e6dd40016 | |||
| aa300c438d | |||
| fe91ce9591 | |||
| 5406ebf5c9 | |||
| b2c06509e5 | |||
| b2f6c247a9 | |||
| 3d232dbd19 | |||
| 5c3fbfe46b | |||
| b4cef5e6c7 | |||
| 0fe85087a9 | |||
| d2b0e97ea6 | |||
| 590bddbfc5 | |||
| ae05a6d83d | |||
| 0933f9d518 | |||
| f1f0d2fab8 | |||
| 81f4b96481 | |||
| 39cd09dc86 | |||
| 919234fe17 | |||
| ebcce2cd36 | |||
| 4121de512e | |||
| 279a5f31b3 | |||
| b8ff05361a | |||
| 637093ae26 | |||
| 33c63e9547 | |||
| ab9f2cfd19 | |||
| dbe298046c | |||
| 625ccd1c4d | |||
| 92ff41abea | |||
| 829b9a62d0 | |||
| 540d54ca8d | |||
| 0783f13960 | |||
| 7655dc3e45 | |||
| f4efda821d | |||
| eb08487b18 | |||
| 7c3a0741c6 | |||
| 00e3f9da46 | |||
| a353bd083d | |||
| 1d20c34717 | |||
| b6af24fba7 | |||
| 0ca2393b47 | |||
| 31a500c86f | |||
| 4e8614e88b | |||
| c6cd5ca3d3 | |||
| df0e0f023e | |||
| b4b78d6317 | |||
| 12817a8ac7 | |||
| c9232d41f4 | |||
| 9bd9294f0e | |||
| da2705198f | |||
| 19b927e52d | |||
| 20d65aa755 | |||
| b159c0a67a | |||
| 6772bb0f7d | |||
| fceafaf582 | |||
| 6b794c756c | |||
| 98deac3879 | |||
| 653124bd46 | |||
| 0b1bdac6af | |||
| d94e3026de | |||
| 3f52738dce | |||
| a01e0018b5 | |||
| 9e7e5baaa8 | |||
| d16aa3dae4 | |||
| 6807af8f46 | |||
| 4c558cf62e | |||
| 77a6bf07ae | |||
| 4082338a25 | |||
| c6b928798e | |||
| b1361c7273 | |||
| 4f0f844b16 | |||
| c5830381af | |||
| d31f97cf57 | |||
| 71683ca6f6 | |||
| e18859298d | |||
| fde0b611a3 | |||
| d0a6301588 | |||
| 45c3936e94 | |||
| ba81acbdc1 | |||
| 53c730286c | |||
| 6534d2fc97 | |||
| 422f22e012 | |||
| 6bd8ebf026 | |||
| dab4f9f764 | |||
| c42fe0b63a | |||
| 5a4b4b3729 | |||
| e5d3d63c42 | |||
| 3d9d40efde | |||
| 67c153b88a | |||
| f7ad6a1eb3 | |||
| 80bb1e8afe | |||
| d030b01548 | |||
| 767e63b860 | |||
| 007dd90859 | |||
| b8a9d0e429 | |||
| 50f2aae1b4 | |||
| 46ae7f6666 | |||
| 1ece7f30ba | |||
| bc8372efc3 | |||
| 8d17fa633e | |||
| 9f909b8996 | |||
| 59f3b93636 | |||
| 78077d5417 | |||
| 6d729c43fb | |||
| 2f4657952b | |||
| 3a7e3bbdd2 | |||
| 4fbd8bb597 | |||
| ad344ef552 | |||
| bbaf9e9cb1 | |||
| 4678503476 | |||
| 93d0652433 | |||
| ea1292ad3e | |||
| dc5e4a653c | |||
| 839ab00349 | |||
| 9b94d6ec8f | |||
| 1891a265d3 | |||
| 95a935fc48 | |||
| 458e74eb90 | |||
| 65abe111a3 | |||
| 807d21b80d | |||
| c90fb03df5 | |||
| 84cf78acee | |||
| 16fb668b61 | |||
| f7dcce7a4a | |||
| 8e13d9fe6d | |||
| 3fa5b25845 | |||
| 14a5d903ab | |||
| 951b038298 | |||
| ebf7605b0d | |||
| bc1d02ac85 | |||
| 1e55dfa7e5 | |||
| 384a052971 | |||
| 39052dbca8 | |||
| 9c97a1c349 | |||
| f919d4cb8f | |||
| afa5b7ca0b | |||
| 1b99028069 | |||
| 5898b135ab | |||
| b799f4b9ea | |||
| 06da44f0cb | |||
| a554991748 | |||
| d1af8b7be9 | |||
| 68b254d673 | |||
| 8c50d62f5a | |||
| b4e2916721 | |||
| 65a7917be4 | |||
| b76753f0b5 | |||
| b81fe83b2c | |||
| 0757551c96 | |||
| 8290d15d2c | |||
| 049c245143 | |||
| 00976db0c3 | |||
| d411df0296 | |||
| 010e0e39ea | |||
| 326976291b | |||
| 7e8d685775 | |||
| c49848396d | |||
| 2a84fb422f | |||
| 534c45b962 | |||
| 3d7363e61c | |||
| 0c5254b82a | |||
| 61f67d8acd | |||
| 42172ad18f | |||
| fbd8595c5c | |||
| 5a16fa614c | |||
| 2d18256e47 | |||
| 56186474f6 | |||
| 1bf5e1f25b | |||
| a6022e6fbc | |||
| 2be07a0db1 | |||
| 0edc0cd52b | |||
| 7920e9b1c5 | |||
| b7c0942b65 | |||
| 9a0c5ded5a | |||
| 10a02535d4 | |||
| 65552b476b | |||
| 7ad7adb67f | |||
| 6ade99eafa | |||
| 3157aebb63 | |||
| 8a0ffd6285 | |||
| 23472ff51c | |||
| 08b751ba74 | |||
| 429e4e2d42 | |||
| 35afe1b30b | |||
| 81c57f60a2 | |||
| 311d875614 | |||
| e3edc0a7a8 | |||
| baece8c3d2 | |||
| 2fcf6b27b6 | |||
| 41b9655751 | |||
| bd875d2eb7 | |||
| f703b923f3 | |||
| cd9b9de1fb | |||
| fe6d8257a1 | |||
| e290594072 | |||
| f756a682d9 | |||
| f0964e29cb | |||
| e789cad6b8 | |||
| e5ebeeba53 | |||
| 7be7f3824a | |||
| ccdae737a0 | |||
| 904063907c | |||
| 43c4f3d77c | |||
| 1712543df6 | |||
| 808a7b69df | |||
| 099c046463 | |||
| af473f0a85 | |||
| 157f9c1368 | |||
| 6f287915d8 | |||
| c152e2a8a0 | |||
| 17eaaef595 | |||
| 3303f134e0 | |||
| b2c8ce57c6 | |||
| a3b9c17b56 | |||
| d57dc2364e | |||
| e2c8f1edec | |||
| 1ee5ead5f8 | |||
| acf8aeb79e | |||
| 7e3a8dc906 | |||
| 139d155781 | |||
| 8c9da6be22 | |||
| 399d2a10e2 | |||
| 4815b00f54 | |||
| 4da8bf20d0 | |||
| 7e0b121812 | |||
| 766bc8162c | |||
| 289b18e670 | |||
| 35171b1172 | |||
| a2c6696bfe | |||
| 5e8398805e | |||
| 136825de75 | |||
| c2dba2dba8 | |||
| 434d2f3f7a | |||
| 8e8e0b6af1 | |||
| 82216dc21f | |||
| 370661856b | |||
| cbc8457b26 | |||
| 4d4297e8fe | |||
| 2a4c825523 | |||
| 4be02a3776 | |||
| f6278b6243 | |||
| ad6c655dde | |||
| 14bcf93a6a | |||
| ecbea55ca2 | |||
| 609b533cb6 | |||
| 5e9455ae8f | |||
| a00d8b236f | |||
| 04cf435d95 | |||
| 7377131a2c | |||
| 6b47ef24de | |||
| 1dc8a70b6d | |||
| f825c6bd22 | |||
| 41b67f4263 | |||
| e8961e963a | |||
| 9a3835aaa9 | |||
| 5c7cc33f4d | |||
| 19c9365aa4 | |||
| eec890c1c1 | |||
| 46a13949d5 | |||
| 31f09c615f | |||
| 31f5dc5b2a | |||
| ec7cb19224 | |||
| 2435ea7ed5 | |||
| 4a6b72c2ab | |||
| b4b9813b5e | |||
| 2cb6ef8996 | |||
| 9edd1db02b | |||
| f263a4b53f | |||
| 54991c548a | |||
| 178d03fbd6 | |||
| fa00c5d75b | |||
| 134a8ee8fd | |||
| 90ec006937 | |||
| a47e6ffe93 | |||
| 98a3a81024 | |||
| de98252f49 | |||
| 796bae07c5 | |||
| 6e20924350 | |||
| dd16bdc798 | |||
| e3c876dca3 | |||
| 5d5d419ca6 | |||
| 302962e806 | |||
| 7e6544c797 | |||
| 8e6c7e873f | |||
| 6a51530437 | |||
| 35509fc5be | |||
| 4b29d2784b | |||
| 59a0b8554b | |||
| 469b3ffaaa | |||
| ae87ddd040 | |||
| a7cb6101ca | |||
| c494f96fbc | |||
| 0c275ad5ad | |||
| 74333ae2f6 | |||
| 83156c7b89 | |||
| 4771df7b2b | |||
| 05fae02175 | |||
| d1bf1b9711 | |||
| 586f286789 | |||
| 811ac13d03 | |||
| e79a12fc3a | |||
| cdfd6871a5 | |||
| 4b3e4474d7 | |||
| bd3db7f469 | |||
| 29b97c0995 | |||
| 7b455cf1c0 | |||
| 8a6e108e76 | |||
| d7b28f3415 | |||
| 6fa41e0c32 | |||
| 031ca762d7 | |||
| 6ad6b8e115 | |||
| f4f4e7ef27 | |||
| 5ea71ff46f | |||
| 7175817637 | |||
| 2dffac464c | |||
| bdcb42e45d | |||
| c09efff976 |
@@ -8,7 +8,8 @@ template = """<!DOCTYPE html>
 <html>
 <body>
 <h1>Links for vLLM</h1/>
-<a href="../{wheel_html_escaped}">{wheel}</a><br/>
+<a href="../{x86_wheel_html_escaped}">{x86_wheel}</a><br/>
+<a href="../{arm_wheel_html_escaped}">{arm_wheel}</a><br/>
 </body>
 </html>
 """
@@ -21,7 +22,25 @@ filename = os.path.basename(args.wheel)
 
 with open("index.html", "w") as f:
     print(f"Generated index.html for {args.wheel}")
+    # sync the abi tag with .buildkite/scripts/upload-wheels.sh
+    if "x86_64" in filename:
+        x86_wheel = filename
+        arm_wheel = filename.replace("x86_64", "aarch64").replace(
+            "manylinux1", "manylinux2014"
+        )
+    elif "aarch64" in filename:
+        x86_wheel = filename.replace("aarch64", "x86_64").replace(
+            "manylinux2014", "manylinux1"
+        )
+        arm_wheel = filename
+    else:
+        raise ValueError(f"Unsupported wheel: {filename}")
     # cloudfront requires escaping the '+' character
     f.write(
-        template.format(wheel=filename, wheel_html_escaped=filename.replace("+", "%2B"))
+        template.format(
+            x86_wheel=x86_wheel,
+            x86_wheel_html_escaped=x86_wheel.replace("+", "%2B"),
+            arm_wheel=arm_wheel,
+            arm_wheel_html_escaped=arm_wheel.replace("+", "%2B"),
+        )
     )
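The new branch derives both wheel names from whichever wheel was actually built, then escapes `+` for CloudFront. A minimal sketch of that mapping, using a made-up wheel filename for illustration:

```python
# Sketch of the filename mapping in the hunk above; the wheel name is a
# hypothetical example, not a real artifact.
filename = "vllm-0.9.1+cu128-cp38-abi3-manylinux1_x86_64.whl"

if "x86_64" in filename:
    x86_wheel = filename
    arm_wheel = filename.replace("x86_64", "aarch64").replace(
        "manylinux1", "manylinux2014"
    )
    # CloudFront requires escaping the '+' character, hence %2B in the href
    print(arm_wheel.replace("+", "%2B"))
    # vllm-0.9.1%2Bcu128-cp38-abi3-manylinux2014_aarch64.whl
```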
@@ -1,12 +0,0 @@
-# For vllm script, with -t option (tensor parallel size).
-# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m HandH1998/QQQ-Llama-3-8b-g128 -b 32 -l 1000 -f 5 -t 1
-model_name: "HandH1998/QQQ-Llama-3-8b-g128"
-tasks:
-- name: "gsm8k"
-  metrics:
-  - name: "exact_match,strict-match"
-    value: 0.419
-  - name: "exact_match,flexible-extract"
-    value: 0.416
-limit: 1000
-num_fewshot: 5
@@ -3,4 +3,3 @@ Meta-Llama-3-70B-Instruct.yaml
 Mixtral-8x7B-Instruct-v0.1.yaml
 Qwen2-57B-A14-Instruct.yaml
 DeepSeek-V2-Lite-Chat.yaml
-Meta-Llama-3-8B-QQQ.yaml
@@ -2,7 +2,7 @@
 # We can use this script to compute baseline accuracy on GSM for transformers.
 #
 # Make sure you have lm-eval-harness installed:
-# pip install lm-eval==0.4.4
+# pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api]
 
 usage() {
     echo``
@@ -3,7 +3,7 @@
 # We use this for fp8, which HF does not support.
 #
 # Make sure you have lm-eval-harness installed:
-# pip install lm-eval==0.4.4
+# pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api]
 
 usage() {
     echo``
@@ -7,7 +7,7 @@ This directory contains two sets of benchmark for vllm.
 - Performance benchmark: benchmark vllm's performance under various workload, for **developers** to gain clarity on whether their PR improves/degrades vllm's performance
 - Nightly benchmark: compare vllm's performance against alternatives (tgi, trt-llm and lmdeploy), for **the public** to know when to choose vllm.
 
-See [vLLM performance dashboard](https://perf.vllm.ai) for the latest performance benchmark results and [vLLM GitHub README](https://github.com/vllm-project/vllm/blob/main/README.md) for latest nightly benchmark results.
+See [vLLM performance dashboard](https://hud.pytorch.org/benchmark/llms?repoName=vllm-project%2Fvllm) for the latest performance benchmark results and [vLLM GitHub README](https://github.com/vllm-project/vllm/blob/main/README.md) for latest nightly benchmark results.
 
 ## Performance benchmark quick overview
@@ -138,28 +138,20 @@ The raw benchmarking results (in the format of json files) are in the `Artifacts
 
 The `compare-json-results.py` helps to compare benchmark results JSON files converted using `convert-results-json-to-markdown.py`.
 When run, benchmark script generates results under `benchmark/results` folder, along with the `benchmark_results.md` and `benchmark_results.json`.
-`compare-json-results.py` compares two `benchmark_results.json` files and provides performance ratio e.g. for Output Tput, Median TTFT and Median TPOT.
+`compare-json-results.py` compares two `benchmark_results.json` files and provides performance ratio e.g. for Output Tput, Median TTFT and Median TPOT.
+If only one benchmark_results.json is passed, `compare-json-results.py` compares different TP and PP configurations in the benchmark_results.json instead.
 
-Here is an example using the script to compare result_a and result_b without detail test name.
-`python3 compare-json-results.py -f results_a/benchmark_results.json -f results_b/benchmark_results.json --ignore_test_name`
-
-| | results_a/benchmark_results.json | results_b/benchmark_results.json | perf_ratio |
-|----|----------------------------------------|----------------------------------------|----------|
-| 0 | 142.633982 | 156.526018 | 1.097396 |
-| 1 | 241.620334 | 294.018783 | 1.216863 |
-| 2 | 218.298905 | 262.664916 | 1.203235 |
-| 3 | 242.743860 | 299.816190 | 1.235113 |
-
-Here is an example using the script to compare result_a and result_b with detail test name.
+Here is an example using the script to compare result_a and result_b with Model, Dataset name, input/output lenght, max concurrency and qps.
 `python3 compare-json-results.py -f results_a/benchmark_results.json -f results_b/benchmark_results.json`
 
-| | results_a/benchmark_results.json_name | results_a/benchmark_results.json | results_b/benchmark_results.json_name | results_b/benchmark_results.json | perf_ratio |
-|---|---------------------------------------------|----------------------------------------|---------------------------------------------|----------------------------------------|----------|
-| 0 | serving_llama8B_tp1_sharegpt_qps_1 | 142.633982 | serving_llama8B_tp1_sharegpt_qps_1 | 156.526018 | 1.097396 |
-| 1 | serving_llama8B_tp1_sharegpt_qps_16 | 241.620334 | serving_llama8B_tp1_sharegpt_qps_16 | 294.018783 | 1.216863 |
-| 2 | serving_llama8B_tp1_sharegpt_qps_4 | 218.298905 | serving_llama8B_tp1_sharegpt_qps_4 | 262.664916 | 1.203235 |
-| 3 | serving_llama8B_tp1_sharegpt_qps_inf | 242.743860 | serving_llama8B_tp1_sharegpt_qps_inf | 299.816190 | 1.235113 |
-| 4 | serving_llama8B_tp2_random_1024_128_qps_1 | 96.613390 | serving_llama8B_tp4_random_1024_128_qps_1 | 108.404853 | 1.122048 |
+| | Model | Dataset Name | Input Len | Output Len | # of max concurrency | qps | results_a/benchmark_results.json | results_b/benchmark_results.json | perf_ratio |
+|----|---------------------------------------|--------|-----|-----|------|-----|-----------|----------|----------|
+| 0 | meta-llama/Meta-Llama-3.1-8B-Instruct | random | 128 | 128 | 1000 | 1 | 142.633982 | 156.526018 | 1.097396 |
+| 1 | meta-llama/Meta-Llama-3.1-8B-Instruct | random | 128 | 128 | 1000 | inf| 241.620334 | 294.018783 | 1.216863 |
+
+A comparison diagram will be generated below the table.
+Here is an example to compare between 96c/results_gnr_96c_091_tp2pp3 and 128c/results_gnr_128c_091_tp2pp3
+<img width="1886" height="828" alt="image" src="https://github.com/user-attachments/assets/c02a43ef-25d0-4fd6-90e5-2169a28682dd" />
 
 ## Nightly test details
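The `perf_ratio` column in the tables above is simply the second file's metric divided by the first's, row by row. A quick check against row 0:

```python
# perf_ratio for row 0: results_b throughput divided by results_a throughput
ratio = 156.526018 / 142.633982
print(round(ratio, 6))  # 1.097396, matching the table
```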
@@ -168,9 +160,9 @@ See [nightly-descriptions.md](nightly-descriptions.md) for the detailed descript
 ### Workflow
 
 - The [nightly-pipeline.yaml](nightly-pipeline.yaml) specifies the docker containers for different LLM serving engines.
-- Inside each container, we run [run-nightly-suite.sh](run-nightly-suite.sh), which will probe the serving engine of the current container.
-- The `run-nightly-suite.sh` will redirect the request to `tests/run-[llm serving engine name]-nightly.sh`, which parses the workload described in [nightly-tests.json](tests/nightly-tests.json) and performs the benchmark.
-- At last, we run [scripts/plot-nightly-results.py](scripts/plot-nightly-results.py) to collect and plot the final benchmarking results, and update the results to buildkite.
+- Inside each container, we run [scripts/run-nightly-benchmarks.sh](scripts/run-nightly-benchmarks.sh), which will probe the serving engine of the current container.
+- The `scripts/run-nightly-benchmarks.sh` will parse the workload described in [nightly-tests.json](tests/nightly-tests.json) and launch the right benchmark for the specified serving engine via `scripts/launch-server.sh`.
+- At last, we run [scripts/summary-nightly-results.py](scripts/summary-nightly-results.py) to collect and plot the final benchmarking results, and update the results to buildkite.
 
 ### Nightly tests
@@ -180,6 +172,6 @@ In [nightly-tests.json](tests/nightly-tests.json), we include the command line a
 
 The docker containers for benchmarking are specified in `nightly-pipeline.yaml`.
 
-WARNING: the docker versions are HARD-CODED and SHOULD BE ALIGNED WITH `nightly-descriptions.md`. The docker versions need to be hard-coded as there are several version-specific bug fixes inside `tests/run-[llm serving engine name]-nightly.sh`.
+WARNING: the docker versions are HARD-CODED and SHOULD BE ALIGNED WITH `nightly-descriptions.md`. The docker versions need to be hard-coded as there are several version-specific bug fixes inside `scripts/run-nightly-benchmarks.sh` and `scripts/launch-server.sh`.
 
 WARNING: populating `trt-llm` to latest version is not easy, as it requires updating several protobuf files in [tensorrt-demo](https://github.com/neuralmagic/tensorrt-demo.git).
@@ -17,7 +17,7 @@ Latest reproduction guilde: [github issue link](https://github.com/vllm-project/
 - SGLang: `lmsysorg/sglang:v0.3.2-cu121`
 - LMDeploy: `openmmlab/lmdeploy:v0.6.1-cu12`
 - TensorRT-LLM: `nvcr.io/nvidia/tritonserver:24.07-trtllm-python-py3`
-  - *NOTE: we uses r24.07 as the current implementation only works for this version. We are going to bump this up.*
+  - *NOTE: we use r24.07 as the current implementation only works for this version. We are going to bump this up.*
 - Check [nightly-pipeline.yaml](nightly-pipeline.yaml) for the concrete docker images, specs and commands we use for the benchmark.
 - Hardware
   - 8x Nvidia A100 GPUs
@@ -1,33 +1,202 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import argparse
+import json
+import os
+from importlib import util
 
 import pandas as pd
 
+plotly_found = util.find_spec("plotly.express") is not None
+
 
 def compare_data_columns(
-    files, name_column, data_column, drop_column, ignore_test_name=False
+    files, name_column, data_column, info_cols, drop_column, debug=False
 ):
-    print("\ncompare_data_column: " + data_column)
-    frames = []
-    compare_frames = []
-    for file in files:
-        data_df = pd.read_json(file)
-        serving_df = data_df.dropna(subset=[drop_column], ignore_index=True)
-        if ignore_test_name is False:
-            serving_df = serving_df.rename(columns={name_column: file + "_name"})
-            frames.append(serving_df[file + "_name"])
-        serving_df = serving_df.rename(columns={data_column: file})
-        frames.append(serving_df[file])
-        compare_frames.append(serving_df[file])
-        if len(compare_frames) >= 2:
-            # Compare numbers among two files
-            ratio_df = compare_frames[1] / compare_frames[0]
-            frames.append(ratio_df)
-            compare_frames.pop(1)
+    """
+    Align concatenation by keys derived from info_cols instead of row order.
+    - Pick one canonical key list: subset of info_cols present in ALL files.
+    - For each file: set index to those keys, aggregate duplicates
+      (mean for metric, first for names).
+    - Concat along axis=1 (indexes align), then reset_index so callers can
+      group by columns.
+    - If --debug, add a <file_label>_name column per file.
+    """
+    print("\ncompare_data_column:", data_column)
+
+    frames = []
+    raw_data_cols = []
+    compare_frames = []
+
+    # 1) choose a canonical key list from info_cols that exists in ALL files
+    cols_per_file = []
+    for f in files:
+        try:
+            df_tmp = pd.read_json(f, orient="records")
+        except Exception as err:
+            raise ValueError(f"Failed to read {f}") from err
+        cols_per_file.append(set(df_tmp.columns))
+
+    key_cols = [c for c in info_cols if all(c in cset for cset in cols_per_file)]
+    if not key_cols:
+        # soft fallback: use any info_cols present in the first file
+        key_cols = [c for c in info_cols if c in list(cols_per_file[0])]
+    if not key_cols:
+        raise ValueError(
+            "No common key columns found from info_cols across the input files."
+        )
+
+    # 2) build a single "meta" block (keys as columns) once, aligned by the key index
+    meta_added = False
+
+    for file in files:
+        df = pd.read_json(file, orient="records")
+
+        # Keep rows that actually have the compared metric (same as original behavior)
+        if drop_column in df.columns:
+            df = df.dropna(subset=[drop_column], ignore_index=True)
+
+        # Stabilize numeric key columns (harmless if missing)
+        for c in (
+            "Input Len",
+            "Output Len",
+            "TP Size",
+            "PP Size",
+            "# of max concurrency.",
+            "qps",
+        ):
+            if c in df.columns:
+                df[c] = pd.to_numeric(df[c], errors="coerce")
+
+        # Ensure all key columns exist
+        for c in key_cols:
+            if c not in df.columns:
+                df[c] = pd.NA
+
+        # Set index = key_cols and aggregate duplicates → unique MultiIndex
+        df_idx = df.set_index(key_cols, drop=False)
+
+        # meta (key columns), unique per key
+        meta = df_idx[key_cols]
+        if not meta.index.is_unique:
+            meta = meta.groupby(level=key_cols, dropna=False).first()
+
+        # metric series for this file, aggregated to one row per key
+        file_label = "/".join(file.split("/")[:-1]) or os.path.basename(file)
+        s = df_idx[data_column]
+        if not s.index.is_unique:
+            s = s.groupby(level=key_cols, dropna=False).mean()
+        s.name = file_label  # column label like original
+
+        # add meta once (from first file) so keys are the leftmost columns
+        if not meta_added:
+            frames.append(meta)
+            meta_added = True
+
+        # (NEW) debug: aligned test-name column per file
+        if debug and name_column in df_idx.columns:
+            name_s = df_idx[name_column]
+            if not name_s.index.is_unique:
+                name_s = name_s.groupby(level=key_cols, dropna=False).first()
+            name_s.name = f"{file_label}_name"
+            frames.append(name_s)
+
+        frames.append(s)
+        raw_data_cols.append(file_label)
+        compare_frames.append(s)
+
+        # Generalize ratio: for any file N>=2, add ratio (fileN / file1)
+        if len(compare_frames) >= 2:
+            base = compare_frames[0]
+            current = compare_frames[-1]
+            ratio = current / base
+            ratio = ratio.mask(base == 0)  # avoid inf when baseline is 0
+            ratio.name = f"Ratio 1 vs {len(compare_frames)}"
+            frames.append(ratio)
+
+    # 4) concat on columns with aligned MultiIndex;
+    #    then reset_index to return keys as columns
     concat_df = pd.concat(frames, axis=1)
-    return concat_df
+    concat_df = concat_df.reset_index(drop=True).reset_index()
+    if "index" in concat_df.columns:
+        concat_df = concat_df.drop(columns=["index"])
+
+    # Ensure key/info columns appear first (in your info_cols order)
+    front = [c for c in info_cols if c in concat_df.columns]
+    rest = [c for c in concat_df.columns if c not in front]
+    concat_df = concat_df[front + rest]
+
+    print(raw_data_cols)
+    return concat_df, raw_data_cols
+
+
+def split_json_by_tp_pp(
+    input_file: str = "benchmark_results.json", output_root: str = "."
+) -> list[str]:
+    """
+    Split a benchmark JSON into separate folders by (TP Size, PP Size).
+
+    Creates: <output_root>/tp{TP}_pp{PP}/benchmark_results.json
+    Returns: list of file paths written.
+    """
+    # Load JSON data into DataFrame
+    with open(input_file, encoding="utf-8") as f:
+        data = json.load(f)
+
+    # If the JSON is a dict with a list under common keys, use that list
+    if isinstance(data, dict):
+        for key in ("results", "serving_results", "benchmarks", "data"):
+            if isinstance(data.get(key), list):
+                data = data[key]
+                break
+
+    df = pd.DataFrame(data)
+
+    # Keep only "serving" tests
+    name_col = next(
+        (c for c in ["Test name", "test_name", "Test Name"] if c in df.columns), None
+    )
+    if name_col:
+        df = df[
+            df[name_col].astype(str).str.contains(r"serving", case=False, na=False)
+        ].copy()
+
+    # Handle alias column names
+    rename_map = {
+        "tp_size": "TP Size",
+        "tensor_parallel_size": "TP Size",
+        "pp_size": "PP Size",
+        "pipeline_parallel_size": "PP Size",
+    }
+    df.rename(
+        columns={k: v for k, v in rename_map.items() if k in df.columns}, inplace=True
+    )
+
+    # Ensure TP/PP columns exist (default to 1 if missing)
+    if "TP Size" not in df.columns:
+        df["TP Size"] = 1
+    if "PP Size" not in df.columns:
+        df["PP Size"] = 1
+
+    # make sure TP/PP are numeric ints with no NaN
+    df["TP Size"] = (
+        pd.to_numeric(df.get("TP Size", 1), errors="coerce").fillna(1).astype(int)
+    )
+    df["PP Size"] = (
+        pd.to_numeric(df.get("PP Size", 1), errors="coerce").fillna(1).astype(int)
+    )
+
+    # Split into separate folders
+    saved_paths: list[str] = []
+    for (tp, pp), group_df in df.groupby(["TP Size", "PP Size"], dropna=False):
+        folder_name = os.path.join(output_root, f"tp{int(tp)}_pp{int(pp)}")
+        os.makedirs(folder_name, exist_ok=True)
+        filepath = os.path.join(folder_name, "benchmark_results.json")
+        group_df.to_json(filepath, orient="records", indent=2, force_ascii=False)
+        print(f"Saved: {filepath}")
+        saved_paths.append(filepath)
+
+    return saved_paths
 
 
 if __name__ == "__main__":
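With a single input file, the new `split_json_by_tp_pp` fans one combined `benchmark_results.json` out into per-configuration folders, which the comparison logic then treats as separate inputs. A sketch of that flow, assuming the script's functions have been made importable (the module name below is hypothetical; the file itself is named with dashes):

```python
# Hypothetical import: assumes compare-json-results.py was exposed as a module.
from compare_json_results import split_json_by_tp_pp

paths = split_json_by_tp_pp("benchmark_results.json", output_root="splits")
# e.g. ["splits/tp1_pp1/benchmark_results.json",
#       "splits/tp2_pp3/benchmark_results.json"]
```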
@@ -36,31 +205,103 @@ if __name__ == "__main__":
         "-f", "--file", action="append", type=str, help="input file name"
     )
     parser.add_argument(
-        "--ignore_test_name", action="store_true", help="ignore_test_name or not"
+        "--debug", action="store_true", help="show all information for debugging"
     )
+    parser.add_argument(
+        "--plot",
+        action=argparse.BooleanOptionalAction,
+        default=True,
+        help="plot perf diagrams or not --no-plot --plot",
+    )
+    parser.add_argument(
+        "-x",
+        "--xaxis",
+        type=str,
+        default="# of max concurrency.",
+        help="column name to use as X Axis in comparision graph",
+    )
     args = parser.parse_args()
-    files = args.file
-    print("comparing : " + ", ".join(files))
 
     drop_column = "P99"
     name_column = "Test name"
+    info_cols = [
+        "Model",
+        "Dataset Name",
+        "Input Len",
+        "Output Len",
+        "TP Size",
+        "PP Size",
+        "# of max concurrency.",
+        "qps",
+    ]
     data_cols_to_compare = ["Output Tput (tok/s)", "Median TTFT (ms)", "Median"]
     html_msgs_for_data_cols = [
         "Compare Output Tokens /n",
         "Median TTFT /n",
         "Median TPOT /n",
     ]
-    ignore_test_name = args.ignore_test_name
 
+    if len(args.file) == 1:
+        files = split_json_by_tp_pp(args.file[0], output_root="splits")
+        info_cols = [c for c in info_cols if c not in ("TP Size", "PP Size")]
+    else:
+        files = args.file
+    print("comparing : " + ", ".join(files))
+    debug = args.debug
+    plot = args.plot
+    # For Plot feature, assign y axis from one of info_cols
+    y_axis_index = info_cols.index(args.xaxis) if args.xaxis in info_cols else 6
     with open("perf_comparison.html", "w") as text_file:
         for i in range(len(data_cols_to_compare)):
-            output_df = compare_data_columns(
+            output_df, raw_data_cols = compare_data_columns(
                 files,
                 name_column,
                 data_cols_to_compare[i],
+                info_cols,
                 drop_column,
-                ignore_test_name=ignore_test_name,
+                debug=debug,
             )
-            print(output_df)
-            html = output_df.to_html()
-            text_file.write(html_msgs_for_data_cols[i])
-            text_file.write(html)
+
+            # For Plot feature, insert y axis from one of info_cols
+            raw_data_cols.insert(0, info_cols[y_axis_index])
+
+            filtered_info_cols = info_cols[:-2]
+            existing_group_cols = [
+                c for c in filtered_info_cols if c in output_df.columns
+            ]
+            if not existing_group_cols:
+                raise ValueError(
+                    f"No valid group-by columns "
+                    f"Expected subset: {filtered_info_cols}, "
+                    f"but DataFrame has: {list(output_df.columns)}"
+                )
+            output_df_sorted = output_df.sort_values(by=existing_group_cols)
+            output_groups = output_df_sorted.groupby(existing_group_cols, dropna=False)
+            for name, group in output_groups:
+                html = group.to_html()
+                text_file.write(html_msgs_for_data_cols[i])
+                text_file.write(html)
+
+                if plot and plotly_found:
+                    import plotly.express as px
+
+                    df = group[raw_data_cols]
+                    df_sorted = df.sort_values(by=info_cols[y_axis_index])
+                    # Melt DataFrame for plotting
+                    df_melted = df_sorted.melt(
+                        id_vars=info_cols[y_axis_index],
+                        var_name="Configuration",
+                        value_name=data_cols_to_compare[i],
+                    )
+                    title = data_cols_to_compare[i] + " vs " + info_cols[y_axis_index]
+                    # Create Plotly line chart
+                    fig = px.line(
+                        df_melted,
+                        x=info_cols[y_axis_index],
+                        y=data_cols_to_compare[i],
+                        color="Configuration",
+                        title=title,
+                        markers=True,
+                    )
+                    # Export to HTML
+                    text_file.write(fig.to_html(full_html=True, include_plotlyjs="cdn"))
@@ -1,17 +1,19 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import argparse
 import json
 import os
+import shlex
 from importlib import util
 from pathlib import Path
+from typing import Any
 
 import pandas as pd
 import psutil
 import regex as re
 from tabulate import tabulate
 
 results_folder = Path("results/")
 
 # latency results and the keys that will be printed into markdown
 latency_results = []
 latency_column_mapping = {
@@ -42,14 +44,22 @@ throughput_results_column_mapping
 serving_results = []
 serving_column_mapping = {
     "test_name": "Test name",
     "model_id": "Model",
+    "dataset_name": "Dataset Name",
+    "input_len": "Input Len",
+    "output_len": "Output Len",
+    "tp_size": "TP Size",
+    "pp_size": "PP Size",
+    "dtype": "dtype",
     "gpu_type": "GPU",
     "completed": "# of req.",
+    "qps": "qps",
+    "max_concurrency": "# of max concurrency.",
     "request_throughput": "Tput (req/s)",
     "total_token_throughput": "Total Token Tput (tok/s)",
     "output_throughput": "Output Tput (tok/s)",
-    "total_input_tokens": "Total input tokens",
-    "total_output_tokens": "Total output tokens",
+    # "total_input_tokens": "Total input tokens",
+    # "total_output_tokens": "Total output tokens",
     "mean_ttft_ms": "Mean TTFT (ms)",
     "median_ttft_ms": "Median TTFT (ms)",
     "p99_ttft_ms": "P99 TTFT (ms)",
@@ -94,7 +104,104 @@ def get_size_with_unit(bytes, suffix="B"):
         bytes /= factor
 
 
+def _coerce(val: str) -> Any:
+    """Best-effort type coercion from string to Python types."""
+    low = val.lower()
+    if low == "null":
+        return None
+    if low == "true":
+        return True
+    if low == "false":
+        return False
+    # integers
+    if re.fullmatch(r"[+-]?\d+", val):
+        try:
+            return int(val)
+        except ValueError:
+            pass
+    # floats (keep 'inf'/'-inf'/'nan' as strings)
+    if re.fullmatch(r"[+-]?\d*\.\d+", val):
+        try:
+            return float(val)
+        except ValueError:
+            pass
+    return val
+
+
+def parse_client_command(cmd: str) -> dict[str, Any]:
+    """Parse the client_command shell string into {executable, script, args}."""
+    toks = shlex.split(cmd)
+    if len(toks) < 2:
+        raise ValueError("client_command must include an executable and a script")
+    executable, script = toks[0], toks[1]
+    args: dict[str, Any] = {}
+
+    i = 2
+    while i < len(toks):
+        t = toks[i]
+        if t.startswith("--"):
+            # --key=value or --key (value) or boolean flag
+            if "=" in t:
+                key, val = t.split("=", 1)
+                if key == "--metadata":
+                    md = {}
+                    if val:
+                        if "=" in val:
+                            k, v = val.split("=", 1)
+                            md[k] = _coerce(v)
+                        else:
+                            md[val] = True
+                    args[key] = md
+                else:
+                    args[key] = _coerce(val)
+                i += 1
+                continue
+
+            key = t
+
+            # Special: consume metadata k=v pairs until next --flag
+            if key == "--metadata":
+                i += 1
+                md = {}
+                while i < len(toks) and not toks[i].startswith("--"):
+                    pair = toks[i]
+                    if "=" in pair:
+                        k, v = pair.split("=", 1)
+                        md[k] = _coerce(v)
+                    else:
+                        md[pair] = True
+                    i += 1
+                args[key] = md
+                continue
+
+            # Standard: check if next token is a value (not a flag)
+            if i + 1 < len(toks) and not toks[i + 1].startswith("--"):
+                args[key] = _coerce(toks[i + 1])
+                i += 2
+            else:
+                # lone flag -> True
+                args[key] = True
+                i += 1
+        else:
+            # unexpected positional; skip
+            i += 1
+
+    return {"executable": executable, "script": script, "args": args}
+
+
 if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-r",
+        "--result",
+        type=str,
+        default="results",
+        help="Folder name for benchmark output results.",
+    )
+    args = parser.parse_args()
+    results_folder = Path(args.result)
+    if not results_folder.exists():
+        raise FileNotFoundError(f"results folder does not exist: {results_folder}")
     # collect results
     for test_file in results_folder.glob("*.json"):
         with open(test_file) as f:
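The `_coerce` helper above turns CLI string tokens back into typed values so the parsed server/client arguments compare cleanly. A trimmed, self-contained re-implementation (using the stdlib `re` instead of the third-party `regex` the script imports) shows the intended behavior:

```python
import re

def coerce(val: str):
    # Same decision order as _coerce above: null/bools, ints, floats, else string.
    low = val.lower()
    if low in ("null", "true", "false"):
        return {"null": None, "true": True, "false": False}[low]
    if re.fullmatch(r"[+-]?\d+", val):
        return int(val)
    if re.fullmatch(r"[+-]?\d*\.\d+", val):
        return float(val)
    return val

print(coerce("512"), coerce("0.9"), coerce("true"), coerce("inf"))
# 512 0.9 True inf  -- 'inf' deliberately stays a string
```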
@@ -102,7 +209,6 @@ if __name__ == "__main__":
 
         if "serving" in str(test_file):
-            # this result is generated via `vllm bench serve` command
 
             # attach the benchmarking command to raw_result
             try:
                 with open(test_file.with_suffix(".commands")) as f:
@@ -110,12 +216,44 @@ if __name__ == "__main__":
             except OSError as e:
                 print(e)
                 continue
+            # Parse Server Command Arg
+            out: dict[str, Any] = {
+                "server_command": parse_client_command(command["server_command"])
+            }
+            parse_args = [
+                "--tensor-parallel-size",
+                "--pipeline-parallel-size",
+                "--dtype",
+            ]
+            col_mapping = ["tp_size", "pp_size", "dtype"]
+            for index, arg in enumerate(parse_args):
+                if arg in out["server_command"]["args"]:
+                    raw_result.update(
+                        {col_mapping[index]: out["server_command"]["args"][arg]}
+                    )
+
+            # Parse Client Command Arg
+            out: dict[str, Any] = {
+                "client_command": parse_client_command(command["client_command"])
+            }
+            parse_args = [
+                "--dataset-name",
+                "--random-input-len",
+                "--random-output-len",
+                "--request-rate",
+            ]
+            col_mapping = ["dataset_name", "input_len", "output_len", "qps"]
+
+            for index, arg in enumerate(parse_args):
+                if arg in out["client_command"]["args"]:
+                    raw_result.update(
+                        {col_mapping[index]: out["client_command"]["args"][arg]}
+                    )
+            # Add Server, Client command
             raw_result.update(command)
 
             # update the test name of this result
             raw_result.update({"test_name": test_file.stem})
 
             # add the result to raw_result
             serving_results.append(raw_result)
             continue
@@ -205,7 +343,10 @@ if __name__ == "__main__":
         columns=latency_column_mapping
     )
     if not serving_results.empty:
-        serving_results = serving_results[list(serving_column_mapping.keys())].rename(
+        valid_columns = [
+            col for col in serving_column_mapping if col in serving_results.columns
+        ]
+        serving_results = serving_results[valid_columns].rename(
             columns=serving_column_mapping
         )
     if not throughput_results.empty:
@@ -245,7 +386,9 @@ if __name__ == "__main__":
     )
 
     # document the result
-    with open(results_folder / "benchmark_results.md", "w") as f:
+    md_file = "benchmark_results.md"
+    json_file = "benchmark_results.json"
+    with open(results_folder / md_file, "w") as f:
         results = read_markdown(
             "../.buildkite/nightly-benchmarks/"
             + "performance-benchmarks-descriptions.md"
@@ -260,7 +403,7 @@ if __name__ == "__main__":
         f.write(results)
 
     # document benchmarking results in json
-    with open(results_folder / "benchmark_results.json", "w") as f:
+    with open(results_folder / json_file, "w") as f:
         results = (
             latency_results.to_dict(orient="records")
             + throughput_results.to_dict(orient="records")
@@ -382,7 +382,7 @@ run_genai_perf_tests() {
         client_command="genai-perf profile \
             -m $model \
             --service-kind openai \
-            --backend vllm \
+            --backend "$backend" \
            --endpoint-type chat \
            --streaming \
            --url localhost:$port \
@@ -194,9 +194,11 @@ run_latency_tests() {
 
         # check if there is enough GPU to run the test
         tp=$(echo "$latency_params" | jq -r '.tensor_parallel_size')
-        if [ "$ON_CPU" == "1" ];then
-            if [[ $numa_count -lt $tp ]]; then
-                echo "Required tensor-parallel-size $tp but only $numa_count NUMA nodes found. Skip testcase $test_name."
+        if [ "$ON_CPU" == "1" ]; then
+            pp=$(echo "$latency_params" | jq -r '.pipeline_parallel_size')
+            world_size=$(($tp*$pp))
+            if [[ $numa_count -lt $world_size && -z "${REMOTE_HOST}" ]]; then
+                echo "Required world-size $world_size but only $numa_count NUMA nodes found. Skip testcase $test_name."
                 continue
             fi
         else
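The reworked CPU guard skips a test only when the host cannot seat the full world size (TP times PP) and no remote host is configured. The arithmetic, with illustrative values:

```python
# Illustrative values only; on the real runner these come from jq and the host.
tp, pp = 2, 3
numa_count = 4        # hypothetical NUMA node count on the host
remote_host = ""      # empty unless REMOTE_HOST is set for multi-node runs

world_size = tp * pp  # 6
if numa_count < world_size and not remote_host:
    print(f"Required world-size {world_size} but only {numa_count} NUMA nodes found.")
```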
@@ -261,9 +263,11 @@ run_throughput_tests() {
 
         # check if there is enough GPU to run the test
         tp=$(echo "$throughput_params" | jq -r '.tensor_parallel_size')
-        if [ "$ON_CPU" == "1" ];then
-            if [[ $numa_count -lt $tp ]]; then
-                echo "Required tensor-parallel-size $tp but only $numa_count NUMA nodes found. Skip testcase $test_name."
+        if [ "$ON_CPU" == "1" ]; then
+            pp=$(echo "$throughput_params" | jq -r '.pipeline_parallel_size')
+            world_size=$(($tp*$pp))
+            if [[ $numa_count -lt $world_size && -z "${REMOTE_HOST}" ]]; then
+                echo "Required world-size $world_size but only $numa_count NUMA nodes found. Skip testcase $test_name."
                 continue
             fi
         else
@@ -329,12 +333,21 @@ run_serving_tests() {
         qps_list=$(echo "$params" | jq -r '.qps_list')
         qps_list=$(echo "$qps_list" | jq -r '.[] | @sh')
         echo "Running over qps list $qps_list"
+        max_concurrency_list=$(echo "$params" | jq -r '.max_concurrency_list')
+        if [[ -z "$max_concurrency_list" || "$max_concurrency_list" == "null" ]]; then
+            num_prompts=$(echo "$client_params" | jq -r '.num_prompts')
+            max_concurrency_list="[$num_prompts]"
+        fi
+        max_concurrency_list=$(echo "$max_concurrency_list" | jq -r '.[] | @sh')
+        echo "Running over max concurrency list $max_concurrency_list"
 
         # check if there is enough resources to run the test
         tp=$(echo "$server_params" | jq -r '.tensor_parallel_size')
-        if [ "$ON_CPU" == "1" ];then
-            if [[ $numa_count -lt $tp ]]; then
-                echo "Required tensor-parallel-size $tp but only $numa_count NUMA nodes found. Skip testcase $test_name."
+        if [ "$ON_CPU" == "1" ]; then
+            pp=$(echo "$server_params" | jq -r '.pipeline_parallel_size')
+            world_size=$(($tp*$pp))
+            if [[ $numa_count -lt $world_size && -z "${REMOTE_HOST}" ]]; then
+                echo "Required world-size $world_size but only $numa_count NUMA nodes found. Skip testcase $test_name."
                 continue
             fi
         else
@@ -390,35 +403,39 @@ run_serving_tests() {
             echo "now qps is $qps"
         fi
 
-        new_test_name=$test_name"_qps_"$qps
+        # iterate over different max_concurrency
+        for max_concurrency in $max_concurrency_list; do
+            new_test_name=$test_name"_qps_"$qps"_concurrency_"$max_concurrency
+            echo " new test name $new_test_name"
+            # pass the tensor parallel size to the client so that it can be displayed
+            # on the benchmark dashboard
+            client_command="vllm bench serve \
+                --save-result \
+                --result-dir $RESULTS_FOLDER \
+                --result-filename ${new_test_name}.json \
+                --request-rate $qps \
+                --max-concurrency $max_concurrency \
+                --metadata "tensor_parallel_size=$tp" \
+                $client_args $client_remote_args "
 
-        # pass the tensor parallel size to the client so that it can be displayed
-        # on the benchmark dashboard
-        client_command="vllm bench serve \
-            --save-result \
-            --result-dir $RESULTS_FOLDER \
-            --result-filename ${new_test_name}.json \
-            --request-rate $qps \
-            --metadata "tensor_parallel_size=$tp" \
-            $client_args $client_remote_args "
+            echo "Running test case $test_name with qps $qps"
+            echo "Client command: $client_command"
 
-        echo "Running test case $test_name with qps $qps"
-        echo "Client command: $client_command"
+            bash -c "$client_command"
 
-        bash -c "$client_command"
-
-        # record the benchmarking commands
-        jq_output=$(jq -n \
-            --arg server "$server_command" \
-            --arg client "$client_command" \
-            --arg gpu "$gpu_type" \
-            '{
-            server_command: $server,
-            client_command: $client,
-            gpu_type: $gpu
-            }')
-        echo "$jq_output" >"$RESULTS_FOLDER/${new_test_name}.commands"
+            # record the benchmarking commands
+            jq_output=$(jq -n \
+                --arg server "$server_command" \
+                --arg client "$client_command" \
+                --arg gpu "$gpu_type" \
+                '{
+                server_command: $server,
+                client_command: $client,
+                gpu_type: $gpu
+                }')
+            echo "$jq_output" >"$RESULTS_FOLDER/${new_test_name}.commands"
+
+        done
     done
 
     # clean up
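Because the serving loop now also iterates over `max_concurrency_list`, each result file name encodes both the qps and the concurrency level. A sketch of the naming scheme with example values:

```python
# Example values; the real lists come from the test's JSON parameters.
test_name = "serving_llama8B_tp1_sharegpt"
qps = "inf"
for max_concurrency in (12, 16, 24):
    print(f"{test_name}_qps_{qps}_concurrency_{max_concurrency}")
# serving_llama8B_tp1_sharegpt_qps_inf_concurrency_12, ...
```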
@@ -12,7 +12,6 @@
         "vllm_server_parameters": {
             "disable_log_stats": "",
             "gpu_memory_utilization": 0.9,
-            "num_scheduler_steps": 10,
             "max_num_seqs": 512,
             "dtype": "bfloat16"
         },
@@ -6,7 +6,7 @@
             "VLLM_CPU_KVCACHE_SPACE": 40
         },
         "parameters": {
-            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+            "model": "meta-llama/Llama-3.1-8B-Instruct",
             "tensor_parallel_size": 1,
             "load_format": "dummy",
             "num_iters_warmup": 5,
@@ -20,7 +20,7 @@
             "VLLM_CPU_KVCACHE_SPACE": 40
         },
         "parameters": {
-            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+            "model": "meta-llama/Llama-3.1-8B-Instruct",
             "tensor_parallel_size": 4,
             "load_format": "dummy",
             "num_iters_warmup": 5,
@@ -36,7 +36,6 @@
         "vllm_server_parameters": {
             "disable_log_stats": "",
             "gpu_memory_utilization": 0.9,
-            "num_scheduler_steps": 10,
             "max_num_seqs": 512,
             "dtype": "bfloat16"
         },
@@ -90,7 +89,6 @@
         "vllm_server_parameters": {
             "disable_log_stats": "",
             "gpu_memory_utilization": 0.9,
-            "num_scheduler_steps": 10,
             "max_num_seqs": 512,
             "dtype": "bfloat16"
         },
@@ -144,7 +142,6 @@
         "vllm_server_parameters": {
             "disable_log_stats": "",
             "gpu_memory_utilization": 0.9,
-            "num_scheduler_steps": 10,
             "max_num_seqs": 512,
             "dtype": "bfloat16"
         },
@@ -195,7 +192,6 @@
         "vllm_server_parameters": {
             "disable_log_stats": "",
             "gpu_memory_utilization": 0.9,
-            "num_scheduler_steps": 10,
             "max_num_seqs": 512,
             "dtype": "bfloat16"
         },
@@ -248,7 +244,6 @@
         "vllm_server_parameters": {
             "disable_log_stats": "",
             "gpu_memory_utilization": 0.9,
-            "num_scheduler_steps": 10,
             "max_num_seqs": 512,
             "dtype": "bfloat16"
         },
@@ -301,7 +296,6 @@
         "vllm_server_parameters": {
             "disable_log_stats": "",
             "gpu_memory_utilization": 0.9,
-            "num_scheduler_steps": 10,
             "max_num_seqs": 512,
             "dtype": "bfloat16"
         },
@@ -1,7 +1,8 @@
 [
     {
         "test_name": "serving_llama8B_tp1_sharegpt",
-        "qps_list": [1, 4, 16, "inf"],
+        "qps_list": ["inf"],
+        "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200],
         "server_environment_variables": {
             "VLLM_RPC_TIMEOUT": 100000,
             "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
@@ -10,7 +11,7 @@
             "VLLM_CPU_KVCACHE_SPACE": 40
         },
         "server_parameters": {
-            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+            "model": "meta-llama/Llama-3.1-8B-Instruct",
             "tensor_parallel_size": 1,
             "dtype": "bfloat16",
             "distributed_executor_backend": "mp",
@@ -23,17 +24,17 @@
             "load_format": "dummy"
         },
         "client_parameters": {
-            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+            "model": "meta-llama/Llama-3.1-8B-Instruct",
             "backend": "vllm",
             "dataset_name": "sharegpt",
             "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
-            "max_concurrency": 60,
             "num_prompts": 200
         }
     },
     {
         "test_name": "serving_llama8B_tp2_sharegpt",
-        "qps_list": [1, 4, 16, "inf"],
+        "qps_list": ["inf"],
+        "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200],
         "server_environment_variables": {
             "VLLM_RPC_TIMEOUT": 100000,
             "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
@@ -42,7 +43,7 @@
             "VLLM_CPU_KVCACHE_SPACE": 40
         },
         "server_parameters": {
-            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+            "model": "meta-llama/Llama-3.1-8B-Instruct",
             "tensor_parallel_size": 2,
             "dtype": "bfloat16",
             "distributed_executor_backend": "mp",
@@ -55,17 +56,17 @@
             "load_format": "dummy"
         },
         "client_parameters": {
-            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+            "model": "meta-llama/Llama-3.1-8B-Instruct",
             "backend": "vllm",
             "dataset_name": "sharegpt",
             "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
-            "max_concurrency": 60,
             "num_prompts": 200
         }
     },
     {
         "test_name": "serving_llama8B_tp4_sharegpt",
-        "qps_list": [1, 4, 16, "inf"],
+        "qps_list": ["inf"],
+        "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200],
         "server_environment_variables": {
             "VLLM_RPC_TIMEOUT": 100000,
             "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
@@ -74,7 +75,7 @@
             "VLLM_CPU_KVCACHE_SPACE": 40
         },
         "server_parameters": {
-            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+            "model": "meta-llama/Llama-3.1-8B-Instruct",
             "tensor_parallel_size": 4,
             "dtype": "bfloat16",
             "distributed_executor_backend": "mp",
@@ -87,17 +88,17 @@
             "load_format": "dummy"
         },
         "client_parameters": {
-            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+            "model": "meta-llama/Llama-3.1-8B-Instruct",
             "backend": "vllm",
             "dataset_name": "sharegpt",
             "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
-            "max_concurrency": 60,
             "num_prompts": 200
         }
     },
     {
         "test_name": "serving_llama8B_tp1_random_128_128",
-        "qps_list": [1, 4, 16, "inf"],
+        "qps_list": ["inf"],
+        "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000],
         "server_environment_variables": {
             "VLLM_RPC_TIMEOUT": 100000,
             "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
@@ -106,7 +107,7 @@
             "VLLM_CPU_KVCACHE_SPACE": 40
         },
         "server_parameters": {
-            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+            "model": "meta-llama/Llama-3.1-8B-Instruct",
             "tensor_parallel_size": 1,
             "dtype": "bfloat16",
             "distributed_executor_backend": "mp",
@@ -120,19 +121,19 @@
             "load_format": "dummy"
         },
         "client_parameters": {
-            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+            "model": "meta-llama/Llama-3.1-8B-Instruct",
             "backend": "vllm",
             "dataset_name": "random",
             "random-input-len": 128,
             "random-output-len": 128,
             "ignore-eos": "",
-            "max_concurrency": 1000,
             "num_prompts": 1000
         }
     },
     {
         "test_name": "serving_llama8B_tp2_random_128_128",
-        "qps_list": [1, 4, 16, "inf"],
+        "qps_list": ["inf"],
+        "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000],
         "server_environment_variables": {
             "VLLM_RPC_TIMEOUT": 100000,
             "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
@@ -141,7 +142,7 @@
             "VLLM_CPU_KVCACHE_SPACE": 40
         },
         "server_parameters": {
-            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+            "model": "meta-llama/Llama-3.1-8B-Instruct",
             "tensor_parallel_size": 2,
             "dtype": "bfloat16",
             "distributed_executor_backend": "mp",
@@ -155,19 +156,19 @@
             "load_format": "dummy"
         },
         "client_parameters": {
-            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+            "model": "meta-llama/Llama-3.1-8B-Instruct",
             "backend": "vllm",
             "dataset_name": "random",
             "random-input-len": 128,
             "random-output-len": 128,
             "ignore-eos": "",
-            "max_concurrency": 1000,
             "num_prompts": 1000
         }
     },
     {
         "test_name": "serving_llama8B_tp4_random_128_128",
-        "qps_list": [1, 4, 16, "inf"],
+        "qps_list": ["inf"],
+        "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000],
         "server_environment_variables": {
             "VLLM_RPC_TIMEOUT": 100000,
             "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
@@ -176,7 +177,7 @@
             "VLLM_CPU_KVCACHE_SPACE": 40
         },
         "server_parameters": {
-            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+            "model": "meta-llama/Llama-3.1-8B-Instruct",
             "tensor_parallel_size": 4,
             "dtype": "bfloat16",
             "distributed_executor_backend": "mp",
@@ -190,13 +191,11 @@
             "load_format": "dummy"
         },
         "client_parameters": {
-            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+            "model": "meta-llama/Llama-3.1-8B-Instruct",
             "backend": "vllm",
             "dataset_name": "random",
             "random-input-len": 128,
             "random-output-len": 128,
             "ignore-eos": "",
-            "max_concurrency": 1000,
             "num_prompts": 1000
         }
     }
@@ -1,7 +1,8 @@
[
  {
    "test_name": "serving_llama8B_pp1_sharegpt",
-   "qps_list": [1, 4, 16, "inf"],
+   "qps_list": ["inf"],
+   "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200],
    "server_environment_variables": {
      "VLLM_RPC_TIMEOUT": 100000,
      "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
@@ -10,7 +11,7 @@
      "VLLM_CPU_KVCACHE_SPACE": 40
    },
    "server_parameters": {
-     "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+     "model": "meta-llama/Llama-3.1-8B-Instruct",
      "pipeline_parallel_size": 1,
      "dtype": "bfloat16",
      "distributed_executor_backend": "mp",
@@ -23,17 +24,17 @@
      "load_format": "dummy"
    },
    "client_parameters": {
-     "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+     "model": "meta-llama/Llama-3.1-8B-Instruct",
      "backend": "vllm",
      "dataset_name": "sharegpt",
      "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
      "max_concurrency": 60,
      "num_prompts": 200
    }
  },
  {
    "test_name": "serving_llama8B_pp3_sharegpt",
-   "qps_list": [1, 4, 16, "inf"],
+   "qps_list": ["inf"],
+   "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200],
    "server_environment_variables": {
      "VLLM_RPC_TIMEOUT": 100000,
      "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
@@ -42,7 +43,7 @@
      "VLLM_CPU_KVCACHE_SPACE": 40
    },
    "server_parameters": {
-     "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+     "model": "meta-llama/Llama-3.1-8B-Instruct",
      "pipeline_parallel_size": 3,
      "dtype": "bfloat16",
      "distributed_executor_backend": "mp",
@@ -55,17 +56,17 @@
      "load_format": "dummy"
    },
    "client_parameters": {
-     "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+     "model": "meta-llama/Llama-3.1-8B-Instruct",
      "backend": "vllm",
      "dataset_name": "sharegpt",
      "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
      "max_concurrency": 60,
      "num_prompts": 200
    }
  },
  {
-   "test_name": "serving_llama8B_tp2pp6_sharegpt",
-   "qps_list": [1, 4, 16, "inf"],
+   "test_name": "serving_llama8B_tp2pp3_sharegpt",
+   "qps_list": ["inf"],
+   "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200],
    "server_environment_variables": {
      "VLLM_RPC_TIMEOUT": 100000,
      "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
@@ -74,7 +75,7 @@
      "VLLM_CPU_KVCACHE_SPACE": 40
    },
    "server_parameters": {
-     "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+     "model": "meta-llama/Llama-3.1-8B-Instruct",
      "tensor_parallel_size": 2,
      "pipeline_parallel_size": 3,
      "dtype": "bfloat16",
@@ -88,17 +89,17 @@
      "load_format": "dummy"
    },
    "client_parameters": {
-     "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+     "model": "meta-llama/Llama-3.1-8B-Instruct",
      "backend": "vllm",
      "dataset_name": "sharegpt",
      "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
      "max_concurrency": 60,
      "num_prompts": 200
    }
  },
  {
    "test_name": "serving_llama8B_pp1_random_128_128",
-   "qps_list": [1, 4, 16, "inf"],
+   "qps_list": ["inf"],
+   "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000],
    "server_environment_variables": {
      "VLLM_RPC_TIMEOUT": 100000,
      "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
@@ -107,7 +108,7 @@
      "VLLM_CPU_KVCACHE_SPACE": 40
    },
    "server_parameters": {
-     "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+     "model": "meta-llama/Llama-3.1-8B-Instruct",
      "pipeline_parallel_size": 1,
      "dtype": "bfloat16",
      "distributed_executor_backend": "mp",
@@ -121,28 +122,28 @@
      "load_format": "dummy"
    },
    "client_parameters": {
-     "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+     "model": "meta-llama/Llama-3.1-8B-Instruct",
      "backend": "vllm",
      "dataset_name": "random",
      "random-input-len": 128,
      "random-output-len": 128,
      "ignore-eos": "",
      "max_concurrency": 1000,
      "num_prompts": 1000
    }
  },
  {
    "test_name": "serving_llama8B_pp3_random_128_128",
-   "qps_list": [1, 4, 16, "inf"],
+   "qps_list": ["inf"],
+   "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000],
    "server_environment_variables": {
      "VLLM_RPC_TIMEOUT": 100000,
      "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
      "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120,
-     "VLLM_CPU_SGL_KERNEL:": 1,
+     "VLLM_CPU_SGL_KERNEL": 1,
      "VLLM_CPU_KVCACHE_SPACE": 40
    },
    "server_parameters": {
-     "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+     "model": "meta-llama/Llama-3.1-8B-Instruct",
      "pipeline_parallel_size": 3,
      "dtype": "bfloat16",
      "distributed_executor_backend": "mp",
@@ -156,19 +157,19 @@
      "load_format": "dummy"
    },
    "client_parameters": {
-     "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+     "model": "meta-llama/Llama-3.1-8B-Instruct",
      "backend": "vllm",
      "dataset_name": "random",
      "random-input-len": 128,
      "random-output-len": 128,
      "ignore-eos": "",
      "max_concurrency": 1000,
      "num_prompts": 1000
    }
  },
  {
    "test_name": "serving_llama8B_tp2pp3_random_128_128",
-   "qps_list": [1, 4, 16, "inf"],
+   "qps_list": ["inf"],
+   "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000],
    "server_environment_variables": {
      "VLLM_RPC_TIMEOUT": 100000,
      "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
@@ -177,7 +178,7 @@
      "VLLM_CPU_KVCACHE_SPACE": 40
    },
    "server_parameters": {
-     "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+     "model": "meta-llama/Llama-3.1-8B-Instruct",
      "tensor_parallel_size": 2,
      "pipeline_parallel_size": 3,
      "dtype": "bfloat16",
@@ -192,13 +193,12 @@
      "load_format": "dummy"
    },
    "client_parameters": {
-     "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+     "model": "meta-llama/Llama-3.1-8B-Instruct",
      "backend": "vllm",
      "dataset_name": "random",
      "random-input-len": 128,
      "random-output-len": 128,
      "ignore-eos": "",
      "max_concurrency": 1000,
      "num_prompts": 1000
    }
  }
@@ -2,6 +2,7 @@
  {
    "test_name": "serving_llama8B_tp1_sharegpt",
    "qps_list": [1, 4, 16, "inf"],
+   "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200],
    "server_environment_variables": {
      "VLLM_RPC_TIMEOUT": 100000,
      "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
@@ -10,7 +11,7 @@
      "VLLM_CPU_KVCACHE_SPACE": 40
    },
    "server_parameters": {
-     "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+     "model": "meta-llama/Llama-3.1-8B-Instruct",
      "tensor_parallel_size": 1,
      "dtype": "bfloat16",
      "distributed_executor_backend": "mp",
@@ -23,17 +24,17 @@
      "load_format": "dummy"
    },
    "client_parameters": {
-     "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+     "model": "meta-llama/Llama-3.1-8B-Instruct",
      "backend": "vllm",
      "dataset_name": "sharegpt",
      "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
      "max_concurrency": 60,
      "num_prompts": 200
    }
  },
  {
    "test_name": "serving_llama8B_tp2_sharegpt",
    "qps_list": [1, 4, 16, "inf"],
+   "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200],
    "server_environment_variables": {
      "VLLM_RPC_TIMEOUT": 100000,
      "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
@@ -42,7 +43,7 @@
      "VLLM_CPU_KVCACHE_SPACE": 40
    },
    "server_parameters": {
-     "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+     "model": "meta-llama/Llama-3.1-8B-Instruct",
      "tensor_parallel_size": 2,
      "dtype": "bfloat16",
      "distributed_executor_backend": "mp",
@@ -55,17 +56,17 @@
      "load_format": "dummy"
    },
    "client_parameters": {
-     "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+     "model": "meta-llama/Llama-3.1-8B-Instruct",
      "backend": "vllm",
      "dataset_name": "sharegpt",
      "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
      "max_concurrency": 60,
      "num_prompts": 200
    }
  },
  {
    "test_name": "serving_llama8B_tp4_sharegpt",
    "qps_list": [1, 4, 16, "inf"],
+   "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200],
    "server_environment_variables": {
      "VLLM_RPC_TIMEOUT": 100000,
      "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
@@ -74,7 +75,7 @@
      "VLLM_CPU_KVCACHE_SPACE": 40
    },
    "server_parameters": {
-     "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+     "model": "meta-llama/Llama-3.1-8B-Instruct",
      "tensor_parallel_size": 4,
      "dtype": "bfloat16",
      "distributed_executor_backend": "mp",
@@ -87,17 +88,17 @@
      "load_format": "dummy"
    },
    "client_parameters": {
-     "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+     "model": "meta-llama/Llama-3.1-8B-Instruct",
      "backend": "vllm",
      "dataset_name": "sharegpt",
      "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
      "max_concurrency": 60,
      "num_prompts": 200
    }
  },
  {
    "test_name": "serving_llama8B_tp4_random_1024_128",
    "qps_list": [1, 4, 16, "inf"],
+   "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200],
    "server_environment_variables": {
      "VLLM_RPC_TIMEOUT": 100000,
      "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
@@ -106,7 +107,7 @@
      "VLLM_CPU_KVCACHE_SPACE": 40
    },
    "server_parameters": {
-     "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+     "model": "meta-llama/Llama-3.1-8B-Instruct",
      "tensor_parallel_size": 4,
      "dtype": "bfloat16",
      "distributed_executor_backend": "mp",
@@ -120,19 +121,19 @@
      "load_format": "dummy"
    },
    "client_parameters": {
-     "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+     "model": "meta-llama/Llama-3.1-8B-Instruct",
      "backend": "vllm",
      "dataset_name": "random",
      "random-input-len": 1024,
      "random-output-len": 128,
      "ignore-eos": "",
      "max_concurrency": 100,
      "num_prompts": 100
    }
  },
  {
    "test_name": "serving_llama8B_pp6_random_1024_128",
    "qps_list": [1, 4, 16, "inf"],
+   "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200],
    "server_environment_variables": {
      "VLLM_RPC_TIMEOUT": 100000,
      "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
@@ -141,7 +142,7 @@
      "VLLM_CPU_KVCACHE_SPACE": 40
    },
    "server_parameters": {
-     "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+     "model": "meta-llama/Llama-3.1-8B-Instruct",
      "pipeline_parallel_size": 6,
      "dtype": "bfloat16",
      "distributed_executor_backend": "mp",
@@ -155,13 +156,12 @@
      "load_format": "dummy"
    },
    "client_parameters": {
-     "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+     "model": "meta-llama/Llama-3.1-8B-Instruct",
      "backend": "vllm",
      "dataset_name": "random",
      "random-input-len": 1024,
      "random-output-len": 128,
      "ignore-eos": "",
      "max_concurrency": 100,
      "num_prompts": 100
    }
  }
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 1,
|
||||
"load_format": "dummy",
|
||||
"dataset": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
@ -21,7 +21,7 @@
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 4,
|
||||
"load_format": "dummy",
|
||||
"dataset": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
|
||||
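Note: the sketch below shows roughly how one of the serving-test entries above is exercised end to end. It is illustrative only, not the harness's actual wiring; the server and client flag spellings are assumptions derived from the JSON keys.

    # Sketch: a tp2 random-128/128 entry, translated to a server launch plus
    # a benchmark client run. Flag names mirror the config keys above.
    VLLM_CPU_KVCACHE_SPACE=40 vllm serve meta-llama/Llama-3.1-8B-Instruct \
      --tensor-parallel-size 2 --dtype bfloat16 \
      --distributed-executor-backend mp --load-format dummy &

    python3 benchmarks/benchmark_serving.py \
      --backend vllm --model meta-llama/Llama-3.1-8B-Instruct \
      --dataset-name random --random-input-len 128 --random-output-len 128 \
      --ignore-eos --max-concurrency 1000 --num-prompts 1000 --request-rate inf

The "qps_list": ["inf"] plus "max_concurrency_list" change reflects the same idea: drive the server at unbounded request rate and sweep the concurrency cap instead.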
@@ -1,4 +1,20 @@
steps:
+ # aarch64 + CUDA builds
+ - label: "Build arm64 wheel - CUDA 12.8"
+   id: build-wheel-arm64-cuda-12-8
+   agents:
+     queue: arm64_cpu_queue_postmerge
+   commands:
+     # #NOTE: torch_cuda_arch_list is derived from upstream PyTorch build files here:
+     # https://github.com/pytorch/pytorch/blob/main/.ci/aarch64_linux/aarch64_ci_build.sh#L7
+     - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --build-arg torch_cuda_arch_list='8.7 9.0 10.0+PTX' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
+     - "mkdir artifacts"
+     - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
+     - "bash .buildkite/scripts/upload-wheels.sh"
+   env:
+     DOCKER_BUILDKIT: "1"
+
+ # x86 + CUDA builds
- label: "Build wheel - CUDA 12.8"
  id: build-wheel-cuda-12-8
  agents:
@@ -11,7 +27,12 @@ steps:
  env:
    DOCKER_BUILDKIT: "1"

+ - block: "Build CUDA 12.6 wheel"
+   key: block-build-cu126-wheel
+   depends_on: ~
+
- label: "Build wheel - CUDA 12.6"
+ depends_on: block-build-cu126-wheel
  id: build-wheel-cuda-12-6
  agents:
    queue: cpu_queue_postmerge
@@ -52,7 +73,7 @@ steps:
    queue: cpu_queue_postmerge
  commands:
    - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
-   - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain -f docker/Dockerfile ."
+   - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --build-arg FLASHINFER_AOT_COMPILE=true --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain -f docker/Dockerfile ."
    - "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT"

- label: "Annotate release workflow"
@@ -121,7 +121,6 @@ fi
if [[ $commands == *" kernels/quantization"* ]]; then
  commands="${commands} \
  --ignore=kernels/quantization/test_int8_quant.py \
- --ignore=kernels/quantization/test_aqlm.py \
  --ignore=kernels/quantization/test_machete_mm.py \
  --ignore=kernels/quantization/test_block_fp8.py \
  --ignore=kernels/quantization/test_block_int8.py \
@@ -46,6 +46,11 @@ function cpu_tests() {
    set -e
    python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m"

+ # Run kernel tests
+ docker exec cpu-test-"$NUMA_NODE" bash -c "
+   set -e
+   pytest -v -s tests/kernels/test_onednn.py"
+
  # Run basic model test
  docker exec cpu-test-"$NUMA_NODE" bash -c "
    set -e
@@ -99,4 +104,4 @@ function cpu_tests() {

# All of CPU tests are expected to be finished less than 40 mins.
export -f cpu_tests
- timeout 1.5h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE"
+ timeout 2h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE"
@@ -4,8 +4,7 @@ set -xu

remove_docker_container() {
- docker rm -f tpu-test || true;
  docker rm -f vllm-tpu || true;
+ docker rm -f tpu-test || true;
}

trap remove_docker_container EXIT
@@ -129,7 +128,7 @@ run_and_track_test() {

# --- Actual Test Execution ---
run_and_track_test 1 "test_struct_output_generate.py" \
- "HF_HUB_DISABLE_XET=1 python3 -m pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py -k \"not test_structured_output_with_reasoning_matrices\""
+ "python3 -m pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py -k \"not test_structured_output_with_reasoning_matrices\""
run_and_track_test 2 "test_moe_pallas.py" \
  "python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py"
run_and_track_test 3 "test_lora.py" \
@@ -140,6 +139,8 @@ run_and_track_test 5 "test_spmd_model_weight_loading.py" \
  "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_spmd_model_weight_loading.py"
run_and_track_test 6 "test_kv_cache_update_kernel.py" \
  "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_kv_cache_update_kernel.py"
+ run_and_track_test 7 "test_tpu_int8.py" \
+   "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_tpu_int8.py"

# After all tests have been attempted, exit with the overall status.
if [ "$overall_script_exit_code" -ne 0 ]; then
@@ -5,7 +5,6 @@ set -xu

remove_docker_container() {
  docker rm -f tpu-test || true;
- docker rm -f vllm-tpu || true;
}

trap remove_docker_container EXIT
@@ -135,7 +134,7 @@ run_and_track_test 1 "test_compilation.py" \
run_and_track_test 2 "test_basic.py" \
  "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_basic.py"
run_and_track_test 3 "test_accuracy.py::test_lm_eval_accuracy_v1_engine" \
- "HF_HUB_DISABLE_XET=1 python3 -m pytest -s -v /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine"
+ "python3 -m pytest -s -v /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine"
run_and_track_test 4 "test_quantization_accuracy.py" \
  "python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py"
run_and_track_test 5 "examples/offline_inference/tpu.py" \
@@ -23,9 +23,13 @@ docker run \
  --device /dev/dri \
  -v /dev/dri/by-path:/dev/dri/by-path \
  --entrypoint="" \
  -e "HF_TOKEN=${HF_TOKEN}" \
+ -e "ZE_AFFINITY_MASK=${ZE_AFFINITY_MASK}" \
  --name "${container_name}" \
  "${image_name}" \
- sh -c '
+ bash -c '
  set -e
+ echo $ZE_AFFINITY_MASK
  VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
+ VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray
+ VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp
@@ -35,8 +39,8 @@ docker run \
  pytest -v -s v1/sample --ignore=v1/sample/test_logprobs.py --ignore=v1/sample/test_logprobs_e2e.py
  pytest -v -s v1/worker --ignore=v1/worker/test_gpu_model_runner.py
  pytest -v -s v1/structured_output
- pytest -v -s v1/spec_decode --ignore=v1/spec_decode/test_max_len.py --ignore=v1/spec_decode/test_eagle.py
- pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_nixl_connector.py
+ pytest -v -s v1/spec_decode --ignore=v1/spec_decode/test_max_len.py --ignore=v1/spec_decode/test_eagle.py --ignore=v1/spec_decode/test_tree_attention.py
+ pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_nixl_connector.py --ignore=v1/kv_connector/unit/test_shared_storage_connector.py
  pytest -v -s v1/test_serial_utils.py
  pytest -v -s v1/test_utils.py
  pytest -v -s v1/test_metrics_reader.py
@ -17,7 +17,7 @@ if [ "$disk_usage" -gt "$threshold" ]; then
|
||||
# Remove dangling images (those that are not tagged and not used by any container)
|
||||
docker image prune -f
|
||||
# Remove unused volumes / force the system prune for old images as well.
|
||||
docker volume prune -f && docker system prune --force --filter "until=72h" --all
|
||||
docker volume prune -f && docker system prune --force --filter "until=24h" --all
|
||||
echo "Docker images and volumes cleanup completed."
|
||||
else
|
||||
echo "Disk usage is below $threshold%. No cleanup needed."
|
||||
|
||||
@@ -1,6 +1,6 @@
# Environment config
TEST_NAME=llama8b
- CONTAINER_NAME=vllm-tpu
+ CONTAINER_NAME=tpu-test

# vllm config
MODEL=meta-llama/Llama-3.1-8B-Instruct
@@ -12,8 +12,6 @@ source /etc/environment
source $ENV_FILE

remove_docker_container() {
- docker rm -f tpu-test || true;
- docker rm -f vllm-tpu || true;
  docker rm -f $CONTAINER_NAME || true;
}
@@ -1,6 +1,6 @@
# Environment config
TEST_NAME=llama8bw8a8
- CONTAINER_NAME=vllm-tpu
+ CONTAINER_NAME=tpu-test

# vllm config
MODEL=RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8
@@ -14,8 +14,19 @@ fi
# Get the single wheel file
wheel="${wheel_files[0]}"

- # Rename 'linux' to 'manylinux1' in the wheel filename
- new_wheel="${wheel/linux/manylinux1}"
+ # Detect architecture and rename 'linux' to appropriate manylinux version
+ arch=$(uname -m)
+ if [[ $arch == "x86_64" ]]; then
+   manylinux_version="manylinux1"
+ elif [[ $arch == "aarch64" ]]; then
+   manylinux_version="manylinux2014"
+ else
+   echo "Warning: Unknown architecture $arch, using manylinux1 as default"
+   manylinux_version="manylinux1"
+ fi
+
+ # Rename 'linux' to the appropriate manylinux version in the wheel filename
+ new_wheel="${wheel/linux/$manylinux_version}"
mv -- "$wheel" "$new_wheel"
wheel="$new_wheel"
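Note: a minimal sketch of what the renaming above does to a wheel filename. The filename below is hypothetical, chosen only to show the ${var/pattern/replacement} expansion on an arm64 builder.

    # Illustrative only; the wheel name is made up.
    wheel="vllm-0.10.0-cp38-abi3-linux_aarch64.whl"
    arch=$(uname -m)                      # e.g. "aarch64" on an arm64 queue
    if [[ $arch == "aarch64" ]]; then
      manylinux_version="manylinux2014"
    else
      manylinux_version="manylinux1"
    fi
    echo "${wheel/linux/$manylinux_version}"
    # -> vllm-0.10.0-cp38-abi3-manylinux2014_aarch64.whl

This matters because aarch64 wheels tagged manylinux1 would be rejected by pip; manylinux2014 is the earliest manylinux tag that covers aarch64.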
@@ -31,16 +31,6 @@
steps:
##### fast check tests  #####

- - label: Documentation Build # 2min
-   mirror_hardwares: [amdexperimental]
-   working_dir: "/vllm-workspace/test_docs"
-   fast_check: true
-   no_gpu: True
-   commands:
-   - pip install -r ../requirements/docs.txt
-   # TODO: add `--strict` once warnings in docstrings are fixed
-   - mkdocs build
-
- label: Pytorch Nightly Dependency Override Check # 2min
  # if this test fails, it means the nightly torch version is not compatible with some
  # of the dependencies. Please check the error message and add the package to whitelist
@@ -57,20 +47,20 @@ steps:
  - vllm/
  - tests/mq_llm_engine
  - tests/async_engine
- - tests/test_inputs
+ - tests/test_inputs.py
  - tests/test_outputs.py
  - tests/multimodal
- - tests/test_utils
+ - tests/utils_
  - tests/worker
  - tests/standalone_tests/lazy_imports.py
  commands:
  - python3 standalone_tests/lazy_imports.py
  - pytest -v -s mq_llm_engine # MQLLMEngine
  - pytest -v -s async_engine # AsyncLLMEngine
  - NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py
  - pytest -v -s test_inputs.py
  - pytest -v -s test_outputs.py
  - pytest -v -s multimodal
- - pytest -v -s test_utils.py # Utils
+ - pytest -v -s utils_ # Utils
  - pytest -v -s worker # Worker

- label: Python-only Installation Test
@@ -98,15 +88,6 @@ steps:
  - pytest -v -s basic_correctness/test_cpu_offload.py
  - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py

- - label: Chunked Prefill Test
-   mirror_hardwares: [amdexperimental]
-   source_file_dependencies:
-   - vllm/
-   - tests/basic_correctness/test_chunked_prefill
-   commands:
-   - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
-   - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
-
- label: Core Test # 10min
  mirror_hardwares: [amdexperimental]
  fast_check: true
@@ -145,7 +126,8 @@ steps:
  - tests/entrypoints/test_chat_utils
  commands:
  - export VLLM_WORKER_MULTIPROC_METHOD=spawn
- - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/
+ - PYTHONPATH=/vllm-workspace pytest -v -s entrypoints/openai/test_collective_rpc.py # PYTHONPATH is needed to import custom Worker extension
+ - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/test_collective_rpc.py
  - pytest -v -s entrypoints/test_chat_utils.py

- label: Distributed Tests (4 GPUs) # 10min
@@ -262,7 +244,9 @@ steps:
  - pytest -v -s v1/core
  - pytest -v -s v1/engine
  - pytest -v -s v1/entrypoints
+ - pytest -v -s v1/executor
  - pytest -v -s v1/sample
+ - pytest -v -s v1/logits_processors
  - pytest -v -s v1/worker
  - pytest -v -s v1/structured_output
  - pytest -v -s v1/spec_decode
@@ -304,15 +288,6 @@ steps:
  - python3 offline_inference/basic/score.py
  - VLLM_USE_V1=0 python3 offline_inference/profiling.py --model facebook/opt-125m run_num_steps --num-steps 2

- - label: Prefix Caching Test # 9min
-   mirror_hardwares: [amdexperimental]
-   source_file_dependencies:
-   - vllm/
-   - tests/prefix_caching
-   commands:
-   - pytest -v -s prefix_caching
-

- label: Platform Tests (CUDA)
  mirror_hardwares: [amdexperimental]
  source_file_dependencies:
@@ -354,6 +329,7 @@ steps:
  - pytest -v -s compile/test_sequence_parallelism.py
  - pytest -v -s compile/test_async_tp.py
  - pytest -v -s compile/test_fusion_all_reduce.py
+ - pytest -v -s compile/test_decorator.py

- label: PyTorch Fullgraph Smoke Test # 9min
  mirror_hardwares: [amdexperimental]
@@ -367,6 +343,7 @@ steps:
  - pytest -v -s compile/piecewise/test_simple.py
  - pytest -v -s compile/piecewise/test_toy_llama.py
  - pytest -v -s compile/piecewise/test_full_cudagraph.py
+ - pytest -v -s compile/piecewise/test_multiple_graphs.py

- label: PyTorch Fullgraph Test # 18min
  mirror_hardwares: [amdexperimental]
@@ -409,6 +386,7 @@ steps:
- label: Kernels MoE Test %N
  mirror_hardwares: [amdexperimental]
  source_file_dependencies:
+ - csrc/quantization/cutlass_w8a8/moe/
  - csrc/moe/
  - tests/kernels/moe
  - vllm/model_executor/layers/fused_moe/
@@ -426,7 +404,6 @@ steps:

- label: Tensorizer Test # 11min
  mirror_hardwares: [amdexperimental]
- soft_fail: true
  source_file_dependencies:
  - vllm/model_executor/model_loader
  - tests/tensorizer_loader
@@ -477,13 +454,11 @@ steps:

- label: LM Eval Small Models # 53min
  mirror_hardwares: [amdexperimental]
  working_dir: "/vllm-workspace/.buildkite/lm-eval-harness"
  source_file_dependencies:
  - csrc/
  - vllm/model_executor/layers/quantization
  commands:
  - export VLLM_WORKER_MULTIPROC_METHOD=spawn
- - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-small.txt --tp-size=1
+ - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt --tp-size=1

- label: OpenAI API correctness
  mirror_hardwares: [amdexperimental]
@@ -535,8 +510,6 @@ steps:
  - vllm/
  - tests/models/language
  commands:
- # Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile.
- - pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8'
  - pip freeze | grep -E 'torch'
  - pytest -v -s models/language -m core_model

@@ -547,8 +520,10 @@ steps:
  - vllm/
  - tests/models/language/generation
  commands:
- # Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile.
- - pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8'
+ # Install fast path packages for testing against transformers
+ # Note: also needed to run plamo2 model in vLLM
+ - uv pip install --system --no-build-isolation 'git+https://github.com/state-spaces/mamba@v2.2.5'
+ - uv pip install --system --no-build-isolation 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.2'
  - pytest -v -s models/language/generation -m hybrid_model

- label: Language Models Test (Extended Generation) # 1hr20min
@@ -571,6 +546,15 @@ steps:
  commands:
  - pytest -v -s models/language/pooling -m 'not core_model'

+ - label: Multi-Modal Processor Test
+   source_file_dependencies:
+   - vllm/
+   - tests/models/multimodal
+   commands:
+   - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
+   - pytest -v -s models/multimodal/processing --ignore models/multimodal/processing/test_tensor_schema.py
+   - pytest -v -s models/multimodal/processing/test_tensor_schema.py
+
- label: Multi-Modal Models Test (Standard)
  mirror_hardwares: [amdexperimental]
  torch_nightly: true
@@ -580,9 +564,7 @@ steps:
  commands:
  - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
  - pip freeze | grep -E 'torch'
- - pytest -v -s models/multimodal/processing
- - pytest -v -s --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/test_tensor_schema.py models/multimodal -m core_model
- - pytest -v -s models/multimodal/test_tensor_schema.py -m core_model # Needs mp_method="spawn"
+ - pytest -v -s models/multimodal -m core_model --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/processing
  - cd .. && pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work

- label: Multi-Modal Models Test (Extended) 1
@@ -593,7 +575,7 @@ steps:
  - tests/models/multimodal
  commands:
  - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
- - pytest -v -s --ignore models/multimodal/generation/test_common.py --ignore models/multimodal/processing models/multimodal -m 'not core_model'
+ - pytest -v -s models/multimodal -m 'not core_model' --ignore models/multimodal/generation/test_common.py --ignore models/multimodal/processing

- label: Multi-Modal Models Test (Extended) 2
  mirror_hardwares: [amdexperimental]
@@ -656,23 +638,29 @@ steps:
  - vllm/model_executor/layers/fused_moe/cutlass_moe.py
  - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
  - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
  - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
  - vllm/v1/attention/backends/flashinfer.py
  - vllm/compilation/fusion.py
  - vllm/compilation/fusion_attn.py
  commands:
  - nvidia-smi
  - python3 examples/offline_inference/basic/chat.py
  # Attention
  # num_heads2 broken by https://github.com/flashinfer-ai/flashinfer/issues/1353
  - pytest -v -s tests/kernels/attention/test_flashinfer.py -k 'not num_heads2'
- - pytest -v -s tests/kernels/attention/test_flashinfer_trtllm_decode_attention.py
+ - pytest -v -s tests/kernels/attention/test_flashinfer_trtllm_attention.py
  - pytest -v -s tests/kernels/test_cutlass_mla_decode.py
  # Quantization
  - pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8'
  - pytest -v -s tests/kernels/quantization/test_nvfp4_quant.py
  - pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py
  - pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py
  - pytest -v -s tests/kernels/moe/test_nvfp4_moe.py
  - pytest -v -s tests/kernels/moe/test_mxfp4_moe.py
  # Fusion
  - pytest -v -s tests/compile/test_fusion_all_reduce.py
  - pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern
  - pytest -v -s tests/kernels/moe/test_flashinfer.py

##### 1 GPU test #####
##### multi gpus test #####
@@ -749,7 +737,6 @@ steps:
  # this test fails consistently.
  # TODO: investigate and fix
  - VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
- - VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/test_disagg.py
  - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown
  - pytest -v -s models/multimodal/generation/test_maverick.py

@@ -774,27 +761,6 @@ steps:
  - pytest -v -s models/test_oot_registration.py # it needs a clean process
  - pytest -v -s plugins/lora_resolvers # unit tests for in-tree lora resolver plugins

- - label: Multi-step Tests (4 GPUs) # 36min
-   mirror_hardwares: [amdexperimental]
-   working_dir: "/vllm-workspace/tests"
-   num_gpus: 4
-   source_file_dependencies:
-   - vllm/model_executor/layers/sampler.py
-   - vllm/sequence.py
-   - vllm/worker/worker_base.py
-   - vllm/worker/worker.py
-   - vllm/worker/multi_step_worker.py
-   - vllm/worker/model_runner_base.py
-   - vllm/worker/model_runner.py
-   - vllm/worker/multi_step_model_runner.py
-   - vllm/engine
-   - tests/multi_step
-   commands:
-   # this test is quite flaky
-   # TODO: investigate and fix.
-   # - pytest -v -s multi_step/test_correctness_async_llm.py
-   - pytest -v -s multi_step/test_correctness_llm.py
-
- label: Pipeline Parallelism Test # 45min
  mirror_hardwares: [amdexperimental]
  working_dir: "/vllm-workspace/tests"
@@ -877,3 +843,10 @@ steps:
  commands:
  - export VLLM_WORKER_MULTIPROC_METHOD=spawn
  - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4
+
+ - label: Qwen MoE EP Test # optional
+   gpu: h200
+   optional: true
+   num_gpus: 2
+   commands:
+   - CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 /vllm-workspace/examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048
.github/CODEOWNERS (vendored, 21 changed lines)
@@ -9,7 +9,8 @@
/vllm/worker/worker_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
/vllm/worker/worker.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
/vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
- /vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth
+ /vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth @yewentao256
+ /vllm/model_executor/layers/mamba @tdoublep
/vllm/multimodal @DarkLight1337 @ywang96
/vllm/vllm_flash_attn @LucasWilkinson
/vllm/lora @jeejeelee
@@ -20,31 +21,31 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson

# Any change to the VllmConfig changes can have a large user-facing impact,
# so spam a lot of people
- /vllm/config.py @simon-mo @WoosukKwon @youkaichao @robertgshaw2-redhat @mgoin @tlrmchlsmth @houseroad @hmellor
+ /vllm/config @simon-mo @WoosukKwon @youkaichao @robertgshaw2-redhat @mgoin @tlrmchlsmth @houseroad @hmellor @yewentao256 @ProExpertProg

# vLLM V1
/vllm/v1 @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat
/vllm/v1/structured_output @mgoin @russellb @aarnphm
+ /vllm/v1/attention/backends/triton_attn.py @tdoublep

# Test ownership
/.buildkite/lm-eval-harness @mgoin @simon-mo
/tests/async_engine @njhill @robertgshaw2-redhat @simon-mo
/tests/basic_correctness/test_chunked_prefill @rkooo567 @comaniac
/tests/distributed/test_multi_node_assignment.py @youkaichao
/tests/distributed/test_pipeline_parallel.py @youkaichao
/tests/distributed/test_same_node.py @youkaichao
/tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @simon-mo @aarnphm
- /tests/kernels @tlrmchlsmth @WoosukKwon
+ /tests/kernels @tlrmchlsmth @WoosukKwon @yewentao256
/tests/models @DarkLight1337 @ywang96
/tests/multi_step @alexm-redhat @comaniac
/tests/multimodal @DarkLight1337 @ywang96
/tests/prefix_caching @comaniac @KuntaiDu
- /tests/quantization @mgoin @robertgshaw2-redhat
+ /tests/quantization @mgoin @robertgshaw2-redhat @yewentao256
/tests/test_inputs.py @DarkLight1337 @ywang96
/tests/v1/entrypoints/llm/test_struct_output_generate.py @mgoin @russellb @aarnphm
/tests/v1/structured_output @mgoin @russellb @aarnphm
- /tests/weight_loading @mgoin @youkaichao
+ /tests/weight_loading @mgoin @youkaichao @yewentao256
/tests/lora @jeejeelee
+ /tests/models/language/generation/test_hybrid.py @tdoublep

# Docs
/docs @hmellor
@@ -73,3 +74,9 @@ mkdocs.yaml @hmellor
/vllm/model_executor/models/pixtral*.py @patrickvonplaten
/vllm/transformers_utils/configs/mistral.py @patrickvonplaten
/vllm/transformers_utils/tokenizers/mistral.py @patrickvonplaten
+
+ # Kernels
+ /vllm/attention/ops/chunked_prefill_paged_decode.py @tdoublep
+ /vllm/attention/ops/triton_unified_attention.py @tdoublep
.github/PULL_REQUEST_TEMPLATE.md (vendored, 20 changed lines)
@@ -1,11 +1,5 @@
- # Essential Elements of an Effective PR Description Checklist
-
- - [ ] The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
- - [ ] The test plan, such as providing test command.
- - [ ] The test results, such as pasting the results comparison before and after, or e2e results
- - [ ] (Optional) The necessary documentation update, such as updating `supported_models.md` and `examples` for a new model.
-
- PLEASE FILL IN THE PR DESCRIPTION HERE ENSURING ALL CHECKLIST ITEMS ABOVE HAVE BEEN CONSIDERED.
+ <!-- markdownlint-disable -->
+ PLEASE FILL IN THE PR DESCRIPTION HERE ENSURING ALL CHECKLIST ITEMS (AT THE BOTTOM) HAVE BEEN CONSIDERED.

## Purpose

@@ -15,4 +9,14 @@ PLEASE FILL IN THE PR DESCRIPTION HERE ENSURING ALL CHECKLIST ITEMS ABOVE HAVE B

## (Optional) Documentation Update

+ ---
+ <details>
+ <summary> Essential Elements of an Effective PR Description Checklist </summary>
+
+ - [ ] The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
+ - [ ] The test plan, such as providing test command.
+ - [ ] The test results, such as pasting the results comparison before and after, or e2e results
+ - [ ] (Optional) The necessary documentation update, such as updating `supported_models.md` and `examples` for a new model.
+ </details>
+
**BEFORE SUBMITTING, PLEASE READ <https://docs.vllm.ai/en/latest/contributing>** (anything written below this line will be removed by GitHub Actions)
.github/mergify.yml (vendored, 14 changed lines)
@@ -118,6 +118,20 @@ pull_request_rules:
      add:
        - qwen

+ - name: label-gpt-oss
+   description: Automatically apply gpt-oss label
+   conditions:
+     - or:
+       - files~=^examples/.*gpt[-_]?oss.*\.py
+       - files~=^tests/.*gpt[-_]?oss.*\.py
+       - files~=^vllm/model_executor/models/.*gpt[-_]?oss.*\.py
+       - files~=^vllm/model_executor/layers/.*gpt[-_]?oss.*\.py
+       - title~=(?i)gpt[-_]?oss
+   actions:
+     label:
+       add:
+         - gpt-oss
+
- name: label-rocm
  description: Automatically apply rocm label
  conditions:
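Note: a quick way to sanity-check the new label-gpt-oss path patterns is to run them through grep -E. The sketch below folds the four files~ patterns into one expression; the sample paths are hypothetical, not taken from the repository.

    # Exercise the mergify file regexes against illustrative paths.
    printf '%s\n' \
      "vllm/model_executor/models/gpt_oss.py" \
      "tests/models/test_gpt-oss_accuracy.py" \
      "vllm/model_executor/layers/attention.py" \
    | grep -E '^(examples|tests|vllm/model_executor/(models|layers))/.*gpt[-_]?oss.*\.py'
    # Matches the first two paths; the third has no gpt-oss component.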
.github/scripts/cleanup_pr_body.sh (vendored, 8 changed lines)
@@ -15,11 +15,11 @@ NEW=/tmp/new_pr_body.txt
gh pr view --json body --template "{{.body}}" "${PR_NUMBER}" > "${OLD}"
cp "${OLD}" "${NEW}"

- # Remove "FIX #xxxx (*link existing issues this PR will resolve*)"
- sed -i '/FIX #xxxx.*$/d' "${NEW}"
+ # Remove markdown comments (like the <!-- markdownlint-disable --> at the start)
+ sed -i '/<!--.*-->$/d' "${NEW}"

- # Remove "FILL IN THE PR DESCRIPTION HERE"
- sed -i '/FILL IN THE PR DESCRIPTION HERE/d' "${NEW}"
+ # Remove "PLEASE FILL IN THE PR DESCRIPTION HERE ENSURING ALL CHECKLIST ITEMS (AT THE BOTTOM) HAVE BEEN CONSIDERED."
+ sed -i '/PLEASE FILL IN THE PR DESCRIPTION HERE.*$/d' "${NEW}"

# Remove all lines after and including "**BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE**"
sed -i '/\*\*BEFORE SUBMITTING, PLEASE READ.*\*\*/,$d' "${NEW}"
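Note: the two updated sed filters track the new PR template above. A small self-contained sketch of their effect follows; the sample body text is made up.

    # Run the two new filters over an illustrative PR body.
    body=$(mktemp)
    cat > "$body" <<'EOF'
    <!-- markdownlint-disable -->
    PLEASE FILL IN THE PR DESCRIPTION HERE ENSURING ALL CHECKLIST ITEMS (AT THE BOTTOM) HAVE BEEN CONSIDERED.
    ## Purpose
    Fix the frobnicator.
    EOF
    sed -i '/<!--.*-->$/d' "$body"
    sed -i '/PLEASE FILL IN THE PR DESCRIPTION HERE.*$/d' "$body"
    cat "$body"   # only the Purpose section remains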
.github/workflows/lint-and-deploy.yaml (vendored, 89 changed lines)
@@ -1,89 +0,0 @@
- name: Lint and Deploy Charts
-
- on: pull_request
-
- concurrency:
-   group: ${{ github.workflow }}-${{ github.ref }}
-   cancel-in-progress: true
-
- permissions:
-   contents: read
-
- jobs:
-   lint-and-deploy:
-     runs-on: ubuntu-latest
-     steps:
-       - name: Checkout
-         uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
-         with:
-           fetch-depth: 0
-
-       - name: Set up Helm
-         uses: azure/setup-helm@b9e51907a09c216f16ebe8536097933489208112 # v4.3.0
-         with:
-           version: v3.14.4
-
-       #Python is required because ct lint runs Yamale and yamllint which require Python.
-       - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
-         with:
-           python-version: '3.13'
-
-       - name: Set up chart-testing
-         uses: helm/chart-testing-action@0d28d3144d3a25ea2cc349d6e59901c4ff469b3b # v2.7.0
-         with:
-           version: v3.10.1
-
-       - name: Run chart-testing (lint)
-         run: ct lint --target-branch ${{ github.event.repository.default_branch }} --chart-dirs examples/online_serving/chart-helm --charts examples/online_serving/chart-helm
-
-       - name: Setup minio
-         run: |
-           docker network create vllm-net
-           docker run -d -p 9000:9000 --name minio --net vllm-net \
-             -e "MINIO_ACCESS_KEY=minioadmin" \
-             -e "MINIO_SECRET_KEY=minioadmin" \
-             -v /tmp/data:/data \
-             -v /tmp/config:/root/.minio \
-             minio/minio server /data
-           export AWS_ACCESS_KEY_ID=minioadmin
-           export AWS_SECRET_ACCESS_KEY=minioadmin
-           export AWS_EC2_METADATA_DISABLED=true
-           mkdir opt-125m
-           cd opt-125m && curl -O -Ls "https://huggingface.co/facebook/opt-125m/resolve/main/{pytorch_model.bin,config.json,generation_config.json,merges.txt,special_tokens_map.json,tokenizer_config.json,vocab.json}" && cd ..
-           aws --endpoint-url http://127.0.0.1:9000/ s3 mb s3://testbucket
-           aws --endpoint-url http://127.0.0.1:9000/ s3 cp opt-125m/ s3://testbucket/opt-125m --recursive
-
-       - name: Create kind cluster
-         uses: helm/kind-action@a1b0e391336a6ee6713a0583f8c6240d70863de3 # v1.12.0
-
-       - name: Build the Docker image vllm cpu
-         run: docker buildx build -f docker/Dockerfile.cpu -t vllm-cpu-env .
-
-       - name: Configuration of docker images, network and namespace for the kind cluster
-         run: |
-           docker pull amazon/aws-cli:2.6.4
-           kind load docker-image amazon/aws-cli:2.6.4 --name chart-testing
-           kind load docker-image vllm-cpu-env:latest --name chart-testing
-           docker network connect vllm-net "$(docker ps -aqf "name=chart-testing-control-plane")"
-           kubectl create ns ns-vllm
-
-       - name: Run chart-testing (install)
-         run: |
-           export AWS_ACCESS_KEY_ID=minioadmin
-           export AWS_SECRET_ACCESS_KEY=minioadmin
-           sleep 30 && kubectl -n ns-vllm logs -f "$(kubectl -n ns-vllm get pods | awk '/deployment/ {print $1;exit}')" &
-           helm install --wait --wait-for-jobs --timeout 5m0s --debug --create-namespace --namespace=ns-vllm test-vllm examples/online_serving/chart-helm -f examples/online_serving/chart-helm/values.yaml --set secrets.s3endpoint=http://minio:9000 --set secrets.s3bucketname=testbucket --set secrets.s3accesskeyid=$AWS_ACCESS_KEY_ID --set secrets.s3accesskey=$AWS_SECRET_ACCESS_KEY --set resources.requests.cpu=1 --set resources.requests.memory=4Gi --set resources.limits.cpu=2 --set resources.limits.memory=5Gi --set image.env[0].name=VLLM_CPU_KVCACHE_SPACE --set image.env[1].name=VLLM_LOGGING_LEVEL --set image.env[2].name=VLLM_CPU_CI_ENV --set-string image.env[0].value="1" --set-string image.env[1].value="DEBUG" --set-string image.env[2].value="1" --set-string extraInit.s3modelpath="opt-125m/" --set-string 'resources.limits.nvidia\.com/gpu=0' --set-string 'resources.requests.nvidia\.com/gpu=0' --set-string image.repository="vllm-cpu-env"
-
-       - name: curl test
-         run: |
-           kubectl -n ns-vllm port-forward service/test-vllm-service 8001:80 &
-           sleep 10
-           CODE="$(curl -v -f --location http://localhost:8001/v1/completions \
-             --header "Content-Type: application/json" \
-             --data '{
-               "model": "opt-125m",
-               "prompt": "San Francisco is a",
-               "max_tokens": 7,
-               "temperature": 0
-             }'):$CODE"
-           echo "$CODE"
.github/workflows/publish.yml (vendored, 111 changed lines)
@@ -1,111 +0,0 @@
- # This workflow will upload a Python Package to Release asset
- # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions
-
- name: Create Release
-
- on:
-   push:
-     tags:
-       - v*
-
- # Needed to create release and upload assets
- permissions:
-   contents: write
-
- jobs:
-   release:
-     # Retrieve tag and create release
-     name: Create Release
-     runs-on: ubuntu-latest
-     outputs:
-       upload_url: ${{ steps.create_release.outputs.upload_url }}
-     steps:
-       - name: Checkout
-         uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
-
-       - name: Extract branch info
-         shell: bash
-         run: |
-           echo "release_tag=${GITHUB_REF#refs/*/}" >> "$GITHUB_ENV"
-
-       - name: Create Release
-         id: create_release
-         uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1
-         env:
-           RELEASE_TAG: ${{ env.release_tag }}
-         with:
-           github-token: "${{ secrets.GITHUB_TOKEN }}"
-           script: |
-             const script = require('.github/workflows/scripts/create_release.js')
-             await script(github, context, core)
-
- # NOTE(simon): No longer build wheel using GitHub Actions. See buildkite's release workflow.
- # wheel:
- #   name: Build Wheel
- #   runs-on: ${{ matrix.os }}
- #   needs: release
-
- #   strategy:
- #     fail-fast: false
- #     matrix:
- #       os: ['ubuntu-20.04']
- #       python-version: ['3.9', '3.10', '3.11', '3.12']
- #       pytorch-version: ['2.4.0'] # Must be the most recent version that meets requirements/cuda.txt.
- #       cuda-version: ['11.8', '12.1']
-
- #   steps:
- #     - name: Checkout
- #       uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
-
- #     - name: Setup ccache
- #       uses: hendrikmuhs/ccache-action@ed74d11c0b343532753ecead8a951bb09bb34bc9 # v1.2.14
- #       with:
- #         create-symlink: true
- #         key: ${{ github.job }}-${{ matrix.python-version }}-${{ matrix.cuda-version }}
-
- #     - name: Set up Linux Env
- #       if: ${{ runner.os == 'Linux' }}
- #       run: |
- #         bash -x .github/workflows/scripts/env.sh
-
- #     - name: Set up Python
- #       uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
- #       with:
- #         python-version: ${{ matrix.python-version }}
-
- #     - name: Install CUDA ${{ matrix.cuda-version }}
- #       run: |
- #         bash -x .github/workflows/scripts/cuda-install.sh ${{ matrix.cuda-version }} ${{ matrix.os }}
-
- #     - name: Install PyTorch ${{ matrix.pytorch-version }} with CUDA ${{ matrix.cuda-version }}
- #       run: |
- #         bash -x .github/workflows/scripts/pytorch-install.sh ${{ matrix.python-version }} ${{ matrix.pytorch-version }} ${{ matrix.cuda-version }}
-
- #     - name: Build wheel
- #       shell: bash
- #       env:
- #         CMAKE_BUILD_TYPE: Release # do not compile with debug symbol to reduce wheel size
- #       run: |
- #         bash -x .github/workflows/scripts/build.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }}
- #         wheel_name=$(find dist -name "*whl" -print0 | xargs -0 -n 1 basename)
- #         asset_name=${wheel_name//"linux"/"manylinux1"}
- #         echo "wheel_name=${wheel_name}" >> "$GITHUB_ENV"
- #         echo "asset_name=${asset_name}" >> "$GITHUB_ENV"
-
- #     - name: Upload Release Asset
- #       uses: actions/upload-release-asset@e8f9f06c4b078e705bd2ea027f0926603fc9b4d5 # v1.0.2
- #       env:
- #         GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- #       with:
- #         upload_url: ${{ needs.release.outputs.upload_url }}
- #         asset_path: ./dist/${{ env.wheel_name }}
- #         asset_name: ${{ env.asset_name }}
- #         asset_content_type: application/*
-
- # (Danielkinz): This last step will publish the .whl to pypi. Warning: untested
- #     - name: Publish package
- #       uses: pypa/gh-action-pypi-publish@release/v1.8
- #       with:
- #         repository-url: https://test.pypi.org/legacy/
- #         password: ${{ secrets.PYPI_API_TOKEN }}
- #         skip-existing: true
.github/workflows/reminder_comment.yml (vendored, 49 changed lines)
@@ -12,16 +12,43 @@ jobs:
        uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1
        with:
          script: |
-           github.rest.issues.createComment({
-             owner: context.repo.owner,
-             repo: context.repo.repo,
-             issue_number: context.issue.number,
-             body: '👋 Hi! Thank you for contributing to the vLLM project.\n\n' +
-               '💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.\n\n' +
-               'Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your `fastcheck` build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping `simon-mo` or `khluu` to add you in our Buildkite org.\n\n' +
-               'Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.\n\n' +
-               'To run CI, PR reviewers can either: Add `ready` label to the PR or enable auto-merge.\n\n' +
-               '🚀'
-           })
+           try {
+             // Get the PR author
+             const prAuthor = context.payload.pull_request.user.login;
+
+             // Check if this is the author's first PR in this repository
+             // Use GitHub's search API to find all PRs by this author
+             const { data: searchResults } = await github.rest.search.issuesAndPullRequests({
+               q: `repo:${context.repo.owner}/${context.repo.repo} type:pr author:${prAuthor}`,
+               per_page: 100
+             });
+
+             const authorPRCount = searchResults.total_count;
+
+             console.log(`Found ${authorPRCount} PRs by ${prAuthor}`);
+
+             // Only post comment if this is the first PR (only one PR by this author)
+             if (authorPRCount === 1) {
+               console.log(`Posting welcome comment for first-time contributor: ${prAuthor}`);
+               await github.rest.issues.createComment({
+                 owner: context.repo.owner,
+                 repo: context.repo.repo,
+                 issue_number: context.issue.number,
+                 body: '👋 Hi! Thank you for contributing to the vLLM project.\n\n' +
+                   '💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.\n\n' +
+                   'Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which starts running only a small and essential subset of CI tests to quickly catch errors. \n\n' +
+                   'You ask your reviewers to trigger select CI tests on top of `fastcheck` CI. \n\n' +
+                   'Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.\n\n' +
+                   'To run CI, PR reviewers can either: Add `ready` label to the PR or enable auto-merge.\n\n' +
+                   'If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.\n\n' +
+                   '🚀'
+               });
+             } else {
+               console.log(`Skipping comment for ${prAuthor} - not their first PR (${authorPRCount} PRs found)`);
+             }
+           } catch (error) {
+             console.error('Error checking PR history or posting comment:', error);
+             // Don't fail the workflow, just log the error
+           }
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
.gitignore (vendored, 9 changed lines)
@@ -4,6 +4,9 @@
# vllm-flash-attn built from source
vllm/vllm_flash_attn/*

+ # triton jit
+ .triton
+
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
@@ -147,7 +150,8 @@ venv.bak/
# mkdocs documentation
/site
docs/argparse
- docs/examples
+ docs/examples/*
+ !docs/examples/README.md

# mypy
.mypy_cache/
@@ -203,3 +207,6 @@ shellcheck*/

# Ignore moe/marlin_moe gen code
csrc/moe/marlin_moe_wna16/kernel_*
+
+ # Ignore ep_kernels_workspace folder
+ ep_kernels_workspace/
@@ -30,7 +30,7 @@ install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)
# Supported python versions.  These versions will be searched in order, the
# first match will be selected.  These should be kept in sync with setup.py.
#
- set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12")
+ set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12", "3.13")

# Supported AMD GPU architectures.
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201")
@@ -249,7 +249,6 @@ set(VLLM_EXT_SRC
  "csrc/quantization/gguf/gguf_kernel.cu"
  "csrc/quantization/activation_kernels.cu"
  "csrc/cuda_utils_kernels.cu"
- "csrc/prepare_inputs/advance_step.cu"
  "csrc/custom_all_reduce.cu"
  "csrc/torch_bindings.cpp")

@@ -287,7 +286,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
  FetchContent_MakeAvailable(cutlass)

  list(APPEND VLLM_EXT_SRC
-   "csrc/quantization/aqlm/gemm_kernels.cu"
    "csrc/quantization/awq/gemm_kernels.cu"
    "csrc/permute_cols.cu"
    "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
@@ -351,20 +349,27 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
  set_gencode_flags_for_srcs(
    SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}"
    CUDA_ARCHS "${MARLIN_ARCHS}")
+ if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
+   set_source_files_properties(${MARLIN_TEMPLATE_KERNEL_SRC}
+     PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
+ endif()

  list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC})

  set(MARLIN_SRCS
    "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
    "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
    "csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu"
    "csrc/quantization/gptq_marlin/gptq_marlin.cu"
    "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
    "csrc/quantization/gptq_marlin/awq_marlin_repack.cu")
  set_gencode_flags_for_srcs(
    SRCS "${MARLIN_SRCS}"
    CUDA_ARCHS "${MARLIN_ARCHS}")
+ if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
+   set_source_files_properties("csrc/quantization/gptq_marlin/gptq_marlin.cu"
+     PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
+ endif()
  list(APPEND VLLM_EXT_SRC "${MARLIN_SRCS}")

  message(STATUS "Building Marlin kernels for archs: ${MARLIN_ARCHS}")
else()
  message(STATUS "Not building Marlin kernels as no compatible archs found"
@@ -427,6 +432,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
  set(SRCS
    "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu"
    "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu"
+   "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8.cu"
  )
  set_gencode_flags_for_srcs(
    SRCS "${SRCS}"
@@ -744,6 +750,33 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
      "found in CUDA target architectures")
  endif()
endif()

+ # Only build W4A8 kernels if we are building for something compatible with sm90a
+ cuda_archs_loose_intersection(W4A8_ARCHS "9.0a" "${CUDA_ARCHS}")
+ if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND W4A8_ARCHS)
+   set(SRCS
+     "csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu")
+
+   set_gencode_flags_for_srcs(
+     SRCS "${SRCS}"
+     CUDA_ARCHS "${W4A8_ARCHS}")
+
+   list(APPEND VLLM_EXT_SRC "${SRCS}")
+
+   message(STATUS "Building W4A8 kernels for archs: ${W4A8_ARCHS}")
+ else()
+   if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0
+       AND W4A8_ARCHS)
+     message(STATUS "Not building W4A8 kernels as CUDA Compiler version is "
+       "not >= 12.0, we recommend upgrading to CUDA 12.0 or "
+       "later if you intend on running w4a16 quantized models on "
+       "Hopper.")
+   else()
+     message(STATUS "Not building W4A8 kernels as no compatible archs "
+       "found in CUDA target architectures")
+   endif()
+ endif()

# if CUDA endif
endif()

@@ -853,6 +886,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
  set_gencode_flags_for_srcs(
    SRCS "${MOE_WNAA16_MARLIN_SRC}"
    CUDA_ARCHS "${MARLIN_MOE_ARCHS}")
+ if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
+   set_source_files_properties(${MOE_WNAA16_MARLIN_SRC}
+     PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
+ endif()

list(APPEND VLLM_MOE_EXT_SRC ${MOE_WNAA16_MARLIN_SRC})
@@ -18,14 +18,16 @@ Easy, fast, and cheap LLM serving for everyone

*Latest News* 🔥

- [2025/05] We hosted [NYC vLLM Meetup](https://lu.ma/c1rqyf1f)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing).
- [2025/08] We hosted [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/pDmAXHcN7Iqc8sUKgJgGtg) focusing on building, developing, and integrating with vLLM! Please find the meetup slides [here](https://drive.google.com/drive/folders/1OvLx39wnCGy_WKq8SiVKf7YcxxYI3WCH).
- [2025/08] We hosted [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/dgkWg1WFpWGO2jCdTqQHxA) focusing on large-scale LLM deployment! Please find the meetup slides [here](https://drive.google.com/drive/folders/1Pid6NSFLU43DZRi0EaTcPgXsAzDvbBqF) and the recording [here](https://www.chaspark.com/#/live/1166916873711665152).
- [2025/05] vLLM is now a hosted project under PyTorch Foundation! Please find the announcement [here](https://pytorch.org/blog/pytorch-foundation-welcomes-vllm/).
- [2025/04] We hosted [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing).
- [2025/01] We are excited to announce the alpha release of vLLM V1: A major architectural upgrade with 1.7x speedup! Clean code, optimized execution loop, zero-overhead prefix caching, enhanced multimodal support, and more. Please check out our blog post [here](https://blog.vllm.ai/2025/01/27/v1-alpha-release.html).

<details>
<summary>Previous News</summary>

- [2025/05] We hosted [NYC vLLM Meetup](https://lu.ma/c1rqyf1f)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing).
- [2025/04] We hosted [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing).
- [2025/03] We hosted [vLLM x Ollama Inference Night](https://lu.ma/vllm-ollama)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/16T2PDD1YwRnZ4Tu8Q5r6n53c5Lr5c73UV9Vd2_eBo4U/edit?usp=sharing).
- [2025/03] We hosted [the first vLLM China Meetup](https://mp.weixin.qq.com/s/n77GibL2corAtQHtVEAzfg)! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1REHvfQMKGnvz6p3Fd23HhSO4c8j5WPGZV0bKYLwnHyQ/edit?usp=sharing).
- [2025/03] We hosted [the East Coast vLLM Meetup](https://lu.ma/7mu4k4xx)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1NHiv8EUFF1NLd3fEYODm56nDmL26lEeXCaDgyDlTsRs/edit#slide=id.g31441846c39_0_0).
@@ -121,6 +123,7 @@ Cash Donations:

Compute Resources:

- Alibaba Cloud
- AMD
- Anyscale
- AWS
@@ -160,7 +163,7 @@ If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs
## Contact Us

<!-- --8<-- [start:contact-us] -->
- For technical questions and feature requests, please use GitHub [Issues](https://github.com/vllm-project/vllm/issues) or [Discussions](https://github.com/vllm-project/vllm/discussions)
- For technical questions and feature requests, please use GitHub [Issues](https://github.com/vllm-project/vllm/issues)
- For discussing with fellow users, please use the [vLLM Forum](https://discuss.vllm.ai)
- For coordinating contributions and development, please use [Slack](https://slack.vllm.ai)
- For security disclosures, please use GitHub's [Security Advisories](https://github.com/vllm-project/vllm/security/advisories) feature

@@ -22,6 +22,25 @@ become available.
<td style="text-align: center;">✅</td>
<td><code>wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json</code></td>
</tr>
<tr>
<td><strong>ShareGPT4V (Image)</strong></td>
<td style="text-align: center;">✅</td>
<td style="text-align: center;">✅</td>
<td>
<code>wget https://huggingface.co/datasets/Lin-Chen/ShareGPT4V/blob/main/sharegpt4v_instruct_gpt4-vision_cap100k.json</code>
<br>
<div>Note that the images need to be downloaded separately. For example, to download COCO's 2017 Train images:</div>
<code>wget http://images.cocodataset.org/zips/train2017.zip</code>
</td>
</tr>
<tr>
<td><strong>ShareGPT4Video (Video)</strong></td>
<td style="text-align: center;">✅</td>
<td style="text-align: center;">✅</td>
<td>
<code>git clone https://huggingface.co/datasets/ShareGPT4Video/ShareGPT4Video</code>
</td>
</tr>
<tr>
<td><strong>BurstGPT</strong></td>
<td style="text-align: center;">✅</td>
@@ -29,7 +48,7 @@ become available.
<td><code>wget https://github.com/HPMLL/BurstGPT/releases/download/v1.1/BurstGPT_without_fails_2.csv</code></td>
</tr>
<tr>
<td><strong>Sonnet</strong></td>
<td><strong>Sonnet (deprecated)</strong></td>
<td style="text-align: center;">✅</td>
<td style="text-align: center;">✅</td>
<td>Local file: <code>benchmarks/sonnet.txt</code></td>
@@ -40,6 +59,18 @@ become available.
<td style="text-align: center;">✅</td>
<td><code>synthetic</code></td>
</tr>
<tr>
<td><strong>RandomMultiModal (Image/Video)</strong></td>
<td style="text-align: center;">🟡</td>
<td style="text-align: center;">🚧</td>
<td><code>synthetic</code></td>
</tr>
<tr>
<td><strong>Prefix Repetition</strong></td>
<td style="text-align: center;">✅</td>
<td style="text-align: center;">✅</td>
<td><code>synthetic</code></td>
</tr>
<tr>
<td><strong>HuggingFace-VisionArena</strong></td>
<td style="text-align: center;">✅</td>
@@ -177,6 +208,7 @@ vllm serve Qwen/Qwen2-VL-7B-Instruct
```bash
vllm bench serve \
  --backend openai-chat \
  --endpoint-type openai-chat \
  --model Qwen/Qwen2-VL-7B-Instruct \
  --endpoint /v1/chat/completions \
  --dataset-name hf \
@@ -213,6 +245,7 @@ vllm serve Qwen/Qwen2-VL-7B-Instruct
```bash
vllm bench serve \
  --backend openai-chat \
  --endpoint-type openai-chat \
  --model Qwen/Qwen2-VL-7B-Instruct \
  --endpoint /v1/chat/completions \
  --dataset-name hf \
@@ -227,6 +260,7 @@ vllm bench serve \
```bash
vllm bench serve \
  --backend openai-chat \
  --endpoint-type openai-chat \
  --model Qwen/Qwen2-VL-7B-Instruct \
  --endpoint /v1/chat/completions \
  --dataset-name hf \
@@ -581,6 +615,20 @@ python3 benchmarks/benchmark_prefix_caching.py \
  --input-length-range 128:256
```
### Prefix Repetition Dataset

```bash
vllm bench serve \
  --backend openai \
  --model meta-llama/Llama-2-7b-chat-hf \
  --dataset-name prefix_repetition \
  --num-prompts 100 \
  --prefix-repetition-prefix-len 512 \
  --prefix-repetition-suffix-len 128 \
  --prefix-repetition-num-prefixes 5 \
  --prefix-repetition-output-len 128
```

</details>

## ⚡ Example - Request Prioritization Benchmark

@@ -616,3 +664,139 @@ python3 benchmarks/benchmark_prioritization.py \
```

</details>

## 👁️ Example - Multi-Modal Benchmark

<details>
<summary>Show more</summary>

<br/>

Benchmark the performance of multi-modal requests in vLLM.

### Images (ShareGPT4V)

Start vLLM:

```bash
python -m vllm.entrypoints.openai.api_server \
  --model Qwen/Qwen2.5-VL-7B-Instruct \
  --dtype bfloat16 \
  --limit-mm-per-prompt '{"image": 1}' \
  --allowed-local-media-path /path/to/sharegpt4v/images
```

Send requests with images:

```bash
python benchmarks/benchmark_serving.py \
  --backend openai-chat \
  --model Qwen/Qwen2.5-VL-7B-Instruct \
  --dataset-name sharegpt \
  --dataset-path /path/to/ShareGPT4V/sharegpt4v_instruct_gpt4-vision_cap100k.json \
  --num-prompts 100 \
  --save-result \
  --result-dir ~/vllm_benchmark_results \
  --save-detailed \
  --endpoint /v1/chat/completions
```

### Videos (ShareGPT4Video)

Start vLLM:

```bash
python -m vllm.entrypoints.openai.api_server \
  --model Qwen/Qwen2.5-VL-7B-Instruct \
  --dtype bfloat16 \
  --limit-mm-per-prompt '{"video": 1}' \
  --allowed-local-media-path /path/to/sharegpt4video/videos
```

Send requests with videos:

```bash
python benchmarks/benchmark_serving.py \
  --backend openai-chat \
  --model Qwen/Qwen2.5-VL-7B-Instruct \
  --dataset-name sharegpt \
  --dataset-path /path/to/ShareGPT4Video/llava_v1_5_mix665k_with_video_chatgpt72k_share4video28k.json \
  --num-prompts 100 \
  --save-result \
  --result-dir ~/vllm_benchmark_results \
  --save-detailed \
  --endpoint /v1/chat/completions
```
### Synthetic Random Images (random-mm)

Generate synthetic image inputs alongside random text prompts to stress-test vision models without external datasets.

Notes:

- Works only with online benchmark via the OpenAI backend (`--backend openai-chat`) and endpoint `/v1/chat/completions`.
- Video sampling is not yet implemented.

Start the server (example):

```bash
vllm serve Qwen/Qwen2.5-VL-3B-Instruct \
  --dtype bfloat16 \
  --max-model-len 16384 \
  --limit-mm-per-prompt '{"image": 3, "video": 0}' \
  --mm-processor-kwargs max_pixels=1003520
```

Run the benchmark. It is recommended to use the flag `--ignore-eos` to simulate real responses. You can set the size of the output via the arg `--random-output-len`.

Ex.1: Fixed number of items and a single image resolution, enforcing generation of approx. 40 tokens:

```bash
vllm bench serve \
  --backend openai-chat \
  --model Qwen/Qwen2.5-VL-3B-Instruct \
  --endpoint /v1/chat/completions \
  --dataset-name random-mm \
  --num-prompts 100 \
  --max-concurrency 10 \
  --random-prefix-len 25 \
  --random-input-len 300 \
  --random-output-len 40 \
  --random-range-ratio 0.2 \
  --random-mm-base-items-per-request 2 \
  --random-mm-limit-mm-per-prompt '{"image": 3, "video": 0}' \
  --random-mm-bucket-config '{(224, 224, 1): 1.0}' \
  --request-rate inf \
  --ignore-eos \
  --seed 42
```

The number of items per request can be controlled by passing multiple image buckets:

```bash
  --random-mm-base-items-per-request 2 \
  --random-mm-num-mm-items-range-ratio 0.5 \
  --random-mm-limit-mm-per-prompt '{"image": 4, "video": 0}' \
  --random-mm-bucket-config '{(256, 256, 1): 0.7, (720, 1280, 1): 0.3}' \
```

Flags specific to `random-mm`:

- `--random-mm-base-items-per-request`: base number of multimodal items per request.
- `--random-mm-num-mm-items-range-ratio`: vary the item count uniformly in the closed integer range [floor(n·(1−r)), ceil(n·(1+r))]; for example, n=2 with r=0.5 gives the range [1, 3]. Set r=0 to keep the count fixed; r=1 allows 0 items.
- `--random-mm-limit-mm-per-prompt`: per-modality hard caps, e.g. `'{"image": 3, "video": 0}'`.
- `--random-mm-bucket-config`: dict mapping (H, W, T) → probability. Entries with probability 0 are removed; the remaining probabilities are renormalized to sum to 1. Use T=1 for images. Any T>1 denotes video (video sampling is not yet supported).

Behavioral notes:

- If the requested base item count cannot be satisfied under the provided per-prompt limits, the tool raises an error rather than silently clamping.

How sampling works (see the sketch after this list):

- Determine the per-request item count k by sampling uniformly from the integer range defined by `--random-mm-base-items-per-request` and `--random-mm-num-mm-items-range-ratio`, then clamp k to at most the sum of the per-modality limits.
- For each of the k items, sample a bucket (H, W, T) according to the normalized probabilities in `--random-mm-bucket-config`, while tracking how many items of each modality have been added.
- If a modality (e.g., image) reaches its limit from `--random-mm-limit-mm-per-prompt`, all buckets of that modality are excluded and the remaining bucket probabilities are renormalized before continuing. This should be seen as an edge case; it can be avoided by setting `--random-mm-limit-mm-per-prompt` to a large number, although that may in turn cause errors against the engine config `--limit-mm-per-prompt`.
- The resulting request contains synthetic image data in `multi_modal_data` (OpenAI Chat format). When `random-mm` is used with the OpenAI Chat backend, prompts remain text and MM content is attached via `multi_modal_data`.
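A minimal sketch of this sampling loop follows. It is an illustration only: the helper names and the simplified two-modality handling are assumptions, not the benchmark's actual code, and the real tool additionally validates the base item count against the limits up front.

```python
import math
import random

# Bucket keys are (height, width, T) tuples; T == 1 means image, T > 1 video.
def modality(bucket: tuple[int, int, int]) -> str:
    return "image" if bucket[2] == 1 else "video"

def sample_mm_items(base_items, range_ratio, bucket_config, limits, rng):
    # Drop zero-probability buckets and buckets whose modality is capped at 0.
    buckets = {
        b: p
        for b, p in bucket_config.items()
        if p > 0 and limits.get(modality(b), 0) > 0
    }

    # k ~ Uniform[floor(n*(1-r)), ceil(n*(1+r))], clamped to the limit sum.
    lo = math.floor(base_items * (1 - range_ratio))
    hi = math.ceil(base_items * (1 + range_ratio))
    k = min(rng.randint(lo, hi), sum(limits.values()))

    counts = {m: 0 for m in limits}
    chosen = []
    for _ in range(k):
        if not buckets:
            break
        b = rng.choices(list(buckets), weights=list(buckets.values()))[0]
        chosen.append(b)
        m = modality(b)
        counts[m] += 1
        if counts[m] >= limits[m]:
            # Exclude the capped modality; the remaining weights are
            # implicitly renormalized by random.choices.
            buckets = {bb: p for bb, p in buckets.items() if modality(bb) != m}
    return chosen

rng = random.Random(42)
print(sample_mm_items(
    2, 0.5,
    {(256, 256, 1): 0.7, (720, 1280, 1): 0.3},
    {"image": 4, "video": 0},
    rng,
))
```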
</details>
@@ -31,9 +31,10 @@ class RequestFuncInput:
    model_name: Optional[str] = None
    logprobs: Optional[int] = None
    extra_body: Optional[dict] = None
    multi_modal_content: Optional[dict] = None
    multi_modal_content: Optional[dict | list[dict]] = None
    ignore_eos: bool = False
    language: Optional[str] = None
    request_id: Optional[str] = None


@dataclass
@@ -71,6 +72,9 @@ async def async_request_tgi(
            "inputs": request_func_input.prompt,
            "parameters": params,
        }
        headers = None
        if request_func_input.request_id:
            headers = {"x-request-id": request_func_input.request_id}
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len
        if request_func_input.ignore_eos:
@@ -82,7 +86,9 @@ async def async_request_tgi(
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload) as response:
            async with session.post(
                url=api_url, json=payload, headers=headers
            ) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
@@ -145,6 +151,9 @@ async def async_request_trt_llm(
        }
        if request_func_input.ignore_eos:
            payload["min_length"] = request_func_input.output_len
        headers = None
        if request_func_input.request_id:
            headers = {"x-request-id": request_func_input.request_id}
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

@@ -152,7 +161,9 @@ async def async_request_trt_llm(
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload) as response:
            async with session.post(
                url=api_url, json=payload, headers=headers
            ) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
@@ -211,6 +222,8 @@ async def async_request_deepspeed_mii(
            "top_p": 1.0,
        }
        headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
        if request_func_input.request_id:
            headers["x-request-id"] = request_func_input.request_id

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len
@@ -283,6 +296,8 @@ async def async_request_openai_completions(
        if request_func_input.extra_body:
            payload.update(request_func_input.extra_body)
        headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
        if request_func_input.request_id:
            headers["x-request-id"] = request_func_input.request_id

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len
@@ -364,7 +379,15 @@ async def async_request_openai_chat_completions(
    ) as session:
        content = [{"type": "text", "text": request_func_input.prompt}]
        if request_func_input.multi_modal_content:
            content.append(request_func_input.multi_modal_content)
            mm_content = request_func_input.multi_modal_content
            if isinstance(mm_content, list):
                content.extend(mm_content)
            elif isinstance(mm_content, dict):
                content.append(mm_content)
            else:
                raise TypeError(
                    "multi_modal_content must be a dict or list[dict] for openai-chat"
                )
        payload = {
            "model": request_func_input.model_name
            if request_func_input.model_name
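For illustration (URLs and prompt text are placeholders, not from the diff), a multi-image request now reaches the branch above as a list of OpenAI-style content items and is spliced into the text content:

```python
# Hypothetical payload fragment for the list[dict] case handled above.
mm_content = [
    {"type": "image_url", "image_url": {"url": "http://example.com/a.png"}},
    {"type": "image_url", "image_url": {"url": "http://example.com/b.png"}},
]
content = [{"type": "text", "text": "Describe both images."}]
content.extend(mm_content)  # same path as isinstance(mm_content, list)
```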
@@ -387,6 +410,8 @@ async def async_request_openai_chat_completions(
            "Content-Type": "application/json",
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        }
        if request_func_input.request_id:
            headers["x-request-id"] = request_func_input.request_id

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len
@@ -483,6 +508,8 @@ async def async_request_openai_audio(
        headers = {
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        }
        if request_func_input.request_id:
            headers["x-request-id"] = request_func_input.request_id

        # Send audio file
        def to_bytes(y, sr):
@@ -491,7 +518,10 @@ async def async_request_openai_audio(
            buffer.seek(0)
            return buffer

        with to_bytes(*request_func_input.multi_modal_content["audio"]) as f:
        mm_audio = request_func_input.multi_modal_content
        if not isinstance(mm_audio, dict) or "audio" not in mm_audio:
            raise TypeError("multi_modal_content must be a dict containing 'audio'")
        with to_bytes(*mm_audio["audio"]) as f:
            form = aiohttp.FormData()
            form.add_field("file", f, content_type="audio/wav")
            for key, value in payload.items():
74 changes: benchmarks/benchmark_block_pool.py (new file)
@@ -0,0 +1,74 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc

from tabulate import tabulate

from benchmark_utils import TimeCollector
from vllm.utils import FlexibleArgumentParser
from vllm.v1.core.block_pool import BlockPool


def main(args):
    rows = []
    for allocate_block in args.allocate_blocks:
        # Enforce a GC collect ahead to minimize the impact among runs
        gc.collect()
        block_pool = BlockPool(num_gpu_blocks=args.num_gpu_blocks, enable_caching=True)

        get_blocks_times = TimeCollector(TimeCollector.US)
        free_blocks_times = TimeCollector(TimeCollector.US)
        for _ in range(args.num_iteration):
            with get_blocks_times:
                blocks = block_pool.get_new_blocks(allocate_block)
            with free_blocks_times:
                block_pool.free_blocks(blocks)

        rows.append(
            [get_blocks_times.cnt, args.num_gpu_blocks, allocate_block]
            + get_blocks_times.dump_avg_max()
            + free_blocks_times.dump_avg_max()
        )

    print(
        tabulate(
            rows,
            headers=[
                "Iterations",
                "Total\nBlocks",
                "Allocated\nBlocks",
                "Get Blocks\nAvg (us)",
                "Get Blocks\nMax (us)",
                "Free Blocks\nAvg (us)",
                "Free Blocks\nMax (us)",
            ],
            tablefmt="grid",
            floatfmt=".3f",
        )
    )


def invoke_main() -> None:
    parser = FlexibleArgumentParser(
        description="Benchmark the performance of BlockPool for KV Cache."
    )
    parser.add_argument("--num-gpu-blocks", type=int, default=100000)
    parser.add_argument(
        "--num-iteration",
        type=int,
        default=1000,
        help="Number of iterations to run to stabilize final data readings",
    )
    parser.add_argument(
        "--allocate-blocks",
        type=int,
        nargs="*",
        default=[10, 50, 100, 500, 1000],
        help="Number of blocks to allocate",
    )
    args = parser.parse_args()
    main(args)


if __name__ == "__main__":
    invoke_main()  # pragma: no cover
@@ -19,6 +19,7 @@ import logging
import random
from abc import ABC, abstractmethod
from collections.abc import Mapping
from copy import deepcopy
from dataclasses import dataclass
from functools import cache
from io import BytesIO
@@ -52,8 +53,9 @@ class SampleRequest:
    prompt: Union[str, Any]
    prompt_len: int
    expected_output_len: int
    multi_modal_data: Optional[Union[MultiModalDataDict, dict]] = None
    multi_modal_data: Optional[Union[MultiModalDataDict, dict, list[dict]]] = None
    lora_request: Optional[LoRARequest] = None
    request_id: Optional[str] = None


# -----------------------------------------------------------------------------
@@ -155,7 +157,10 @@ class BenchmarkDataset(ABC):

    @abstractmethod
    def sample(
        self, tokenizer: PreTrainedTokenizerBase, num_requests: int
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        request_id_prefix: str = "",
    ) -> list[SampleRequest]:
        """
        Abstract method to generate sample requests from the dataset.
@@ -167,6 +172,7 @@ class BenchmarkDataset(ABC):
            tokenizer (PreTrainedTokenizerBase): The tokenizer to be used
                for processing the dataset's text.
            num_requests (int): The number of sample requests to generate.
            request_id_prefix (str): The prefix of the request id.

        Returns:
            list[SampleRequest]: A list of sample requests generated from the
@@ -175,7 +181,10 @@ class BenchmarkDataset(ABC):
        raise NotImplementedError("sample must be implemented in subclasses.")

    def maybe_oversample_requests(
        self, requests: list[SampleRequest], num_requests: int
        self,
        requests: list[SampleRequest],
        num_requests: int,
        request_id_prefix: str = "",
    ) -> None:
        """
        Oversamples the list of requests if its size is less than the desired
@@ -183,11 +192,18 @@ class BenchmarkDataset(ABC):

        Args:
            requests (List[SampleRequest]): The current list of sampled
                requests. num_requests (int): The target number of requests.
                requests.
            num_requests (int): The target number of requests.
            request_id_prefix (str): The prefix of the request ids.
        """
        if len(requests) < num_requests:
            random.seed(self.random_seed)
            additional = random.choices(requests, k=num_requests - len(requests))
            additional = deepcopy(
                random.choices(requests, k=num_requests - len(requests))
            )
            for i in range(len(additional)):
                req = additional[i]
                req.request_id = request_id_prefix + str(len(requests) + i)
            requests.extend(additional)
            logger.info("Oversampled requests to reach %d total samples.", num_requests)
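The switch to `deepcopy` matters because `random.choices` returns references: without copying, assigning a fresh `request_id` to an oversampled entry would also overwrite the id of the original it aliases. A toy illustration (the `Req` dataclass is a hypothetical stand-in for `SampleRequest`):

```python
import copy
import random
from dataclasses import dataclass

@dataclass
class Req:  # stand-in for SampleRequest
    request_id: str

reqs = [Req("prefix-0")]
aliased = random.choices(reqs, k=1)   # same object, not a copy
aliased[0].request_id = "prefix-1"
print(reqs[0].request_id)             # "prefix-1" -- original clobbered

reqs = [Req("prefix-0")]
copied = copy.deepcopy(random.choices(reqs, k=1))
copied[0].request_id = "prefix-1"
print(reqs[0].request_id)             # "prefix-0" -- original intact
```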
@@ -277,6 +293,41 @@ def process_image(image: Any) -> Mapping[str, Any]:
    )


def process_video(video: Any) -> Mapping[str, Any]:
    """
    Process a single video input and return a multimedia content dictionary.

    Supports the following input types:

    1. Dictionary with raw video bytes: - Expects a dict with a 'bytes' key
       containing raw video data.

    2. String input: - Treats the string as a URL or local file path. -
       Prepends "file://" if the string doesn't start with "http://" or
       "file://". - Returns a dictionary with the video URL.

    Raises:
        ValueError: If the input is not a supported type.
    """
    if isinstance(video, dict) and "bytes" in video:
        video_bytes = video["bytes"]
        video_base64 = base64.b64encode(video_bytes).decode("utf-8")
        return {
            "type": "video_url",
            "video_url": {"url": f"data:video/mp4;base64,{video_base64}"},
        }

    if isinstance(video, str):
        video_url = (
            video if video.startswith(("http://", "file://")) else f"file://{video}"
        )
        return {"type": "video_url", "video_url": {"url": video_url}}

    raise ValueError(
        f"Invalid video input {video}. Must be a string of local path/remote url, or a dictionary with raw video bytes in the form of `{{'bytes': raw_video_bytes}}`."  # noqa: E501
    )

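For instance (the path and `raw_video_bytes` below are placeholders), `process_video` maps either input form to an OpenAI-style content item:

```python
process_video("/data/clip.mp4")
# -> {"type": "video_url", "video_url": {"url": "file:///data/clip.mp4"}}

process_video({"bytes": raw_video_bytes})
# -> {"type": "video_url",
#     "video_url": {"url": "data:video/mp4;base64,<base64 of raw_video_bytes>"}}
```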
# -----------------------------------------------------------------------------
# Random Dataset Implementation (Synthetic Data)
# -----------------------------------------------------------------------------
@@ -303,6 +354,7 @@ class RandomDataset(BenchmarkDataset):
        range_ratio: float = DEFAULT_RANGE_RATIO,
        input_len: int = DEFAULT_INPUT_LEN,
        output_len: int = DEFAULT_OUTPUT_LEN,
        request_id_prefix: str = "",
        **kwargs,
    ) -> list[SampleRequest]:
        # Enforce range_ratio < 1
@@ -363,8 +415,10 @@
                    prompt=prompt,
                    prompt_len=total_input_len,
                    expected_output_len=int(output_lens[i]),
                    request_id=request_id_prefix + str(i),
                )
            )

        return requests


@@ -406,9 +460,11 @@ class ShareGPTDataset(BenchmarkDataset):
        max_loras: Optional[int] = None,
        output_len: Optional[int] = None,
        enable_multimodal_chat: bool = False,
        request_id_prefix: str = "",
        **kwargs,
    ) -> list:
        samples: list = []
        ind = 0
        for entry in self.data:
            if len(samples) >= num_requests:
                break
@@ -430,17 +486,26 @@
                skip_min_output_len_check=output_len is not None,
            ):
                continue
            if image_path := entry.get("image"):
                mm_content = process_image(image_path)
            elif video_path := entry.get("video"):
                mm_content = process_video(video_path)
            else:
                mm_content = None
            if enable_multimodal_chat:
                prompt = self.apply_multimodal_chat_transformation(prompt, None)
                prompt = self.apply_multimodal_chat_transformation(prompt, mm_content)
            samples.append(
                SampleRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=new_output_len,
                    lora_request=lora_request,
                    multi_modal_data=mm_content,
                    request_id=request_id_prefix + str(ind),
                )
            )
        self.maybe_oversample_requests(samples, num_requests)
            ind += 1
        self.maybe_oversample_requests(samples, num_requests, request_id_prefix)
        return samples


@@ -506,10 +571,11 @@ class CustomDataset(BenchmarkDataset):
        output_len: Optional[int] = None,
        enable_multimodal_chat: bool = False,
        skip_chat_template: bool = False,
        request_id_prefix: str = "",
        **kwargs,
    ) -> list:
        sampled_requests = []
        for item in self.data:
        for i, item in enumerate(self.data):
            if len(sampled_requests) >= num_requests:
                break
            prompt = item["prompt"]
@@ -528,9 +594,12 @@
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=output_len,
                    request_id=request_id_prefix + str(i),
                )
            )
        self.maybe_oversample_requests(sampled_requests, num_requests)
        self.maybe_oversample_requests(
            sampled_requests, num_requests, request_id_prefix
        )

        return sampled_requests

@@ -572,6 +641,7 @@ class SonnetDataset(BenchmarkDataset):
        input_len: int = DEFAULT_INPUT_LEN,
        output_len: int = DEFAULT_OUTPUT_LEN,
        return_prompt_formatted: bool = False,
        request_id_prefix: str = "",
        **kwargs,
    ) -> list:
        # Calculate average token length for a poem line.
@@ -597,6 +667,7 @@
        prefix_lines = self.data[:num_prefix_lines]

        samples = []
        ind = 0
        while len(samples) < num_requests:
            extra_lines = random.choices(
                self.data, k=num_input_lines - num_prefix_lines
@@ -607,14 +678,17 @@
                msg, add_generation_prompt=True, tokenize=False
            )
            prompt_len = len(tokenizer(prompt_formatted).input_ids)

            if prompt_len <= input_len:
                samples.append(
                    SampleRequest(
                        prompt=prompt_formatted if return_prompt_formatted else prompt,
                        prompt_len=prompt_len,
                        expected_output_len=output_len,
                        request_id=request_id_prefix + str(ind),
                    )
                )
            ind += 1
        return samples


@@ -666,6 +740,7 @@ class BurstGPTDataset(BenchmarkDataset):
        num_requests: int,
        max_loras: Optional[int] = None,
        lora_path: Optional[str] = None,
        request_id_prefix: str = "",
        **kwargs,
    ) -> list[SampleRequest]:
        samples = []
@@ -687,6 +762,7 @@
                    prompt_len=input_len,
                    expected_output_len=output_len,
                    lora_request=lora_req,
                    request_id=request_id_prefix + str(i),
                )
            )
        return samples
@@ -746,12 +822,14 @@ class ConversationDataset(HuggingFaceDataset):
        num_requests: int,
        output_len: Optional[int] = None,
        enable_multimodal_chat: bool = False,
        request_id_prefix: str = "",
        **kwargs,
    ) -> list:
        # Filter examples with at least 2 conversations
        filtered_data = self.data.filter(lambda x: len(x["conversations"]) >= 2)
        sampled_requests = []
        dynamic_output = output_len is None
        ind = 0

        for item in filtered_data:
            if len(sampled_requests) >= num_requests:
@@ -779,9 +857,13 @@
                    prompt_len=prompt_len,
                    expected_output_len=output_len,
                    multi_modal_data=mm_content,
                    request_id=request_id_prefix + str(ind),
                )
            )
        self.maybe_oversample_requests(sampled_requests, num_requests)
            ind += 1
        self.maybe_oversample_requests(
            sampled_requests, num_requests, request_id_prefix
        )
        return sampled_requests


@@ -808,11 +890,12 @@ class VisionArenaDataset(HuggingFaceDataset):
        num_requests: int,
        output_len: Optional[int] = None,
        enable_multimodal_chat: bool = False,
        request_id_prefix: str = "",
        **kwargs,
    ) -> list:
        output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
        sampled_requests = []
        for item in self.data:
        for i, item in enumerate(self.data):
            if len(sampled_requests) >= num_requests:
                break
            parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path)
@@ -832,9 +915,12 @@
                    prompt_len=prompt_len,
                    expected_output_len=output_len,
                    multi_modal_data=mm_content,
                    request_id=request_id_prefix + str(i),
                )
            )
        self.maybe_oversample_requests(sampled_requests, num_requests)
        self.maybe_oversample_requests(
            sampled_requests, num_requests, request_id_prefix
        )
        return sampled_requests


@@ -864,15 +950,18 @@ class InstructCoderDataset(HuggingFaceDataset):
        num_requests: int,
        output_len: Optional[int] = None,
        enable_multimodal_chat: bool = False,
        request_id_prefix: str = "",
        **kwargs,
    ) -> list:
        output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
        sampled_requests = []
        for item in self.data:
        for i, item in enumerate(self.data):
            if len(sampled_requests) >= num_requests:
                break
            prompt = f"{item['input']}\n\n{item['instruction']} Just output \
            the code, do not include any explanation."
            prompt = (
                f"{item['input']}\n\n{item['instruction']} Just output "
                "the code, do not include any explanation."
            )

            # apply template
            prompt = tokenizer.apply_chat_template(
@@ -886,9 +975,12 @@
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=output_len,
                    request_id=request_id_prefix + str(i),
                )
            )
        self.maybe_oversample_requests(sampled_requests, num_requests)
        self.maybe_oversample_requests(
            sampled_requests, num_requests, request_id_prefix
        )
        return sampled_requests


@@ -918,12 +1010,13 @@ class MTBenchDataset(HuggingFaceDataset):
        num_requests: int,
        output_len: Optional[int] = None,
        enable_multimodal_chat: bool = False,
        request_id_prefix: str = "",
        **kwargs,
    ) -> list:
        output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
        sampled_requests = []

        for item in self.data:
        for i, item in enumerate(self.data):
            if len(sampled_requests) >= num_requests:
                break
            prompt = item["turns"][0]
@@ -941,9 +1034,12 @@
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=output_len,
                    request_id=request_id_prefix + str(i),
                )
            )
        self.maybe_oversample_requests(sampled_requests, num_requests)
        self.maybe_oversample_requests(
            sampled_requests, num_requests, request_id_prefix
        )
        return sampled_requests


@@ -968,10 +1064,12 @@ class AIMODataset(HuggingFaceDataset):
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        output_len: Optional[int] = None,
        request_id_prefix: str = "",
        **kwargs,
    ) -> list:
        sampled_requests = []
        dynamic_output = output_len is None
        ind = 0

        for item in self.data:
            if len(sampled_requests) >= num_requests:
@@ -994,9 +1092,13 @@
                    prompt_len=prompt_len,
                    expected_output_len=output_len,
                    multi_modal_data=None,
                    request_id=request_id_prefix + str(ind),
                )
            )
        self.maybe_oversample_requests(sampled_requests, num_requests)
            ind += 1
        self.maybe_oversample_requests(
            sampled_requests, num_requests, request_id_prefix
        )
        return sampled_requests


@@ -1066,12 +1168,18 @@ class NextEditPredictionDataset(HuggingFaceDataset):
        "zed-industries/zeta": _format_zeta_prompt,
    }

    def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int, **kwargs):
    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        request_id_prefix: str = "",
        **kwargs,
    ):
        formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(self.dataset_path)
        if formatting_prompt_func is None:
            raise ValueError(f"Unsupported dataset path: {self.dataset_path}")
        samples = []
        for sample in self.data:
        for i, sample in enumerate(self.data):
            sample = formatting_prompt_func(sample)
            samples.append(
                SampleRequest(
@@ -1080,11 +1188,12 @@ class NextEditPredictionDataset(HuggingFaceDataset):
                    expected_output_len=len(
                        tokenizer(sample["expected_output"]).input_ids
                    ),
                    request_id=request_id_prefix + str(i),
                )
            )
            if len(samples) >= num_requests:
                break
        self.maybe_oversample_requests(samples, num_requests)
        self.maybe_oversample_requests(samples, num_requests, request_id_prefix)
        return samples


@@ -1133,6 +1242,7 @@ class ASRDataset(HuggingFaceDataset):
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        output_len: Optional[int] = None,
        request_id_prefix: str = "",
        **kwargs,
    ) -> list:
        import librosa
@@ -1142,6 +1252,7 @@
        prompt_len = len(tokenizer(prompt).input_ids)
        sampled_requests = []
        skipped = 0
        ind = 0
        for item in self.data:
            if len(sampled_requests) >= num_requests:
                break
@@ -1160,8 +1271,10 @@
                    prompt_len=prompt_len,
                    expected_output_len=output_len,
                    multi_modal_data=mm_content,
                    request_id=request_id_prefix + str(ind),
                )
            )
            ind += 1
        if skipped:
            logger.warning(
                "%d samples discarded from dataset due to"
@@ -1169,5 +1282,7 @@
                " what Whisper supports.",
                skipped,
            )
        self.maybe_oversample_requests(sampled_requests, num_requests)
        self.maybe_oversample_requests(
            sampled_requests, num_requests, request_id_prefix
        )
        return sampled_requests
112 changes: benchmarks/benchmark_ngram_proposer.py (new file)
@@ -0,0 +1,112 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc

import numpy as np
from tabulate import tabulate

from benchmark_utils import TimeCollector
from vllm.config import ModelConfig, SpeculativeConfig, VllmConfig
from vllm.utils import FlexibleArgumentParser
from vllm.v1.spec_decode.ngram_proposer import NgramProposer


def main(args):
    rows = []
    for max_ngram in args.max_ngram:
        collector = TimeCollector(TimeCollector.US)

        model_config = ModelConfig(
            model="facebook/opt-125m",
            task="generate",
            max_model_len=args.num_token + args.num_spec_token,
            tokenizer="facebook/opt-125m",
            tokenizer_mode="auto",
            dtype="auto",
            seed=None,
            trust_remote_code=False,
        )
        proposer = NgramProposer(
            vllm_config=VllmConfig(
                model_config=model_config,
                speculative_config=SpeculativeConfig(
                    prompt_lookup_min=args.min_ngram,
                    prompt_lookup_max=max_ngram,
                    num_speculative_tokens=args.num_spec_token,
                    method="ngram",
                ),
            )
        )

        # Warm up
        proposer.propose(np.random.randint(0, 20, (args.num_token,)))

        gc.collect()
        for _ in range(args.num_iteration):
            tokens = np.random.randint(0, 20, (args.num_req, args.num_token))
            with collector:
                for i in range(args.num_req):
                    proposer.propose(tokens[i, :])
        rows.append(
            [args.num_req, args.num_token, args.min_ngram, max_ngram]
            + collector.dump_avg_max()
        )

    print(
        tabulate(
            rows,
            headers=[
                "# Request",
                "# Token",
                "Min Ngram",
                "Max Ngram",
                "Avg (us)",
                "Max (us)",
            ],
            tablefmt="grid",
            floatfmt=".3f",
        )
    )


def invoke_main() -> None:
    parser = FlexibleArgumentParser(
        description="Benchmark the performance of N-gram speculative decode drafting"
    )
    parser.add_argument(
        "--num-iteration",
        type=int,
        default=100,
        help="Number of iterations to run to stabilize final data readings",
    )
    parser.add_argument(
        "--num-req", type=int, default=128, help="Number of requests in the batch"
    )
    parser.add_argument(
        "--num-token", type=int, default=1500, help="Number of tokens for each request"
    )
    parser.add_argument(
        "--min-ngram",
        type=int,
        default=3,
        help="Minimum n-gram to match",
    )
    parser.add_argument(
        "--max-ngram",
        type=int,
        nargs="*",
        default=[5, 7, 10, 15, 20],
        help="Maximum n-gram to match",
    )
    parser.add_argument(
        "--num-spec-token",
        type=int,
        default=3,
        help="Number of speculative tokens to generate",
    )
    args = parser.parse_args()
    main(args)


if __name__ == "__main__":
    invoke_main()  # pragma: no cover
@@ -263,7 +263,14 @@ async def benchmark(
        input_requests[0].multi_modal_data,
    )

    assert test_mm_content is None or isinstance(test_mm_content, dict)
    assert (
        test_mm_content is None
        or isinstance(test_mm_content, dict)
        or (
            isinstance(test_mm_content, list)
            and all(isinstance(item, dict) for item in test_mm_content)
        )
    ), "multi_modal_data must be a dict or list[dict]"
    test_input = RequestFuncInput(
        model=model_id,
        model_name=model_name,
@@ -368,11 +375,12 @@ async def benchmark(
            rps_change_events.append({"rps": rps_val, "timestamp": timestamp})
            last_int_rps = current_int_rps

        prompt, prompt_len, output_len, mm_content = (
        prompt, prompt_len, output_len, mm_content, request_id = (
            request.prompt,
            request.prompt_len,
            request.expected_output_len,
            request.multi_modal_data,
            request.request_id,
        )
        req_model_id, req_model_name = model_id, model_name
        if lora_modules:
@@ -390,6 +398,7 @@ async def benchmark(
            multi_modal_content=mm_content,
            ignore_eos=ignore_eos,
            extra_body=extra_body,
            request_id=request_id,
        )
        task = limited_request_func(request_func_input=request_func_input, pbar=pbar)
        tasks.append(asyncio.create_task(task))
@@ -658,6 +667,7 @@ def main(args: argparse.Namespace):
            tokenizer=tokenizer,
            output_len=args.custom_output_len,
            skip_chat_template=args.custom_skip_chat_template,
            request_id_prefix=args.request_id_prefix,
        )

    elif args.dataset_name == "sonnet":
@@ -671,6 +681,7 @@ def main(args: argparse.Namespace):
                prefix_len=args.sonnet_prefix_len,
                tokenizer=tokenizer,
                return_prompt_formatted=False,
                request_id_prefix=args.request_id_prefix,
            )
        else:
            assert tokenizer.chat_template or tokenizer.default_chat_template, (
@@ -683,6 +694,7 @@ def main(args: argparse.Namespace):
                prefix_len=args.sonnet_prefix_len,
                tokenizer=tokenizer,
                return_prompt_formatted=True,
                request_id_prefix=args.request_id_prefix,
            )

    elif args.dataset_name == "hf":
@@ -744,6 +756,7 @@ def main(args: argparse.Namespace):
            num_requests=args.num_prompts,
            tokenizer=tokenizer,
            output_len=args.hf_output_len,
            request_id_prefix=args.request_id_prefix,
        )

    else:
@@ -755,10 +768,15 @@ def main(args: argparse.Namespace):
                tokenizer=tokenizer,
                num_requests=args.num_prompts,
                output_len=args.sharegpt_output_len,
                request_id_prefix=args.request_id_prefix,
            ),
            "burstgpt": lambda: BurstGPTDataset(
                random_seed=args.seed, dataset_path=args.dataset_path
            ).sample(tokenizer=tokenizer, num_requests=args.num_prompts),
            ).sample(
                tokenizer=tokenizer,
                num_requests=args.num_prompts,
                request_id_prefix=args.request_id_prefix,
            ),
            "random": lambda: RandomDataset(dataset_path=args.dataset_path).sample(
                tokenizer=tokenizer,
                num_requests=args.num_prompts,
@@ -766,6 +784,7 @@ def main(args: argparse.Namespace):
                input_len=args.random_input_len,
                output_len=args.random_output_len,
                range_ratio=args.random_range_ratio,
                request_id_prefix=args.request_id_prefix,
            ),
        }

@@ -1111,6 +1130,13 @@ def create_argument_parser():
        "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
        "and the blog: https://hao-ai-lab.github.io/blogs/distserve",
    )
    parser.add_argument(
        "--request-id-prefix",
        type=str,
        required=False,
        default="benchmark-serving",
        help="Specify the prefix of request id.",
    )

    # group for dataset specific arguments
    custom_group = parser.add_argument_group("custom dataset options")
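Taken together with the `backend_request_func.py` changes above, every benchmark request now carries a deterministic id built from this prefix and its sample index, and the id is forwarded to the server as an `x-request-id` header. A hypothetical sketch of the resulting plumbing (values are illustrative):

```python
import os

request_id_prefix = "benchmark-serving"   # --request-id-prefix (default)
request_id = request_id_prefix + str(0)   # id assigned to the first sample

headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
if request_id:
    headers["x-request-id"] = request_id  # mirrors the diffs above
```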
@@ -597,8 +597,8 @@ def validate_args(args):
    # https://github.com/vllm-project/vllm/issues/16222
    if args.data_parallel_size > 1:
        raise ValueError(
            "Data parallel is not supported in offline benchmark, \
            please use benchmark serving instead"
            "Data parallel is not supported in offline benchmark, "
            "please use benchmark serving instead"
        )

@@ -1,11 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import argparse
import json
import math
import os
from typing import Any
import time
from types import TracebackType
from typing import Any, Optional, Union


def convert_to_pytorch_benchmark_format(
@@ -72,3 +73,53 @@ def write_to_json(filename: str, records: list) -> None:
        cls=InfEncoder,
        default=lambda o: f"<{type(o).__name__} object is not JSON serializable>",
    )


# Collect time and generate time metrics
#
# Example Usage:
#   collector = TimeCollector(TimeCollector.US)
#   for _ in range(total_iteration):
#       with collector:
#           ...
#   collector.dump_avg_max()
class TimeCollector:
    NS: int = 1
    US: int = NS * 1000
    MS: int = US * 1000
    S: int = MS * 1000

    def __init__(self, scale: int) -> None:
        self.cnt: int = 0
        self._sum: int = 0
        self._max: Optional[int] = None
        self.scale = scale
        self.start_time: int = time.monotonic_ns()

    def collect(self, v: int) -> None:
        self.cnt += 1
        self._sum += v
        if self._max is None:
            self._max = v
        else:
            self._max = max(self._max, v)

    def avg(self) -> Union[float, str]:
        return self._sum * 1.0 / self.cnt / self.scale if self.cnt > 0 else "N/A"

    def max(self) -> Union[float, str]:
        return self._max / self.scale if self._max else "N/A"

    def dump_avg_max(self) -> list[Union[float, str]]:
        return [self.avg(), self.max()]

    def __enter__(self) -> None:
        self.start_time = time.monotonic_ns()

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_value: Optional[BaseException],
        exc_traceback: Optional[TracebackType],
    ) -> None:
        self.collect(time.monotonic_ns() - self.start_time)
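A minimal usage sketch of the `TimeCollector` added above (the timed body below is a stand-in; any code under measurement works the same way):

```python
from benchmark_utils import TimeCollector

collector = TimeCollector(TimeCollector.US)
for _ in range(1000):
    with collector:            # __enter__ stamps the start, __exit__ collects
        sum(range(10_000))     # placeholder workload being timed

avg_us, max_us = collector.dump_avg_max()
print(f"avg: {avg_us} us, max: {max_us} us")
```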
@ -1,63 +1,199 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
|
||||
import aiohttp
|
||||
from quart import Quart, make_response, request
|
||||
from quart import Quart, Response, make_response, request
|
||||
from rate_limiter import RateLimiter
|
||||
from request_queue import RequestQueue
|
||||
|
||||
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
|
||||
|
||||
app = Quart(__name__)
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def forward_request(url, data):
|
||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||
def parse_args():
|
||||
"""parse command line arguments"""
|
||||
parser = argparse.ArgumentParser(description="vLLM P/D disaggregation proxy server")
|
||||
|
||||
# Add args
|
||||
parser.add_argument(
|
||||
"--timeout",
|
||||
type=float,
|
||||
default=300,
|
||||
help="Timeout for backend service requests in seconds (default: 300)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-concurrent",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Maximum concurrent requests to backend services (default: 100)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--queue-size",
|
||||
type=int,
|
||||
default=500,
|
||||
help="Maximum number of requests in the queue (default: 500)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rate-limit",
|
||||
type=int,
|
||||
default=40,
|
||||
help="Maximum requests per second (default: 40)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=8000,
|
||||
help="Port to run the server on (default: 8000)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prefill-url",
|
||||
type=str,
|
||||
default="http://localhost:8100/v1/completions",
|
||||
help="Prefill service endpoint URL",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--decode-url",
|
||||
type=str,
|
||||
default="http://localhost:8200/v1/completions",
|
||||
help="Decode service endpoint URL",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
"""parse command line arguments"""
|
||||
args = parse_args()
|
||||
|
||||
# Initialize configuration using command line parameters
|
||||
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=args.timeout)
|
||||
MAX_CONCURRENT_REQUESTS = args.max_concurrent
|
||||
REQUEST_QUEUE_SIZE = args.queue_size
|
||||
RATE_LIMIT = args.rate_limit
|
||||
PREFILL_SERVICE_URL = args.prefill_url
|
||||
DECODE_SERVICE_URL = args.decode_url
|
||||
PORT = args.port
|
||||
|
||||
app = Quart(__name__)
|
||||
|
||||
# Initialize the rate limiter and request queue
|
||||
rate_limiter = RateLimiter(RATE_LIMIT)
|
||||
request_queue = RequestQueue(MAX_CONCURRENT_REQUESTS, REQUEST_QUEUE_SIZE)
|
||||
|
||||
# Attach the configuration object to the application instance
|
||||
app.config.update(
|
||||
{
|
||||
"AIOHTTP_TIMEOUT": AIOHTTP_TIMEOUT,
|
||||
"rate_limiter": rate_limiter,
|
||||
"request_queue": request_queue,
|
||||
"PREFILL_SERVICE_URL": PREFILL_SERVICE_URL,
|
||||
"DECODE_SERVICE_URL": DECODE_SERVICE_URL,
|
||||
}
|
||||
)
|
||||
|
||||
# Start queue processing on app startup
|
||||
@app.before_serving
|
||||
async def startup():
    """Start request processing task when app starts serving"""
    asyncio.create_task(request_queue.process())


async def forward_request(url, data):
    """Forward request to backend service with rate limiting and error handling"""
    headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
    # Use rate limiter as context manager
    async with (
        rate_limiter,
        aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session,
    ):
        try:
            async with session.post(url=url, json=data, headers=headers) as response:
                if response.status == 200:
                    # Stream response chunks
                    async for chunk_bytes in response.content.iter_chunked(1024):
                        yield chunk_bytes
                else:
                    # Handle backend service errors
                    error_text = await response.text()
                    logger.error(
                        "Backend service error: %s - %s",
                        response.status,
                        error_text,
                    )
                    yield b'{"error": "Backend service error"}'
        except aiohttp.ClientError as e:
            # Handle connection errors
            logger.error("Connection error to %s: %s", url, str(e))
            yield b'{"error": "Service unavailable"}'
        except asyncio.TimeoutError:
            # Handle timeout errors
            logger.error("Timeout connecting to %s", url)
            yield b'{"error": "Service timeout"}'


async def process_request():
    """Process a single request through prefill and decode stages"""
    try:
        original_request_data = await request.get_json()

        # Create prefill request (max_tokens=1) so the backend only does prefill
        prefill_request = original_request_data.copy()
        prefill_request["max_tokens"] = 1

        # Execute prefill stage
        async for _ in forward_request(PREFILL_SERVICE_URL, prefill_request):
            continue

        # Execute decode stage and stream response
        generator = forward_request(DECODE_SERVICE_URL, original_request_data)
        response = await make_response(generator)
        response.timeout = None  # Disable timeout for streaming response

        return response

    except Exception:
        logger.exception("Error occurred in disagg prefill proxy server")
        return Response(
            response=b'{"error": "Internal server error"}',
            status=500,
            content_type="application/json",
        )


@app.route("/v1/completions", methods=["POST"])
async def handle_request():
    """Handle incoming API requests with concurrency and rate limiting"""
    # Create task for request processing
    task = asyncio.create_task(process_request())

    # Enqueue request or reject if queue is full
    if not await request_queue.enqueue(task):
        return Response(
            response=b'{"error": "Server busy, try again later"}',
            status=503,
            content_type="application/json",
        )

    try:
        # Return the response from the processing task
        return await task
    except asyncio.CancelledError:
        # Handle task cancellation (timeout or queue full)
        logger.warning("Request cancelled due to timeout or queue full")
        return Response(
            response=b'{"error": "Request cancelled"}',
            status=503,
            content_type="application/json",
        )


def main():
    # Start the Quart server; host can be set to 0.0.0.0 if needed
    app.run(port=PORT)


if __name__ == "__main__":
    main()
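The handlers above reference module-level state (the Quart app, the rate limiter, the request queue, the service URLs) that is defined earlier in the proxy. A minimal wiring sketch follows; the prefill/decode URLs and port come from this diff, while the limiter and queue sizes are illustrative assumptions:

```python
# Hypothetical wiring sketch (not part of the diff): module-level setup the
# handlers above assume. Limiter/queue sizes are illustrative assumptions.
import asyncio
import logging
import os

import aiohttp
from quart import Quart, Response, make_response, request

from rate_limiter import RateLimiter  # assumed module path
from request_queue import RequestQueue  # assumed module path

AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
PREFILL_SERVICE_URL = "http://localhost:8100/v1/completions"
DECODE_SERVICE_URL = "http://localhost:8200/v1/completions"
PORT = 8000

logger = logging.getLogger(__name__)
app = Quart(__name__)
rate_limiter = RateLimiter(rate_limit=40)  # assumed requests/second
request_queue = RequestQueue(max_concurrent=16, max_queue_size=128)  # assumed sizes

# Run the queue processor once the app starts serving
app.before_serving(startup)
```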
benchmarks/disagg_benchmarks/rate_limiter.py (new file, 45 lines)
@@ -0,0 +1,45 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import asyncio
import time


class RateLimiter:
    """Token bucket rate limiter implementation"""

    def __init__(self, rate_limit):
        self.rate_limit = rate_limit  # Requests per second
        self.num_available_tokens = rate_limit  # Available tokens
        self.last_refill = time.monotonic()  # Last token refill time
        self.lock = asyncio.Lock()  # Synchronization lock

    async def acquire(self):
        """Acquire a token from the rate limiter"""
        while True:
            async with self.lock:
                current_time = time.monotonic()
                elapsed = current_time - self.last_refill

                # Refill the token bucket if more than 1 second has passed
                if elapsed > 1.0:
                    self.num_available_tokens = self.rate_limit
                    self.last_refill = current_time

                # Consume a token if one is available
                if self.num_available_tokens > 0:
                    self.num_available_tokens -= 1
                    return True

            # No tokens left: wait (outside the lock, so other coroutines can
            # still refill/acquire) for the remainder of the refill window
            wait_time = 1.0 - elapsed
            await asyncio.sleep(wait_time)

    async def __aenter__(self):
        """Enter async context manager - acquire token"""
        await self.acquire()
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Exit async context manager - no cleanup needed"""
        pass
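A quick usage sketch for the limiter above (a hypothetical standalone driver; the 2-requests-per-second limit is an arbitrary example):

```python
# Hypothetical driver for RateLimiter: with rate_limit=2, the third acquire()
# in the same second has to wait for the next one-second refill window.
import asyncio
import time

from rate_limiter import RateLimiter  # assumed module path


async def demo():
    limiter = RateLimiter(rate_limit=2)
    t0 = time.monotonic()
    for i in range(4):
        async with limiter:
            print(f"request {i} admitted at t={time.monotonic() - t0:.2f}s")


asyncio.run(demo())
```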
benchmarks/disagg_benchmarks/request_queue.py (new file, 39 lines)
@@ -0,0 +1,39 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import asyncio
from collections import deque


class RequestQueue:
    """Request queue manager with concurrency control"""

    def __init__(self, max_concurrent, max_queue_size):
        self.max_concurrent = max_concurrent  # Maximum concurrent requests
        self.max_queue_size = max_queue_size  # Maximum queue size
        self.semaphore = asyncio.Semaphore(max_concurrent)  # Concurrency control
        self.queue = deque()  # Request queue
        self.queue_size = 0  # Current queue size
        self.lock = asyncio.Lock()  # Queue synchronization lock

    async def enqueue(self, task):
        """Add a request task to the queue"""
        async with self.lock:
            if self.queue_size >= self.max_queue_size:
                return False

            self.queue.append(task)
            self.queue_size += 1
            return True

    async def process(self):
        """Process queued requests using semaphore for concurrency control"""
        while True:
            if self.queue:
                # Pop under the lock only; awaiting the task while holding the
                # lock would block enqueue() for the task's whole duration
                async with self.lock:
                    task = self.queue.popleft()
                    self.queue_size -= 1
                async with self.semaphore:
                    await task
            await asyncio.sleep(0.01)  # Yield control to event loop
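A hypothetical smoke test for the queue (names and sizes are assumptions, not part of the diff); it shows the enqueue-side backpressure that `handle_request` relies on:

```python
# Hypothetical smoke test for RequestQueue: the fifth enqueue is rejected
# once max_queue_size=4 is reached before the processor drains the queue.
import asyncio

from request_queue import RequestQueue  # assumed module path


async def demo():
    queue = RequestQueue(max_concurrent=2, max_queue_size=4)

    async def work(i):
        await asyncio.sleep(0.1)
        return i

    tasks = [asyncio.create_task(work(i)) for i in range(5)]
    results = [await queue.enqueue(t) for t in tasks]
    print(results)  # [True, True, True, True, False]

    processor = asyncio.create_task(queue.process())
    await asyncio.gather(*tasks)  # the tasks run regardless; process() awaits them
    processor.cancel()


asyncio.run(demo())
```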
@@ -1,345 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
import sys
from typing import Optional

import torch
import torch.nn.functional as F

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.aqlm import (
    dequantize_weight,
    generic_dequantize_gemm,
    get_int_dtype,
    optimized_dequantize_gemm,
)
from vllm.utils import FlexibleArgumentParser

os.environ["CUDA_VISIBLE_DEVICES"] = "0"


def torch_mult(
    # [..., in_features]
    input: torch.Tensor,
    weights: torch.Tensor,
    # [num_out_groups, 1, 1, 1]
    scales: torch.Tensor,
) -> torch.Tensor:
    output = F.linear(input, weights)
    return output


def dequant_out_scale(
    # [..., in_features]
    input: torch.Tensor,
    # [num_out_groups, num_in_groups, num_codebooks]
    codes: torch.IntTensor,
    # [num_codebooks, codebook_size, out_group_size, in_group_size]
    codebooks: torch.Tensor,
    # [num_out_groups, 1, 1, 1]
    scales: torch.Tensor,
    output_partition_sizes: torch.IntTensor,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)

    if bias is None:
        output = F.linear(input, weights, bias)
        orig_shape = output.shape
        flattened_output = output.view(-1, output.size(-1))
        f_scales = scales.view(-1, scales.shape[0])
        b_scales = f_scales.expand(flattened_output.shape[0], -1)
        flattened_output *= b_scales
        return flattened_output.view(orig_shape)
    else:
        b_scales = scales.view(scales.shape[:-3] + (-1,)).expand(-1, weights.shape[1])
        weights *= b_scales
        return F.linear(input, weights, bias)


def dequant_weight_scale(
    # [..., in_features]
    input: torch.Tensor,
    # [num_out_groups, num_in_groups, num_codebooks]
    codes: torch.IntTensor,
    # [num_codebooks, codebook_size, out_group_size, in_group_size]
    codebooks: torch.Tensor,
    # [num_out_groups, 1, 1, 1]
    scales: torch.Tensor,
    output_partition_sizes: torch.IntTensor,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)

    b_scales = scales.view(scales.shape[:-3] + (-1,)).expand(-1, weights.shape[1])
    weights *= b_scales
    return F.linear(input, weights, bias)


def dequant_no_scale(
    # [..., in_features]
    input: torch.Tensor,
    # [num_out_groups, num_in_groups, num_codebooks]
    codes: torch.IntTensor,
    # [num_codebooks, codebook_size, out_group_size, in_group_size]
    codebooks: torch.Tensor,
    # [num_out_groups, 1, 1, 1]
    scales: torch.Tensor,
    output_partition_sizes: torch.IntTensor,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)

    return F.linear(input, weights, bias)


# Compare the optimized 1x16 and 2x8 cuda decompression/dequant kernels against
# the generic pytorch version.
# Just visual comparison.
def dequant_test(k: int, parts: torch.Tensor, nbooks: int, bits: int) -> None:
    n = int(parts.sum().item())

    device = torch.device("cuda:0")

    code_range = (1 << bits) // 2
    ingroups = 8

    codes = torch.randint(
        -code_range,
        code_range,
        size=(n, k // ingroups, nbooks),
        dtype=get_int_dtype(bits),
        device=device,
    )

    codebooks = torch.randn(
        size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
        dtype=torch.float16,
        device=device,
    )

    count = 0
    for index in range(16):
        for i in range(8):
            for book in range(nbooks):
                codebooks[book, index, 0, i] = count * (10**book)
                count += 1

    print("codes shape", codes.shape)

    for i in range(16):
        for book in range(nbooks):
            codes[0, i, book] = i
            codes[0, -i, book] = i

    weights = dequantize_weight(codes, codebooks, None)
    weights2 = ops.aqlm_dequant(codes, codebooks, parts)

    print("weights shape:", weights.shape)
    print("weights2 shape:", weights2.shape)

    print("weights are:", weights)
    print("weights2 are:", weights2)

    print("first 128 weights are", weights[0, 0:128].to(torch.int32))
    print("first 128 weights2 are:", weights2[0, 0:128].to(torch.int32))

    print("last 128 weights are", weights[0, -128:])
    print("last 128 weights2 are:", weights2[0, -128:])


def main():
    parser = FlexibleArgumentParser(description="Benchmark aqlm performance.")

    # Add arguments
    parser.add_argument(
        "--nbooks", type=int, default=1, help="Number of codebooks (default: 1)"
    )
    parser.add_argument(
        "--bits",
        type=int,
        default=16,
        help="Number of bits per code element (default: 16)",
    )
    parser.add_argument(
        "--test",
        type=bool,
        default=False,
        help="Run the decompression/dequant tester rather than benchmarking "
        "(default: False)",
    )

    # Parse the arguments
    args = parser.parse_args()

    # Extract values
    nbooks = args.nbooks
    bits = args.bits

    if args.test:
        dequant_test(4096, torch.tensor((4096,)), nbooks, bits)
        return

    # Otherwise, benchmark.
    methods = [
        ops.aqlm_gemm,
        dequant_out_scale,
        generic_dequantize_gemm,
        optimized_dequantize_gemm,
        dequant_weight_scale,
        torch_mult,
        dequant_no_scale,
    ]

    filename = f"./aqlm_benchmark_{nbooks}x{bits}.csv"
    print(f"writing benchmarks to file {filename}")
    with open(filename, "w") as f:
        sys.stdout = f

        print("m | k | n | n parts", end="")
        for method in methods:
            print(f" | {method.__name__.replace('_', ' ')} (µs)", end="")
        print("")

        # These are reasonable prefill sizes.
        ksandpartions = (
            (4096, (4096, 4096, 4096)),
            (4096, (4096,)),
            (4096, (11008, 11008)),
            (11008, (4096,)),
        )

        # reasonable ranges for m.
        for m in [1, 2, 4, 8, 10, 12, 14, 16, 24, 32, 48, 52, 56, 64, 96, 112,
                  128, 256, 512, 1024, 1536, 2048, 3072, 4096]:
            print(f"{m}", file=sys.__stdout__)
            for ksp in ksandpartions:
                run_grid(m, ksp[0], torch.tensor(ksp[1]), nbooks, bits, methods)

        sys.stdout = sys.__stdout__


def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int, methods):
    # I didn't see visible improvements from increasing these, but feel free :)
    num_warmup_trials = 1
    num_trials = 1

    num_calls = 100

    # warmup.
    for method in methods:
        for _ in range(num_warmup_trials):
            run_timing(
                num_calls=num_calls,
                m=m,
                k=k,
                parts=parts,
                nbooks=nbooks,
                bits=bits,
                method=method,
            )

    n = parts.sum().item()
    print(f"{m} | {k} | {n} | {parts.tolist()}", end="")

    for method in methods:
        best_time_us = 1e20
        for _ in range(num_trials):
            kernel_dur_ms = run_timing(
                num_calls=num_calls,
                m=m,
                k=k,
                parts=parts,
                nbooks=nbooks,
                bits=bits,
                method=method,
            )

            kernel_dur_us = 1000 * kernel_dur_ms

            if kernel_dur_us < best_time_us:
                best_time_us = kernel_dur_us

        print(f" | {kernel_dur_us:.0f}", end="")

    print("")


def run_timing(
    num_calls: int, m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int, method
) -> float:
    n = int(parts.sum().item())

    device = torch.device("cuda:0")

    input = torch.randn((1, m, k), dtype=torch.float16, device=device)

    code_range = (1 << bits) // 2
    ingroups = 8

    codes = torch.randint(
        -code_range,
        code_range,
        size=(n, k // ingroups, nbooks),
        dtype=get_int_dtype(bits),
        device=device,
    )

    codebooks = torch.randn(
        size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
        dtype=torch.float16,
        device=device,
    )

    scales = torch.randn(size=(n, 1, 1, 1), dtype=torch.float16, device=device)

    # for comparison to just a pytorch mult.
    weights = torch.randn((n, k), dtype=torch.float16, device=device)

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()

    if method is torch_mult:
        for i in range(num_calls):
            torch_mult(input, weights, scales)
    else:
        for i in range(num_calls):
            method(input, codes, codebooks, scales, parts, None)

    end_event.record()
    end_event.synchronize()

    dur_ms = start_event.elapsed_time(end_event) / num_calls
    return dur_ms


if __name__ == "__main__":
    sys.exit(main())
@@ -3,6 +3,8 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.

+from packaging import version
+
 from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
     MINIMUM_BITBLAS_VERSION,
 )
@@ -10,7 +12,7 @@ from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
 try:
     import bitblas

-    if bitblas.__version__ < MINIMUM_BITBLAS_VERSION:
+    if version.parse(bitblas.__version__) < version.parse(MINIMUM_BITBLAS_VERSION):
         raise ImportError(
             "bitblas version is wrong. Please "
             f"install bitblas>={MINIMUM_BITBLAS_VERSION}"
@@ -80,6 +80,11 @@ def bench_run(
         a, score, topk, renormalize=False
     )

+    ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
+    ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64)
+    c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64)
+    c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
+
     def run_triton_moe(
         a: torch.Tensor,
         w1: torch.Tensor,
@@ -111,6 +116,10 @@ def bench_run(
         w2: torch.Tensor,
         w1_scale: torch.Tensor,
         w2_scale: torch.Tensor,
+        ab_strides1: torch.Tensor,
+        ab_strides2: torch.Tensor,
+        c_strides1: torch.Tensor,
+        c_strides2: torch.Tensor,
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         per_act_token: bool,
@@ -125,6 +134,10 @@ def bench_run(
             topk_ids,
             w1_scale,
             w2_scale,
+            ab_strides1,
+            ab_strides2,
+            c_strides1,
+            c_strides2,
             per_act_token,
             a1_scale=None,
         )
@@ -136,6 +149,10 @@ def bench_run(
         w2_q: torch.Tensor,
         w1_scale: torch.Tensor,
         w2_scale: torch.Tensor,
+        ab_strides1: torch.Tensor,
+        ab_strides2: torch.Tensor,
+        c_strides1: torch.Tensor,
+        c_strides2: torch.Tensor,
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
     ):
@@ -150,6 +167,10 @@ def bench_run(
             topk_ids,
             w1_scale,
             w2_scale,
+            ab_strides1,
+            ab_strides2,
+            c_strides1,
+            c_strides2,
             per_act_token,
             a1_scale=None,
         )
@@ -194,6 +215,10 @@ def bench_run(
             w2_q,
             w1_scale,
             w2_scale,
+            ab_strides1,
+            ab_strides2,
+            c_strides1,
+            c_strides2,
             topk_weights,
             topk_ids,
         )
@@ -231,6 +256,10 @@ def bench_run(
         "w1_scale": w1_scale,
         "w2_scale": w2_scale,
         "per_act_token": per_act_token,
+        "ab_strides1": ab_strides1,
+        "ab_strides2": ab_strides2,
+        "c_strides1": c_strides1,
+        "c_strides2": c_strides2,
         # cuda graph params
         "cutlass_graph": cutlass_graph,
         "triton_graph": triton_graph,
@@ -289,6 +318,10 @@ def bench_run(
             w2_q,
             w1_scale,
             w2_scale,
+            ab_strides1,
+            ab_strides2,
+            c_strides1,
+            c_strides2,
             topk_weights,
             topk_ids,
             per_act_token,
@@ -297,7 +330,7 @@ def bench_run(

     results.append(
         benchmark.Timer(
-            stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)",  # noqa: E501
+            stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, ab_strides1, ab_strides2, c_strides1, c_strides2, topk_weights, topk_ids, per_act_token, num_runs)",  # noqa: E501
             globals=globals,
             label=label,
             sub_label=sub_label,
@@ -236,6 +236,7 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
             a=bt.a,
             c=None,
             b_q_weight=w_q,
+            b_bias=None,
             b_scales=w_s,
             global_scale=None,
             b_zeros=w_zp,
@@ -252,28 +253,7 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
     else:
         assert bt.a.dtype == torch.int8
         assert bt.wtype == scalar_types.uint4b8
-
-        if bt.w_ch_s is not None:
-            s_ch = bt.w_ch_s.to(torch.float32)
-        else:
-            s_ch = torch.ones(bt.w_ref.shape[1], dtype=torch.float32, device=device)
-
-        if bt.w_tok_s is not None:
-            s_tok = bt.w_tok_s.to(torch.float32)
-        else:
-            s_tok = torch.ones(bt.a.shape[0], dtype=torch.float32, device=device)
-
-        fn = lambda: ops.marlin_qqq_gemm(
-            a=bt.a,
-            b_q_weight=w_q,
-            s_group=w_s,
-            s_tok=s_tok,
-            s_ch=s_ch,
-            workspace=workspace.scratch,
-            size_m=bt.a.shape[0],
-            size_n=bt.w_ref.shape[1],
-            size_k=bt.w_ref.shape[0],
-        )
+        raise NotImplementedError("QQQ is not supported anymore")

     return fn

@@ -304,6 +284,25 @@ def machete_create_bench_fn(
     )


+def cutlass_w4a8_create_bench_fn(
+    bt: BenchmarkTensors, out_type=torch.dtype, schedule=None
+) -> Callable:
+    w_q = bt.w_q.t().contiguous().t()  # make col major
+    w_q = ops.cutlass_encode_and_reorder_int4b(w_q)
+    # expects fp8 scales
+    w_s = ops.cutlass_pack_scale_fp8(bt.w_g_s.to(torch.float8_e4m3fn))
+
+    return lambda: ops.cutlass_w4a8_mm(
+        a=bt.a,
+        b_q=w_q,
+        b_group_scales=w_s,
+        b_group_size=bt.group_size,
+        b_channel_scales=bt.w_ch_s,
+        a_token_scales=bt.w_tok_s,
+        maybe_schedule=schedule,
+    )
+
+
 # impl

 # bench
@@ -405,6 +404,20 @@ def bench(
         )
     )

+    # cutlass w4a8
+    if types.act_type == torch.float8_e4m3fn and group_size == 128:
+        timers.append(
+            bench_fns(
+                label,
+                sub_label,
+                f"cutlass w4a8 ({name_type_string})",
+                [
+                    cutlass_w4a8_create_bench_fn(bt, out_type=types.output_type)
+                    for bt in benchmark_tensors
+                ],
+            )
+        )
+
     if sweep_schedules:
         global _SWEEP_SCHEDULES_RESULTS
@@ -3,6 +3,7 @@

 import argparse
 import json
+import os
 import time
 from contextlib import nullcontext
 from datetime import datetime
@@ -22,10 +23,10 @@ from vllm.utils import FlexibleArgumentParser
 FP8_DTYPE = current_platform.fp8_dtype()


-def ensure_divisibility(numerator, denominator):
+def ensure_divisibility(numerator, denominator, text):
     """Ensure that numerator is divisible by the denominator."""
-    assert numerator % denominator == 0, (
-        "intermediate_size {} is not divisible by tp {}.".format(numerator, denominator)
+    assert numerator % denominator == 0, "{} {} is not divisible by tp {}.".format(
+        text, numerator, denominator
     )
@@ -429,7 +430,6 @@ class BenchmarkWorker:
                 hidden_size,
                 topk,
                 dtype_str,
-                is_marlin=False,
             )
         else:
             config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))]
@@ -542,6 +542,7 @@ def save_configs(
     use_fp8_w8a8: bool,
     use_int8_w8a16: bool,
     block_quant_shape: list[int],
+    save_dir: str,
 ) -> None:
     dtype_str = get_config_dtype_str(
         dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
@@ -552,7 +553,8 @@ def save_configs(
     filename = get_config_file_name(
         num_experts, shard_intermediate_size // 2, dtype_str, block_quant_shape
     )
-
+    os.makedirs(save_dir, exist_ok=True)
+    filename = os.path.join(save_dir, filename)
     print(f"Writing best config to {filename}...")
     with open(filename, "w") as f:
         json.dump(configs, f, indent=4)
@@ -577,12 +579,10 @@ def main(args: argparse.Namespace):
         E = config.ffn_config.moe_num_experts
         topk = config.ffn_config.moe_top_k
         intermediate_size = config.ffn_config.ffn_hidden_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
     elif config.architectures[0] == "JambaForCausalLM":
         E = config.num_experts
         topk = config.num_experts_per_tok
         intermediate_size = config.intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
     elif config.architectures[0] in (
         "DeepseekV3ForCausalLM",
         "DeepseekV2ForCausalLM",
@@ -591,17 +591,14 @@ def main(args: argparse.Namespace):
         E = config.n_routed_experts
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
     elif config.architectures[0] in ("Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"):
         E = config.num_experts
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
     elif config.architectures[0] in ("HunYuanMoEV1ForCausalLM"):
         E = config.num_experts
         topk = config.moe_topk[0]
         intermediate_size = config.moe_intermediate_size[0]
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
     else:
         # Support for llama4
         config = config.get_text_config()
@@ -609,8 +606,14 @@ def main(args: argparse.Namespace):
         E = config.num_local_experts
         topk = config.num_experts_per_tok
         intermediate_size = config.intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
-    ensure_divisibility(intermediate_size, args.tp_size)
+    enable_ep = bool(args.enable_expert_parallel)
+    if enable_ep:
+        ensure_divisibility(E, args.tp_size, "Number of experts")
+        E = E // args.tp_size
+        shard_intermediate_size = 2 * intermediate_size
+    else:
+        ensure_divisibility(intermediate_size, args.tp_size, "intermediate_size")
+        shard_intermediate_size = 2 * intermediate_size // args.tp_size
     hidden_size = config.hidden_size
     dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
     use_fp8_w8a8 = args.dtype == "fp8_w8a8"
@@ -706,6 +709,7 @@ def main(args: argparse.Namespace):
             use_fp8_w8a8,
             use_int8_w8a16,
             block_quant_shape,
+            args.save_dir,
         )
         end = time.time()
         print(f"Tuning took {end - start:.2f} seconds")
@@ -742,10 +746,14 @@ if __name__ == "__main__":
     parser.add_argument(
         "--tp-size", "-tp", "--tensor-parallel-size", type=int, default=2
     )
+    parser.add_argument("--enable-expert-parallel", "-enable-ep", action="store_true")
     parser.add_argument(
         "--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto"
     )
     parser.add_argument("--use-deep-gemm", action="store_true")
+    parser.add_argument(
+        "--save-dir", type=str, default="./", help="Directory to save tuned results"
+    )
     parser.add_argument("--seed", type=int, default=0)
     parser.add_argument("--batch-size", type=int, nargs="+", required=False)
     parser.add_argument("--tune", action="store_true")
benchmarks/kernels/benchmark_mrope.py (new file, 328 lines)
@@ -0,0 +1,328 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# This script benchmarks the mrope kernel (mainly for Qwen2VL and Qwen2.5VL models).
# It generates test data, runs benchmarks, and saves results to a CSV file.
#
# The CSV file (named with current date/time) contains these columns:
# model_name, tp_size, num_tokens, num_heads, num_kv_heads, head_dim, max_position,
# rope_theta, is_neox_style, rope_scaling, dtype, torch_mean, torch_median, torch_p99,
# torch_min, torch_max, triton_mean, triton_median, triton_p99, triton_min, triton_max,
# speedup
#
# == Usage Examples ==
#
# Single model benchmark:
# python3 benchmark_mrope.py --model-name Qwen/Qwen2-VL-7B-Instruct --tp-size 1 \
#     --warmup-iter 10 --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024
#
# All models benchmark:
# python3 benchmark_mrope.py --model-name "" --tp-size 1 --warmup-iter 10 \
#     --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024
#
# All models with different TP sizes:
# python3 benchmark_mrope.py --model-name "" --tp-size 1 2 4 8 --warmup-iter 10 \
#     --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024
#
# All models with different token counts:
# python3 benchmark_mrope.py --model-name "" --tp-size 1 --warmup-iter 10 \
#     --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024 4096 16384
import csv
import os
import time
from datetime import datetime
from typing import Any, Optional

import numpy as np
import torch

from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config
from vllm.utils import FlexibleArgumentParser

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def generate_test_data(
    num_tokens: int,
    num_q_heads: int,
    num_kv_heads: int,
    head_size: int,
    max_position_embeddings: int,
    dtype: torch.dtype,
    device: torch.device,
):
    """Generate test data for given configuration."""
    # Create 2D positions (3, num_tokens) for multimodal case
    positions = torch.randint(
        0, max_position_embeddings // 4, (3, num_tokens), device=device
    )

    # Create query and key tensors
    query = torch.randn(num_tokens, num_q_heads * head_size, dtype=dtype, device=device)
    key = torch.randn(num_tokens, num_kv_heads * head_size, dtype=dtype, device=device)

    return positions, query, key


def calculate_stats(times: list[float]) -> dict[str, float]:
    """Calculate statistics from a list of times."""
    times_array = np.array(times)
    return {
        "mean": np.mean(times_array),
        "median": np.median(times_array),
        "p99": np.percentile(times_array, 99),
        "min": np.min(times_array),
        "max": np.max(times_array),
    }


def benchmark_mrope(
    model_name: str,
    num_tokens: int,
    head_dim: int,
    tp_size: int,
    num_heads: int,
    num_kv_heads: int,
    max_position: int = 8192,
    rope_theta: float = 10000,
    is_neox_style: bool = True,
    rope_scaling: Optional[dict[str, Any]] = None,
    dtype: torch.dtype = torch.bfloat16,
    seed: int = 0,
    warmup_iter: int = 10,
    benchmark_iter: int = 100,
    csv_writer=None,
):
    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    # the parameters to compute the q k v size based on tp_size
    mrope_helper_class = get_rope(
        head_size=head_dim,
        rotary_dim=head_dim,
        max_position=max_position,
        base=rope_theta,
        is_neox_style=is_neox_style,
        rope_scaling=rope_scaling,
        dtype=dtype,
    ).to(device=device)

    print(80 * "=")
    print(
        f"Evaluating model: {model_name} "
        f"with tp_size: {tp_size} "
        f"and num_tokens: {num_tokens}, "
        f"dtype: {dtype}"
    )

    # create q k v input tensors
    # create rotary pos emb input tensors
    positions, query, key = generate_test_data(
        num_tokens, num_heads, num_kv_heads, head_dim, max_position, dtype, device
    )

    # Warm up
    for _ in range(warmup_iter):
        mrope_helper_class.forward_native(
            positions,
            query.clone(),
            key.clone(),
        )

        mrope_helper_class.forward_cuda(
            positions,
            query.clone(),
            key.clone(),
        )

    torch.cuda.synchronize()

    # Time reference implementation
    torch_times = []
    for _ in range(benchmark_iter):
        query_clone = query.clone()
        key_clone = key.clone()
        torch.cuda.synchronize()
        start_time = time.time()

        mrope_helper_class.forward_native(
            positions,
            query_clone,
            key_clone,
        )

        torch.cuda.synchronize()
        torch_times.append(time.time() - start_time)

    # Time triton kernel implementation
    triton_times = []
    for _ in range(benchmark_iter):
        query_clone = query.clone()
        key_clone = key.clone()
        torch.cuda.synchronize()
        start_time = time.time()
        mrope_helper_class.forward_cuda(
            positions,
            query_clone,
            key_clone,
        )
        torch.cuda.synchronize()
        triton_times.append(time.time() - start_time)

    # Calculate statistics
    torch_stats = calculate_stats(torch_times)
    triton_stats = calculate_stats(triton_times)
    print(f"\nPerformance for config ({num_tokens}, {num_heads}, {num_kv_heads}):")

    print(
        f"Torch implementation: "
        f"mean={torch_stats['mean']:.8f}s, "
        f"median={torch_stats['median']:.8f}s, "
        f"p99={torch_stats['p99']:.8f}s"
    )

    print(
        f"Triton implementation: "
        f"mean={triton_stats['mean']:.8f}s, "
        f"median={triton_stats['median']:.8f}s, "
        f"p99={triton_stats['p99']:.8f}s"
    )

    print(
        f"Triton Speedup over Torch: {torch_stats['mean'] / triton_stats['mean']:.8f}x"
    )

    # Write to CSV
    if csv_writer:
        row = [
            model_name,
            tp_size,
            num_tokens,
            num_heads,
            num_kv_heads,
            head_dim,
            max_position,
            rope_theta,
            is_neox_style,
            str(rope_scaling),
            str(dtype).split(".")[-1],
            torch_stats["mean"],
            torch_stats["median"],
            torch_stats["p99"],
            torch_stats["min"],
            torch_stats["max"],
            triton_stats["mean"],
            triton_stats["median"],
            triton_stats["p99"],
            triton_stats["min"],
            triton_stats["max"],
            torch_stats["mean"] / triton_stats["mean"],  # speedup
        ]
        csv_writer.writerow(row)

    return torch_stats, triton_stats


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="Benchmark the rotary embedding kernels."
    )
    parser.add_argument("--model-name", type=str, default="")
    parser.add_argument("--tp-size", type=int, default=1)
    parser.add_argument("--warmup-iter", type=int, default=10)
    parser.add_argument("--benchmark-iter", type=int, default=100)
    parser.add_argument("--dtype", type=str, choices=["bfloat16"], default="bfloat16")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--num-tokens", type=int, nargs="+", required=False)
    parser.add_argument("--trust-remote-code", action="store_true")
    parser.add_argument("--output-csv", type=str, default="mrope_benchmark_results.csv")
    args = parser.parse_args()
    print(args)

    # Create CSV file for results
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    csv_filename = f"{os.path.splitext(args.output_csv)[0]}_{timestamp}.csv"

    with open(csv_filename, "w", newline="") as csvfile:
        csv_writer = csv.writer(csvfile)
        # Write header
        header = [
            "model_name",
            "tp_size",
            "num_tokens",
            "num_heads",
            "num_kv_heads",
            "head_dim",
            "max_position",
            "rope_theta",
            "is_neox_style",
            "rope_scaling",
            "dtype",
            "torch_mean",
            "torch_median",
            "torch_p99",
            "torch_min",
            "torch_max",
            "triton_mean",
            "triton_median",
            "triton_p99",
            "triton_min",
            "triton_max",
            "speedup",
        ]
        csv_writer.writerow(header)

        model_tp_dict = {}
        if args.model_name == "":
            model_tp_dict = {
                "Qwen/Qwen2-VL-2B-Instruct": [1],
                "Qwen/Qwen2-VL-7B-Instruct": [1],
                "Qwen/Qwen2-VL-72B-Instruct": [2, 4, 8],
                "Qwen/Qwen2.5-VL-3B-Instruct": [1, 2, 4, 8],
                "Qwen/Qwen2.5-VL-7B-Instruct": [1, 2, 4, 8],
                "Qwen/Qwen2.5-VL-72B-Instruct": [2, 4, 8],
            }
        else:
            model_tp_dict[args.model_name] = [args.tp_size]

        if args.num_tokens is None:
            num_tokens_list = [2**i for i in range(0, 18)]
        else:
            num_tokens_list = args.num_tokens

        for model_name, tp_list in model_tp_dict.items():
            config = get_config(model_name, trust_remote_code=args.trust_remote_code)
            for tp_size in tp_list:
                # get the model config
                total_num_kv_heads = config.num_key_value_heads
                total_num_heads = config.num_attention_heads
                num_heads = total_num_heads // tp_size
                num_kv_heads = max(1, total_num_kv_heads // tp_size)
                head_dim = config.hidden_size // total_num_heads
                q_size = num_heads * head_dim
                kv_size = num_kv_heads * head_dim
                is_neox_style = True
                rope_theta = config.rope_theta
                max_position = config.max_position_embeddings

                for num_tokens in num_tokens_list:
                    benchmark_mrope(
                        model_name=model_name,
                        num_tokens=num_tokens,
                        head_dim=head_dim,
                        tp_size=tp_size,
                        num_heads=num_heads,
                        num_kv_heads=num_kv_heads,
                        max_position=max_position,
                        rope_theta=rope_theta,
                        is_neox_style=is_neox_style,
                        rope_scaling=config.rope_scaling,
                        dtype=getattr(torch, args.dtype),
                        seed=args.seed,
                        warmup_iter=args.warmup_iter,
                        benchmark_iter=args.benchmark_iter,
                        csv_writer=csv_writer,
                    )

    print(f"Benchmark results saved to {csv_filename}")
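The resulting CSV is easy to aggregate; a minimal post-processing sketch, assuming pandas is installed and a results file from a previous run exists (the filename below is an example placeholder):

```python
# Hypothetical post-processing sketch: mean Triton speedup per model/TP size
# from a results file produced by the benchmark above.
import pandas as pd

df = pd.read_csv("mrope_benchmark_results_20250101_000000.csv")  # example filename
print(df.groupby(["model_name", "tp_size"])["speedup"].mean())
```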
benchmarks/kernels/benchmark_silu_mul_fp8_quant.py (new file, 77 lines)
@@ -0,0 +1,77 @@
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time

import torch

from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
    silu_mul_fp8_quant_deep_gemm,
)
from vllm.platforms import current_platform


def benchmark(E, T, H, G=128, runs=50):
    current_platform.seed_everything(42)
    y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda")
    tokens_per_expert = torch.randint(
        T // 2, T, size=(E,), dtype=torch.int32, device="cuda"
    )

    # Warmup
    for _ in range(10):
        silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=G)
    torch.cuda.synchronize()

    # Benchmark
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(runs):
        silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=G)
    torch.cuda.synchronize()

    avg_time = (time.perf_counter() - start) / runs * 1000

    # Calculate actual work done (only count valid tokens)
    actual_tokens = tokens_per_expert.sum().item()
    actual_elements = actual_tokens * H

    # GFLOPS: operations per element = exp + 3 muls + 1 div + quantization ops ≈ 8 ops
    ops_per_element = 8
    total_ops = actual_elements * ops_per_element
    gflops = total_ops / (avg_time / 1000) / 1e9

    # Memory bandwidth: bfloat16 inputs (2 bytes), fp8 output (1 byte), scales (4 bytes)
    input_bytes = actual_tokens * 2 * H * 2  # 2*H bfloat16 inputs
    output_bytes = actual_tokens * H * 1  # H fp8 outputs
    scale_bytes = actual_tokens * (H // G) * 4  # scales in float32
    total_bytes = input_bytes + output_bytes + scale_bytes
    memory_bw = total_bytes / (avg_time / 1000) / 1e9

    return avg_time, gflops, memory_bw


configs = [
    (8, 32, 1024),
    (16, 64, 2048),
    (32, 128, 4096),
    # DeepSeekV3 Configs
    (256, 16, 7168),
    (256, 32, 7168),
    (256, 64, 7168),
    (256, 128, 7168),
    (256, 256, 7168),
    (256, 512, 7168),
    (256, 1024, 7168),
]

print(f"GPU: {torch.cuda.get_device_name()}")
print(f"{'Config':<20} {'Time(ms)':<10} {'GFLOPS':<10} {'GB/s':<10}")
print("-" * 50)

for E, T, H in configs:
    try:
        time_ms, gflops, gbps = benchmark(E, T, H)
        print(f"E={E:3d},T={T:4d},H={H:4d} {time_ms:8.3f} {gflops:8.1f} {gbps:8.1f}")
    except Exception:
        print(f"E={E:3d},T={T:4d},H={H:4d} FAILED")
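To make the bandwidth model above concrete, here is a quick sanity check of the per-token byte count for the H=7168, G=128 DeepSeekV3 configs, using only the formulas from the script:

```python
# Per-token traffic implied by the byte-count model above, for H=7168, G=128.
H, G = 7168, 128
input_bytes = 2 * H * 2     # 28672 B: the 2*H bfloat16 input slice
output_bytes = H * 1        #  7168 B: H fp8 outputs
scale_bytes = (H // G) * 4  #   224 B: 56 float32 group scales
print(input_bytes + output_bytes + scale_bytes)  # 36064 bytes per token
```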
@@ -1,254 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import csv
import os
import random
from datetime import datetime

import flashinfer
import torch

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8

# KV Cache Layout for TRT-LLM
# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)


def to_float8(x, dtype=torch.float8_e4m3fn):
    finfo = torch.finfo(dtype)
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    scale = finfo.max / amax * 0.1
    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    return x_scl_sat.to(dtype), scale.float().reciprocal()


@torch.no_grad()
def benchmark_decode(
    num_seqs,
    max_seq_len,
    page_size=16,
    dtype=torch.bfloat16,
    kv_layout="HND",
    num_kv_heads=8,
    kv_cache_dtype="auto",
    head_dim=128,
    warmup=10,
    trials=20,
):
    torch.set_default_device("cuda")
    device = "cuda"
    torch.manual_seed(0)

    # Currently only HEAD_GRP_SIZE == 8 is supported
    HEAD_GRP_SIZE = 8
    MAX_SEQ_LEN = max_seq_len

    # large number to reduce kv_cache reuse
    NUM_BLOCKS = int(256000 / page_size)

    workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8, device=device)

    # For decode, batch_size is num_decode_token
    num_qo_heads = num_kv_heads * HEAD_GRP_SIZE
    sm_scale = float(1.0 / (head_dim**0.5))
    q = torch.randn(num_seqs, num_qo_heads, head_dim, device=device, dtype=dtype)
    kv_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]

    max_kv_len = max(kv_lens)
    kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int, device=device)
    max_num_blocks_per_seq = (max_kv_len + page_size - 1) // page_size

    block_tables = torch.randint(
        0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
    )

    kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim)
    kv_cache = torch.randn(size=kv_cache_shape, device=device, dtype=dtype)
    k_scale = v_scale = 1.0

    if kv_cache_dtype.startswith("fp8"):
        kv_cache, _ = to_float8(kv_cache)

    output_trtllm = torch.empty(q.shape, dtype=dtype)

    # Benchmark TRT decode
    def trt_decode():
        return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
            q,
            kv_cache,
            workspace_buffer,
            block_tables,
            kv_lens_tensor,
            max_kv_len,
            bmm1_scale=k_scale * sm_scale,
            bmm2_scale=v_scale,
            out=output_trtllm,
        )

    def time_fn(fn, warmup=10, trials=20):
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        times = []
        for i in range(warmup):
            fn()
        for i in range(trials):
            start.record()
            fn()
            end.record()
            torch.cuda.synchronize()
            times.append(start.elapsed_time(end))  # ms
        return sum(times) / len(times), torch.std(torch.tensor(times))

    # TRT Decode
    trt_mean, trt_std = time_fn(trt_decode)

    kv_indptr = [0]
    kv_indices = []
    kv_last_page_lens = []
    for i in range(num_seqs):
        seq_len = kv_lens[i]
        assert seq_len > 0
        num_blocks = (seq_len + page_size - 1) // page_size
        kv_indices.extend(block_tables[i, :num_blocks])
        kv_indptr.append(kv_indptr[-1] + num_blocks)
        kv_last_page_len = seq_len % page_size
        if kv_last_page_len == 0:
            kv_last_page_len = page_size
        kv_last_page_lens.append(kv_last_page_len)

    kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
    kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
    kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)

    output_baseline = torch.empty(q.shape, dtype=dtype)

    wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer,
        kv_layout,
        use_tensor_cores=((num_qo_heads // num_kv_heads) > 4),
    )

    wrapper.plan(
        kv_indptr,
        kv_indices,
        kv_last_page_lens,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        page_size,
        "NONE",
        q_data_type=dtype,
        kv_data_type=torch.float8_e4m3fn if kv_cache_dtype.startswith("fp8") else dtype,
    )

    def baseline_decode():
        return wrapper.run(q, kv_cache, sm_scale, k_scale, v_scale, output_baseline)

    baseline_mean, baseline_std = time_fn(baseline_decode)

    # Calculate percentage speedup (positive means TRT is faster)
    speedup_percent = (baseline_mean - trt_mean) / baseline_mean

    print(
        f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.3f}\t{trt_std.item():.3f}"
        f"\t{baseline_mean:.3f}\t{baseline_std.item():.3f}\t{speedup_percent:.3f}"
    )

    # Return results for CSV writing
    return {
        "num_seqs": num_seqs,
        "trt_mean": trt_mean,
        "trt_std": trt_std.item(),
        "baseline_mean": baseline_mean,
        "baseline_std": baseline_std.item(),
        "speedup_percent": speedup_percent,
        "q_dtype": str(dtype),
        "kv_cache_dtype": kv_cache_dtype,
        "page_size": page_size,
        "num_kv_heads": num_kv_heads,
        "head_dim": head_dim,
        "max_seq_len": max_seq_len,
    }


def write_results_to_csv(results, filename=None):
    """Write benchmark results to CSV file."""
    if filename is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv"

    fieldnames = [
        "num_seqs",
        "trt_mean",
        "trt_std",
        "baseline_mean",
        "baseline_std",
        "speedup_percent",
        "q_dtype",
        "kv_cache_dtype",
        "page_size",
        "num_kv_heads",
        "head_dim",
        "max_seq_len",
    ]

    file_exists = os.path.exists(filename)

    with open(filename, "a", newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

        if not file_exists:
            writer.writeheader()

        for result in results:
            writer.writerow(result)

    print(f"Results written to {filename}")


if __name__ == "__main__":
    num_seqs = [1, 4, 8, 16, 32, 64, 128, 256]
    max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
    all_results = []

    print(
        "Running benchmark for q_dtype = bfloat16, kv_cache_dtype: bfloat16, "
        "output_dtype: bfloat16"
    )
    print(
        "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t"
        "baseline_std\tspeedup_percent"
    )
    for max_seq_len in max_seq_lens:
        for bs in num_seqs:
            result = benchmark_decode(
                bs,
                max_seq_len,
                dtype=torch.bfloat16,
                kv_cache_dtype="auto",
            )
            all_results.append(result)

    print(
        "Running benchmark for q_dtype = bfloat16, kv_cache_dtype: fp8, "
        "output_dtype: bfloat16"
    )
    print(
        "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t"
        "baseline_std\tspeedup_percent"
    )
    for max_seq_len in max_seq_lens:
        for bs in num_seqs:
            result = benchmark_decode(
                bs,
                max_seq_len,
                dtype=torch.bfloat16,
                kv_cache_dtype="fp8",
            )
            all_results.append(result)

    # Write all results to CSV
    write_results_to_csv(all_results)
benchmarks/kernels/benchmark_trtllm_decode_attention.py (new file, 292 lines)
@@ -0,0 +1,292 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import csv
import os
from datetime import datetime
from typing import Optional

import flashinfer
import torch

from vllm.utils import round_up

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
FP8_DTYPE = torch.float8_e4m3fn
FP4_DTYPE = torch.uint8


def to_float8(x, dtype=torch.float8_e4m3fn):
    finfo = torch.finfo(dtype)
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    scale = finfo.max / amax * 0.1
    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    return x_scl_sat.to(dtype), scale.float().reciprocal()


@torch.no_grad()
def benchmark_decode(
    dtype: torch.dtype,
    quant_dtypes: tuple[
        Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype]
    ],
    batch_size: int,
    max_seq_len: int,
    num_heads: tuple[int, int] = (64, 8),
    head_size: int = 128,
    kv_layout: str = "HND",
    block_size: int = 16,
    warmup: int = 10,
    trials: int = 20,
):
    torch.set_default_device("cuda")
    torch.manual_seed(0)

    q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
    q_quant_dtype = q_quant_dtype or dtype
    kv_quant_dtype = kv_quant_dtype or dtype
    o_quant_dtype = o_quant_dtype or dtype

    num_qo_heads, num_kv_heads = num_heads
    assert num_qo_heads % num_kv_heads == 0

    sm_scale = float(1.0 / (head_size**0.5))

    # large number to reduce kv_cache reuse
    NUM_BLOCKS = int(256000 / block_size)

    kv_cache_shape = None
    if kv_layout == "NHD":
        kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
    elif kv_layout == "HND":
        kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
    else:
        raise ValueError(f"Invalid kv_layout: {kv_layout}")

    # Always using 1.0 scale to reflect the real perf in benchmarking
    q_scale = 1.0
    ref_query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype)
    if q_quant_dtype == FP8_DTYPE:
        query, _ = to_float8(ref_query)
    else:
        query = ref_query

    kv_lens = torch.randint(1, max_seq_len, (batch_size,), dtype=torch.int32)
    kv_lens[-1] = max_seq_len

    seq_lens = kv_lens
    max_seq_len = torch.max(seq_lens).item()

    # Always using 1.0 scale to reflect the real perf in benchmarking
    k_scale = v_scale = 1.0
    ref_kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
    if kv_quant_dtype == FP8_DTYPE:
        kv_cache, _ = to_float8(ref_kv_cache)
    else:
        kv_cache = ref_kv_cache

    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables = torch.randint(
        0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32
    )
    kv_indptr = [0]
    kv_indices = []
    kv_last_page_lens = []
    for i in range(batch_size):
        seq_len = seq_lens[i]
        assert seq_len > 0
        num_blocks = (seq_len + block_size - 1) // block_size
        kv_indices.extend(block_tables[i, :num_blocks])
        kv_indptr.append(kv_indptr[-1] + num_blocks)
        kv_last_page_len = seq_len % block_size
        if kv_last_page_len == 0:
            kv_last_page_len = block_size
        kv_last_page_lens.append(kv_last_page_len)

    kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
    kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
    kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
    workspace_buffer = torch.zeros(1024 * 1024 * 1024, dtype=torch.int8)

    wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer,
        kv_layout,
        use_tensor_cores=True,
    )
    wrapper.plan(
        kv_indptr,
        kv_indices,
        kv_last_page_lens,
        num_qo_heads,
        num_kv_heads,
        head_size,
        block_size,
        "NONE",
        sm_scale=sm_scale,
        q_data_type=dtype,
        kv_data_type=dtype,
    )

    def time_fn(fn, warmup=10, trials=20):
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        times = []
        for i in range(warmup):
            fn()
        for i in range(trials):
            start.record()
            fn()
            end.record()
            torch.cuda.synchronize()
            times.append(start.elapsed_time(end))  # ms
        return sum(times) / len(times), torch.std(torch.tensor(times))

    o_scale = 1.0
    o_sf_scale = None
    output_baseline = torch.empty(ref_query.shape, dtype=dtype)
    if o_quant_dtype == FP4_DTYPE:
        o_sf_scale = 500.0
        output_trtllm = flashinfer.utils.FP4Tensor(
            torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8),
            torch.empty(
                (
                    round_up(query.shape[0], 128),
                    round_up(query.shape[1] * query.shape[2] // 16, 4),
                ),
                dtype=torch.float8_e4m3fn,
            ),
        )
    else:
        output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)

    def baseline_decode():
        return wrapper.run(
            ref_query,
            ref_kv_cache,
            k_scale=k_scale,
            v_scale=v_scale,
            out=output_baseline,
        )

    def trtllm_decode():
        return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
            query=query,
            kv_cache=kv_cache,
            workspace_buffer=workspace_buffer,
            block_tables=block_tables,
            seq_lens=seq_lens,
            max_seq_len=max_seq_len,
            bmm1_scale=q_scale * k_scale * sm_scale,
            bmm2_scale=v_scale / o_scale,
            o_sf_scale=o_sf_scale,
            out=output_trtllm,
        )

    baseline_mean, baseline_std = time_fn(baseline_decode)
    trtllm_mean, trtllm_std = time_fn(trtllm_decode)

    # Calculate percentage speedup (positive means TRT is faster)
    speedup_percent = (baseline_mean - trtllm_mean) / baseline_mean

    print(
        f"\t{batch_size}\t{max_seq_len}\t{trtllm_mean:.3f}\t{trtllm_std.item():.3f}"
        f"\t{baseline_mean:.3f}\t{baseline_std.item():.3f}\t{speedup_percent:.3f}"
    )

    # Return results for CSV writing
    return {
        "batch_size": batch_size,
        "trtllm_mean": trtllm_mean,
        "trtllm_std": trtllm_std.item(),
        "baseline_mean": baseline_mean,
        "baseline_std": baseline_std.item(),
        "speedup_percent": speedup_percent,
        "q_dtype": str(q_quant_dtype),
        "kv_cache_dtype": str(kv_quant_dtype),
        "output_dtype": str(o_quant_dtype),
        "block_size": block_size,
        "num_kv_heads": num_kv_heads,
        "head_size": head_size,
        "max_seq_len": max_seq_len,
    }


def write_results_to_csv(results, filename=None):
    """Write benchmark results to CSV file."""
    if filename is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv"

    fieldnames = [
        "batch_size",
        "trtllm_mean",
        "trtllm_std",
        "baseline_mean",
        "baseline_std",
        "speedup_percent",
        "q_dtype",
        "kv_cache_dtype",
        "output_dtype",
        "block_size",
        "num_kv_heads",
        "head_size",
        "max_seq_len",
    ]

    file_exists = os.path.exists(filename)

    with open(filename, "a", newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

        if not file_exists:
            writer.writeheader()

        for result in results:
            writer.writerow(result)

    print(f"Results written to {filename}")


if __name__ == "__main__":
    batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256]
    max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
    all_results = []

    dtype = torch.bfloat16
    quant_dtypes = [
        # (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
        (None, None, None),
        (None, FP8_DTYPE, None),
        (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
        (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
    ]

    for quant_dtype in quant_dtypes:
        q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtype
        q_quant_dtype = q_quant_dtype or dtype
        kv_quant_dtype = kv_quant_dtype or dtype
        o_quant_dtype = o_quant_dtype or dtype

        print(
            f"Running benchmark for q_dtype = {q_quant_dtype}, "
            f"kv_cache_dtype: {kv_quant_dtype}, "
            f"output_dtype: {o_quant_dtype}"
        )
        print(
            "\tbatch_size\tmax_seq_len\ttrtllm_mean\ttrtllm_std\tbaseline_mean\t"
            "baseline_std\tspeedup_percent"
        )
        for max_seq_len in max_seq_lens:
            for bs in batch_sizes:
                result = benchmark_decode(
                    dtype=dtype,
                    quant_dtypes=quant_dtype,
                    batch_size=bs,
                    max_seq_len=max_seq_len,
                )
                all_results.append(result)

    # Write all results to CSV
    write_results_to_csv(all_results)
benchmarks/kernels/benchmark_trtllm_prefill_attention.py (new file, 307 lines)
@@ -0,0 +1,307 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import csv
import os
from datetime import datetime
from typing import Optional

import flashinfer
import torch

from vllm.utils import round_up

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
FP8_DTYPE = torch.float8_e4m3fn
FP4_DTYPE = torch.uint8


def to_float8(x, dtype=torch.float8_e4m3fn):
    finfo = torch.finfo(dtype)
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    # Map the max magnitude to 10% of the FP8 range (headroom against saturation)
    scale = finfo.max / amax * 0.1
    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    return x_scl_sat.to(dtype), scale.float().reciprocal()
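
# Usage sketch: x_fp8, inv_scale = to_float8(torch.randn(4, 8, dtype=torch.bfloat16));
# x_fp8.to(torch.bfloat16) * inv_scale approximately recovers x (up to FP8 rounding).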

@torch.no_grad()
def benchmark_prefill(
    dtype: torch.dtype,
    quant_dtypes: tuple[
        Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype]
    ],
    batch_size: int,
    max_seq_len: int,
    num_heads: tuple[int, int] = (64, 8),
    head_size: int = 128,
    kv_layout: str = "HND",
    block_size: int = 16,
    warmup: int = 10,
    trials: int = 20,
):
    torch.set_default_device("cuda")
    torch.manual_seed(0)

    q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
    q_quant_dtype = q_quant_dtype or dtype
    kv_quant_dtype = kv_quant_dtype or dtype
    o_quant_dtype = o_quant_dtype or dtype

    max_q_len = max_kv_len = max_seq_len

    num_qo_heads, num_kv_heads = num_heads
    assert num_qo_heads % num_kv_heads == 0

    sm_scale = float(1.0 / (head_size**0.5))

    # large number to reduce kv_cache reuse
    NUM_BLOCKS = int(256000 / block_size)

    kv_cache_shape = None
    if kv_layout == "NHD":
        kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
    elif kv_layout == "HND":
        kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
    else:
        raise ValueError(f"Invalid kv_layout: {kv_layout}")

    q_lens = torch.randint(1, max_q_len, (batch_size,), dtype=torch.int32)
    # Pin one sequence to the maximum so every run exercises max_q_len
    q_lens[-1] = max_q_len
    q_indptr = torch.cat(
        [
            torch.tensor([0], dtype=torch.int32),
            torch.cumsum(q_lens, dim=0, dtype=torch.int32),
        ]
    )

    # Always using 1.0 scale to reflect the real perf in benchmarking
    q_scale = 1.0
    ref_query = torch.randn(
        torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype
    )
    if q_quant_dtype == FP8_DTYPE:
        query, _ = to_float8(ref_query)
    else:
        query = ref_query

    kv_lens = torch.randint(0, max_kv_len, (batch_size,), dtype=torch.int32)
    kv_lens[-1] = max_kv_len

    seq_lens = kv_lens + q_lens
    max_seq_len = torch.max(seq_lens).item()

    # Always using 1.0 scale to reflect the real perf in benchmarking
    k_scale = v_scale = 1.0
    ref_kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
    if kv_quant_dtype == FP8_DTYPE:
        kv_cache, _ = to_float8(ref_kv_cache)
    else:
        kv_cache = ref_kv_cache

    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables = torch.randint(
        0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32
    )
    kv_indptr = [0]
    kv_indices = []
    kv_last_page_lens = []
    for i in range(batch_size):
        seq_len = seq_lens[i]
        assert seq_len > 0
        num_blocks = (seq_len + block_size - 1) // block_size
        kv_indices.extend(block_tables[i, :num_blocks])
        kv_indptr.append(kv_indptr[-1] + num_blocks)
        # The final page may be partially filled; a full page reports block_size
        kv_last_page_len = seq_len % block_size
        if kv_last_page_len == 0:
            kv_last_page_len = block_size
        kv_last_page_lens.append(kv_last_page_len)

    kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
    kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
    kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
    workspace_buffer = torch.zeros(1024 * 1024 * 1024, dtype=torch.int8)

    wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
        workspace_buffer, kv_layout
    )
    wrapper.plan(
        q_indptr,
        kv_indptr,
        kv_indices,
        kv_last_page_lens,
        num_qo_heads,
        num_kv_heads,
        head_size,
        block_size,
        causal=True,
        sm_scale=sm_scale,
        q_data_type=dtype,
        kv_data_type=dtype,
    )

    def time_fn(fn, warmup=10, trials=20):
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        times = []
        for i in range(warmup):
            fn()
        for i in range(trials):
            start.record()
            fn()
            end.record()
            torch.cuda.synchronize()
            times.append(start.elapsed_time(end))  # ms
        return sum(times) / len(times), torch.std(torch.tensor(times))
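    # time_fn reports device time via CUDA events (ms) and synchronizes once per
    # trial, so host-side overhead is largely excluded from the measurement.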

    o_scale = 1.0
    o_sf_scale = None
    output_baseline = torch.empty(ref_query.shape, dtype=dtype)
    if o_quant_dtype == FP4_DTYPE:
        o_sf_scale = 500.0
        output_trtllm = flashinfer.utils.FP4Tensor(
            torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8),
            torch.empty(
                (
                    round_up(query.shape[0], 128),
                    round_up(query.shape[1] * query.shape[2] // 16, 4),
                ),
                dtype=torch.float8_e4m3fn,
            ),
        )
    else:
        output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)

    def baseline_prefill():
        return wrapper.run(
            ref_query,
            ref_kv_cache,
            k_scale=k_scale,
            v_scale=v_scale,
            out=output_baseline,
        )

    def trtllm_prefill():
        return flashinfer.prefill.trtllm_batch_context_with_kv_cache(
            query=query,
            kv_cache=kv_cache,
            workspace_buffer=workspace_buffer,
            block_tables=block_tables,
            seq_lens=seq_lens,
            max_q_len=max_q_len,
            max_kv_len=max_seq_len,
            bmm1_scale=q_scale * k_scale * sm_scale,
            bmm2_scale=v_scale / o_scale,
            batch_size=batch_size,
            cum_seq_lens_q=q_indptr,
            cum_seq_lens_kv=kv_indptr,
            o_sf_scale=o_sf_scale,
            out=output_trtllm,
        )

    baseline_mean, baseline_std = time_fn(baseline_prefill)
    trtllm_mean, trtllm_std = time_fn(trtllm_prefill)

    # Calculate percentage speedup (positive means TRT is faster)
    speedup_percent = (baseline_mean - trtllm_mean) / baseline_mean

    print(
        f"\t{batch_size}\t{max_seq_len}\t{trtllm_mean:8.3f}\t{trtllm_std.item():8.3f}"
        f"\t{baseline_mean:8.3f}\t{baseline_std.item():8.3f}\t{speedup_percent:8.3f}"
    )

    # Return results for CSV writing
    return {
        "batch_size": batch_size,
        "trtllm_mean": trtllm_mean,
        "trtllm_std": trtllm_std.item(),
        "baseline_mean": baseline_mean,
        "baseline_std": baseline_std.item(),
        "speedup_percent": speedup_percent,
        "q_dtype": str(q_quant_dtype),
        "kv_cache_dtype": str(kv_quant_dtype),
        "output_dtype": str(o_quant_dtype),
        "block_size": block_size,
        "num_kv_heads": num_kv_heads,
        "head_size": head_size,
        "max_seq_len": max_seq_len,
    }


def write_results_to_csv(results, filename=None):
    """Write benchmark results to CSV file."""
    if filename is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv"

    fieldnames = [
        "batch_size",
        "trtllm_mean",
        "trtllm_std",
        "baseline_mean",
        "baseline_std",
        "speedup_percent",
        "q_dtype",
        "kv_cache_dtype",
        "output_dtype",
        "block_size",
        "num_kv_heads",
        "head_size",
        "max_seq_len",
    ]

    file_exists = os.path.exists(filename)

    with open(filename, "a", newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

        if not file_exists:
            writer.writeheader()

        for result in results:
            writer.writerow(result)

    print(f"Results written to {filename}")


if __name__ == "__main__":
    batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256]
    max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
    all_results = []

    dtype = torch.bfloat16
    quant_dtypes = [
        # (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
        (None, None, None),
        (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
        (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
    ]

    for quant_dtype in quant_dtypes:
        q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtype
        q_quant_dtype = q_quant_dtype or dtype
        kv_quant_dtype = kv_quant_dtype or dtype
        o_quant_dtype = o_quant_dtype or dtype

        print(
            f"Running benchmark for q_dtype = {q_quant_dtype}, "
            f"kv_cache_dtype: {kv_quant_dtype}, "
            f"output_dtype: {o_quant_dtype}"
        )
        print(
            "\tbatch_size\tmax_seq_len\ttrtllm_mean\ttrtllm_std\tbaseline_mean\t"
            "baseline_std\tspeedup_percent"
        )
        for max_seq_len in max_seq_lens:
            for bs in batch_sizes:
                result = benchmark_prefill(
                    dtype=dtype,
                    quant_dtypes=quant_dtype,
                    batch_size=bs,
                    max_seq_len=max_seq_len,
                )
                all_results.append(result)

    # Write all results to CSV
    write_results_to_csv(all_results)
@@ -11,8 +11,8 @@ from datetime import datetime
 from typing import Any

 import torch
-import tqdm
 import triton
+from tqdm import tqdm

 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     _w8a8_block_fp8_matmul,
@@ -95,4 +95,10 @@ WEIGHT_SHAPES = {
         ([2048, 2816], 1),
         ([1408, 2048], 0),
     ],
+    "CohereLabs/c4ai-command-a-03-2025": [
+        ([12288, 14336], 1),
+        ([12288, 12288], 0),
+        ([12288, 73728], 1),
+        ([36864, 12288], 0),
+    ],
 }
@@ -1,108 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
import time
from typing import Optional

from tabulate import tabulate

from vllm.utils import FlexibleArgumentParser
from vllm.v1.core.block_pool import BlockPool


class Metric:
    def __init__(self) -> None:
        self.cnt: int = 0
        self.sum_v: int = 0
        self.max_v: Optional[int] = None

    def update(self, v: int) -> None:
        self.cnt += 1
        self.sum_v += v
        if self.max_v is None:
            self.max_v = v
        else:
            self.max_v = max(self.max_v, v)

    def avg_v(self) -> float:
        return self.sum_v * 1.0 / self.cnt


def main(args):
    rows = []
    for allocate_block in args.allocate_blocks:
        # Enforce a GC collect ahead to minimize the impact among runs
        gc.collect()
        block_pool = BlockPool(num_gpu_blocks=args.num_gpu_blocks, enable_caching=True)

        get_blocks_metric: Metric = Metric()
        free_blocks_metric: Metric = Metric()
        for _ in range(args.num_iteration):
            t1 = time.monotonic_ns()
            blocks = block_pool.get_new_blocks(allocate_block)
            t2 = time.monotonic_ns()
            block_pool.free_blocks(blocks)
            t3 = time.monotonic_ns()
            get_blocks_metric.update(t2 - t1)
            free_blocks_metric.update(t3 - t2)

        if get_blocks_metric.max_v is not None and free_blocks_metric.max_v is not None:
            rows.append(
                [
                    get_blocks_metric.cnt,
                    args.num_gpu_blocks,
                    allocate_block,
                    get_blocks_metric.avg_v() / 1000000,
                    get_blocks_metric.max_v / 1000000.0,
                    free_blocks_metric.avg_v() / 1000000,
                    free_blocks_metric.max_v / 1000000.0,
                ]
            )
        else:
            print(
                "No valid metrics found."
                f" {get_blocks_metric.max_v=} {free_blocks_metric.max_v=}"
            )

    print(
        tabulate(
            rows,
            headers=[
                "Iterations",
                "Total\nBlocks",
                "Allocated\nBlocks",
                "Get Blocks\nAvg (ms)",
                "Get Blocks\nMax (ms)",
                "Free Blocks\nAvg (ms)",
                "Free Blocks\nMax (ms)",
            ],
            tablefmt="grid",
            floatfmt=".6f",
        )
    )


def invoke_main() -> None:
    parser = FlexibleArgumentParser(
        description="Benchmark the performance of BlockPool for KV Cache."
    )
    parser.add_argument("--num-gpu-blocks", type=int, default=100000)
    parser.add_argument(
        "--num-iteration",
        type=int,
        default=1000,
        help="Number of iterations to run to stabilize final data readings",
    )
    parser.add_argument(
        "--allocate-blocks",
        type=int,
        nargs="*",
        default=[10, 50, 100, 500, 1000],
        help="Number of blocks to allocate",
    )
    args = parser.parse_args()
    main(args)


if __name__ == "__main__":
    invoke_main()  # pragma: no cover
73  benchmarks/multi_turn/README.md  Normal file
@@ -0,0 +1,73 @@
# Benchmark KV Cache Offloading with Multi-Turn Conversations

The requirements (pip) for `benchmark_serving_multi_turn.py` can be found in `requirements.txt`.

First start serving your model:

```bash
export MODEL_PATH=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/

vllm serve $MODEL_PATH --served-model-name Llama --disable-log-requests
```

The variable `MODEL_PATH` should be a path to the model files (e.g. downloaded from Hugging Face).

## Synthetic Multi-Turn Conversations

Download the following text file (used for generation of synthetic conversations):

```bash
wget https://www.gutenberg.org/ebooks/1184.txt.utf-8
mv 1184.txt.utf-8 pg1184.txt
```

The filename `pg1184.txt` is used in `generate_multi_turn.json` (see `"text_files"`),
but you may use other text files if you prefer (this specific file is not required).

Then run the benchmarking script:

```bash
export MODEL_PATH=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/

python benchmark_serving_multi_turn.py --model $MODEL_PATH --served-model-name Llama \
--input-file generate_multi_turn.json --num-clients 2 --max-active-conversations 6
```

You can edit the file `generate_multi_turn.json` to change the conversation parameters (number of turns, token counts, etc.).
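
For example, to widen the turn-count range you can edit the JSON by hand or script the change; the snippet below is a sketch (not part of this PR's files) that assumes the key layout of the shipped `generate_multi_turn.json`:

```python
import json

# Load the shipped config, widen the per-conversation turn-count range,
# and write it back.
with open("generate_multi_turn.json") as f:
    conf = json.load(f)

conf["prompt_input"]["num_turns"]["min"] = 12
conf["prompt_input"]["num_turns"]["max"] = 24

with open("generate_multi_turn.json", "w") as f:
    json.dump(conf, f, indent=4)
```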

If successful, you will see the following output:

```bash
----------------------------------------------------------------------------------------------------
Statistics summary:
runtime_sec = 215.810
requests_per_sec = 0.769
----------------------------------------------------------------------------------------------------
                   count     mean      std      min      25%      50%      75%      90%      99%      max
ttft_ms            166.0    78.22    67.63    45.91    59.94    62.26    64.43    69.66   353.18   567.54
tpot_ms            166.0    25.37     0.57    24.40    25.07    25.31    25.50    25.84    27.50    28.05
latency_ms         166.0  2591.07   326.90  1998.53  2341.62  2573.01  2860.10  3003.50  3268.46  3862.94
input_num_turns    166.0     7.43     4.57     1.00     3.00     7.00    11.00    13.00    17.00    17.00
input_num_tokens   166.0  2006.20   893.56   522.00  1247.75  2019.00  2718.00  3233.00  3736.45  3899.00
output_num_tokens  166.0   100.01    11.80    80.00    91.00    99.00   109.75   116.00   120.00   120.00
output_num_chunks  166.0    99.01    11.80    79.00    90.00    98.00   108.75   115.00   119.00   119.00
----------------------------------------------------------------------------------------------------
```

## ShareGPT Conversations

To run with the ShareGPT data, download the following ShareGPT dataset:
`https://huggingface.co/datasets/philschmid/sharegpt-raw/blob/main/sharegpt_20230401_clean_lang_split.json`

Use the `convert_sharegpt_to_openai.py` script to convert the dataset to a format supported by `benchmark_serving_multi_turn.py`:

```bash
python convert_sharegpt_to_openai.py sharegpt_20230401_clean_lang_split.json sharegpt_conv_128.json --seed=99 --max-items=128
```

The script converts the ShareGPT dataset to a dataset with the standard user/assistant roles.
The flag `--max-items=128` samples 128 conversations from the original dataset (change as needed).

Use the output JSON file `sharegpt_conv_128.json` as the `--input-file` for `benchmark_serving_multi_turn.py`.
493  benchmarks/multi_turn/bench_dataset.py  Normal file
@@ -0,0 +1,493 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from statistics import mean
from typing import Any, NamedTuple, Optional, Union

import numpy as np  # type: ignore
import pandas as pd  # type: ignore
from bench_utils import (
    TEXT_SEPARATOR,
    Color,
    logger,
)
from transformers import AutoTokenizer  # type: ignore

# Conversation ID is a string (e.g: "UzTK34D")
ConvId = str

# A list of dicts (dicts with keys "id" and "messages")
ShareGptConversations = list[dict[str, Any]]

# A list of dicts (dicts with keys "role" and "content")
MessagesList = list[dict[str, str]]

# Map conversation ID to conversation messages
ConversationsMap = dict[ConvId, MessagesList]


class Distribution(ABC):
    @abstractmethod
    def sample(self, size: int = 1) -> np.ndarray:
        pass


class UniformDistribution(Distribution):
    def __init__(
        self,
        min_val: Union[int, float],
        max_val: Union[int, float],
        is_integer: bool = True,
    ) -> None:
        self.min_val = min_val
        self.max_val = max_val
        self.is_integer = is_integer

    def sample(self, size: int = 1) -> np.ndarray:
        if self.is_integer:
            # randint's upper bound is exclusive, so +1 keeps max_val inclusive
            return np.random.randint(
                int(self.min_val), int(self.max_val + 1), size=size
            )
        else:
            return np.random.uniform(self.min_val, self.max_val, size=size)

    def __repr__(self) -> str:
        return f"UniformDistribution[{self.min_val}, {self.max_val}]"


class ConstantDistribution(Distribution):
    def __init__(self, value: Union[int, float]) -> None:
        self.value = value
        self.max_val = value

    def sample(self, size: int = 1) -> np.ndarray:
        return np.full(shape=size, fill_value=self.value)

    def __repr__(self) -> str:
        return f"Constant[{self.value}]"


class ZipfDistribution(Distribution):
    def __init__(self, alpha: float, max_val: Optional[int] = None) -> None:
        self.alpha = alpha
        self.max_val = max_val

    def sample(self, size: int = 1) -> np.ndarray:
        samples = np.random.zipf(self.alpha, size=size)
        if self.max_val:
            samples = np.minimum(samples, self.max_val)
        return samples

    def __repr__(self) -> str:
        return f"ZipfDistribution[{self.alpha}]"


class PoissonDistribution(Distribution):
    def __init__(self, alpha: float, max_val: Optional[int] = None) -> None:
        self.alpha = alpha
        self.max_val = max_val

    def sample(self, size: int = 1) -> np.ndarray:
        samples = np.random.poisson(self.alpha, size=size)
        if self.max_val:
            samples = np.minimum(samples, self.max_val)
        return samples

    def __repr__(self) -> str:
        return f"PoissonDistribution[{self.alpha}]"


class LognormalDistribution(Distribution):
    def __init__(
        self, mean: float, sigma: float, max_val: Optional[int] = None
    ) -> None:
        self.mean = mean
        self.sigma = sigma
        self.max_val = max_val

    def sample(self, size: int = 1) -> np.ndarray:
        samples = np.random.lognormal(mean=self.mean, sigma=self.sigma, size=size)
        if self.max_val:
            samples = np.minimum(samples, self.max_val)

        return np.round(samples).astype(int)

    def __repr__(self) -> str:
        return f"LognormalDistribution[{self.mean}, {self.sigma}]"


class GenConvArgs(NamedTuple):
    num_conversations: int
    text_files: list[str]
    input_num_turns: Distribution
    input_common_prefix_num_tokens: Distribution
    input_prefix_num_tokens: Distribution
    input_num_tokens: Distribution
    output_num_tokens: Distribution
    print_stats: bool


def verify_field_exists(
    conf: dict, field_name: str, section: str, subsection: str
) -> None:
    if field_name not in conf:
        raise ValueError(
            f"Missing field '{field_name}' in {section=} and {subsection=}"
        )


def get_random_distribution(
    conf: dict, section: str, subsection: str, optional: bool = False
) -> Distribution:
    # section can be "prompt_input" or "prompt_output" (both required)
    conf = conf[section]

    if optional and subsection not in conf:
        # Optional subsection, if not found assume the value is always 0
        return ConstantDistribution(0)

    # subsection can be "num_turns", "num_tokens" or "prefix_num_tokens"
    if subsection not in conf:
        raise ValueError(f"Missing subsection {subsection} in section {section}")

    conf = conf[subsection]

    distribution = conf.get("distribution")
    if distribution is None:
        raise ValueError(
            f"Missing field 'distribution' in {section=} and {subsection=}"
        )

    if distribution == "constant":
        verify_field_exists(conf, "value", section, subsection)
        return ConstantDistribution(conf["value"])

    elif distribution == "zipf":
        verify_field_exists(conf, "alpha", section, subsection)
        max_val = conf.get("max", None)
        return ZipfDistribution(conf["alpha"], max_val=max_val)

    elif distribution == "poisson":
        verify_field_exists(conf, "alpha", section, subsection)
        max_val = conf.get("max", None)
        return PoissonDistribution(conf["alpha"], max_val=max_val)

    elif distribution == "lognormal":
        verify_field_exists(conf, "mean", section, subsection)
        verify_field_exists(conf, "sigma", section, subsection)
        max_val = conf.get("max", None)
        return LognormalDistribution(conf["mean"], conf["sigma"], max_val=max_val)

    elif distribution == "uniform":
        verify_field_exists(conf, "min", section, subsection)
        verify_field_exists(conf, "max", section, subsection)

        min_value = conf["min"]
        max_value = conf["max"]

        assert min_value > 0
        assert min_value <= max_value

        is_integer = isinstance(min_value, int) and isinstance(max_value, int)
        return UniformDistribution(min_value, max_value, is_integer)
    else:
        raise ValueError(f"Unknown distribution: {distribution}")


def parse_input_json_file(conf: dict) -> GenConvArgs:
    # Validate the input file
    assert isinstance(conf, dict)
    required_fields = [
        "filetype",
        "num_conversations",
        "text_files",
        "prompt_input",
        "prompt_output",
    ]
    for field in required_fields:
        assert field in conf, f"Missing field {field} in input {conf}"

    assert conf["filetype"] == "generate_conversations"

    assert conf["num_conversations"] > 0, "num_conversations should be larger than zero"

    text_files = conf["text_files"]

    assert isinstance(text_files, list), "Field 'text_files' should be a list"
    assert len(text_files) > 0, (
        "Field 'text_files' should be a list with at least one file"
    )

    # Parse the parameters for the prompt input/output workload
    input_num_turns = get_random_distribution(conf, "prompt_input", "num_turns")
    input_num_tokens = get_random_distribution(conf, "prompt_input", "num_tokens")
    input_common_prefix_num_tokens = get_random_distribution(
        conf, "prompt_input", "common_prefix_num_tokens", optional=True
    )
    input_prefix_num_tokens = get_random_distribution(
        conf, "prompt_input", "prefix_num_tokens"
    )
    output_num_tokens = get_random_distribution(conf, "prompt_output", "num_tokens")

    print_stats: bool = conf.get("print_stats", False)
    assert isinstance(print_stats, bool), (
        "Field 'print_stats' should be either 'true' or 'false'"
    )

    args = GenConvArgs(
        num_conversations=conf["num_conversations"],
        text_files=text_files,
        input_num_turns=input_num_turns,
        input_common_prefix_num_tokens=input_common_prefix_num_tokens,
        input_prefix_num_tokens=input_prefix_num_tokens,
        input_num_tokens=input_num_tokens,
        output_num_tokens=output_num_tokens,
        print_stats=print_stats,
    )
    return args


def print_conv_stats(conversations: ConversationsMap, tokenizer: AutoTokenizer) -> None:
    # Collect statistics
    conv_stats: list[dict[Any, Any]] = []
    req_stats: list[int] = []

    print("\nCollecting statistics...")
    for messages in conversations.values():
        # messages is a list of dicts
        user_tokens: list[int] = []
        assistant_tokens: list[int] = []
        request_tokens: list[int] = []

        req_tokens = 0
        for m in messages:
            content = m["content"]
            num_tokens = len(tokenizer(content).input_ids)

            if m["role"] == "user":
                user_tokens.append(num_tokens)
                # New user prompt including all chat history
                req_tokens += num_tokens
                request_tokens.append(req_tokens)

            elif m["role"] == "assistant":
                assistant_tokens.append(num_tokens)
                # Update assistant answer
                # (will be part of chat history for the next user prompt)
                req_tokens += num_tokens

        item_stats = {
            "conversation_turns": len(messages),
            "user_tokens": mean(user_tokens),
            "assistant_tokens": mean(assistant_tokens),
        }

        conv_stats.append(item_stats)
        req_stats.extend(request_tokens)

    # Print statistics
    percentiles = [0.25, 0.5, 0.75, 0.9, 0.99]

    print(TEXT_SEPARATOR)
    print(f"{Color.YELLOW}Conversations statistics:{Color.RESET}")
    print(TEXT_SEPARATOR)
    df = pd.DataFrame(conv_stats)
    print(df.describe(percentiles=percentiles).transpose())
    print(TEXT_SEPARATOR)
    print(f"{Color.YELLOW}Request statistics:{Color.RESET}")
    print(TEXT_SEPARATOR)
    df = pd.DataFrame(req_stats, columns=["request_tokens"])
    print(df.describe(percentiles=percentiles).transpose())
    print(TEXT_SEPARATOR)


def generate_conversations(
    args: GenConvArgs, tokenizer: AutoTokenizer
) -> ConversationsMap:
    # Text for all user prompts
    # (text from the input text files will be appended to this line)
    base_prompt_text = "Please rewrite the following text and add more content: "
    base_prompt_token_count = len(
        tokenizer.encode(base_prompt_text, add_special_tokens=False)
    )

    logger.info(f"{Color.PURPLE}Generating conversations...{Color.RESET}")
    logger.info(args)

    list_of_tokens = []

    for filename in args.text_files:
        # Load text file that will be used to generate prompts
        with open(filename) as file:
            data = file.read()
        tokens_in_file = tokenizer.encode(data, add_special_tokens=False)
        list_of_tokens.extend(tokens_in_file)

    conversations: ConversationsMap = {}
    conv_id = 0

    # Generate number of turns for every conversation
    turn_count: np.ndarray = args.input_num_turns.sample(args.num_conversations)

    # Turn count should be at least 2 (one user prompt and one assistant answer)
    turn_count = np.maximum(turn_count, 2)

    # Round up to an even number (every user prompt should have an answer)
    turn_count = turn_count + (turn_count % 2)

    # Generate number of prefix tokens for every conversation
    conv_prefix_tokens: np.ndarray = args.input_prefix_num_tokens.sample(
        args.num_conversations
    )

    # Used to reduce shared text between conversations
    # (jump/skip over text sections between conversations)
    base_offset = 0

    # Common prefix size for all conversations (only 1 sample required)
    common_prefix_text = ""
    common_prefix_tokens: int = args.input_common_prefix_num_tokens.sample(1)[0]
    if common_prefix_tokens > 0:
        # Using "." at the end to separate sentences
        common_prefix_text = (
            tokenizer.decode(list_of_tokens[: common_prefix_tokens - 2]) + "."
        )
        base_offset += common_prefix_tokens

    for conv_id in range(args.num_conversations):
        # Generate a single conversation
        messages: MessagesList = []

        nturns = turn_count[conv_id]

        # User prompt token count per turn (with lower limit)
        input_token_count: np.ndarray = args.input_num_tokens.sample(nturns)
        input_token_count = np.maximum(input_token_count, base_prompt_token_count)

        # Assistant answer token count per turn (with lower limit)
        output_token_count: np.ndarray = args.output_num_tokens.sample(nturns)
        output_token_count = np.maximum(output_token_count, 1)

        user_turn = True
        for turn_id in range(nturns):
            if user_turn:
                role = "user"
                num_tokens = input_token_count[turn_id]

                # Generate the user prompt,
                # use a unique prefix (the conv_id) for each conversation
                # (to avoid shared prefix between conversations)
                content = f"{conv_id} is a nice number... "

                if len(common_prefix_text) > 0 and turn_id == 0:
                    content = common_prefix_text + content

                # Update the number of tokens left for the content
                num_tokens -= len(tokenizer.encode(content, add_special_tokens=False))

                if turn_id == 0:
                    prefix_num_tokens = conv_prefix_tokens[conv_id]
                    if prefix_num_tokens > 0:
                        # Add prefix text (context) to the first turn
                        start_offset = base_offset
                        end_offset = start_offset + prefix_num_tokens
                        assert len(list_of_tokens) > end_offset, (
                            "Not enough input text to generate "
                            f"{prefix_num_tokens} tokens for the "
                            f"prefix text ({start_offset=}, {end_offset=})"
                        )

                        content += f"{conv_id}, " + tokenizer.decode(
                            list_of_tokens[start_offset:end_offset]
                        )
                        base_offset += prefix_num_tokens

                # Add the actual user prompt/question after the prefix text
                content += base_prompt_text
                num_tokens -= base_prompt_token_count

                if num_tokens > 0:
                    # Add text from the input file (to reach the desired token count)
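                    # (offsets stride by input_token_count.max() per turn, so
                    # successive turns read disjoint regions of the source text)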
                    start_offset = base_offset + turn_id * input_token_count.max()
                    end_offset = start_offset + num_tokens
                    assert len(list_of_tokens) > end_offset, (
                        f"Not enough input text to generate {num_tokens} tokens "
                        f"for the prompt ({start_offset=}, {end_offset=})"
                    )

                    # Convert tokens back to text
                    content += tokenizer.decode(list_of_tokens[start_offset:end_offset])
            else:
                role = "assistant"
                # This content will not be used as input to the LLM server
                # (actual answers will be used instead).
                # Content is only required to determine the min_tokens/max_tokens
                # (inputs to the LLM server).
                num_tokens = output_token_count[turn_id]
                assert len(list_of_tokens) > num_tokens, (
                    f"Not enough input text to generate {num_tokens} "
                    "tokens for assistant content"
                )
                content = tokenizer.decode(list_of_tokens[:num_tokens])

            # Append the user/assistant message to the list of messages
            messages.append({"role": role, "content": content})
            user_turn = not user_turn

        # Add the new conversation
        conversations[f"CONV_ID_{conv_id}"] = messages

        # Increase base offset for the next conversation
        base_offset += nturns

    if args.print_stats:
        print_conv_stats(conversations, tokenizer)

    return conversations


def conversations_list_to_dict(input_list: ShareGptConversations) -> ConversationsMap:
    conversations: ConversationsMap = {}

    for item in input_list:
        conv_id: str = item["id"]
        assert isinstance(conv_id, str)

        assert conv_id not in conversations, (
            f"Conversation ID {conv_id} found more than once in the input"
        )

        messages: MessagesList = item["messages"]
        assert isinstance(messages, list), (
            f"Conversation messages should be a list (ID: {conv_id})"
        )
        assert len(messages) > 0, f"Conversation with no messages (ID: {conv_id})"

        conversations[conv_id] = messages

    logger.info(f"Using {len(conversations)} unique conversations (IDs)")
    assert len(conversations) == len(input_list)

    # Print statistics about the selected conversations
    stats: list[dict[str, Any]] = []
    for conv_data in conversations.values():
        stats.append({"num_turns": len(conv_data)})

    print(TEXT_SEPARATOR)
    print(f"{Color.YELLOW}Conversations statistics:{Color.RESET}")
    print(TEXT_SEPARATOR)
    percentiles = [0.25, 0.5, 0.75, 0.9, 0.99, 0.999, 0.9999]
    conv_stats = pd.DataFrame(stats).describe(percentiles=percentiles)
    print(conv_stats.transpose())
    print(TEXT_SEPARATOR)

    return conversations


def conversations_dict_to_list(input_dict: ConversationsMap) -> ShareGptConversations:
    output: ShareGptConversations = []
    for conv_id, conv_data in input_dict.items():
        new_item = {"id": conv_id, "messages": conv_data}
        output.append(new_item)

    return output
28  benchmarks/multi_turn/bench_utils.py  Normal file
@@ -0,0 +1,28 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
from enum import Enum


class Color(Enum):
    RED = "\033[91m"
    GREEN = "\033[92m"
    BLUE = "\033[94m"
    PURPLE = "\033[95m"
    CYAN = "\033[96m"
    YELLOW = "\033[93m"
    RESET = "\033[0m"

    def __str__(self):
        return self.value


TEXT_SEPARATOR = "-" * 100

# Configure the logger
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] - %(message)s",
    datefmt="%d-%m-%Y %H:%M:%S",
)
logger = logging.getLogger(__name__)
1569  benchmarks/multi_turn/benchmark_serving_multi_turn.py  Normal file
File diff suppressed because it is too large

354  benchmarks/multi_turn/convert_sharegpt_to_openai.py  Normal file
@@ -0,0 +1,354 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Download dataset from:
https://huggingface.co/datasets/philschmid/sharegpt-raw/blob/main/sharegpt_20230401_clean_lang_split.json

Convert to OpenAI API:
export INPUT_FILE=sharegpt_20230401_clean_lang_split.json
python convert_sharegpt_to_openai.py $INPUT_FILE sharegpt_conv_128.json --max-items=128
"""

import argparse
import json
import random
from statistics import mean
from typing import Any, Optional

import pandas as pd  # type: ignore
import tqdm  # type: ignore
from transformers import AutoTokenizer  # type: ignore


def has_non_english_chars(text: str) -> bool:
    return not text.isascii()


def content_is_valid(
    content: str, min_content_len: Optional[int], max_content_len: Optional[int]
) -> bool:
    if min_content_len and len(content) < min_content_len:
        return False

    if max_content_len and len(content) > max_content_len:
        return False

    # Valid content must be ASCII-only (filters out non-English conversations)
    return not has_non_english_chars(content)
def print_stats(
    conversations: "list[dict[Any, Any]]", tokenizer: Optional[AutoTokenizer] = None
) -> None:
    # Collect statistics
    stats = []

    print("\nCollecting statistics...")
    for item in tqdm.tqdm(conversations):
        # item has "id" and "messages"
        messages = item["messages"]

        user_turns = 0
        assistant_turns = 0
        user_words = 0
        assistant_words = 0
        conv_chars = 0

        user_tokens: list[int] = []
        assistant_tokens: list[int] = []

        for m in messages:
            content = m["content"]
            conv_chars += len(content)
            content_num_words = content.count(" ") + 1

            num_tokens = 0
            if tokenizer:
                num_tokens = len(tokenizer(m["content"]).input_ids)

            if m["role"] == "user":
                user_turns += 1
                user_words += content_num_words
                if tokenizer:
                    user_tokens.append(num_tokens)

            elif m["role"] == "assistant":
                assistant_turns += 1
                assistant_words += content_num_words
                if tokenizer:
                    assistant_tokens.append(num_tokens)

        # assert user_turns == assistant_turns, \
        #     f"Invalid conversation ID {item['id']}"

        conv_words = user_words + assistant_words
        item_stats = {
            "user_turns": user_turns,
            "assistant_turns": assistant_turns,
            "user_words": user_words,
            "assistant_words": assistant_words,
            "conv_turns": len(messages),
            "conv_words": conv_words,
            "conv_characters": conv_chars,
        }

        if len(user_tokens) > 0:
            item_stats["user_tokens"] = int(mean(user_tokens))

        if len(assistant_tokens) > 0:
            item_stats["assistant_tokens"] = int(mean(assistant_tokens))

        stats.append(item_stats)

    print("\nStatistics:")
    percentiles = [0.25, 0.5, 0.75, 0.9, 0.99, 0.999, 0.9999]
    df = pd.DataFrame(stats)
    print(df.describe(percentiles=percentiles).transpose())


def convert_sharegpt_to_openai(
    seed: int,
    input_file: str,
    output_file: str,
    max_items: Optional[int],
    min_content_len: Optional[int] = None,
    max_content_len: Optional[int] = None,
    min_turns: Optional[int] = None,
    max_turns: Optional[int] = None,
    model: Optional[str] = None,
) -> None:
    if min_turns and max_turns:
        assert min_turns <= max_turns

    if min_content_len and max_content_len:
        # Verify that min is not larger than max if both were given
        assert min_content_len <= max_content_len

    print(
        f"Input parameters:\n{seed=}, {max_items=}, {min_content_len=},"
        f" {max_content_len=}, {min_turns=}, {max_turns=}\n"
    )

    random.seed(seed)

    tokenizer = None
    if model is not None:
        print(f"Loading tokenizer from: {model}")
        tokenizer = AutoTokenizer.from_pretrained(model)

    # Read the ShareGPT JSON file
    print(f"Reading file: {input_file}")
    with open(input_file, encoding="utf-8") as f:
        # Should be a list of dicts
        # Each dict should have "id" (string) and "conversations" (list of dicts)
        sharegpt_data = json.load(f)

    assert isinstance(sharegpt_data, list), "Input file should contain a list of dicts"

    print(f"Total items in input file: {len(sharegpt_data):,}")

    print(f"Shuffling dataset with seed {seed}")
    random.shuffle(sharegpt_data)

    # Map conversation ID to all of the conversation's messages
    conversation_parts: dict[str, list[Any]] = {}

    for item in tqdm.tqdm(sharegpt_data):
        assert "id" in item, "Missing key 'id'"
        assert "conversations" in item, "Missing key 'conversations'"

        # Conversation ID (e.g: "hiWPlMD") and part/session (0, 1, 2, etc.)
        conv_id, _ = item["id"].split("_")
        new_turns = item["conversations"]

        if conv_id not in conversation_parts:
            # Start new conversation
            conversation_parts[conv_id] = []
        elif len(conversation_parts[conv_id]) > 0 and len(new_turns) > 0:
            # Drop a duplicated speaker at the boundary between two parts
            prev_turns = conversation_parts[conv_id][-1]
            if prev_turns[-1]["from"] == new_turns[0]["from"]:
                new_turns = new_turns[1:]

        if len(new_turns) > 0:
            # We assume that parts are in order in the ShareGPT dataset
            conversation_parts[conv_id].append(new_turns)

    dataset: list[dict[str, Any]] = []
    for conv_id, conv_parts in conversation_parts.items():
        new_item = {"id": conv_id}

        conversations: list[dict[str, str]] = []

        # Merge all parts
        for conv_part in conv_parts:
            conversations.extend(conv_part)

        if len(conversations) > 0:
            new_item["conversations"] = conversations
            dataset.append(new_item)

    print(f"Total unique conversations (IDs) in input file: {len(dataset):,}")

    # Final output data
    final_openai_dataset: list[dict] = []

    # Filter conversations from the ShareGPT dataset and convert to OpenAI format
    for item in tqdm.tqdm(dataset):
        messages: list[dict] = []

        assert "id" in item, "Missing key 'id'"
        assert "conversations" in item, "Missing key 'conversations'"

        conv_id = item["id"]
        conversations = item["conversations"]

        if min_turns is not None and len(conversations) < min_turns:
            # Skip short conversations
            continue

        # Convert each message in the conversation, up to max_turns if specified
        for i, turn in enumerate(conversations):
            assert "from" in turn and "value" in turn, (
                f"Invalid conversation ID {conv_id} - missing 'from' or 'value'"
            )

            role = None
            turn_from = turn["from"]

            if turn_from in {"human", "user"}:
                role = "user"
            elif turn_from in {"gpt", "bing", "chatgpt", "bard"}:
                role = "assistant"
            elif turn_from == "system":
                role = "system"

            assert role is not None, (
                f"Invalid conversation ID {conv_id} - 'from'='{turn_from}' is invalid"
            )

            if i == 0 and role != "user":
                # If the first message is from assistant (gpt), skip it.
                # This happens when the conversation is a follow-up
                # to a previous conversation (from the same user).
                continue

            if max_turns is not None and i >= max_turns:
                break

            # Convert message to OpenAI format (with "role" and "content")
            content = turn["value"]
            messages.append({"role": role, "content": content})

        # Add the converted conversation to the OpenAI format
        if len(messages) > 0:
            valid_messages = True

            # First turn should always be from the user
            user_turn = True

            for m in messages:
                # Make sure that turns alternate between user and assistant
                if (user_turn and m["role"] != "user") or (
                    not user_turn and m["role"] != "assistant"
                ):
                    valid_messages = False
                    break

                user_turn = not user_turn

                content = m["content"]
                valid_messages = content_is_valid(
                    content, min_content_len, max_content_len
                )
                if not valid_messages:
                    break

            if valid_messages is True:
                final_openai_dataset.append({"id": conv_id, "messages": messages})

    assert len(final_openai_dataset) > 0, "Final number of conversations is zero"

    print_stats(final_openai_dataset)

    print_stats_again = False
    if max_items is not None and len(final_openai_dataset) > max_items:
        print(f"\n\nSampling {max_items} items from the dataset...")
        print_stats_again = True
        final_openai_dataset = random.sample(final_openai_dataset, max_items)

    if print_stats_again:
        # Print stats after the dataset changed
        print_stats(final_openai_dataset, tokenizer)

    # Write the converted data to a new JSON file
    final_size = len(final_openai_dataset)
    print(f"\nTotal conversations converted (after filtering): {final_size:,}")
    print(f"\nWriting file: {output_file}")
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(final_openai_dataset, f, ensure_ascii=False, indent=2)


def main() -> None:
    parser = argparse.ArgumentParser(
        description="Convert ShareGPT dataset to OpenAI API format"
    )
    parser.add_argument("input_file", help="Path to the input ShareGPT JSON file")
    parser.add_argument(
        "output_file", help="Path to the output OpenAI format JSON file"
    )
    parser.add_argument(
        "--seed", type=int, default=0, help="Seed for random number generators"
    )
    parser.add_argument(
        "--max-items",
        type=int,
        default=None,
        help="Maximum number of items in the output file",
    )
    parser.add_argument(
        "--min-turns",
        type=int,
        default=None,
        help="Minimum number of turns per conversation",
    )
    parser.add_argument(
        "--max-turns",
        type=int,
        default=None,
        help="Maximum number of turns per conversation",
    )
    parser.add_argument(
        "--min-content-len",
        type=int,
        default=None,
        help="Min number of characters in the messages' content",
    )
    parser.add_argument(
        "--max-content-len",
        type=int,
        default=None,
        help="Max number of characters in the messages' content",
    )
    parser.add_argument(
        "--model",
        type=str,
        default=None,
        help="LLM model, only the tokenizer will be used",
    )

    args = parser.parse_args()

    convert_sharegpt_to_openai(
        args.seed,
        args.input_file,
        args.output_file,
        args.max_items,
        args.min_content_len,
        args.max_content_len,
        args.min_turns,
        args.max_turns,
        args.model,
    )


if __name__ == "__main__":
    main()
35  benchmarks/multi_turn/generate_multi_turn.json  Normal file
@@ -0,0 +1,35 @@
{
    "filetype": "generate_conversations",
    "num_conversations": 24,
    "text_files": ["pg1184.txt"],
    "print_stats": false,
    "prompt_input": {
        "num_turns": {
            "distribution": "uniform",
            "min": 12,
            "max": 18
        },
        "common_prefix_num_tokens": {
            "distribution": "constant",
            "value": 500
        },
        "prefix_num_tokens": {
            "distribution": "lognormal",
            "mean": 6,
            "sigma": 4,
            "max": 1500
        },
        "num_tokens": {
            "distribution": "uniform",
            "min": 120,
            "max": 160
        }
    },
    "prompt_output": {
        "num_tokens": {
            "distribution": "uniform",
            "min": 80,
            "max": 120
        }
    }
}
5  benchmarks/multi_turn/requirements.txt  Normal file
@@ -0,0 +1,5 @@
numpy>=1.24
pandas>=2.0.0
aiohttp>=3.10
transformers>=4.46
xlsxwriter>=3.2.1
@@ -182,17 +182,17 @@ endif()
 #
 # Build oneDNN for W8A8 GEMM kernels (only for x86-AVX512 /ARM platforms)
 # Flag to enable ACL kernels for AARCH64 platforms
-if ( VLLM_BUILD_ACL STREQUAL "ON")
+if (VLLM_BUILD_ACL STREQUAL "ON")
   set(USE_ACL ON)
 else()
   set(USE_ACL OFF)
 endif()

-if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND)
+if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND OR POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND)
   FetchContent_Declare(
     oneDNN
     GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
-    GIT_TAG v3.8.1
+    GIT_TAG v3.9
     GIT_PROGRESS TRUE
     GIT_SHALLOW TRUE
   )
@@ -204,7 +204,7 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND)
     endif()
     set(ONEDNN_AARCH64_USE_ACL "ON")
     set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ENV{ACL_ROOT_DIR}/build/")
   endif()
 endif()

 set(ONEDNN_LIBRARY_TYPE "STATIC")
 set(ONEDNN_BUILD_DOC "OFF")
@@ -217,38 +217,23 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND)
   set(ONEDNN_ENABLE_ITT_TASKS "OFF")
   set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF")
   set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF")
+  set(ONEDNN_VERBOSE "OFF")
   set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)

   FetchContent_MakeAvailable(oneDNN)

-  list(APPEND LIBS dnnl)
-elseif(POWER10_FOUND)
-  FetchContent_Declare(
-    oneDNN
-    GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
-    GIT_TAG v3.7.2
-    GIT_PROGRESS TRUE
-    GIT_SHALLOW TRUE
-  )
-
-  set(ONEDNN_LIBRARY_TYPE "STATIC")
-  set(ONEDNN_BUILD_DOC "OFF")
-  set(ONEDNN_BUILD_EXAMPLES "OFF")
-  set(ONEDNN_BUILD_TESTS "OFF")
-  set(ONEDNN_ENABLE_WORKLOAD "INFERENCE")
-  set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER")
-  set(ONEDNN_BUILD_GRAPH "OFF")
-  set(ONEDNN_ENABLE_JIT_PROFILING "OFF")
-  set(ONEDNN_ENABLE_ITT_TASKS "OFF")
-  set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF")
-  set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF")
-  set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)
-
-  set(DNNL_CPU_RUNTIME "OMP")
-
-  FetchContent_MakeAvailable(oneDNN)
-
-  list(APPEND LIBS dnnl)
+  add_library(dnnl_ext OBJECT "csrc/cpu/dnnl_helper.cpp")
+  target_include_directories(
+    dnnl_ext
+    PUBLIC ${oneDNN_SOURCE_DIR}/include
+    PUBLIC ${oneDNN_BINARY_DIR}/include
+    PRIVATE ${oneDNN_SOURCE_DIR}/src
+  )
+  target_link_libraries(dnnl_ext dnnl)
+  target_compile_options(dnnl_ext PRIVATE ${CXX_COMPILE_FLAGS} -fPIC)
+  list(APPEND LIBS dnnl_ext)
+  set(USE_ONEDNN ON)
+else()
+  set(USE_ONEDNN OFF)
 endif()

 message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}")
@@ -275,7 +260,6 @@ set(VLLM_EXT_SRC

 if (AVX512_FOUND AND NOT AVX512_DISABLED)
   set(VLLM_EXT_SRC
-    "csrc/cpu/quant.cpp"
     "csrc/cpu/shm.cpp"
     ${VLLM_EXT_SRC})
   if (ENABLE_AVX512BF16 AND ENABLE_AVX512VNNI)
@@ -289,14 +273,11 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)
       ${VLLM_EXT_SRC})
     add_compile_definitions(-DCPU_CAPABILITY_AVX512)
   endif()
-elseif(POWER10_FOUND)
-  set(VLLM_EXT_SRC
-    "csrc/cpu/quant.cpp"
-    ${VLLM_EXT_SRC})
 endif()
-if (ASIMD_FOUND)
+
+if(USE_ONEDNN)
   set(VLLM_EXT_SRC
     "csrc/cpu/quant.cpp"
+    "csrc/cpu/dnnl_kernels.cpp"
     ${VLLM_EXT_SRC})
 endif()

@@ -19,7 +19,7 @@ else()
   FetchContent_Declare(
     flashmla
     GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git
-    GIT_TAG 575f7724b9762f265bbee5889df9c7d630801845
+    GIT_TAG a757314c04eedd166e329e846c820eb1bdd702de
     GIT_PROGRESS TRUE
     CONFIGURE_COMMAND ""
     BUILD_COMMAND ""
@@ -37,13 +37,14 @@ cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}")
 if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS)
   set(FlashMLA_SOURCES
     ${flashmla_SOURCE_DIR}/csrc/flash_api.cpp
-    ${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_bf16_sm90.cu
-    ${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_fp16_sm90.cu
-    ${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_metadata.cu)
+    ${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu
+    ${flashmla_SOURCE_DIR}/csrc/kernels/mla_combine.cu
+    ${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu
+    ${flashmla_SOURCE_DIR}/csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu)

   set(FlashMLA_INCLUDES
     ${flashmla_SOURCE_DIR}/csrc/cutlass/include
-    ${flashmla_SOURCE_DIR}/csrc/include)
+    ${flashmla_SOURCE_DIR}/csrc)

   set_gencode_flags_for_srcs(
     SRCS "${FlashMLA_SOURCES}"
@@ -38,7 +38,7 @@ else()
   FetchContent_Declare(
     vllm-flash-attn
     GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
-    GIT_TAG 1c2624e53c078854e0637ee566c72fe2107e75f4
+    GIT_TAG 57b4e68b9f9d94750b46de8f8dbd2bfcc86edd4f
     GIT_PROGRESS TRUE
     # Don't share the vllm-flash-attn build between build types
     BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
@@ -467,6 +467,12 @@ function (define_gpu_extension_target GPU_MOD_NAME)
   if (GPU_LANGUAGE STREQUAL "HIP")
     # Make this target dependent on the hipify preprocessor step.
     add_dependencies(${GPU_MOD_NAME} hipify${GPU_MOD_NAME})
+    # Make sure we include the hipified versions of the headers, and avoid conflicts with the ones in the original source folder
+    target_include_directories(${GPU_MOD_NAME} PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/csrc
+      ${GPU_INCLUDE_DIRECTORIES})
+  else()
+    target_include_directories(${GPU_MOD_NAME} PRIVATE csrc
+      ${GPU_INCLUDE_DIRECTORIES})
   endif()

   if (GPU_ARCHITECTURES)
@@ -482,8 +488,6 @@ function (define_gpu_extension_target GPU_MOD_NAME)
   target_compile_definitions(${GPU_MOD_NAME} PRIVATE
                              "-DTORCH_EXTENSION_NAME=${GPU_MOD_NAME}")

-  target_include_directories(${GPU_MOD_NAME} PRIVATE csrc
-                             ${GPU_INCLUDE_DIRECTORIES})
-
   target_link_libraries(${GPU_MOD_NAME} PRIVATE torch ${GPU_LIBRARIES})

@@ -128,6 +128,45 @@ __global__ void act_and_mul_kernel_with_param(
   }
 }

+template <typename T>
+__device__ __forceinline__ T swigluoai_and_mul(const T& gate, const T& up,
+                                               float alpha, float limit) {
+  // clamp gate: min=None, max=limit
+  const float gate_f = (float)gate;
+  const float clamped_gate = gate_f > limit ? limit : gate_f;
+
+  // clamp up: min=-limit, max=limit
+  const float up_f = (float)up;
+  const float clamped_up =
+      up_f > limit ? limit : (up_f < -limit ? -limit : up_f);
+
+  // glu = gate * sigmoid(gate * alpha)
+  const float sigmoid_val = 1.0f / (1.0f + expf(-clamped_gate * alpha));
+  const float glu = clamped_gate * sigmoid_val;
+
+  // (up + 1) * glu
+  return (T)((clamped_up + 1.0f) * glu);
+}
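+
+// Reference semantics (editor's sketch), elementwise:
+//   g = min(gate, limit); u = clamp(up, -limit, limit)
+//   out = (u + 1) * g * sigmoid(alpha * g)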
|
||||
|
||||
template <typename scalar_t,
|
||||
scalar_t (*ACT_FN)(const scalar_t&, const scalar_t&, const float,
|
||||
const float)>
|
||||
__global__ void swigluoai_and_mul_kernel(
|
||||
scalar_t* __restrict__ out, // [..., d]
|
||||
const scalar_t* __restrict__ input, // [..., 2, d]
|
||||
const int d, const float alpha, const float limit) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
// TODO: Vectorize loads and stores.
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
// gate = x[..., ::2] (even indices)
|
||||
const scalar_t gate = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx]);
|
||||
// up = x[..., 1::2] (odd indices)
|
||||
const scalar_t up = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx + 1]);
|
||||
|
||||
out[token_idx * d + idx] = ACT_FN(gate, up, alpha, limit);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
#define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PARAM) \
|
||||
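The swigluoai_and_mul device function above is easier to audit against a plain host-side restatement of the same math (clamp the gate from above only, clamp up symmetrically, sigmoid-gate the gate, then multiply by up + 1). The sketch below is illustrative only; the name swigluoai_ref is hypothetical and does not exist in the tree.

#include <algorithm>
#include <cmath>

// Hypothetical scalar reference of the swigluoai_and_mul math above.
float swigluoai_ref(float gate, float up, float alpha, float limit) {
  gate = std::min(gate, limit);                // clamp gate: max=limit only
  up = std::min(std::max(up, -limit), limit);  // clamp up: [-limit, limit]
  // gate * sigmoid(gate * alpha) == gate / (1 + exp(-gate * alpha))
  const float glu = gate / (1.0f + std::exp(-gate * alpha));
  return (up + 1.0f) * glu;
}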
@@ -145,11 +184,31 @@ __global__ void act_and_mul_kernel_with_param(
                                   PARAM); \
       });

+#define LAUNCH_SIGLUOAI_AND_MUL(KERNEL, ALPHA, LIMIT)                       \
+  int d = input.size(-1) / 2;                                               \
+  int64_t num_tokens = input.numel() / input.size(-1);                      \
+  dim3 grid(num_tokens);                                                    \
+  dim3 block(std::min(d, 1024));                                            \
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));         \
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();             \
+  VLLM_DISPATCH_FLOATING_TYPES(                                             \
+      input.scalar_type(), "clamp_swiglu_kernel_with_params", [&] {         \
+        vllm::swigluoai_and_mul_kernel<scalar_t, KERNEL<scalar_t>>          \
+            <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),          \
+                                         input.data_ptr<scalar_t>(), d,     \
+                                         ALPHA, LIMIT);                     \
+      });
+
 void fatrelu_and_mul(torch::Tensor& out,    // [..., d]
                      torch::Tensor& input,  // [..., 2 * d]
                      double threshold) {
   LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(vllm::fatrelu_kernel, threshold);
 }
+void swigluoai_and_mul(torch::Tensor& out,    // [..., d]
+                       torch::Tensor& input,  // [..., 2 * d]
+                       double alpha, double limit) {
+  LAUNCH_SIGLUOAI_AND_MUL(vllm::swigluoai_and_mul, alpha, limit);
+}
 namespace vllm {

 // Element-wise activation kernel template.
@@ -167,7 +167,7 @@ typename T::Fmha::Arguments args_from_options(
       // TODO(trevor-m): Change split_kv back to -1 when
       // https://github.com/NVIDIA/cutlass/issues/2274 is fixed. Split_kv=1 will
       // perform worse with larger context length and smaller batch sizes.
-      num_kv_splits,  // split_kv
+      static_cast<int>(num_kv_splits),  // split_kv
       nullptr,  // is_var_split_kv
   };
   // TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
@@ -264,7 +264,7 @@ int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_ba
   // Assumes device 0 when getting sm_count.
   arguments.hw_info.sm_count =
       sm_count <= 0 ? cutlass::KernelHardwareInfo::query_device_multiprocessor_count(/*device_id=*/0) : sm_count;
-  arguments.split_kv = num_kv_splits;
+  arguments.split_kv = static_cast<int>(num_kv_splits);
   MlaSm100Type::Fmha::set_split_kv(arguments);

   return MlaSm100Type::Fmha::get_workspace_size(arguments);
@@ -40,9 +40,11 @@ void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
 void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
                  const double scale, const std::string& kv_cache_dtype);

-void gather_cache(
+void gather_and_maybe_dequant_cache(
     torch::Tensor const& src_cache,    // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
     torch::Tensor const& dst,          // [TOT_TOKENS, ENTRIES...]
     torch::Tensor const& block_table,  // [BATCH, BLOCK_INDICES]
     torch::Tensor const& cu_seq_lens,  // [BATCH+1]
-    int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt);
+    int64_t batch_size, const std::string& kv_cache_dtype,
+    torch::Tensor const& scale,
+    std::optional<torch::Tensor> seq_starts = std::nullopt);
@@ -624,9 +624,9 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
 namespace vllm {

 // grid is launched with dimensions (batch, num_splits)
-template <typename scalar_t>
-__global__ void gather_cache(
-    const scalar_t* __restrict__ src_cache,   // [NUM_BLOCKS, BLOCK_SIZE,
+template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
+__global__ void gather_and_maybe_dequant_cache(
+    const cache_t* __restrict__ src_cache,    // [NUM_BLOCKS, BLOCK_SIZE,
                                               // ENTRIES...]
     scalar_t* __restrict__ dst,               // [TOT_TOKENS, ENTRIES...]
     const int32_t* __restrict__ block_table,  // [BATCH, BLOCK_INDICES]
@@ -634,6 +634,7 @@ __global__ void gather_cache(
     const int32_t block_size, const int32_t entry_size,
     const int64_t block_table_stride, const int64_t cache_block_stride,
     const int64_t cache_entry_stride, const int64_t dst_entry_stride,
+    const float* __restrict__ scale,
     const int32_t* __restrict__ seq_starts) {  // Optional: starting offsets per
                                                // batch

@@ -675,10 +676,16 @@ __global__ void gather_cache(
     if (partial_block_size) full_blocks_end -= 1;
   }

-  auto copy_entry = [&](const scalar_t* __restrict__ _src,
+  auto copy_entry = [&](const cache_t* __restrict__ _src,
                         scalar_t* __restrict__ _dst) {
-    for (int i = threadIdx.x; i < entry_size; i += blockDim.x)
-      _dst[i] = _src[i];
+    for (int i = threadIdx.x; i < entry_size; i += blockDim.x) {
+      if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
+        _dst[i] = static_cast<scalar_t>(_src[i]);
+      } else {
+        _dst[i] =
+            fp8::scaled_convert<scalar_t, cache_t, kv_dt>(_src[i], *scale);
+      }
+    }
   };

   for (int pid = split_start; pid < full_blocks_end; ++pid) {
@@ -705,25 +712,31 @@ __global__ void gather_cache(
 }  // namespace vllm

 // Macro to dispatch the kernel based on the data type.
-#define CALL_GATHER_CACHE(CPY_DTYPE)                                     \
-  vllm::gather_cache<CPY_DTYPE><<<grid, block, 0, stream>>>(             \
-      reinterpret_cast<CPY_DTYPE*>(src_cache.data_ptr()),                \
-      reinterpret_cast<CPY_DTYPE*>(dst.data_ptr()),                      \
-      block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(),  \
-      block_size, entry_size, block_table_stride, cache_block_stride,    \
-      cache_entry_stride, dst_entry_stride, seq_starts_ptr);
+// SCALAR_T is the data type of the destination tensor.
+// CACHE_T is the stored data type of kv-cache.
+// KV_DTYPE is the real data type of kv-cache.
+#define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE)                       \
+  vllm::gather_and_maybe_dequant_cache<SCALAR_T, CACHE_T, KV_DTYPE>         \
+      <<<grid, block, 0, stream>>>(                                         \
+          reinterpret_cast<CACHE_T*>(src_cache.data_ptr()),                 \
+          reinterpret_cast<SCALAR_T*>(dst.data_ptr()),                      \
+          block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(), \
+          block_size, entry_size, block_table_stride, cache_block_stride,   \
+          cache_entry_stride, dst_entry_stride,                             \
+          reinterpret_cast<const float*>(scale.data_ptr()), seq_starts_ptr);

 // Gather sequences from the cache into the destination tensor.
 // - cu_seq_lens contains the cumulative sequence lengths for each batch
 // - block_table contains the cache block indices for each sequence
 // - Optionally, seq_starts (if provided) offsets the starting block index by
 //   (seq_starts[bid] / page_size)
-void gather_cache(
+void gather_and_maybe_dequant_cache(
     torch::Tensor const& src_cache,    // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
     torch::Tensor const& dst,          // [TOT_TOKENS, ENTRIES...]
     torch::Tensor const& block_table,  // [BATCH, BLOCK_INDICES]
     torch::Tensor const& cu_seq_lens,  // [BATCH+1]
-    int64_t batch_size,
+    int64_t batch_size, const std::string& kv_cache_dtype,
+    torch::Tensor const& scale,
     std::optional<torch::Tensor> seq_starts = std::nullopt) {
   at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -761,20 +774,8 @@ void gather_cache(
   dim3 grid(batch_size, num_splits);
   dim3 block(1024);

-  TORCH_CHECK(src_cache.dtype() == dst.dtype(),
-              "src_cache and dst must have the same dtype");
-
-  const int dtype_bits = src_cache.element_size() * 8;
   const int32_t* seq_starts_ptr =
       seq_starts.has_value() ? seq_starts.value().data_ptr<int32_t>() : nullptr;

-  if (dtype_bits == 32) {
-    CALL_GATHER_CACHE(uint32_t);
-  } else if (dtype_bits == 16) {
-    CALL_GATHER_CACHE(uint16_t);
-  } else if (dtype_bits == 8) {
-    CALL_GATHER_CACHE(uint8_t);
-  } else {
-    TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits);
-  }
+  DISPATCH_BY_KV_CACHE_DTYPE(dst.dtype(), kv_cache_dtype, CALL_GATHER_CACHE);
 }
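As a sanity check on the comment block above: tokens of batch entry bid land in dst[cu_seq_lens[bid] : cu_seq_lens[bid + 1]], and their source blocks come from block_table[bid], optionally skipping the first (seq_starts[bid] / block_size) blocks. A minimal host-side sketch of that indexing follows; the helper name first_block_of_seq is hypothetical.

#include <cstdint>

// Hypothetical illustration of the block lookup described above.
int32_t first_block_of_seq(const int32_t* block_table, int64_t table_stride,
                           int32_t bid, const int32_t* seq_starts,
                           int32_t block_size) {
  const int64_t offset = seq_starts ? seq_starts[bid] / block_size : 0;
  return block_table[bid * table_stride + offset];
}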
@@ -321,6 +321,8 @@ static inline constexpr auto kFE3M2f =
     ScalarType::float_(3, 2, true, ScalarType::NAN_NONE);
 static inline constexpr auto kFE4M3fn =
     ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN);
+static inline constexpr auto kFE8M0fnu =
+    ScalarType(8, 0, false, 0, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN);
 static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2);
 static inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7);
 static inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10);
@@ -89,7 +89,7 @@ struct FP16Vec16 : public Vec<FP16Vec16> {

   explicit FP16Vec16(const FP32Vec16&);

-  void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; }
+  void save(void* ptr) const { _mm256_storeu_si256((__m256i*)ptr, reg); }

   void save(void* ptr, const int elem_num) const {
     constexpr uint32_t M = 0xFFFFFFFF;
@@ -126,7 +126,7 @@ struct BF16Vec16 : public Vec<BF16Vec16> {

   explicit BF16Vec16(const FP32Vec16&);

-  void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; }
+  void save(void* ptr) const { _mm256_storeu_si256((__m256i*)ptr, reg); }

   void save(void* ptr, const int elem_num) const {
     constexpr uint32_t M = 0xFFFFFFFF;
@@ -180,8 +180,8 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
                                    (__m128i)vec8_data.reg, 1)) {}

   void save(void* ptr) const {
-    *reinterpret_cast<__m256i*>(ptr) = reg_low;
-    *reinterpret_cast<__m256i*>((__m256i*)ptr + 1) = reg_high;
+    _mm256_storeu_si256((__m256i*)ptr, reg_low);
+    _mm256_storeu_si256((__m256i*)ptr + 1, reg_high);
   }
 };
 #endif
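The three save overloads above swap an aligned-assuming dereference for an explicitly unaligned store: dereferencing a __m256i* typically compiles to an aligned vmovdqa, which can fault when the destination is not 32-byte aligned, while _mm256_storeu_si256 is valid at any address. A minimal sketch, assuming AVX2 is available (the buffer name buf is illustrative):

#include <immintrin.h>
#include <cstdint>

void demo(uint8_t* buf /* possibly not 32-byte aligned */, __m256i reg) {
  // *reinterpret_cast<__m256i*>(buf) = reg;  // aligned store; may fault if
  //                                          // buf is not 32-byte aligned
  _mm256_storeu_si256((__m256i*)buf, reg);    // unaligned store is always safe
}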
csrc/cpu/dnnl_helper.cpp (new file, 346 lines)
@@ -0,0 +1,346 @@
#include <list>
#include <optional>

#include "common/memory_desc.hpp"
#include "common/memory.hpp"

#include "dnnl_helper.h"

static dnnl::engine& default_engine() {
  static dnnl::engine engine(dnnl::engine::kind::cpu, 0);
  return engine;
}

static dnnl::stream& default_stream() {
  static dnnl::stream stream(default_engine());
  return stream;
}

void release_dnnl_matmul_handler(int64_t handler) {
  DNNLMatMulPrimitiveHandler* ptr =
      reinterpret_cast<DNNLMatMulPrimitiveHandler*>(handler);
  delete ptr;
}

template <typename KT, typename VT>
class DNNLPrimitiveCache {
 public:
  using cache_value_t = std::pair<KT, VT>;
  using result_value_t = VT;
  using container_t = std::list<cache_value_t>;
  using value_iterator_t = typename container_t::iterator;
  using map_t = std::unordered_map<KT, value_iterator_t>;
  using creator_t = VT (*)();

 public:
  DNNLPrimitiveCache(size_t capacity)
      : capacity_(capacity),
        values_(),
        key_to_value_(std::min(256lu, capacity)) {
    assert(capacity > 0);
  }

  template <typename F>
  result_value_t get_or_create(const KT& key, F&& creator) {
    std::optional<value_iterator_t> value = get_value(key);
    if (value.has_value()) {
      return value.value()->second;
    } else {
      return add_value({key, creator()})->second;
    }
  }

  size_t size() const { return values_.size(); }

 private:
  void dump_data() {
    std::stringstream ss;
    ss << "table_id: " << std::hex << reinterpret_cast<size_t>(this) << std::dec
       << "\n";
    ss << "container: [";
    for (auto&& iter : values_) {
      ss << "(" << iter.first << ", " << std::hex
         << reinterpret_cast<size_t>(iter.second.get()) << "), " << std::dec;
    }
    ss << "]\n";

    ss << "map: [";
    for (auto&& iter : key_to_value_) {
      ss << "(" << iter.first << ", " << iter.second->first << ", " << std::hex
         << reinterpret_cast<size_t>(iter.second->second.get()) << std::dec
         << "), ";
    }
    ss << "]\n";
    std::printf("%s\n", ss.str().c_str());
  }

  value_iterator_t add_value(cache_value_t&& new_value) {
    if (size() == capacity_) {
      cache_value_t& last_item = values_.back();
      key_to_value_.erase(last_item.first);
      values_.pop_back();
    }

    auto& added_value_ = values_.emplace_front(std::move(new_value));
    key_to_value_.emplace(added_value_.first, values_.begin());
    return values_.begin();
  }

  std::optional<value_iterator_t> get_value(const KT& key) {
    if (key_to_value_.size() > 0 && key == values_.begin()->first) {
      return values_.begin();
    }

    auto value_map_iterator = key_to_value_.find(key);
    if (value_map_iterator != key_to_value_.end()) {
      values_.splice(values_.begin(), values_, value_map_iterator->second);
      return value_map_iterator->second;
    } else {
      return {};
    }
  }

 private:
  const size_t capacity_;
  container_t values_;
  map_t key_to_value_;
};
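DNNLPrimitiveCache is a small LRU keyed map: a hit splices the entry to the front of the list, and an insert at capacity evicts the back. A minimal usage sketch under the same interface (the int/string instantiation is purely illustrative):

// Hypothetical instantiation, illustrating get_or_create semantics only.
DNNLPrimitiveCache<int, std::string> cache(/*capacity=*/2);
cache.get_or_create(1, [] { return std::string("one"); });    // miss: creates
cache.get_or_create(2, [] { return std::string("two"); });    // miss: creates
cache.get_or_create(1, [] { return std::string("!"); });      // hit: creator not run
cache.get_or_create(3, [] { return std::string("three"); });  // evicts key 2 (LRU)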
DNNLMatMulPrimitiveHandler::DNNLMatMulPrimitiveHandler(
    const Args& args, dnnl::memory::data_type b_type)
    : b_n_size_(args.b_n_size),
      b_n_stride_(args.b_n_stride),
      b_k_size_(args.b_k_size),
      b_k_stride_(args.b_k_stride),
      b_type_(b_type),
      c_type_(args.c_type),
      runtime_memory_ptrs_(8),
      primitive_cache_size_(args.primitive_cache_size) {
  assert(primitive_cache_size_ > 0);
}

void DNNLMatMulPrimitiveHandler::prepack_weight(
    void* original_b_ptr, dnnl::memory::desc b_target_mem_desc) {
  dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_,
                                   {b_k_stride_, b_n_stride_});
  dnnl::memory original_weight(original_b_md, default_engine(), original_b_ptr);
  dnnl::memory packed_weight(b_target_mem_desc, default_engine());
  {
    dnnl::reorder(original_weight, packed_weight)
        .execute(default_stream(), original_weight, packed_weight);
    default_stream().wait();
  }
  memory_cache_[DNNL_ARG_WEIGHTS] = packed_weight;
  b_target_mem_desc_ = b_target_mem_desc;
}

void DNNLMatMulPrimitiveHandler::set_runtime_memory_ptr(
    size_t index, dnnl_memory* memory_ptr) {
  dnnl::impl::memory_storage_t* mem_storage_ptr = memory_ptr->memory_storage();
  dnnl_memory_desc* mem_desc = const_cast<dnnl_memory_desc*>(memory_ptr->md());
  runtime_memory_ptrs_[index] = {mem_storage_ptr, mem_desc};
}

std::pair<dnnl::impl::memory_storage_t*, dnnl_memory_desc*>
DNNLMatMulPrimitiveHandler::get_runtime_memory_ptr(size_t index) {
  return runtime_memory_ptrs_[index];
}

namespace std {
template <>
struct hash<W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey> {
  size_t operator()(
      const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& val) const {
    return hash<dnnl_dim_t>()(val.b_n_size) ^ hash<dnnl_dim_t>()(val.b_k_size) ^
           hash<int>()(static_cast<int>(val.a_qs)) ^
           hash<int>()(static_cast<int>(val.b_qs)) ^ hash<bool>()(val.use_azp) ^
           hash<int>()(static_cast<int>(val.c_type));
  }
};

template <>
struct hash<W8A8MatMulPrimitiveHandler::MSizeCacheKey> {
  size_t operator()(
      const W8A8MatMulPrimitiveHandler::MSizeCacheKey& val) const {
    return hash<dnnl_dim_t>()(val.a_m_size) ^ hash<bool>()(val.use_bias) ^
           hash<int>()(static_cast<int>(val.bias_type));
  }
};
}  // namespace std

bool operator==(const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& l,
                const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& r) {
  return l.b_n_size == r.b_n_size && l.b_k_size == r.b_k_size &&
         l.a_qs == r.a_qs && l.b_qs == r.b_qs && l.use_azp == r.use_azp &&
         l.c_type == r.c_type;
}

bool operator==(const W8A8MatMulPrimitiveHandler::MSizeCacheKey& l,
                const W8A8MatMulPrimitiveHandler::MSizeCacheKey& r) {
  return l.use_bias == r.use_bias && l.a_m_size == r.a_m_size &&
         l.bias_type == r.bias_type;
}

static std::shared_ptr<W8A8MatMulPrimitiveHandler::MSizeCache>
get_w8a8_class_primitive_cache(
    const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& key,
    int64_t cache_size) {
  static W8A8MatMulPrimitiveHandler::ClassMatmulCache cache(128);
  assert(cache_size > 0);
  return cache.get_or_create(key, [&]() {
    return std::make_shared<W8A8MatMulPrimitiveHandler::MSizeCache>(cache_size);
  });
}

W8A8MatMulPrimitiveHandler::W8A8MatMulPrimitiveHandler(const Args& args)
    : DNNLMatMulPrimitiveHandler(
          static_cast<const DNNLMatMulPrimitiveHandler::Args&>(args),
          dnnl::memory::data_type::s8),
      use_azp_(args.use_a_zero_point),
      a_qs_(args.a_quantization_strategy),
      b_qs_(args.b_quantization_strategy),
      m_size_cache_(nullptr) {
  assert(a_qs_ != QuantizationStrategy::PER_OUTPUT_CHANNEL);
  assert(b_qs_ != QuantizationStrategy::PER_TOKEN);
  if (a_qs_ == QuantizationStrategy::PER_TOKEN) {
    assert(!use_azp_);
  };
  prepack_weight(args.b_ptr,
                 create_primitive_desc(
                     MSizeCacheKey{.a_m_size = DNNL_RUNTIME_DIM_VAL,
                                   .use_bias = false,
                                   .bias_type = dnnl::memory::data_type::undef},
                     true)
                     .weights_desc());
  init_runtime_memory_cache(args);
}

void W8A8MatMulPrimitiveHandler::execute(ExecArgs& args) {
  auto&& [a_storage, a_mem_desc] = get_runtime_memory_ptr(0);
  auto&& [c_storage, c_mem_desc] = get_runtime_memory_ptr(1);
  a_storage->set_data_handle((void*)args.a_ptr);
  a_mem_desc->dims[0] = args.a_m_size;
  c_storage->set_data_handle((void*)args.c_ptr);
  c_mem_desc->dims[0] = args.a_m_size;

  if (a_qs_ == QuantizationStrategy::PER_TENSOR) {
    auto&& [a_scale_storage, a_scale_mem_desc] = get_runtime_memory_ptr(2);
    a_scale_storage->set_data_handle((void*)args.a_scales_ptr);
  }
  if (use_azp_) {
    auto&& [a_zero_point_storage, a_zero_point_mem_desc] =
        get_runtime_memory_ptr(3);
    a_zero_point_storage->set_data_handle((void*)args.a_zero_points_ptr);
  }

  if (args.use_bias) {
    auto&& [bias_storage, bias_mem_desc] = get_runtime_memory_ptr(4);
    bias_storage->set_data_handle((void*)args.bias_ptr);
  }

  dnnl::matmul matmul = get_matmul_cache(args);
  matmul.execute(default_stream(), memory_cache_);
  default_stream().wait();
}

dnnl::matmul W8A8MatMulPrimitiveHandler::get_matmul_cache(
    const MSizeCacheKey& key) {
  if (m_size_cache_.get() == nullptr) {
    ClassMatmulCacheKey class_key = {.b_n_size = b_n_size_,
                                     .b_k_size = b_k_size_,
                                     .a_qs = a_qs_,
                                     .b_qs = b_qs_,
                                     .use_azp = use_azp_,
                                     .c_type = c_type_};
    m_size_cache_ =
        get_w8a8_class_primitive_cache(class_key, primitive_cache_size_);
  }

  return m_size_cache_->get_or_create(key, [&]() {
    dnnl::matmul::primitive_desc desc = this->create_primitive_desc(key, false);
    return dnnl::matmul(desc);
  });
}

void W8A8MatMulPrimitiveHandler::init_runtime_memory_cache(const Args& args) {
  memory_cache_[DNNL_ARG_SRC] = dnnl::memory({{1, b_k_size_},
                                              dnnl::memory::data_type::s8,
                                              dnnl::memory::format_tag::ab},
                                             default_engine(), nullptr);
  set_runtime_memory_ptr(0, memory_cache_[DNNL_ARG_SRC].get());
  memory_cache_[DNNL_ARG_DST] =
      dnnl::memory({{1, b_n_size_}, c_type_, dnnl::memory::format_tag::ab},
                   default_engine(), nullptr);
  set_runtime_memory_ptr(1, memory_cache_[DNNL_ARG_DST].get());

  // For PER_TOKEN, scales will be applied in outside epilogue
  if (a_qs_ == QuantizationStrategy::PER_TENSOR) {
    memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC] = dnnl::memory(
        {{1}, dnnl::memory::data_type::f32, {1}}, default_engine(), nullptr);
    set_runtime_memory_ptr(
        2, memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC].get());
    if (use_azp_) {
      memory_cache_[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC] = dnnl::memory(
          {{1}, dnnl::memory::data_type::s32, {1}}, default_engine(), nullptr);
      set_runtime_memory_ptr(
          3, memory_cache_[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC].get());
    }
  }

  if (b_qs_ == QuantizationStrategy::PER_TENSOR) {
    memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] =
        dnnl::memory({{1}, dnnl::memory::data_type::f32, {1}}, default_engine(),
                     (void*)args.b_scales_ptr);
  } else if (b_qs_ == QuantizationStrategy::PER_OUTPUT_CHANNEL) {
    memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] =
        dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
                     default_engine(), (void*)args.b_scales_ptr);
  }

  memory_cache_[DNNL_ARG_BIAS] =
      dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
                   default_engine(), nullptr);
  set_runtime_memory_ptr(4, memory_cache_[DNNL_ARG_BIAS].get());
}

dnnl::matmul::primitive_desc W8A8MatMulPrimitiveHandler::create_primitive_desc(
    const MSizeCacheKey& key, bool first_time) {
  dnnl::memory::desc a_md({key.a_m_size, b_k_size_},
                          dnnl::memory::data_type::s8,
                          dnnl::memory::format_tag::ab);
  dnnl::memory::desc b_md;
  if (first_time) {
    b_md =
        dnnl::memory::desc({b_k_size_, b_n_size_}, dnnl::memory::data_type::s8,
                           dnnl::memory::format_tag::any);
  } else {
    b_md = b_target_mem_desc_;
  }
  dnnl::memory::desc c_md({key.a_m_size, b_n_size_}, c_type_,
                          dnnl::memory::format_tag::ab);

  dnnl::primitive_attr attr;
  // For PER_TOKEN, scales will be applied in outside epilogue
  if (a_qs_ == QuantizationStrategy::PER_TENSOR) {
    attr.set_scales_mask(DNNL_ARG_SRC, 0);
    if (use_azp_) {
      attr.set_zero_points_mask(DNNL_ARG_SRC, 0);
    }
  }

  if (b_qs_ == QuantizationStrategy::PER_TENSOR) {
    attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
  } else if (b_qs_ == QuantizationStrategy::PER_OUTPUT_CHANNEL) {
    attr.set_scales_mask(DNNL_ARG_WEIGHTS, 2);
  }

  if (key.use_bias) {
    // For PER_TOKEN, bias will be applied in epilogue
    assert(a_qs_ == QuantizationStrategy::PER_TENSOR);
    dnnl::memory::desc bias_md({1, b_n_size_}, key.bias_type, {b_n_size_, 1});
    return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, bias_md,
                                        c_md, attr);
  } else {
    return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, c_md,
                                        attr);
  }
}
csrc/cpu/dnnl_helper.h (new file, 169 lines)
@@ -0,0 +1,169 @@
#ifndef DNNL_HELPER_H
#define DNNL_HELPER_H

#include <optional>
#include <cassert>

#include "oneapi/dnnl/dnnl.hpp"

namespace c10 {
struct BFloat16;
struct Half;
}  // namespace c10

namespace dnnl {
namespace impl {
struct memory_storage_t;
struct matmul_pd_t;
struct matmul_desc_t;
}  // namespace impl
}  // namespace dnnl
struct dnnl_memory_desc;

template <typename KT, typename VT>
class DNNLPrimitiveCache;

template <typename T>
struct DNNLType {
  static constexpr dnnl::memory::data_type type =
      dnnl::memory::data_type::undef;
};

template <>
struct DNNLType<int8_t> {
  static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s8;
};

template <>
struct DNNLType<int32_t> {
  static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s32;
};

template <>
struct DNNLType<float> {
  static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f32;
};

template <>
struct DNNLType<c10::BFloat16> {
  static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::bf16;
};

template <>
struct DNNLType<c10::Half> {
  static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f16;
};

template <typename T>
constexpr inline dnnl::memory::data_type get_dnnl_type() {
  return DNNLType<std::decay_t<T>>::type;
}

class DNNLMatMulPrimitiveHandler {
 public:
  virtual ~DNNLMatMulPrimitiveHandler() = default;

 protected:
  struct Args {
    dnnl_dim_t b_n_size;
    dnnl_dim_t b_n_stride;
    dnnl_dim_t b_k_size;
    dnnl_dim_t b_k_stride;
    void* b_ptr;
    dnnl::memory::data_type c_type;
    size_t primitive_cache_size;
  };

 protected:
  DNNLMatMulPrimitiveHandler(const Args& args, dnnl::memory::data_type b_type);

  void prepack_weight(void* original_b_ptr,
                      dnnl::memory::desc b_target_mem_desc);

  void set_runtime_memory_ptr(size_t index, dnnl_memory* memory_ptr);

  std::pair<dnnl::impl::memory_storage_t*, dnnl_memory_desc*>
  get_runtime_memory_ptr(size_t index);

 protected:
  const dnnl_dim_t b_n_size_;
  const dnnl_dim_t b_n_stride_;
  const dnnl_dim_t b_k_size_;
  const dnnl_dim_t b_k_stride_;
  dnnl::memory::data_type b_type_;
  dnnl::memory::data_type c_type_;
  std::unordered_map<int, dnnl::memory> memory_cache_;
  std::vector<std::pair<dnnl::impl::memory_storage_t*, dnnl_memory_desc*>>
      runtime_memory_ptrs_;
  dnnl::memory::desc b_target_mem_desc_;
  int64_t primitive_cache_size_;
};

class W8A8MatMulPrimitiveHandler : public DNNLMatMulPrimitiveHandler {
 public:
  enum class QuantizationStrategy { PER_TOKEN, PER_TENSOR, PER_OUTPUT_CHANNEL };

  struct Args : public DNNLMatMulPrimitiveHandler::Args {
    bool use_a_zero_point;
    QuantizationStrategy a_quantization_strategy;
    QuantizationStrategy b_quantization_strategy;
    float* b_scales_ptr;
  };

  struct ClassMatmulCacheKey {
    dnnl_dim_t b_n_size;
    dnnl_dim_t b_k_size;
    QuantizationStrategy a_qs;
    QuantizationStrategy b_qs;
    bool use_azp;
    dnnl::memory::data_type c_type;

    friend bool operator==(const ClassMatmulCacheKey& l,
                           const ClassMatmulCacheKey& r);
  };

  struct MSizeCacheKey {
    dnnl_dim_t a_m_size;
    bool use_bias;
    dnnl::memory::data_type bias_type;

    friend bool operator==(const MSizeCacheKey& l, const MSizeCacheKey& r);
  };

  using MSizeCache = DNNLPrimitiveCache<MSizeCacheKey, dnnl::matmul>;
  using ClassMatmulCache =
      DNNLPrimitiveCache<ClassMatmulCacheKey, std::shared_ptr<MSizeCache>>;

  struct ExecArgs : public MSizeCacheKey {
    const int8_t* a_ptr;
    const float* a_scales_ptr;
    const int32_t* a_zero_points_ptr;
    const void* bias_ptr;
    void* c_ptr;
  };

 public:
  W8A8MatMulPrimitiveHandler(const Args& args);

  QuantizationStrategy get_input_scale_strategy() const { return a_qs_; }

  bool get_input_use_zero_point() const { return use_azp_; }

  void execute(ExecArgs& args);

 private:
  dnnl::matmul::primitive_desc create_primitive_desc(const MSizeCacheKey& key,
                                                     bool first_time);

  void init_runtime_memory_cache(const Args& args);

  dnnl::matmul get_matmul_cache(const MSizeCacheKey& key);

 private:
  const bool use_azp_;
  const QuantizationStrategy a_qs_;
  const QuantizationStrategy b_qs_;
  std::shared_ptr<MSizeCache> m_size_cache_;
};

#endif
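Taken together, the API above follows a create/execute/release lifecycle: the constructor prepacks B once, execute reuses cached primitives per distinct M, and release_dnnl_matmul_handler frees the handler. A minimal sketch of the intended call pattern, assuming valid int8 weights in column-major layout; all pointer names and sizes here (b_int8_ptr, b_scales_ptr, a_int8_ptr, c_fp32_ptr, 1024, 4096) are illustrative:

// Illustrative only: fields mirror the Args/ExecArgs structs above.
W8A8MatMulPrimitiveHandler::Args args{};
args.b_k_size = 1024;              // K (input channels)
args.b_k_stride = 1;               // column-major B: stride along K is 1
args.b_n_size = 4096;              // N (output channels)
args.b_n_stride = 1024;
args.b_ptr = b_int8_ptr;           // assumed valid int8 weights
args.c_type = get_dnnl_type<float>();
args.primitive_cache_size = 64;
args.use_a_zero_point = false;
args.a_quantization_strategy =
    W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TOKEN;
args.b_quantization_strategy =
    W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_OUTPUT_CHANNEL;
args.b_scales_ptr = b_scales_ptr;  // assumed valid

auto* handler = new W8A8MatMulPrimitiveHandler(args);  // prepacks B once

W8A8MatMulPrimitiveHandler::ExecArgs exec{};
exec.a_m_size = 32;  // M may change per call; primitives are cached per key
exec.use_bias = false;  // for PER_TOKEN, bias/scales are applied outside
exec.bias_type = dnnl::memory::data_type::undef;
exec.a_ptr = a_int8_ptr;
exec.c_ptr = c_fp32_ptr;
handler->execute(exec);

release_dnnl_matmul_handler(reinterpret_cast<int64_t>(handler));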
@@ -1,206 +0,0 @@
#ifndef DNNL_HELPER_HPP
#define DNNL_HELPER_HPP

#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>

#include "oneapi/dnnl/dnnl.hpp"

namespace {
template <typename T>
struct DNNLType {
  static constexpr dnnl::memory::data_type type =
      dnnl::memory::data_type::undef;
};

template <>
struct DNNLType<int8_t> {
  static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s8;
};

template <>
struct DNNLType<int32_t> {
  static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s32;
};

template <>
struct DNNLType<float> {
  static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f32;
};

template <>
struct DNNLType<c10::BFloat16> {
  static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::bf16;
};

template <>
struct DNNLType<c10::Half> {
  static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f16;
};

template <typename T>
constexpr inline dnnl::memory::data_type get_dnnl_type() {
  return DNNLType<std::decay_t<T>>::type;
}
};  // namespace

template <bool InputNoScale>
class DNNLPrimitiveHelper {
 public:
  // I8 input GEMM kernel (C = a_scales * A @ (b_scales * B^T) + bias)
  // A: [M, K], row-major
  // B: [K, N], column-major
  // C: [M, N], row-major
  // bias: [N], row-major, optional
  // a_scales: [MS]
  // b_scales: [NS]
  // Note: Due to the limitation of oneDNN
  // (https://github.com/oneapi-src/oneDNN/issues/1636), the quantized bias is
  // not supported.

  template <typename OutputT, typename BiasT>
  static void gemm_s8s8_jit(const int8_t* a, const int8_t* b, OutputT* c,
                            const BiasT* bias, dnnl_dim_t M, dnnl_dim_t N,
                            dnnl_dim_t K, const float* a_scales,
                            const float* b_scales, dnnl_dim_t MS,
                            dnnl_dim_t NS) {
    auto&& OutputType = get_dnnl_type<OutputT>();
    auto&& BiasType = get_dnnl_type<BiasT>();

    dnnl::memory::desc a_md({M, K}, dnnl::memory::data_type::s8, {K, 1});
    dnnl::memory::desc b_md({K, N}, dnnl::memory::data_type::s8, {1, K});
    dnnl::memory::desc c_md({M, N}, OutputType, {N, 1});

    dnnl::primitive_attr attr;
    if constexpr (!InputNoScale) {
      if (MS == 1) {
        // per-tensor
        attr.set_scales_mask(DNNL_ARG_SRC, 0);
      } else {
        // per-token
        TORCH_CHECK(false, "per-token quantization is unsupported.");
      }
    }

    if (NS == 1) {
      // per-tensor
      attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
    } else {
      // per-channel
      attr.set_scales_mask(DNNL_ARG_WEIGHTS, 2);
    }

    dnnl::matmul::primitive_desc matmul_pd;
    // Create memory descriptors with format_tag::any for the primitive. This
    // enables the matmul primitive to choose memory layouts for an
    // optimized primitive implementation, and these layouts may differ from the
    // ones provided by the user.
#ifdef __aarch64__
    auto mat_src_md = dnnl::memory::desc({M, K}, dnnl::memory::data_type::s8,
                                         dnnl::memory::format_tag::any);
    auto mat_weights_md = dnnl::memory::desc(
        {K, N}, dnnl::memory::data_type::s8, dnnl::memory::format_tag::any);
    auto mat_dst_md =
        dnnl::memory::desc({M, N}, OutputType, dnnl::memory::format_tag::any);
    if (bias) {
      dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1});
      matmul_pd = dnnl::matmul::primitive_desc(default_engine(), mat_src_md,
                                               mat_weights_md, bias_md,
                                               mat_dst_md, attr);
    } else {
      matmul_pd = dnnl::matmul::primitive_desc(
          default_engine(), mat_src_md, mat_weights_md, mat_dst_md, attr);
    }
#else
    if (bias) {
      dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1});
      matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
                                               bias_md, c_md, attr);
    } else {
      matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
                                               c_md, attr);
    }
#endif
    dnnl::matmul matmul(matmul_pd);

    auto& engine = default_engine();

    dnnl::memory a_m(a_md, engine, (void*)a);
    dnnl::memory b_m(b_md, engine, (void*)b);
    dnnl::memory c_m(c_md, engine, (void*)c);
    dnnl::memory a_scales_m({{MS}, dnnl::memory::data_type::f32, {1}}, engine,
                            (void*)a_scales);
    dnnl::memory b_scales_m({{NS}, dnnl::memory::data_type::f32, {1}}, engine,
                            (void*)b_scales);

    auto& stream = default_stream();

    auto mat_src_mem = a_m;
    auto mat_weights_mem = b_m;
    auto mat_dst_mem = c_m;
#ifdef __aarch64__
    if (matmul_pd.weights_desc() != b_m.get_desc()) {
      mat_weights_mem = dnnl::memory(matmul_pd.weights_desc(), engine);
      dnnl::reorder(b_m, mat_weights_mem).execute(stream, b_m, mat_weights_mem);
    }
#endif
    if constexpr (InputNoScale) {
      if (bias) {
        dnnl::memory::desc bias_md({N}, BiasType, {1});
        dnnl::memory bias_m(bias_md, engine, (void*)bias);
        matmul.execute(
            stream, {
                        {DNNL_ARG_SRC, mat_src_mem},
                        {DNNL_ARG_WEIGHTS, mat_weights_mem},
                        {DNNL_ARG_BIAS, bias_m},
                        {DNNL_ARG_DST, mat_dst_mem},
                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
                    });
      } else {
        matmul.execute(
            stream, {
                        {DNNL_ARG_SRC, mat_src_mem},
                        {DNNL_ARG_WEIGHTS, mat_weights_mem},
                        {DNNL_ARG_DST, mat_dst_mem},
                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
                    });
      }
    } else {
      if (bias) {
        dnnl::memory::desc bias_md({N}, BiasType, {1});
        dnnl::memory bias_m(bias_md, engine, (void*)bias);
        matmul.execute(
            stream, {
                        {DNNL_ARG_SRC, mat_src_mem},
                        {DNNL_ARG_WEIGHTS, mat_weights_mem},
                        {DNNL_ARG_BIAS, bias_m},
                        {DNNL_ARG_DST, mat_dst_mem},
                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
                    });
      } else {
        matmul.execute(
            stream, {
                        {DNNL_ARG_SRC, mat_src_mem},
                        {DNNL_ARG_WEIGHTS, mat_weights_mem},
                        {DNNL_ARG_DST, mat_dst_mem},
                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
                    });
      }
    }
    stream.wait();
  }

 private:
  static dnnl::engine& default_engine() {
    static dnnl::engine engine(dnnl::engine::kind::cpu, 0);
    return engine;
  }

  static dnnl::stream& default_stream() {
    static dnnl::stream stream(default_engine());
    return stream;
  }
};
#endif
csrc/cpu/dnnl_kernels.cpp (new file, 494 lines)
@@ -0,0 +1,494 @@
#include "cpu_types.hpp"
|
||||
#include "dnnl_helper.h"
|
||||
|
||||
namespace {
|
||||
template <typename scalar_t>
|
||||
struct KernelVecType {
|
||||
using load_vec_type = void;
|
||||
using cvt_vec_type = void;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct KernelVecType<float> {
|
||||
using load_vec_type = vec_op::FP32Vec16;
|
||||
using cvt_vec_type = vec_op::FP32Vec16;
|
||||
};
|
||||
|
||||
#if !defined(__aarch64__) || defined(ARM_BF16_SUPPORT)
|
||||
template <>
|
||||
struct KernelVecType<c10::BFloat16> {
|
||||
using load_vec_type = vec_op::BF16Vec16;
|
||||
using cvt_vec_type = vec_op::FP32Vec16;
|
||||
};
|
||||
#endif
|
||||
|
||||
template <>
|
||||
struct KernelVecType<c10::Half> {
|
||||
#if defined(__powerpc64__) || defined(__s390x__)
|
||||
// Power architecture-specific vector type
|
||||
using load_vec_type = vec_op::FP32Vec16;
|
||||
#else
|
||||
// Fallback for other architectures
|
||||
using load_vec_type = vec_op::FP16Vec16;
|
||||
#endif
|
||||
using cvt_vec_type = vec_op::FP32Vec16;
|
||||
};
|
||||
|
||||
template <bool AZP, typename scalar_t>
|
||||
void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
|
||||
const float* scale, const int32_t* azp,
|
||||
const int64_t num_tokens,
|
||||
const int64_t input_stride,
|
||||
const int64_t hidden_size) {
|
||||
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
|
||||
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
|
||||
constexpr int64_t vec_elem_num = load_vec_t::VEC_ELEM_NUM;
|
||||
|
||||
constexpr float i8_min =
|
||||
static_cast<float>(std::numeric_limits<int8_t>::min());
|
||||
constexpr float i8_max =
|
||||
static_cast<float>(std::numeric_limits<int8_t>::max());
|
||||
const cvt_vec_t inv_scale(1.0 / *scale);
|
||||
const cvt_vec_t i8_min_vec(i8_min);
|
||||
const cvt_vec_t i8_max_vec(i8_max);
|
||||
|
||||
cvt_vec_t zp_vec;
|
||||
if constexpr (AZP) {
|
||||
zp_vec = cvt_vec_t(static_cast<float>(*azp));
|
||||
}
|
||||
|
||||
#pragma omp parallel for
|
||||
for (int64_t i = 0; i < num_tokens; ++i) {
|
||||
int64_t j = 0;
|
||||
const scalar_t* input_ptr = input + i * input_stride;
|
||||
int8_t* output_ptr = output + i * hidden_size;
|
||||
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
|
||||
load_vec_t elems(input_ptr + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
elems_fp32 = elems_fp32 * inv_scale;
|
||||
|
||||
if constexpr (AZP) {
|
||||
elems_fp32 = elems_fp32 + zp_vec;
|
||||
}
|
||||
|
||||
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
|
||||
vec_op::INT8Vec16 elems_int8(elems_fp32);
|
||||
elems_int8.save(output_ptr + j);
|
||||
}
|
||||
|
||||
load_vec_t elems(input_ptr + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
elems_fp32 = elems_fp32 * inv_scale;
|
||||
|
||||
if constexpr (AZP) {
|
||||
elems_fp32 = elems_fp32 + zp_vec;
|
||||
}
|
||||
|
||||
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
|
||||
vec_op::INT8Vec16 elems_int8(elems_fp32);
|
||||
elems_int8.save(output_ptr + j, hidden_size - j);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool AZP, typename scalar_t>
|
||||
void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
|
||||
float* scale, int32_t* azp,
|
||||
const int64_t num_tokens,
|
||||
const int64_t input_stride,
|
||||
const int64_t hidden_size) {
|
||||
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
|
||||
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
|
||||
constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
|
||||
|
||||
constexpr float i8_min =
|
||||
static_cast<float>(std::numeric_limits<int8_t>::min());
|
||||
constexpr float i8_max =
|
||||
static_cast<float>(std::numeric_limits<int8_t>::max());
|
||||
const cvt_vec_t i8_min_vec(i8_min);
|
||||
const cvt_vec_t i8_max_vec(i8_max);
|
||||
|
||||
#pragma omp parallel for
|
||||
for (int64_t i = 0; i < num_tokens; ++i) {
|
||||
cvt_vec_t max_value(std::numeric_limits<float>::lowest());
|
||||
cvt_vec_t min_value(std::numeric_limits<float>::max());
|
||||
{
|
||||
int64_t j = 0;
|
||||
const scalar_t* input_ptr = input + i * input_stride;
|
||||
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
|
||||
load_vec_t elems(input_ptr + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
if constexpr (AZP) {
|
||||
max_value = max_value.max(elems_fp32);
|
||||
min_value = min_value.min(elems_fp32);
|
||||
} else {
|
||||
max_value = max_value.max(elems_fp32.abs());
|
||||
}
|
||||
}
|
||||
|
||||
load_vec_t elems(input_ptr + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
|
||||
if (j + vec_elem_num == hidden_size) {
|
||||
if constexpr (AZP) {
|
||||
max_value = max_value.max(elems_fp32);
|
||||
min_value = min_value.min(elems_fp32);
|
||||
} else {
|
||||
max_value = max_value.max(elems_fp32.abs());
|
||||
}
|
||||
} else {
|
||||
if constexpr (AZP) {
|
||||
max_value = max_value.max(elems_fp32, hidden_size - j);
|
||||
min_value = min_value.min(elems_fp32, hidden_size - j);
|
||||
} else {
|
||||
max_value = max_value.max(elems_fp32.abs(), hidden_size - j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float scale_val, azp_val;
|
||||
if constexpr (AZP) {
|
||||
float max_scalar = max_value.reduce_max();
|
||||
float min_scalar = min_value.reduce_min();
|
||||
scale_val = (max_scalar - min_scalar) / 255.0f;
|
||||
azp_val = std::nearbyint(-128.0f - min_scalar / scale_val);
|
||||
azp[i] = azp_val;
|
||||
scale[i] = scale_val;
|
||||
} else {
|
||||
scale_val = max_value.reduce_max() / 127.0f;
|
||||
scale[i] = scale_val;
|
||||
}
|
||||
|
||||
const cvt_vec_t inv_scale(1.0 / scale_val);
|
||||
const cvt_vec_t azp_vec(azp_val);
|
||||
|
||||
{
|
||||
int64_t j = 0;
|
||||
const scalar_t* input_ptr = input + i * input_stride;
|
||||
int8_t* output_ptr = output + i * hidden_size;
|
||||
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
|
||||
load_vec_t elems(input_ptr + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
elems_fp32 = (elems_fp32 * inv_scale);
|
||||
|
||||
if constexpr (AZP) {
|
||||
elems_fp32 = elems_fp32 + azp_vec;
|
||||
}
|
||||
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
|
||||
vec_op::INT8Vec16 elems_int8(elems_fp32);
|
||||
elems_int8.save(output_ptr + j);
|
||||
}
|
||||
|
||||
load_vec_t elems(input_ptr + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
elems_fp32 = (elems_fp32 * inv_scale);
|
||||
|
||||
if constexpr (AZP) {
|
||||
elems_fp32 = elems_fp32 + azp_vec;
|
||||
}
|
||||
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
|
||||
vec_op::INT8Vec16 elems_int8(elems_fp32);
|
||||
elems_int8.save(output_ptr + j, hidden_size - j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
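The per-token scale and zero-point arithmetic above is compact, so a scalar restatement may help: the asymmetric case maps [min, max] onto the int8 range [-128, 127], while the symmetric case uses max(|x|) / 127. This is a minimal sketch of the same formulas; the function name token_quant_params is hypothetical.

#include <cmath>

// Hypothetical scalar version of the per-token parameter computation above.
void token_quant_params(float min_v, float max_v, float& scale, float& azp) {
  scale = (max_v - min_v) / 255.0f;               // 255 = 127 - (-128)
  azp = std::nearbyint(-128.0f - min_v / scale);  // so min_v quantizes to -128
  // Symmetric case (no zero point): scale = max(|x|) / 127.
}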
template <bool AZP, bool Bias, typename scalar_t>
void dynamic_quant_epilogue(const float* input, scalar_t* output,
                            const float* a_scale, const int32_t* azp,
                            const float* azp_adj, const scalar_t* bias,
                            const int64_t num_tokens,
                            const int64_t hidden_size) {
  CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue)
  using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
  using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
  constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;

  const int64_t thread_num = omp_get_max_threads();
  if (num_tokens > thread_num) {
#pragma omp parallel for
    for (int64_t i = 0; i < num_tokens; ++i) {
      const float* input_ptr = input + i * hidden_size;
      scalar_t* output_ptr = output + i * hidden_size;
      int64_t j = 0;
      cvt_vec_t token_scale_vec(a_scale[i]);
      cvt_vec_t token_zp_scale_vec;
      if constexpr (AZP) {
        float zp_scale_val = a_scale[i] * static_cast<float>(azp[i]);
        token_zp_scale_vec = cvt_vec_t(zp_scale_val);
      }
      for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
        cvt_vec_t elems_fp32(input_ptr + j);
        elems_fp32 = elems_fp32 * token_scale_vec;
        if constexpr (AZP) {
          cvt_vec_t azp_adj_fp32(azp_adj + j);
          elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec;
        }
        if constexpr (Bias) {
          load_vec_t bias_vec(bias + j);
          cvt_vec_t bias_vec_fp32(bias_vec);
          elems_fp32 = elems_fp32 + bias_vec_fp32;
        }
        load_vec_t elems_out(elems_fp32);
        elems_out.save(output_ptr + j);
      }
      cvt_vec_t elems_fp32(input_ptr + j);
      elems_fp32 = elems_fp32 * token_scale_vec;
      if constexpr (AZP) {
        cvt_vec_t azp_adj_fp32(azp_adj + j);
        elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec;
      }
      if constexpr (Bias) {
        load_vec_t bias_vec(bias + j);
        cvt_vec_t bias_vec_fp32(bias_vec);
        elems_fp32 = elems_fp32 + bias_vec_fp32;
      }
      load_vec_t elems_out(elems_fp32);
      elems_out.save(output_ptr + j, hidden_size - j);
    }
  } else {
    const int64_t vec_iteration =
        (hidden_size + vec_elem_num - 1) / vec_elem_num;
    const int64_t vec_iteration_per_thread =
        (vec_iteration + thread_num - 1) / thread_num;
    const int64_t elem_num_per_thread = vec_iteration_per_thread * vec_elem_num;
#pragma omp parallel for schedule(static, 1)
    for (int64_t i = 0; i < thread_num; ++i) {
      const int64_t start = elem_num_per_thread * i;
      const int64_t end = std::min(hidden_size, elem_num_per_thread + start);
      for (int64_t j = 0; j < num_tokens; ++j) {
        cvt_vec_t token_scale_vec(a_scale[j]);
        cvt_vec_t token_zp_scale_vec;
        if constexpr (AZP) {
          float zp_scale_val = a_scale[j] * static_cast<float>(azp[j]);
          token_zp_scale_vec = cvt_vec_t(zp_scale_val);
        }
        int64_t k = start;
        const float* input_ptr = input + j * hidden_size;
        scalar_t* output_ptr = output + j * hidden_size;
        for (; k < end - vec_elem_num; k += vec_elem_num) {
          cvt_vec_t elems_fp32(input_ptr + k);
          elems_fp32 = elems_fp32 * token_scale_vec;
          if constexpr (AZP) {
            cvt_vec_t azp_adj_fp32(azp_adj + k);
            elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec;
          }
          if constexpr (Bias) {
            load_vec_t bias_vec(bias + k);
            cvt_vec_t bias_vec_fp32(bias_vec);
            elems_fp32 = elems_fp32 + bias_vec_fp32;
          }
          load_vec_t elems_out(elems_fp32);
          elems_out.save(output_ptr + k);
        }
        if (k < end) {
          cvt_vec_t elems_fp32(input_ptr + k);
          elems_fp32 = elems_fp32 * token_scale_vec;
          if constexpr (AZP) {
            cvt_vec_t azp_adj_fp32(azp_adj + k);
            elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec;
          }
          if constexpr (Bias) {
            load_vec_t bias_vec(bias + k);
            cvt_vec_t bias_vec_fp32(bias_vec);
            elems_fp32 = elems_fp32 + bias_vec_fp32;
          }
          load_vec_t elems_out(elems_fp32);
          elems_out.save(output_ptr + k, end - k);
        }
      }
    }
  }
}
}  // namespace
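For one output element, the epilogue above reduces to a single line of arithmetic. A scalar sketch, using the same names as the kernel (the reference function itself is illustrative):

// Hypothetical scalar form of dynamic_quant_epilogue for token i, column n:
//   out = acc * a_scale[i] - azp_adj[n] * (a_scale[i] * azp[i]) + bias[n]
// where acc is the fp32 matmul result for that element.
float epilogue_ref(float acc, float a_scale, int32_t azp, float azp_adj,
                   float bias) {
  return acc * a_scale - azp_adj * (a_scale * static_cast<float>(azp)) + bias;
}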
int64_t create_onednn_scaled_mm_handler(
    const torch::Tensor& b,         // [IC, OC], column-major
    const torch::Tensor& b_scales,  // [1] or [OC]
    at::ScalarType output_type, bool dynamic_act_quant, bool use_azp,
    int64_t primitive_cache_size) {
  TORCH_CHECK(b.dim() == 2);
  TORCH_CHECK(b.stride(0) == 1);  // Column-major
  TORCH_CHECK(b_scales.is_contiguous());

  W8A8MatMulPrimitiveHandler::Args args;
  args.primitive_cache_size = primitive_cache_size;

  if (b_scales.numel() == 1) {
    args.b_quantization_strategy =
        W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR;
  } else {
    TORCH_CHECK_EQ(b_scales.numel(), b.size(1));
    args.b_quantization_strategy =
        W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_OUTPUT_CHANNEL;
  }
  args.b_scales_ptr = b_scales.data_ptr<float>();
  args.b_k_size = b.size(0);
  args.b_k_stride = b.stride(0);
  args.b_n_size = b.size(1);
  args.b_n_stride = b.stride(1);
  args.b_ptr = b.data_ptr<int8_t>();

  if (dynamic_act_quant) {
    // dynamic per-token; bias, A scales and A zero points are applied outside.
    args.a_quantization_strategy =
        W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TOKEN;
    args.use_a_zero_point = false;
  } else {
    // static per-tensor
    args.a_quantization_strategy =
        W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR;
    args.use_a_zero_point = use_azp;
  }

  VLLM_DISPATCH_FLOATING_TYPES(output_type, "create_onednn_scaled_mm_handler",
                               [&] {
                                 if (dynamic_act_quant) {
                                   args.c_type = get_dnnl_type<float>();
                                 } else {
                                   args.c_type = get_dnnl_type<scalar_t>();
                                 }
                               });

  return reinterpret_cast<int64_t>(new W8A8MatMulPrimitiveHandler(args));
}

void onednn_scaled_mm(
    torch::Tensor& c,                             // [M, OC], row-major
    const torch::Tensor& a,                       // [M, IC], row-major
    const torch::Tensor& a_scales,                // [M] or [1]
    const std::optional<torch::Tensor>& azp,      // [M] or [1]
    const std::optional<torch::Tensor>& azp_adj,  // [M] or [1]
    const std::optional<torch::Tensor>& bias,     // [N]
    int64_t handler) {
  CPU_KERNEL_GUARD_IN(onednn_scaled_mm)
  TORCH_CHECK(a.dim() == 2);
  TORCH_CHECK(a.is_contiguous());
  TORCH_CHECK(c.is_contiguous());
  W8A8MatMulPrimitiveHandler* ptr =
      reinterpret_cast<W8A8MatMulPrimitiveHandler*>(handler);
  const int32_t* azp_ptr = nullptr;
  if (azp.has_value()) {
    azp_ptr = azp->data_ptr<int32_t>();
  }
  if (ptr->get_input_scale_strategy() ==
      W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR) {
    TORCH_CHECK_EQ(a_scales.numel(), 1);
  }

  W8A8MatMulPrimitiveHandler::ExecArgs exec_args;
  exec_args.a_ptr = a.data_ptr<int8_t>();
  exec_args.a_m_size = a.size(0);
  exec_args.bias_ptr = nullptr;
  exec_args.use_bias = false;
  exec_args.a_scales_ptr = nullptr;
  exec_args.a_zero_points_ptr = nullptr;

  VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "onednn_scaled_mm", [&] {
    if (ptr->get_input_scale_strategy() ==
        W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR) {
      if (bias.has_value()) {
        exec_args.bias_ptr = bias->data_ptr<scalar_t>();
        exec_args.bias_type = get_dnnl_type<scalar_t>();
        exec_args.use_bias = true;
      }
      exec_args.a_scales_ptr = a_scales.data_ptr<float>();
      exec_args.a_zero_points_ptr = azp_ptr;
      exec_args.c_ptr = c.data_ptr<scalar_t>();
      ptr->execute(exec_args);
    } else if (ptr->get_input_scale_strategy() ==
               W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TOKEN) {
      torch::Tensor tmp_fp32_out =
          torch::empty_like(c, ::at::ScalarType::Float);
      exec_args.c_ptr = tmp_fp32_out.data_ptr<float>();
      ptr->execute(exec_args);
      if (bias.has_value()) {
        if (azp.has_value()) {
          dynamic_quant_epilogue<true, true>(
              tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
              a_scales.data_ptr<float>(), azp_ptr, azp_adj->data_ptr<float>(),
              bias->data_ptr<scalar_t>(), c.size(0), c.size(1));
        } else {
          dynamic_quant_epilogue<false, true>(
              tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
              a_scales.data_ptr<float>(), azp_ptr, nullptr,
              bias->data_ptr<scalar_t>(), c.size(0), c.size(1));
        }
      } else {
        if (azp.has_value()) {
          dynamic_quant_epilogue<true, false>(
              tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
              a_scales.data_ptr<float>(), azp_ptr, azp_adj->data_ptr<float>(),
              (scalar_t*)nullptr, c.size(0), c.size(1));
        } else {
          dynamic_quant_epilogue<false, false>(
              tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
              a_scales.data_ptr<float>(), azp_ptr, nullptr, (scalar_t*)nullptr,
              c.size(0), c.size(1));
        }
      }
    } else {
      TORCH_CHECK(false, "invalid act quant type.");
    }
  });
}

// static-per-tensor quantization.
void static_scaled_int8_quant(
    torch::Tensor& out,          // [batch, hidden_size]
    const torch::Tensor& input,  // [batch, hidden_size]
    const torch::Tensor& scale, std::optional<torch::Tensor> const& azp) {
  CPU_KERNEL_GUARD_IN(static_scaled_int8_quant)
  TORCH_CHECK(out.is_contiguous());
  TORCH_CHECK_EQ(input.dim(), 2);
  TORCH_CHECK_EQ(input.stride(1), 1);
  TORCH_CHECK(scale.numel() == 1);
  TORCH_CHECK(!azp.has_value() || azp->numel() == 1);

  const int64_t stride = input.stride(0);
  const int64_t hidden_size = input.size(1);
  const int64_t num_tokens = input.size(0);
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "static_scaled_int8_quant_impl", [&] {
        if (azp.has_value()) {
          static_scaled_int8_quant_impl<true>(
              input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
              scale.data_ptr<float>(), azp->data_ptr<int32_t>(), num_tokens,
              stride, hidden_size);
        } else {
          static_scaled_int8_quant_impl<false>(input.data_ptr<scalar_t>(),
                                               out.data_ptr<int8_t>(),
                                               scale.data_ptr<float>(), nullptr,
                                               num_tokens, stride, hidden_size);
        }
      });
}

// dynamic-per-token quantization.
void dynamic_scaled_int8_quant(
    torch::Tensor& out,          // [batch, hidden_size]
    const torch::Tensor& input,  // [batch, hidden_size]
    torch::Tensor& scale,        // [batch, 1]
    std::optional<torch::Tensor> const& azp) {
  CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant)
  TORCH_CHECK(out.is_contiguous());
  TORCH_CHECK_EQ(input.dim(), 2);
  TORCH_CHECK_EQ(input.stride(1), 1);

  const int64_t hidden_size = input.size(1);
  const int64_t num_tokens = input.size(0);
  const int64_t stride = input.stride(0);
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "dynamic_scaled_int8_quant_impl", [&] {
        if (azp.has_value()) {
          dynamic_scaled_int8_quant_impl<true>(
              input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
              scale.data_ptr<float>(), azp->data_ptr<int32_t>(), num_tokens,
              stride, hidden_size);
        } else {
          dynamic_scaled_int8_quant_impl<false>(
              input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
              scale.data_ptr<float>(), nullptr, num_tokens, stride,
              hidden_size);
        }
      });
}
@@ -1,951 +0,0 @@
#include "cpu_types.hpp"
|
||||
#include "dnnl_helper.hpp"
|
||||
|
||||
namespace {
|
||||
template <typename scalar_t>
|
||||
struct KernelVecType {
|
||||
using load_vec_type = void;
|
||||
using azp_adj_load_vec_type = void;
|
||||
using cvt_vec_type = void;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct KernelVecType<float> {
|
||||
using load_vec_type = vec_op::FP32Vec16;
|
||||
using azp_adj_load_vec_type = vec_op::INT32Vec16;
|
||||
using cvt_vec_type = vec_op::FP32Vec16;
|
||||
};
|
||||
|
||||
#if !defined(__aarch64__) || defined(ARM_BF16_SUPPORT)
|
||||
template <>
|
||||
struct KernelVecType<c10::BFloat16> {
|
||||
using load_vec_type = vec_op::BF16Vec16;
|
||||
using azp_adj_load_vec_type = vec_op::INT32Vec16;
|
||||
using cvt_vec_type = vec_op::FP32Vec16;
|
||||
};
|
||||
#endif
|
||||
|
||||
template <>
|
||||
struct KernelVecType<c10::Half> {
|
||||
#if defined(__powerpc64__) || defined(__s390x__)
|
||||
// Power architecture-specific vector type
|
||||
using load_vec_type = vec_op::FP32Vec16;
|
||||
#else
|
||||
// Fallback for other architectures
|
||||
using load_vec_type = vec_op::FP16Vec16;
|
||||
#endif
|
||||
using azp_adj_load_vec_type = vec_op::INT32Vec16;
|
||||
using cvt_vec_type = vec_op::FP32Vec16;
|
||||
};
|
||||
|
||||
#if defined(__AVX512F__) || defined(__aarch64__)
|
||||
template <bool AZP, typename scalar_t>
|
||||
void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
|
||||
const float* scale, const int32_t* azp,
|
||||
const int num_tokens,
|
||||
const int hidden_size) {
|
||||
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
|
||||
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
|
||||
constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
|
||||
|
||||
constexpr float i8_min =
|
||||
static_cast<float>(std::numeric_limits<int8_t>::min());
|
||||
constexpr float i8_max =
|
||||
static_cast<float>(std::numeric_limits<int8_t>::max());
|
||||
const cvt_vec_t inv_scale(1.0 / *scale);
|
||||
const cvt_vec_t i8_min_vec(i8_min);
|
||||
const cvt_vec_t i8_max_vec(i8_max);
|
||||
|
||||
cvt_vec_t zp_vec;
|
||||
if constexpr (AZP) {
|
||||
zp_vec = cvt_vec_t(static_cast<float>(*azp));
|
||||
}
|
||||
|
||||
#pragma omp parallel for
|
||||
for (int i = 0; i < num_tokens; ++i) {
|
||||
int j = 0;
|
||||
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
|
||||
load_vec_t elems(input + i * hidden_size + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
elems_fp32 = elems_fp32 * inv_scale;
|
||||
|
||||
if constexpr (AZP) {
|
||||
elems_fp32 = elems_fp32 + zp_vec;
|
||||
}
|
||||
|
||||
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
|
||||
vec_op::INT8Vec16 elems_int8(elems_fp32);
|
||||
elems_int8.save(output + i * hidden_size + j);
|
||||
}
|
||||
|
||||
load_vec_t elems(input + i * hidden_size + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
elems_fp32 = elems_fp32 * inv_scale;
|
||||
|
||||
if constexpr (AZP) {
|
||||
elems_fp32 = elems_fp32 + zp_vec;
|
||||
}
|
||||
|
||||
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
|
||||
vec_op::INT8Vec16 elems_int8(elems_fp32);
|
||||
elems_int8.save(output + i * hidden_size + j, hidden_size - j);
|
||||
}
|
||||
}
|
||||
|
||||
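// A quick sketch of the static path above, under the conventions used here:
// each element is mapped as q = clamp(x * (1 / *scale) + azp, -128, 127) and
// then converted to int8 (azp is omitted on the symmetric path). For example,
// with *scale = 0.05f and no AZP, x = 1.25f quantizes to 25.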
template <bool AZP, typename scalar_t>
void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
                                    float* scale, int32_t* azp,
                                    const int num_tokens,
                                    const int hidden_size) {
  using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
  using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
  constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;

  constexpr float i8_min =
      static_cast<float>(std::numeric_limits<int8_t>::min());
  constexpr float i8_max =
      static_cast<float>(std::numeric_limits<int8_t>::max());
  const cvt_vec_t i8_min_vec(i8_min);
  const cvt_vec_t i8_max_vec(i8_max);

#pragma omp parallel for
  for (int i = 0; i < num_tokens; ++i) {
    cvt_vec_t max_value(std::numeric_limits<float>::lowest());
    cvt_vec_t min_value(std::numeric_limits<float>::max());
    {
      int j = 0;
      for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
        load_vec_t elems(input + i * hidden_size + j);
        cvt_vec_t elems_fp32(elems);
        if constexpr (AZP) {
          max_value = max_value.max(elems_fp32);
          min_value = min_value.min(elems_fp32);
        } else {
          max_value = max_value.max(elems_fp32.abs());
        }
      }

      load_vec_t elems(input + i * hidden_size + j);
      cvt_vec_t elems_fp32(elems);

      if (j + vec_elem_num == hidden_size) {
        if constexpr (AZP) {
          max_value = max_value.max(elems_fp32);
          min_value = min_value.min(elems_fp32);
        } else {
          max_value = max_value.max(elems_fp32.abs());
        }
      } else {
        if constexpr (AZP) {
          max_value = max_value.max(elems_fp32, hidden_size - j);
          min_value = min_value.min(elems_fp32, hidden_size - j);
        } else {
          max_value = max_value.max(elems_fp32.abs(), hidden_size - j);
        }
      }
    }

    float scale_val, azp_val;
    if constexpr (AZP) {
      float max_scalar = max_value.reduce_max();
      float min_scalar = min_value.reduce_min();
      scale_val = (max_scalar - min_scalar) / 255.0f;
      azp_val = std::nearbyint(-128.0f - min_scalar / scale_val);
      azp[i] = static_cast<int32_t>(azp_val);
      scale[i] = scale_val;
    } else {
      scale_val = max_value.reduce_max() / 127.0f;
      scale[i] = scale_val;
    }

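    // Per-token parameters: with AZP the full [min, max] range is mapped onto
    // all 256 int8 codes (scale = (max - min) / 255 and
    // azp = round(-128 - min / scale)); without AZP the symmetric range is
    // used (scale = max|x| / 127). Note azp_val is only initialized on the
    // AZP path, so azp_vec below is only meaningful when AZP is true.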
    const cvt_vec_t inv_scale(1.0 / scale_val);
    const cvt_vec_t azp_vec(azp_val);

    {
      int j = 0;
      for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
        load_vec_t elems(input + i * hidden_size + j);
        cvt_vec_t elems_fp32(elems);
        elems_fp32 = (elems_fp32 * inv_scale);

        if constexpr (AZP) {
          elems_fp32 = elems_fp32 + azp_vec;
        }
        elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
        vec_op::INT8Vec16 elems_int8(elems_fp32);
        elems_int8.save(output + i * hidden_size + j);
      }

      load_vec_t elems(input + i * hidden_size + j);
      cvt_vec_t elems_fp32(elems);
      elems_fp32 = (elems_fp32 * inv_scale);

      if constexpr (AZP) {
        elems_fp32 = elems_fp32 + azp_vec;
      }
      elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
      vec_op::INT8Vec16 elems_int8(elems_fp32);
      elems_int8.save(output + i * hidden_size + j, hidden_size - j);
    }
  }
}

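// Epilogues for the W8A8 GEMMs below: static_quant_epilogue computes
// C = C_inter - (a_scale * b_scale) * azp_with_adj, removing the zero-point
// contribution from an already-scaled accumulator, while
// dynamic_quant_epilogue computes
// C = a_scale[i] * C_inter - a_scale[i] * azp[i] * b_scale * azp_adj + bias,
// applying the per-token activation scale that the GEMM itself cannot fuse.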
template <bool PerChannel, typename scalar_t>
void static_quant_epilogue(const float* input, scalar_t* output,
                           const float a_scale, const float* b_scale,
                           const int32_t* azp_with_adj, const int num_tokens,
                           const int hidden_size) {
  CPU_KERNEL_GUARD_IN(dynamic_output_scale_impl)
  using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
  using azp_adj_load_vec_t =
      typename KernelVecType<scalar_t>::azp_adj_load_vec_type;
  using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
  constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;

#pragma omp parallel for
  for (int i = 0; i < num_tokens; ++i) {
    cvt_vec_t a_scale_vec(a_scale);
    cvt_vec_t b_scale_vec(*b_scale);
    cvt_vec_t scale_vec = a_scale_vec * b_scale_vec;

    int j = 0;
    for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
      cvt_vec_t elems_fp32(input + i * hidden_size + j);
      azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j);
      cvt_vec_t azp_adj_fp32(azp_adj_vec);

      if constexpr (PerChannel) {
        b_scale_vec = cvt_vec_t(b_scale + j);
        scale_vec = b_scale_vec * a_scale_vec;
      }

      elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32;

      load_vec_t elems_out(elems_fp32);
      elems_out.save(output + i * hidden_size + j);
    }

    cvt_vec_t elems_fp32(input + i * hidden_size + j);
    azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j);
    cvt_vec_t azp_adj_fp32(azp_adj_vec);

    if constexpr (PerChannel) {
      b_scale_vec = cvt_vec_t(b_scale + j);
      scale_vec = b_scale_vec * a_scale_vec;
    }

    elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32;

    load_vec_t elems_out(elems_fp32);
    elems_out.save(output + i * hidden_size + j, hidden_size - j);
  }
}

template <bool AZP, bool PerChannel, bool Bias, typename scalar_t>
void dynamic_quant_epilogue(const float* input, scalar_t* output,
                            const float* a_scale, const float* b_scale,
                            const int32_t* azp, const int32_t* azp_adj,
                            const scalar_t* bias, const int num_tokens,
                            const int hidden_size) {
  CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue)
  using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
  using azp_adj_load_vec_t =
      typename KernelVecType<scalar_t>::azp_adj_load_vec_type;
  using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
  constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;

#pragma omp parallel for
  for (int i = 0; i < num_tokens; ++i) {
    int j = 0;
    cvt_vec_t token_scale_vec(a_scale[i]);
    cvt_vec_t token_zp_scale_vec;
    if constexpr (AZP) {
      float zp_scale_val = a_scale[i] * static_cast<float>(azp[i]);
      if constexpr (!PerChannel) {
        zp_scale_val *= *b_scale;
      }
      token_zp_scale_vec = cvt_vec_t(zp_scale_val);
    }

    for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
      cvt_vec_t elems_fp32(input + i * hidden_size + j);
      elems_fp32 = elems_fp32 * token_scale_vec;

      if constexpr (AZP) {
        azp_adj_load_vec_t azp_adj_vec(azp_adj + j);
        cvt_vec_t azp_adj_fp32(azp_adj_vec);
        azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec;

        if constexpr (PerChannel) {
          cvt_vec_t b_scale_vec(b_scale + j);
          azp_adj_fp32 = azp_adj_fp32 * b_scale_vec;
        }

        elems_fp32 = elems_fp32 - azp_adj_fp32;
      }

      if constexpr (Bias) {
        load_vec_t bias_vec(bias + j);
        cvt_vec_t bias_vec_fp32(bias_vec);
        elems_fp32 = elems_fp32 + bias_vec_fp32;
      }

      load_vec_t elems_out(elems_fp32);
      elems_out.save(output + i * hidden_size + j);
    }

    cvt_vec_t elems_fp32(input + i * hidden_size + j);
    elems_fp32 = elems_fp32 * token_scale_vec;

    if constexpr (AZP) {
      azp_adj_load_vec_t azp_adj_vec(azp_adj + j);
      cvt_vec_t azp_adj_fp32(azp_adj_vec);
      azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec;

      if constexpr (PerChannel) {
        cvt_vec_t b_scale_vec(b_scale + j);
        azp_adj_fp32 = azp_adj_fp32 * b_scale_vec;
      }

      elems_fp32 = elems_fp32 - azp_adj_fp32;
    }

    if constexpr (Bias) {
      load_vec_t bias_vec(bias + j);
      cvt_vec_t bias_vec_fp32(bias_vec);
      elems_fp32 = elems_fp32 + bias_vec_fp32;
    }

    load_vec_t elems_out(elems_fp32);
    elems_out.save(output + i * hidden_size + j, hidden_size - j);
  }
}
#elif defined(__powerpc64__)
template <bool AZP, typename scalar_t>
void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
                                   const float* scale, const int32_t* azp,
                                   const int num_tokens,
                                   const int hidden_size) {
  using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
  using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
  constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;

  constexpr float i8_min =
      static_cast<float>(std::numeric_limits<int8_t>::min());
  constexpr float i8_max =
      static_cast<float>(std::numeric_limits<int8_t>::max());

  const cvt_vec_t inv_scale(1.0 / *scale);
  const cvt_vec_t i8_min_vec(i8_min);
  const cvt_vec_t i8_max_vec(i8_max);

  cvt_vec_t zp_vec;
  if constexpr (AZP) {
    zp_vec = cvt_vec_t(static_cast<float>(*azp));
  }
#pragma omp parallel for
  for (int i = 0; i < num_tokens; ++i) {
    int j = 0;
    for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
      load_vec_t elems(input + i * hidden_size + j);
      cvt_vec_t elems_fp32(elems);
      elems_fp32 = elems_fp32 * inv_scale;
      if constexpr (AZP) {
        elems_fp32 = elems_fp32 + zp_vec;
      }
      elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
      vec_op::INT8Vec16 elems_int8(elems_fp32);
      elems_int8.save(output + i * hidden_size + j);
    }
    load_vec_t elems(input + i * hidden_size + j);
    cvt_vec_t elems_fp32(elems);
    elems_fp32 = elems_fp32 * inv_scale;

    if constexpr (AZP) {
      elems_fp32 = elems_fp32 + zp_vec;
    }

    elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
    vec_op::INT8Vec16 elems_int8(elems_fp32);
    elems_int8.save(output + i * hidden_size + j, hidden_size - j);
  }
}
template <bool AZP, typename scalar_t>
void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
                                    float* scale, int32_t* azp,
                                    const int num_tokens,
                                    const int hidden_size) {
  using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
  using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
  constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;

  constexpr float i8_min =
      static_cast<float>(std::numeric_limits<int8_t>::min());
  constexpr float i8_max =
      static_cast<float>(std::numeric_limits<int8_t>::max());
  const cvt_vec_t i8_min_vec(i8_min);
  const cvt_vec_t i8_max_vec(i8_max);

#pragma omp parallel for
  for (int i = 0; i < num_tokens; ++i) {
    cvt_vec_t max_value(std::numeric_limits<float>::lowest());
    cvt_vec_t min_value(std::numeric_limits<float>::max());
    {
      int j = 0;
      for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
        load_vec_t elems(input + i * hidden_size + j);
        cvt_vec_t elems_fp32(elems);
        if constexpr (AZP) {
          max_value = max_value.max(elems_fp32);
          min_value = min_value.min(elems_fp32);
        } else {
          max_value = max_value.max(elems_fp32.abs());
        }
      }

      load_vec_t elems(input + i * hidden_size + j);
      cvt_vec_t elems_fp32(elems);

      if (j + vec_elem_num == hidden_size) {
        if constexpr (AZP) {
          max_value = max_value.max(elems_fp32);
          min_value = min_value.min(elems_fp32);
        } else {
          max_value = max_value.max(elems_fp32.abs());
        }
      } else {
        if constexpr (AZP) {
          max_value = max_value.max(elems_fp32, hidden_size - j);
          min_value = min_value.min(elems_fp32, hidden_size - j);
        } else {
          max_value = max_value.max(elems_fp32.abs(), hidden_size - j);
        }
      }
    }

    float scale_val, azp_val;
    if constexpr (AZP) {
      float max_scalar = max_value.reduce_max();
      float min_scalar = min_value.reduce_min();
      scale_val = (max_scalar - min_scalar) / 255.0f;
      azp_val = std::nearbyint(-128.0f - min_scalar / scale_val);
      azp[i] = static_cast<int32_t>(azp_val);
      scale[i] = scale_val;
    } else {
      scale_val = max_value.reduce_max() / 127.0f;
      scale[i] = scale_val;
    }

    const cvt_vec_t inv_scale(1.0 / scale_val);
    const cvt_vec_t azp_vec(azp_val);

    {
      int j = 0;
      for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
        load_vec_t elems(input + i * hidden_size + j);
        cvt_vec_t elems_fp32(elems);
        elems_fp32 = (elems_fp32 * inv_scale);

        if constexpr (AZP) {
          elems_fp32 = elems_fp32 + azp_vec;
        }
        elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
        vec_op::INT8Vec16 elems_int8(elems_fp32);
        elems_int8.save(output + i * hidden_size + j);
      }

      load_vec_t elems(input + i * hidden_size + j);
      cvt_vec_t elems_fp32(elems);
      elems_fp32 = (elems_fp32 * inv_scale);

      if constexpr (AZP) {
        elems_fp32 = elems_fp32 + azp_vec;
      }
      elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
      vec_op::INT8Vec16 elems_int8(elems_fp32);
      elems_int8.save(output + i * hidden_size + j, hidden_size - j);
    }
  }
}
template <bool PerChannel, typename scalar_t>
void static_quant_epilogue(const float* input, scalar_t* output,
                           const float a_scale, const float* b_scale,
                           const int32_t* azp_with_adj, const int num_tokens,
                           const int hidden_size) {
  CPU_KERNEL_GUARD_IN(dynamic_output_scale_impl)
  using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
  using azp_adj_load_vec_t =
      typename KernelVecType<scalar_t>::azp_adj_load_vec_type;
  using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
  constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;

#pragma omp parallel for
  for (int i = 0; i < num_tokens; ++i) {
    cvt_vec_t a_scale_vec(a_scale);
    cvt_vec_t b_scale_vec(*b_scale);
    cvt_vec_t scale_vec = a_scale_vec * b_scale_vec;

    int j = 0;
    for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
      cvt_vec_t elems_fp32(input + i * hidden_size + j);
      azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j);
      cvt_vec_t azp_adj_fp32(azp_adj_vec);

      if constexpr (PerChannel) {
        b_scale_vec = cvt_vec_t(b_scale + j);
        scale_vec = b_scale_vec * a_scale_vec;
      }
      elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32;
      load_vec_t elems_out(elems_fp32);
      elems_out.save(output + i * hidden_size + j);
    }

    cvt_vec_t elems_fp32(input + i * hidden_size + j);
    azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j);
    cvt_vec_t azp_adj_fp32(azp_adj_vec);

    if constexpr (PerChannel) {
      b_scale_vec = cvt_vec_t(b_scale + j);
      scale_vec = b_scale_vec * a_scale_vec;
    }

    elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32;

    load_vec_t elems_out(elems_fp32);
    elems_out.save(output + i * hidden_size + j, hidden_size - j);
  }
}
template <bool AZP, bool PerChannel, bool Bias, typename scalar_t>
void dynamic_quant_epilogue(const float* input, scalar_t* output,
                            const float* a_scale, const float* b_scale,
                            const int32_t* azp, const int32_t* azp_adj,
                            const scalar_t* bias, const int num_tokens,
                            const int hidden_size) {
  CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue)
  using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
  using azp_adj_load_vec_t =
      typename KernelVecType<scalar_t>::azp_adj_load_vec_type;
  using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
  constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;

#pragma omp parallel for
  for (int i = 0; i < num_tokens; ++i) {
    int j = 0;
    cvt_vec_t token_scale_vec(a_scale[i]);
    cvt_vec_t token_zp_scale_vec;
    if constexpr (AZP) {
      float zp_scale_val = a_scale[i] * static_cast<float>(azp[i]);
      if constexpr (!PerChannel) {
        zp_scale_val *= *b_scale;
      }
      token_zp_scale_vec = cvt_vec_t(zp_scale_val);
    }

    for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
      cvt_vec_t elems_fp32(input + i * hidden_size + j);
      elems_fp32 = elems_fp32 * token_scale_vec;

      if constexpr (AZP) {
        azp_adj_load_vec_t azp_adj_vec(azp_adj + j);
        cvt_vec_t azp_adj_fp32(azp_adj_vec);
        azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec;

        if constexpr (PerChannel) {
          cvt_vec_t b_scale_vec(b_scale + j);
          azp_adj_fp32 = azp_adj_fp32 * b_scale_vec;
        }

        elems_fp32 = elems_fp32 - azp_adj_fp32;
      }

      if constexpr (Bias) {
        load_vec_t bias_vec(bias + j);
        cvt_vec_t bias_vec_fp32(bias_vec);
        elems_fp32 = elems_fp32 + bias_vec_fp32;
      }

      load_vec_t elems_out(elems_fp32);
      elems_out.save(output + i * hidden_size + j);
    }

    cvt_vec_t elems_fp32(input + i * hidden_size + j);
    elems_fp32 = elems_fp32 * token_scale_vec;

    if constexpr (AZP) {
      azp_adj_load_vec_t azp_adj_vec(azp_adj + j);
      cvt_vec_t azp_adj_fp32(azp_adj_vec);
      azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec;

      if constexpr (PerChannel) {
        cvt_vec_t b_scale_vec(b_scale + j);
        azp_adj_fp32 = azp_adj_fp32 * b_scale_vec;
      }

      elems_fp32 = elems_fp32 - azp_adj_fp32;
    }

    if constexpr (Bias) {
      load_vec_t bias_vec(bias + j);
      cvt_vec_t bias_vec_fp32(bias_vec);
      elems_fp32 = elems_fp32 + bias_vec_fp32;
    }

    load_vec_t elems_out(elems_fp32);
    elems_out.save(output + i * hidden_size + j, hidden_size - j);
  }
}
#else
template <typename scalar_t>
void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
                                   const float* scale, const int32_t* azp,
                                   const int num_tokens,
                                   const int hidden_size) {
  TORCH_CHECK(false,
              "static_scaled_int8_quant_impl requires AVX512/powerpc64/AArch64 "
              "support.")
}

template <typename scalar_t>
void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
                                    float* scale, int32_t* azp,
                                    const int num_tokens,
                                    const int hidden_size) {
  TORCH_CHECK(false,
              "dynamic_scaled_int8_quant_impl requires "
              "AVX512/powerpc64/AArch64 support.")
}

template <bool PerChannel, typename scalar_t>
void static_quant_epilogue(const float* input, scalar_t* output,
                           const float a_scale, const float* b_scale,
                           const int32_t* azp_with_adj, const int num_tokens,
                           const int hidden_size) {
  TORCH_CHECK(
      false, "static_quant_epilogue requires AVX512/powerpc64/AArch64 support.")
}

template <bool AZP, bool PerChannel, bool Bias, typename scalar_t>
void dynamic_quant_epilogue(const float* input, scalar_t* output,
                            const float* a_scale, const float* b_scale,
                            const int32_t* azp, const int32_t* azp_with_adj,
                            const scalar_t* bias, const int num_tokens,
                            const int hidden_size) {
  TORCH_CHECK(
      false,
      "dynamic_quant_epilogue requires AVX512/powerpc64/AArch64 support.")
}
#endif
}  // namespace

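// Zero-point handling: for asymmetric activations A_q = A / s_a + azp, the
// identity (A_q - azp) @ B = A_q @ B - azp * azp_adj (assuming azp_adj[j]
// holds the column sum over k of B[k, j], as the epilogue formulas below
// imply) is what lets the zero-point correction be folded into a per-column
// adjustment instead of materializing A_q - azp.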
void int8_scaled_mm(torch::Tensor& c,               // [M, OC], row-major
                    const torch::Tensor& a,         // [M, IC], row-major
                    const torch::Tensor& b,         // [IC, OC], column-major
                    const torch::Tensor& a_scales,  // [1] or [M]
                    const torch::Tensor& b_scales,  // [1] or [OC]
                    const std::optional<torch::Tensor>& bias  // [OC]
) {
  CPU_KERNEL_GUARD_IN(cutlass_scaled_mm)
  // Checks for conformality
  TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8,
              "int8_scaled_mm only supports INT8 inputs.")
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
              b.size(1) == c.size(1));
  TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
  TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));

  // Check for strides and alignment
  TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
  TORCH_CHECK(b.stride(0) == 1);                      // Column-major
  TORCH_CHECK(c.stride(0) % 16 == 0 &&
              b.stride(1) % 16 == 0);  // 16 Byte Alignment
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());

  if (bias) {
    TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
                bias->dim() == 1);
  }

  VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm", [&] {
    if (a_scales.numel() != 1) {
      // per-token
      // Note: oneDNN doesn't support per-token activation quantization.
      // Ideally we would fuse the GEMM and the scaling step with the oneDNN
      // JIT so the intermediate data stays cached in registers or L1, but for
      // now oneDNN GEMM code generation only supports two quantization
      // patterns: per-tensor, or per-output-channel of the weight.
      // So we have to apply the per-token scale with an 'epilogue'. In
      // C = s_a * s_b * (A@B) + bias, C_inter = s_b * (A@B) is computed by the
      // oneDNN GEMM, then the per-token scale (and bias) is applied with the
      // epilogue C = s_a * C_inter + bias.
      torch::Tensor tmp_fp32_out =
          torch::empty_like(c, ::at::ScalarType::Float);
      // Compute C_inter=s_b * (A@B)
      DNNLPrimitiveHelper<true>::gemm_s8s8_jit<float, void>(
          a.data_ptr<int8_t>(), b.data_ptr<int8_t>(),
          tmp_fp32_out.data_ptr<float>(), nullptr, a.size(0), b.size(1),
          a.size(1), nullptr, b_scales.data_ptr<float>(), 0, b_scales.numel());
      if (bias.has_value()) {
        // Compute C=s_a * C_inter + bias
        dynamic_quant_epilogue<false, true, true>(
            tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
            a_scales.data_ptr<float>(), nullptr, nullptr, nullptr,
            bias->data_ptr<scalar_t>(), c.size(0), c.size(1));
      } else {
        // Compute C=s_a * C_inter
        dynamic_quant_epilogue<false, true, false, scalar_t>(
            tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
            a_scales.data_ptr<float>(), nullptr, nullptr, nullptr, nullptr,
            c.size(0), c.size(1));
      }
    } else {
      // per-tensor
      if (bias.has_value()) {
        // Compute C=s_a * s_b * (A@B) + bias
        DNNLPrimitiveHelper<false>::gemm_s8s8_jit(
            a.data_ptr<int8_t>(), b.data_ptr<int8_t>(), c.data_ptr<scalar_t>(),
            bias->data_ptr<scalar_t>(), a.size(0), b.size(1), a.size(1),
            a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
            a_scales.numel(), b_scales.numel());
      } else {
        // Compute C=s_a * s_b * (A@B)
        DNNLPrimitiveHelper<false>::gemm_s8s8_jit<scalar_t, void>(
            a.data_ptr<int8_t>(), b.data_ptr<int8_t>(), c.data_ptr<scalar_t>(),
            nullptr, a.size(0), b.size(1), a.size(1),
            a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
            a_scales.numel(), b_scales.numel());
      }
    }
  });
}

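// Asymmetric variant of the GEMM above. With per-token activation scales the
// epilogue computes C = s_a * C_inter - s_a * s_b * azp * azp_adj + bias;
// with a per-tensor activation scale the correction reduces to
// C = C_inter - s_a * s_b * azp_adj, applied by static_quant_epilogue.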
void int8_scaled_mm_azp(torch::Tensor& c,               // [M, OC], row-major
                        const torch::Tensor& a,         // [M, IC], row-major
                        const torch::Tensor& b,         // [IC, OC], column-major
                        const torch::Tensor& a_scales,  // [1] or [M]
                        const torch::Tensor& b_scales,  // [1] or [OC]
                        const torch::Tensor& azp_adj,   // [OC]
                        const std::optional<torch::Tensor>& azp,  // [1] or [M]
                        const std::optional<torch::Tensor>& bias  // [OC]
) {
  CPU_KERNEL_GUARD_IN(cutlass_scaled_mm_azp)
  // Checks for conformality
  TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8,
              "int8_scaled_mm_azp only supports INT8 inputs.")
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
              b.size(1) == c.size(1));
  TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
  TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));

  // Check for strides and alignment
  TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
  TORCH_CHECK(b.stride(0) == 1);                      // Column-major
  TORCH_CHECK(c.stride(0) % 16 == 0 &&
              b.stride(1) % 16 == 0);  // 16 Byte Alignment
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());

  if (bias) {
    TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous());
  }
  if (azp) {
    TORCH_CHECK(azp->numel() == a.size(0) && azp->is_contiguous());
  }
  TORCH_CHECK(azp_adj.numel() == b.size(1) && azp_adj.is_contiguous());

  // azp & bias types
  TORCH_CHECK(azp_adj.dtype() == torch::kInt32);
  TORCH_CHECK(!azp || azp->dtype() == torch::kInt32);
  TORCH_CHECK(!bias || bias->dtype() == c.dtype(),
              "currently bias dtype must match output dtype ", c.dtype());

  VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm_azp", [&] {
    torch::Tensor tmp_fp32_out = torch::empty_like(c, ::at::ScalarType::Float);
    if (a_scales.numel() != 1) {
      // per-token
      // Note: oneDNN doesn't support per-token activation quantization.
      // Compute C_inter=s_b * (A@B)
      DNNLPrimitiveHelper<true>::gemm_s8s8_jit<float, void>(
          a.data_ptr<int8_t>(), b.data_ptr<int8_t>(),
          tmp_fp32_out.data_ptr<float>(), nullptr, a.size(0), b.size(1),
          a.size(1), nullptr, b_scales.data_ptr<float>(), 0, b_scales.numel());
      if (bias.has_value()) {
        // Compute C=s_a * C_inter - s_a * s_b * azp * azp_adj + bias
        if (b_scales.numel() != 1) {
          // Per-Channel
          dynamic_quant_epilogue<true, true, true>(
              tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
              a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
              azp->data_ptr<int32_t>(), azp_adj.data_ptr<int32_t>(),
              bias->data_ptr<scalar_t>(), c.size(0), c.size(1));
        } else {
          // Per-Tensor
          dynamic_quant_epilogue<true, false, true>(
              tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
              a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
              azp->data_ptr<int32_t>(), azp_adj.data_ptr<int32_t>(),
              bias->data_ptr<scalar_t>(), c.size(0), c.size(1));
        }
      } else {
        // Compute C=s_a * C_inter - s_a * s_b * azp * azp_adj
        if (b_scales.numel() != 1) {
          // Per-Channel
          dynamic_quant_epilogue<true, true, false, scalar_t>(
              tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
              a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
              azp->data_ptr<int32_t>(), azp_adj.data_ptr<int32_t>(), nullptr,
              c.size(0), c.size(1));
        } else {
          // Per-Tensor
          dynamic_quant_epilogue<true, false, false, scalar_t>(
              tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
              a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
              azp->data_ptr<int32_t>(), azp_adj.data_ptr<int32_t>(), nullptr,
              c.size(0), c.size(1));
        }
      }
    } else {
      // per-tensor
      if (bias.has_value()) {
        // Compute C_inter=s_a * s_b * (A@B) + bias
        DNNLPrimitiveHelper<false>::gemm_s8s8_jit(
            a.data_ptr<int8_t>(), b.data_ptr<int8_t>(),
            tmp_fp32_out.data_ptr<float>(), bias->data_ptr<scalar_t>(),
            a.size(0), b.size(1), a.size(1), a_scales.data_ptr<float>(),
            b_scales.data_ptr<float>(), a_scales.numel(), b_scales.numel());
      } else {
        // Compute C_inter=s_a * s_b * (A@B)
        DNNLPrimitiveHelper<false>::gemm_s8s8_jit<float, void>(
            a.data_ptr<int8_t>(), b.data_ptr<int8_t>(),
            tmp_fp32_out.data_ptr<float>(), nullptr, a.size(0), b.size(1),
            a.size(1), a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
            a_scales.numel(), b_scales.numel());
      }

      // Compute C=C_inter - s_a * s_b * azp_adj
      if (b_scales.numel() != 1) {
        // Per-Channel
        static_quant_epilogue<true>(
            tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
            *a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
            azp_adj.data_ptr<int32_t>(), a.size(0), b.size(1));
      } else {
        // Per-Tensor
        static_quant_epilogue<false>(
            tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
            *a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
            azp_adj.data_ptr<int32_t>(), a.size(0), b.size(1));
      }
    }
  });
}

// static-per-tensor quantization.
void static_scaled_int8_quant(torch::Tensor& out,          // [..., hidden_size]
                              const torch::Tensor& input,  // [..., hidden_size]
                              const torch::Tensor& scale,
                              std::optional<torch::Tensor> const& azp) {
  CPU_KERNEL_GUARD_IN(static_scaled_int8_quant)
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());
  TORCH_CHECK(scale.numel() == 1);
  TORCH_CHECK(!azp.has_value() || azp->numel() == 1);

  const int hidden_size = input.size(-1);
  const int num_tokens = input.numel() / hidden_size;
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "static_scaled_int8_quant_impl", [&] {
        if (azp.has_value()) {
          static_scaled_int8_quant_impl<true>(
              input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
              scale.data_ptr<float>(), azp->data_ptr<int32_t>(), num_tokens,
              hidden_size);
        } else {
          static_scaled_int8_quant_impl<false>(
              input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
              scale.data_ptr<float>(), nullptr, num_tokens, hidden_size);
        }
      });
}

// dynamic-per-token quantization.
void dynamic_scaled_int8_quant(
    torch::Tensor& out,          // [..., hidden_size]
    const torch::Tensor& input,  // [..., hidden_size]
    torch::Tensor& scale,        // [..., 1]
    std::optional<torch::Tensor> const& azp) {
  CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant)
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());

  int const hidden_size = input.size(-1);
  int const num_tokens = input.numel() / hidden_size;
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "dynamic_scaled_int8_quant_impl", [&] {
        if (azp.has_value()) {
          dynamic_scaled_int8_quant_impl<true>(
              input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
              scale.data_ptr<float>(), azp->data_ptr<int32_t>(), num_tokens,
              hidden_size);
        } else {
          dynamic_scaled_int8_quant_impl<false>(
              input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
              scale.data_ptr<float>(), nullptr, num_tokens, hidden_size);
        }
      });
}

#if defined(__powerpc64__)
void int8_scaled_mm_ppc64le(torch::Tensor& c,        // [M, OC], row-major
                            const torch::Tensor& a,  // [M, IC], row-major
                            const torch::Tensor& b,  // [IC, OC], column-major
                            const torch::Tensor& a_scales,
                            const torch::Tensor& b_scales,
                            const std::optional<torch::Tensor>& bias  // [OC]
) {
  CPU_KERNEL_GUARD_IN(cutlass_scaled_mm)
  // Checks for conformality
  TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8,
              "int8_scaled_mm_ppc64le only supports INT8 inputs.");
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
              b.size(1) == c.size(1));
  // We don't need this
  TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
  TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));

  // Check for strides and alignment
  TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
  TORCH_CHECK(b.stride(0) == 1);                      // Column-major
  TORCH_CHECK(c.stride(0) % 16 == 0 &&
              b.stride(1) % 16 == 0);  // 16 Byte Alignment
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());

  if (bias) {
    TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
                bias->dim() == 1);
  }
  VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm_ppc64le", [&] {
    torch::Tensor tmp_fp32_out = torch::empty_like(c, ::at::ScalarType::Float);
    // Compute C_inter=s_b * (A@B)
    DNNLPrimitiveHelper<true>::gemm_s8s8_jit<float, void>(
        a.data_ptr<int8_t>(), b.data_ptr<int8_t>(),
        tmp_fp32_out.data_ptr<float>(), nullptr, a.size(0), b.size(1),
        a.size(1), nullptr, b_scales.data_ptr<float>(), 0, b_scales.numel());
    if (bias.has_value()) {
      // Compute C=s_a * C_inter + bias
      dynamic_quant_epilogue<false, true, true>(
          tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
          a_scales.data_ptr<float>(), nullptr, nullptr, nullptr,
          bias->data_ptr<scalar_t>(), c.size(0), c.size(1));
    } else {
      // Compute C=s_a * C_inter
      dynamic_quant_epilogue<false, true, false, scalar_t>(
          tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
          a_scales.data_ptr<float>(), nullptr, nullptr, nullptr, nullptr,
          c.size(0), c.size(1));
    }
  });
}

#endif
@@ -6,25 +6,20 @@

std::string init_cpu_threads_env(const std::string& cpu_ids);

void int8_scaled_mm(torch::Tensor& c, const torch::Tensor& a,
                    const torch::Tensor& b, const torch::Tensor& a_scales,
                    const torch::Tensor& b_scales,
                    const std::optional<torch::Tensor>& bias);
void release_dnnl_matmul_handler(int64_t handler);

void int8_scaled_mm_azp(torch::Tensor& c, const torch::Tensor& a,
                        const torch::Tensor& b, const torch::Tensor& a_scales,
                        const torch::Tensor& b_scales,
                        const torch::Tensor& azp_adj,
                        const std::optional<torch::Tensor>& azp,
                        const std::optional<torch::Tensor>& bias);
int64_t create_onednn_scaled_mm_handler(const torch::Tensor& b,
                                        const torch::Tensor& b_scales,
                                        at::ScalarType output_type,
                                        bool dynamic_act_quant, bool use_azp,
                                        int64_t primitive_cache_size);

#if defined(__powerpc64__)
void int8_scaled_mm_ppc64le(torch::Tensor& c, const torch::Tensor& a,
                            const torch::Tensor& b,
                            const torch::Tensor& a_scales,
                            const torch::Tensor& b_scales,
                            const std::optional<torch::Tensor>& bias);
#endif
void onednn_scaled_mm(torch::Tensor& c, const torch::Tensor& a,
                      const torch::Tensor& a_scales,
                      const std::optional<torch::Tensor>& azp,
                      const std::optional<torch::Tensor>& azp_adj,
                      const std::optional<torch::Tensor>& bias,
                      int64_t handler);

void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query,
                        torch::Tensor& kv_cache, double scale,
@@ -151,8 +146,25 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding);

  // Quantization
#if defined(__AVX512F__) || (defined(__aarch64__) && !defined(__APPLE__))
#if defined(__AVX512F__) || (defined(__aarch64__) && !defined(__APPLE__)) || \
    defined(__powerpc64__)
  at::Tag stride_tag = at::Tag::needs_fixed_stride_order;
  // Helper function to release oneDNN handlers
  ops.def("release_dnnl_matmul_handler(int handler) -> ()",
          &release_dnnl_matmul_handler);

  // Create oneDNN W8A8 handler
  ops.def(
      "create_onednn_scaled_mm_handler(Tensor b, Tensor b_scales, ScalarType "
      "output_type, bool dynamic_act_quant, bool use_azp, int "
      "primitive_cache_size) -> int",
      &create_onednn_scaled_mm_handler);
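  // Intended usage, as a sketch inferred from the op schemas: create one
  // handler per weight tensor via create_onednn_scaled_mm_handler, pass the
  // returned int64 handle to every onednn_scaled_mm call for that weight, and
  // finally free it with release_dnnl_matmul_handler.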
  // oneDNN scaled_mm for W8A8 with static per-tensor activation quantization
  ops.def(
      "onednn_scaled_mm(Tensor! c, Tensor a, Tensor a_scales, Tensor? azp, "
      "Tensor? azp_adj, Tensor? bias, int handler) -> ()");
  ops.impl("onednn_scaled_mm", torch::kCPU, &onednn_scaled_mm);

  // Compute int8 quantized tensor for given scaling factor.
  ops.def(
@@ -168,50 +180,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
      {stride_tag});
  ops.impl("dynamic_scaled_int8_quant", torch::kCPU,
           &dynamic_scaled_int8_quant);
  // W8A8 GEMM, supporting symmetric per-tensor or per-row/column
  // quantization.
  ops.def(
      "cutlass_scaled_mm(Tensor! out, Tensor a,"
      "                  Tensor b, Tensor a_scales,"
      "                  Tensor b_scales, Tensor? bias) -> ()",
      {stride_tag});
  ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm);
  // w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
  // quantization.
  ops.def(
      "cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
      "                      Tensor b, Tensor a_scales,"
      "                      Tensor b_scales, Tensor azp_adj,"
      "                      Tensor? azp, Tensor? bias) -> ()",
      {stride_tag});
  ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp);
#elif defined(__powerpc64__)
  // Compute int8 quantized tensor for given scaling factor.
  ops.def(
      "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
      "Tensor? azp) -> ()");
  ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant);

  // Compute int8 quantized tensor and scaling factor
  ops.def(
      "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, "
      "Tensor!? azp) -> ()");
  ops.impl("dynamic_scaled_int8_quant", torch::kCPU,
           &dynamic_scaled_int8_quant);
  // W8A8 GEMM, supporting symmetric quantization.
  ops.def(
      "cutlass_scaled_mm(Tensor! out, Tensor a,"
      "                  Tensor b, Tensor a_scales,"
      "                  Tensor b_scales, Tensor? bias) -> ()");
  ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm_ppc64le);
  // w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
  // quantization.
  ops.def(
      "cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
      "                      Tensor b, Tensor a_scales,"
      "                      Tensor b_scales, Tensor azp_adj,"
      "                      Tensor? azp, Tensor? bias) -> ()");
  ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp);
#endif

  // SHM CCL

@@ -60,3 +60,13 @@ struct enable_sm100_only : Kernel {
#endif
  }
};

template <typename Kernel>
struct enable_sm120_only : Kernel {
  template <typename... Args>
  CUTLASS_DEVICE void operator()(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 1200
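    // Compile the kernel body only for compute capability 12.0
    // (__CUDA_ARCH__ == 1200); on other architectures this functor becomes a
    // no-op, mirroring the enable_sm100_only guard above.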
    Kernel::operator()(std::forward<Args>(args)...);
#endif
  }
};

@@ -45,6 +45,9 @@ struct SSMParamsBase {
  index_t out_d_stride;
  index_t out_z_batch_stride;
  index_t out_z_d_stride;
  index_t ssm_states_batch_stride;
  index_t ssm_states_dim_stride;
  index_t ssm_states_dstate_stride;
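  // Explicit strides for the SSM state cache, so the kernel can index
  // ssm_states even when the cache is not a densely packed
  // [batch, dim, dstate] tensor (which the old pointer arithmetic assumed).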

  // Common data pointers.
  void *__restrict__ A_ptr;

@@ -132,8 +132,10 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
  input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + sequence_start_index * params.B_batch_stride + group_id * params.B_group_stride;
  weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * kNRows * params.C_d_stride;
  input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + sequence_start_index * params.C_batch_stride + group_id * params.C_group_stride;
  input_t *ssm_states = reinterpret_cast<input_t *>(params.ssm_states_ptr) + (cache_index * params.dim + dim_id * kNRows) * params.dstate;
  input_t *ssm_states = reinterpret_cast<input_t *>(params.ssm_states_ptr) +
                        cache_index * params.ssm_states_batch_stride +
                        dim_id * kNRows * params.ssm_states_dim_stride;

  float D_val[kNRows] = {0};
  if (params.D_ptr != nullptr) {
#pragma unroll
@@ -248,7 +250,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
  }
  // Initialize running total

  scan_t running_prefix = chunk > 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.0, has_initial_state ? float(ssm_states[state_idx]) : 0.0);
  scan_t running_prefix = chunk > 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.0, has_initial_state ? float(ssm_states[state_idx * params.ssm_states_dstate_stride]) : 0.0);

  SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
  typename Ktraits::BlockScanT(smem_scan).InclusiveScan(
@@ -259,7 +261,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
  if (threadIdx.x == 0) {
    smem_running_prefix[state_idx] = prefix_op.running_prefix;
    if (chunk == n_chunks - 1) {
      ssm_states[state_idx] = input_t(prefix_op.running_prefix.y);
      ssm_states[state_idx * params.ssm_states_dstate_stride] = input_t(prefix_op.running_prefix.y);
    }
  }
#pragma unroll
@@ -481,6 +483,10 @@ void set_ssm_params_fwd(SSMParamsBase &params,
    params.out_batch_stride = out.stride(1);
    params.out_d_stride = out.stride(0);

    params.ssm_states_batch_stride = ssm_states.stride(0);
    params.ssm_states_dim_stride = ssm_states.stride(1);
    params.ssm_states_dstate_stride = ssm_states.stride(2);
  }
  else {
    if (!is_variable_B) {
@@ -509,6 +515,10 @@ void set_ssm_params_fwd(SSMParamsBase &params,
    }
    params.out_batch_stride = out.stride(0);
    params.out_d_stride = out.stride(1);

    params.ssm_states_batch_stride = ssm_states.stride(0);
    params.ssm_states_dim_stride = ssm_states.stride(1);
    params.ssm_states_dstate_stride = ssm_states.stride(2);
  }
}

@@ -20,6 +20,7 @@ namespace MARLIN_NAMESPACE_NAME {
TEMPLATE = ("template __global__ void Marlin<"
            "{{scalar_t}}, "
            "{{w_type_id}}, "
            "{{s_type_id}}, "
            "{{threads}}, "
            "{{thread_m_blocks}}, "
            "{{thread_n_blocks}}, "
@@ -77,6 +78,7 @@ def generate_new_kernels():
        if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
            continue
        # nvfp4 only supports group_size == 16
        # mxfp4 only supports group_size == 32
        if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]:
            continue
        # other quantization methods don't support group_size = 16
@@ -89,9 +91,22 @@ def generate_new_kernels():

        c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"

        if scalar_type == "vllm::kFE2M1f" and group_blocks == 1:
            s_type = "vllm::kFE4M3fn"
        elif scalar_type == "vllm::kFE2M1f" and group_blocks == 2:
            s_type = "vllm::kFE8M0fnu"
            if dtype == "fp16":
                # we cannot safely dequantize e8m0 to fp16, so skip this
                continue
        elif dtype == "fp16":
            s_type = "vllm::kFloat16"
        elif dtype == "bf16":
            s_type = "vllm::kBFloat16"

        template_str = jinja2.Template(TEMPLATE).render(
            scalar_t=c_dtype,
            w_type_id=scalar_type + ".id()",
            s_type_id=s_type + ".id()",
            threads=threads,
            thread_m_blocks=max(m_blocks, 1),
            thread_n_blocks=n_blocks,
@@ -7,23 +7,25 @@
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
#include "core/scalar_type.hpp"

#define MARLIN_KERNEL_PARAMS                                          \
  const int4 *__restrict__ A, const int4 *__restrict__ B,             \
      int4 *__restrict__ C, int4 *__restrict__ C_tmp,                 \
      const int4 *__restrict__ scales_ptr,                            \
      const uint16_t *__restrict__ scale2_ptr,                        \
      const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
      const int32_t *__restrict__ sorted_token_ids_ptr,               \
      const int32_t *__restrict__ expert_ids_ptr,                     \
      const int32_t *__restrict__ num_tokens_past_padded_ptr,         \
      const float *__restrict__ topk_weights_ptr, int top_k,          \
      bool mul_topk_weights, bool is_ep, int num_groups, int prob_m,  \
      int prob_n, int prob_k, int *locks, bool use_atomic_add,        \
#define MARLIN_KERNEL_PARAMS                                          \
  const int4 *__restrict__ A, const int4 *__restrict__ B,             \
      int4 *__restrict__ C, int4 *__restrict__ C_tmp,                 \
      const int4 *__restrict__ b_bias_ptr,                            \
      const int4 *__restrict__ scales_ptr,                            \
      const uint16_t *__restrict__ scale2_ptr,                        \
      const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
      const int32_t *__restrict__ sorted_token_ids_ptr,               \
      const int32_t *__restrict__ expert_ids_ptr,                     \
      const int32_t *__restrict__ num_tokens_past_padded_ptr,         \
      const float *__restrict__ topk_weights_ptr, int top_k,          \
      bool mul_topk_weights, bool is_ep, int num_groups, int prob_m,  \
      int prob_n, int prob_k, int *locks, bool has_bias, bool use_atomic_add, \
  bool use_fp32_reduce, int max_shared_mem

namespace MARLIN_NAMESPACE_NAME {
template <typename scalar_t,  // compute dtype, half or nv_float16
          const vllm::ScalarTypeId w_type_id,  // weight ScalarType id
          const vllm::ScalarTypeId s_type_id,  // weight scale ScalarType id
          const int threads,          // number of threads in a threadblock
          const int thread_m_blocks,  // number of 16x16 blocks in the m
                                      // dimension (batchsize) of the

@@ -280,6 +280,7 @@ __device__ inline void wait_negative_and_add(int* lock) {

template <typename scalar_t,  // compute dtype, half or nv_float16
          const vllm::ScalarTypeId w_type_id,  // weight ScalarType id
          const vllm::ScalarTypeId s_type_id,  // weight scale ScalarType id
          const int threads,          // number of threads in a threadblock
          const int thread_m_blocks,  // number of 16x16 blocks in the m
                                      // dimension (batchsize) of the
@@ -299,6 +300,7 @@ __global__ void Marlin(
    const int4* __restrict__ B,  // 4bit quantized weight matrix of shape kxn
    int4* __restrict__ C,        // fp16 output buffer of shape mxn
    int4* __restrict__ C_tmp,    // fp32 tmp output buffer (for reduce)
    const int4* __restrict__ b_bias_ptr,
    const int4* __restrict__ scales_ptr,  // fp16 quantization scales of shape
                                          // (k/groupsize)xn
    const uint16_t* __restrict__ scale2_ptr,  // fp16 global scale (for nvfp4
@@ -318,8 +320,9 @@ __global__ void Marlin(
    int prob_n,   // output dimension n
    int prob_k,   // reduction dimension k
    int* locks,   // extra global storage for barrier synchronization
    bool use_atomic_add,   // whether to use atomic add to reduce
    bool use_fp32_reduce,  // whether to use fp32 global reduce
    bool has_bias,
    bool use_atomic_add,   // whether to use atomic add to reduce
    bool use_fp32_reduce,  // whether to use fp32 global reduce
    int max_shared_mem) {
  // Each threadblock processes one "stripe" of the B matrix with (roughly) the
  // same size, which might involve multiple column "slices" (of width 16 *
@@ -342,12 +345,23 @@ __global__ void Marlin(

  extern __shared__ int4 sh[];
  static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
  static constexpr auto s_type = vllm::ScalarType::from_id(s_type_id);
  if constexpr (w_type == vllm::kFE2M1f) {
    static_assert(s_type == vllm::kFE4M3fn && group_blocks == 1 ||
                  s_type == vllm::kFE8M0fnu && group_blocks == 2);
  } else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
    static_assert(s_type == vllm::kBFloat16);
  } else if constexpr (std::is_same<scalar_t, half>::value) {
    static_assert(s_type == vllm::kFloat16);
  }

  constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8;
  constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 ||
                               w_type == vllm::kU4B8 || w_type == vllm::kU8B128;
  // see comments of dequant.h for more details
  constexpr bool dequant_skip_flop =
      !is_int_type ||
      w_type == vllm::kFE4M3fn ||
      w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn ||
      has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
      has_zp && !is_zp_float && !(w_type == vllm::kU8);

@@ -365,6 +379,7 @@ __global__ void Marlin(
  const int zp_expert_stride =
      is_zp_float ? prob_n * prob_k / group_size / 8
                  : prob_n * prob_k / group_size / (pack_factor * 4);
  const int b_bias_expert_stride = prob_n / 8;

  // parallel: num valid moe blocks
  int num_tokens_past_padded = num_tokens_past_padded_ptr[0];
@@ -475,7 +490,7 @@ __global__ void Marlin(
    for (int i = 0; i < 4; i++) {
      int idx = tid4 * 4 + i;
      idx = idx < block_num_valid_tokens ? idx : 0;
      if constexpr (w_type == vllm::kFE2M1f) {
      if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
        sh_block_topk_weights[idx] = __hmul2(
            global_scale, Dtype::num2num2(Dtype::float2num(
                              topk_weights_ptr[sh_block_sorted_ids[idx]])));
@@ -513,7 +528,7 @@ __global__ void Marlin(
      expert_id = expert_ids_ptr[block_id];
    }

    if constexpr (w_type == vllm::kFE2M1f) {
    if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
      uint16_t val = scale2_ptr[expert_id];
      global_scale = Dtype::num2num2(*reinterpret_cast<scalar_t*>(&val));
    }
@@ -526,6 +541,9 @@ __global__ void Marlin(
    if constexpr (has_act_order) {
      g_idx += (expert_id - old_expert_id) * prob_k;
    }
    if (has_bias) {
      b_bias_ptr += (expert_id - old_expert_id) * b_bias_expert_stride;
    }

    read_moe_block_data(block_id);
  };
@@ -721,7 +739,7 @@ __global__ void Marlin(

    s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
              (threadIdx.x % 32) / 4;
    s_sh_rd = s_sh_rd * 2 + warp_row % 2;
    s_sh_rd = s_sh_rd * 2 + (warp_row / group_blocks) % 2;

  } else if constexpr (group_blocks != -1)
    s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
@@ -734,6 +752,18 @@ __global__ void Marlin(
    s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
              (threadIdx.x % 32) % 4;

  int bias_sh_rd;
  if constexpr (m_block_size_8) {
    bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
                 (threadIdx.x % 32) / 8;
  } else {
    bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
                 (threadIdx.x % 32) % 4;
  }

  int bias_sh_wr = threadIdx.x;
  int bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x;

  // Zero-points have the same read layout as the scales
  // (without column-wise case)
  constexpr int num_col_threads = 8;
@@ -793,7 +823,19 @@ __global__ void Marlin(
  constexpr int sh_b_size = stages * b_sh_stage;
  int4* sh_b = sh_new;
  int4* sh_red = sh_new;
  int4* sh_g_idx = sh_b + (sh_red_size > sh_b_size ? sh_red_size : sh_b_size);

  constexpr int sh_size_b_red_min =
      (sh_red_size < sh_b_size ? sh_red_size : sh_b_size);
  constexpr int sh_size_b_red_max =
      (sh_red_size > sh_b_size ? sh_red_size : sh_b_size);
  constexpr int sh_bias_size = (thread_n_blocks * 16 / 8);
  constexpr int sh_b_red_bias_size =
      sh_size_b_red_max > (sh_size_b_red_min + sh_bias_size)
          ? sh_size_b_red_max
          : (sh_size_b_red_min + sh_bias_size);
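  // Shared-memory packing: sh_b and sh_red alias the same region, the bias
  // tile is placed immediately after the smaller of the two, and the next
  // buffer starts at max(max(b, red), min(b, red) + bias) int4s, so the bias
  // tile can never overlap whichever of sh_b/sh_red is currently live.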

  int4* sh_bias = sh_new + sh_size_b_red_min;
  int4* sh_g_idx = sh_new + sh_b_red_bias_size;
  int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
  constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride)
                                          : (stages * s_sh_stage);
@@ -803,9 +845,9 @@ __global__ void Marlin(
  static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <=
                stages * b_sh_stage);
  int4* sh_a = sh_s + sh_s_size;
  constexpr int shm_size_used =
      moe_block_size + stages * (g_idx_stage + zp_sh_stage) + sh_s_size +
      (sh_red_size > sh_b_size ? sh_red_size : sh_b_size);
  constexpr int shm_size_used = moe_block_size +
                                stages * (g_idx_stage + zp_sh_stage) +
                                sh_s_size + sh_b_red_bias_size;

  // all remaining shared memory is used to cache A (input)
  // sh_a_max_row is at least ` stages * 16 * thread_m_blocks `
@@ -816,7 +858,8 @@ __global__ void Marlin(
  FragA frag_a[2][thread_m_blocks];
  I4 frag_b_quant[2][b_thread_vecs];
  FragC frag_c[thread_m_blocks][4][2];
  FragS frag_s[2][4];  // No act-order
  FragS frag_s[2][4];  // No act-order
  FragS frag_bias[2][4];
  FragS act_frag_s[2][4][4];             // For act-order
  int frag_qzp[2][num_ints_per_thread];  // Zero-points
  FragZP frag_zp;                        // Zero-points in fp16
@@ -1065,10 +1108,15 @@ __global__ void Marlin(
        if constexpr (w_type_id != vllm::kFE2M1f.id()) {
          reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
              sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
        } else {
        } else if constexpr (group_blocks == 1 || thread_k_blocks > 4) {
          reinterpret_cast<int2*>(&frag_s[k % 2])[0] =
              reinterpret_cast<int2*>(
                  sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
        } else {
          reinterpret_cast<int2*>(&frag_s[k % 2])[0] =
              reinterpret_cast<int2*>(
                  sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride) +
                              k % 2];
        }
      }
    }
@@ -1281,9 +1329,9 @@ __global__ void Marlin(
      int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
      int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];

      dequant_fp8_scales<scalar_t2>(s_quant_0,
                                    reinterpret_cast<scalar_t2*>(&frag_s[k2]));
      dequant_fp8_scales<scalar_t2>(
      dequant_fp8_scales<scalar_t2, s_type_id>(
          s_quant_0, reinterpret_cast<scalar_t2*>(&frag_s[k2]));
      dequant_fp8_scales<scalar_t2, s_type_id>(
          s_quant_1, reinterpret_cast<scalar_t2*>(&frag_s[k2]) + 2);
    }

@@ -1566,7 +1614,7 @@ __global__ void Marlin(
  // Write out the reduce final result in the correct layout. We only actually
  // reshuffle matrix fragments in this step, the reduction above is performed
  // in fragment layout.
-  auto write_result = [&]() {
+  auto write_result = [&](bool last) {
    int c_gl_stride = prob_n / 8;
    constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
    int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
@@ -1592,7 +1640,7 @@ __global__ void Marlin(

    // We first reorder in shared memory to guarantee the most efficient final
    // global write patterns
-    auto write = [&](int idx, float c0, float c1, FragS& s) {
+    auto write = [&](int idx, float c0, float c1, FragS& s, FragS& b_bias) {
      scalar_t2 res =
          Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1));

@@ -1601,14 +1649,27 @@ __global__ void Marlin(
      if constexpr (!has_act_order && group_blocks == -1 &&
                    w_type.size_bits() == 4 &&
                    (has_zp && dequant_skip_flop || !has_zp)) {
-        res = __hmul2(res, s[0]);
+        scalar_t2 tmp_scale = s[0];
+        if constexpr (m_block_size_8) {
+          tmp_scale = Dtype::num2num2(
+              reinterpret_cast<scalar_t*>(&s[0])[(threadIdx.x % 8) / 4]);
+        }
+        res = __hmul2(res, tmp_scale);
      }

-      if constexpr (w_type == vllm::kFE2M1f) {
+      if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
        if (!mul_topk_weights) {
          res = __hmul2(res, global_scale);
        }
      }
+      if (has_bias && last) {
+        scalar_t2 tmp_bias = b_bias[0];
+        if constexpr (m_block_size_8) {
+          tmp_bias = Dtype::num2num2(
+              reinterpret_cast<scalar_t*>(&b_bias[0])[(threadIdx.x % 8) / 4]);
+        }
+        res = __hadd2(res, tmp_bias);
+      }

      if constexpr (m_block_size_8) {
        ((scalar_t*)sh_red)[idx] = res.x;
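Worth noting for reviewers: the epilogue applies the per-column scale, and now the bias, on packed half2 values, with m_block_size_8 lanes first replicating the one scalar they own. A minimal CUDA sketch of the same packed math outside the kernel (the lane-index expression mirrors the diff; the values are illustrative):

#include <cuda_fp16.h>
#include <cstdio>

__global__ void epilogue_demo(half2* out, half2 res, half2 scale, half2 bias,
                              bool m_block_size_8, bool apply_bias) {
  if (m_block_size_8) {
    // With 8-row blocks each lane needs only one of the two packed scalars:
    // pick .x or .y based on the lane's position, then replicate it.
    half s = ((threadIdx.x % 8) / 4 == 0) ? __low2half(scale)
                                          : __high2half(scale);
    scale = __half2half2(s);
  }
  res = __hmul2(res, scale);                  // per-column scale
  if (apply_bias) res = __hadd2(res, bias);   // bias only on the last slice
  out[threadIdx.x] = res;
}

int main() {
  half2* d_out;
  cudaMalloc(&d_out, 8 * sizeof(half2));
  epilogue_demo<<<1, 8>>>(d_out, __floats2half2_rn(2.f, 3.f),
                          __floats2half2_rn(0.5f, 4.f),
                          __floats2half2_rn(1.f, 1.f), true, true);
  cudaDeviceSynchronize();
  cudaFree(d_out);
}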
@@ -1626,19 +1687,25 @@ __global__ void Marlin(
      if constexpr (m_block_size_8) {
        int wr = c_sh_wr + 16 * j;
        write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1],
-              frag_s[j / 2][2 * (j % 2) + 0]);
+              frag_s[j / 2][2 * (j % 2) + 0],
+              frag_bias[j / 2][2 * (j % 2) + 0]);
        write(wr + 8, frag_c[i][j][0][2], frag_c[i][j][0][3],
-              frag_s[j / 2][2 * (j % 2) + 1]);
+              frag_s[j / 2][2 * (j % 2) + 1],
+              frag_bias[j / 2][2 * (j % 2) + 1]);
      } else {
        int wr = c_sh_wr + 8 * j;
        write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],
-              frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
+              frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0],
+              frag_bias[j / 2][2 * (j % 2) + 0]);
        write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],
-              frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);
+              frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0],
+              frag_bias[j / 2][2 * (j % 2) + 0]);
        write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],
-              frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);
+              frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1],
+              frag_bias[j / 2][2 * (j % 2) + 1]);
        write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],
-              frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);
+              frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1],
+              frag_bias[j / 2][2 * (j % 2) + 1]);
      }
    }
    c_sh_wr += 16 * (4 * c_sh_stride);
@@ -1805,6 +1872,14 @@ __global__ void Marlin(
    }

    thread_block_reduce();

+    if (has_bias && last) {
+      __syncthreads();
+      cp_async4_pred(&sh_bias[bias_sh_wr], &b_bias_ptr[bias_gl_rd],
+                     threadIdx.x < 16 * thread_n_blocks / 8);
+      cp_async_fence();
+    }
+
    if constexpr (!has_act_order && group_blocks == -1 &&
                  (has_zp && dequant_skip_flop || !has_zp)) {
      if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
@@ -1867,11 +1942,20 @@ __global__ void Marlin(
      }
      barrier_release(&locks[locks_off], last);
    }

+    if (has_bias && last) {
+      cp_async_wait<0>();
+      __syncthreads();
+      reinterpret_cast<int4*>(&frag_bias)[0] = sh_bias[bias_sh_rd];
+      reinterpret_cast<int4*>(&frag_bias)[1] = sh_bias[bias_sh_rd + 4];
+      __syncthreads();
+    }
+
    if (use_atomic_add && slice_count > 1 && slice_idx != 0)
      wait_negative_and_add(&locks[locks_off]);
    if (last || use_atomic_add)
      // only the last block in a slice actually writes the result
-      write_result();
+      write_result(last);
    int old_slice_row = slice_row;
    slice_row = 0;
    slice_col_par++;
@@ -1904,6 +1988,7 @@ __global__ void Marlin(
        for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
      }

+      bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x;
      // Update slice k/n for scales loading
      if constexpr (has_act_order) {
        slice_k_start = tb_k * slice_row;

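The bias path mirrors the scale path: a predicated 16-byte copy per N-tile stages the row into sh_bias while the reduction is still running, and each thread then pulls two int4 fragments back out before write_result(last). A simplified sketch of that staging pattern, with a plain shared-memory copy standing in for this codebase's cp_async4_pred/cp_async_fence helpers:

#include <cuda_runtime.h>

// Stage a row of N bias values (one int4 = 8 halves per slot) through shared
// memory, then let each thread keep private register fragments. n_int4 plays
// the role of 16 * thread_n_blocks / 8 in the kernel; the read offsets are
// simplified stand-ins for bias_sh_rd and bias_sh_rd + 4.
__global__ void stage_bias_demo(const int4* __restrict__ g_bias, int n_int4) {
  extern __shared__ int4 sh_bias[];
  if (threadIdx.x < n_int4)            // predicated copy, one int4 per thread
    sh_bias[threadIdx.x] = g_bias[threadIdx.x];
  __syncthreads();                     // the real kernel uses cp_async + fence

  int4 frag_bias[2];                   // per-thread register fragments
  int rd = threadIdx.x % n_int4;
  frag_bias[0] = sh_bias[rd];
  frag_bias[1] = sh_bias[(rd + 4) % n_int4];
  (void)frag_bias;                     // consumed by the epilogue in reality
}

int main() {
  int4* d_bias;
  cudaMalloc(&d_bias, 32 * sizeof(int4));
  stage_bias_demo<<<1, 128, 32 * sizeof(int4)>>>(d_bias, 32);
  cudaDeviceSynchronize();
  cudaFree(d_bias);
}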
@@ -51,8 +51,9 @@ __global__ void permute_cols_kernel(
 }  // namespace marlin

 torch::Tensor moe_wna16_marlin_gemm(
-    torch::Tensor& a, std::optional<torch::Tensor> const& c_or_none,
-    torch::Tensor& b_q_weight, torch::Tensor& b_scales,
+    torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
+    torch::Tensor& b_q_weight,
+    std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
     std::optional<torch::Tensor> const& b_zeros_or_none,
     std::optional<torch::Tensor> const& g_idx_or_none,
     std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
@@ -212,7 +213,7 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
   // Get B size
   int tb_k = th_config.thread_k;
   int tb_n = th_config.thread_n;
-  int tb_m = thread_m_blocks * (m_block_size_8 ? 8 : 16);
+  int tb_m = thread_m_blocks * 16;

   // shm size for block_sorted_ids/rd_block_sorted_ids/block_topk_weights
   // both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32)
@@ -220,6 +221,11 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
   int sh_a_size = pipe_stages * (tb_m * tb_k) * 2;
   int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
   int sh_red_size = tb_m * (tb_n + 8) * 2;
+  int sh_bias_size = tb_n * 2;
+  int tmp_size =
+      (sh_b_size > sh_red_size ? sh_red_size : sh_b_size) + sh_bias_size;
+  tmp_size = max(max(sh_b_size, sh_red_size), tmp_size);

   int sh_s_size =
       get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
                             group_size, has_act_order, is_k_full);
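The host-side estimate mirrors the device layout above: reserve the larger of "B or the reduce scratch alone" and "the smaller of the two plus the bias row". A worked example with made-up tile sizes:

#include <algorithm>
#include <cstdio>

int main() {
  // Illustrative tile: tb_m = 16, tb_n = 256, tb_k = 128, 4-bit weights.
  const int pipe_stages = 4, pack_factor = 8;
  const int tb_m = 16, tb_n = 256, tb_k = 128;
  int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;  // 65536 B
  int sh_red_size = tb_m * (tb_n + 8) * 2;                        //  8448 B
  int sh_bias_size = tb_n * 2;                                    //   512 B

  // Same expression as the diff: min(B, red) + bias, then take the max
  // against either buffer standing alone.
  int tmp_size = std::min(sh_b_size, sh_red_size) + sh_bias_size;
  tmp_size = std::max(std::max(sh_b_size, sh_red_size), tmp_size);
  printf("b=%d red=%d -> combined=%d bytes\n", sh_b_size, sh_red_size,
         tmp_size);  // here the B buffer dominates: 65536
}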
@@ -234,8 +240,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
     sh_zp_size = sh_s_size / 2;
   }

-  int total_size = max(sh_b_size, sh_red_size) + sh_a_size + sh_s_size +
-                   sh_zp_size + sh_g_idx_size + sh_block_meta_size;
+  int total_size = tmp_size + sh_a_size + sh_s_size + sh_zp_size +
+                   sh_g_idx_size + sh_block_meta_size;

   return total_size;
 }
@@ -270,20 +276,25 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
   int cache_size = get_kernel_cache_size(
       th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k,
       num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float);
-  return cache_size <= max_shared_mem;
+  return cache_size + 512 <= max_shared_mem;
 }

 #define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS,   \
                 M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT)      \
   else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS &&         \
            thread_n_blocks == THREAD_N_BLOCKS &&                             \
            thread_k_blocks == THREAD_K_BLOCKS &&                             \
            m_block_size_8 == M_BLOCK_SIZE_8 &&                               \
            group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS &&     \
            is_zp_float == IS_ZP_FLOAT) {                                     \
-    kernel = Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS,     \
-                    THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8,        \
-                    pipe_stages, GROUP_BLOCKS, IS_ZP_FLOAT>;                 \
+    constexpr auto S_TYPE =                                                  \
+        W_TYPE == vllm::kFE2M1f                                              \
+            ? (GROUP_BLOCKS == 1 ? vllm::kFE4M3fn : vllm::kFE8M0fnu)         \
+            : (std::is_same<scalar_t, half>::value ? vllm::kFloat16          \
+                                                   : vllm::kBFloat16);       \
+    kernel = Marlin<scalar_t, W_TYPE.id(), S_TYPE.id(), NUM_THREADS,         \
+                    THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS,       \
+                    M_BLOCK_SIZE_8, pipe_stages, GROUP_BLOCKS, IS_ZP_FLOAT>; \
   }

 // COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false)
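The new S_TYPE ternary picks the scale element type at compile time: FP4 weights get FP8 scales (e4m3 when GROUP_BLOCKS == 1, i.e. 16-wide NVFP4 groups; e8m0 otherwise, for the 32-wide MXFP4 groups), and every other weight type keeps scales in the activation type. The same decision as a plain constexpr function (the enum stands in for the vllm::ScalarType constants):

#include <cstdio>

enum class ScaleType { kFloat16, kBFloat16, kFE4M3fn, kFE8M0fnu };

// Mirrors the S_TYPE ternary in _GET_IF. group_blocks counts 16-wide blocks,
// so group_blocks == 1 means 16-element groups (NVFP4) and 2 means 32 (MXFP4).
constexpr ScaleType scale_type(bool weight_is_fp4, int group_blocks,
                               bool activations_are_half) {
  return weight_is_fp4
             ? (group_blocks == 1 ? ScaleType::kFE4M3fn : ScaleType::kFE8M0fnu)
             : (activations_are_half ? ScaleType::kFloat16
                                     : ScaleType::kBFloat16);
}

int main() {
  static_assert(scale_type(true, 1, true) == ScaleType::kFE4M3fn);   // NVFP4
  static_assert(scale_type(true, 2, true) == ScaleType::kFE8M0fnu);  // MXFP4
  static_assert(scale_type(false, -1, false) == ScaleType::kBFloat16);
  printf("ok\n");
}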
@@ -335,31 +346,45 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
   _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)  \
   _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
   _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)  \
                                                                         \
   _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
   _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)

-#define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS)         \
-  _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false)  \
-  _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
-
-#define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS)       \
-  _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
-  _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
-  _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
-
-#define FP4_GET_IF(W_TYPE)            \
-  FP4_GET_IF_M1(W_TYPE, 8, 8, 256)    \
-  FP4_GET_IF_M1(W_TYPE, 8, 4, 128)    \
-  FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
-  FP4_GET_IF_M234(W_TYPE, 8, 4, 128)
-
 #define BIGGROUP_GET_IF(W_TYPE)            \
   BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256)    \
   BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128)    \
   BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \
   BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128)

+#define NVFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS)       \
+  _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false)  \
+  _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
+
+#define NVFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS)     \
+  _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
+  _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
+  _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
+
+#define NVFP4_GET_IF(W_TYPE)            \
+  NVFP4_GET_IF_M1(W_TYPE, 8, 8, 256)    \
+  NVFP4_GET_IF_M1(W_TYPE, 8, 4, 128)    \
+  NVFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
+  NVFP4_GET_IF_M234(W_TYPE, 8, 4, 128)
+
+#define MXFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS)       \
+  _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false)  \
+  _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)
+
+#define MXFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS)     \
+  _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
+  _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
+  _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)
+
+#define MXFP4_GET_IF(W_TYPE)            \
+  MXFP4_GET_IF_M1(W_TYPE, 8, 8, 256)    \
+  MXFP4_GET_IF_M1(W_TYPE, 8, 4, 128)    \
+  MXFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
+  MXFP4_GET_IF_M234(W_TYPE, 8, 4, 128)

 // We currently have 4-bit models only with group_blocks == 4
 #define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS)        \
   _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true)  \
@@ -408,12 +433,17 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
   COMMON_GET_IF(vllm::kU4B8)
   COMMON_GET_IF(vllm::kU8B128)

-  BIGGROUP_GET_IF(vllm::kFE4M3fn)
+  NVFP4_GET_IF(vllm::kFE2M1f)

-  FP4_GET_IF(vllm::kFE2M1f)
+  BIGGROUP_GET_IF(vllm::kFE4M3fn)

   ACT_GET_IF(vllm::kU4B8)
   ACT_GET_IF(vllm::kU8B128)
+  if (std::is_same<scalar_t, nv_bfloat16>::value) {
+    if (false) {
+    }
+    MXFP4_GET_IF(vllm::kFE2M1f)
+  }

   return kernel;
 }
@@ -482,16 +512,16 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
 }

 template <typename scalar_t>
-void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
-               void* s2, void* zp, void* g_idx, void* perm, void* a_tmp,
-               void* sorted_token_ids, void* expert_ids,
+void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
+               void* s, void* s2, void* zp, void* g_idx, void* perm,
+               void* a_tmp, void* sorted_token_ids, void* expert_ids,
                void* num_tokens_past_padded, void* topk_weights,
                int moe_block_size, int top_k, bool mul_topk_weights, bool is_ep,
                int prob_m, int prob_n, int prob_k, void* workspace,
-               vllm::ScalarType const& q_type, bool has_act_order,
-               bool is_k_full, bool has_zp, int num_groups, int group_size,
-               int dev, cudaStream_t stream, int thread_k, int thread_n,
-               int sms, bool use_atomic_add, bool use_fp32_reduce,
+               vllm::ScalarType const& q_type, bool has_bias,
+               bool has_act_order, bool is_k_full, bool has_zp, int num_groups,
+               int group_size, int dev, cudaStream_t stream, int thread_k,
+               int thread_n, int sms, bool use_atomic_add, bool use_fp32_reduce,
                bool is_zp_float) {
   int thread_m_blocks = div_ceil(moe_block_size, 16);
   bool m_block_size_8 = moe_block_size == 8;
@@ -538,6 +568,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
   const int4* B_ptr = (const int4*)B;
   int4* C_ptr = (int4*)C;
   int4* C_tmp_ptr = (int4*)C_tmp;
+  const int4* bias_ptr = (const int4*)b_bias;
   const int4* s_ptr = (const int4*)s;
   const uint16_t* s2_ptr = (const uint16_t*)s2;
   const int4* zp_ptr = (const int4*)zp;
@@ -648,10 +679,10 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
   // avoid ">>>" being formatted to "> > >"
   // clang-format off
   kernel<<<blocks, num_threads, max_shared_mem, stream>>>(
-      A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr,
+      A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr,
       sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr,
       topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m,
-      prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce, max_shared_mem);
+      prob_n, prob_k, locks, has_bias, use_atomic_add, use_fp32_reduce, max_shared_mem);
   // clang-format on
 }

@@ -659,7 +690,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,

 torch::Tensor moe_wna16_marlin_gemm(
     torch::Tensor& a, std::optional<torch::Tensor> const& c_or_none,
-    torch::Tensor& b_q_weight, torch::Tensor& b_scales,
+    torch::Tensor& b_q_weight,
+    std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
     std::optional<torch::Tensor> const& global_scale_or_none,
     std::optional<torch::Tensor> const& b_zeros_or_none,
     std::optional<torch::Tensor> const& g_idx_or_none,
@@ -766,7 +798,6 @@ torch::Tensor moe_wna16_marlin_gemm(
   num_groups = b_scales.size(1);

   torch::Tensor g_idx, perm, a_tmp;
-  ;
   if (g_idx_or_none.has_value() && perm_or_none.has_value()) {
     g_idx = g_idx_or_none.value();
     perm = perm_or_none.value();
@@ -815,12 +846,24 @@ torch::Tensor moe_wna16_marlin_gemm(
   torch::Tensor global_scale;
   if (global_scale_or_none.has_value()) {
     global_scale = global_scale_or_none.value();
-    TORCH_CHECK(b_q_type == vllm::kFE2M1f,
-                "global_scale can only be used for float4_e2m1f.");
+    TORCH_CHECK(b_q_type == vllm::kFE2M1f && group_size == 16,
+                "global_scale can only be used for nvfp4 format.");
   } else {
     global_scale = torch::empty({0}, options);
-    TORCH_CHECK(!(b_q_type == vllm::kFE2M1f),
-                "the global_scale parameter must be passed for float4_e2m1f.");
+    TORCH_CHECK(!(b_q_type == vllm::kFE2M1f && group_size == 16),
+                "the global_scale parameter must be passed for nvfp4 format.");
   }

+  bool has_bias = b_bias_or_none.has_value();
+  torch::Tensor b_bias;
+  if (has_bias) {
+    b_bias = b_bias_or_none.value();
+    TORCH_CHECK(b_bias.device().is_cuda(), "b_bias is not on GPU");
+    TORCH_CHECK(b_bias.is_contiguous(), "b_bias is not contiguous");
+    TORCH_CHECK(b_bias.size(1) == size_n, "b_bias.size(0) != size_n");
+    TORCH_CHECK(b_bias.stride(1) == 1, "b_bias.stride(1) != 1");
+  } else {
+    b_bias = torch::empty({0}, options);
+  }
+
   torch::Tensor b_zeros;
@@ -832,7 +875,6 @@ torch::Tensor moe_wna16_marlin_gemm(
     b_zeros = torch::empty({0}, options);
   }
   bool has_zp = b_zeros.size(-1) > 0;
-
   if (has_zp) {
     TORCH_CHECK(
         b_q_type == vllm::kU4 || b_q_type == vllm::kU8,
@@ -890,41 +932,58 @@ torch::Tensor moe_wna16_marlin_gemm(
   if (a.scalar_type() == at::ScalarType::Half) {
     void* scales_ptr;
     if (b_q_type == vllm::kFE2M1f) {
-      scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
+      if (group_size == 16)
+        scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
+      else if (group_size == 32)
+        scales_ptr = b_scales.data_ptr<at::Float8_e8m0fnu>();
+      else
+        TORCH_CHECK(false,
+                    "float4_e2m1f only supports group_size == 16 (NVFP4) ",
+                    "and group_size == 32 (MXFP4)");
     } else {
       scales_ptr = b_scales.data_ptr<at::Half>();
     }

     MARLIN_NAMESPACE_NAME::marlin_mm<half>(
         a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
-        c_tmp.data_ptr<float>(), scales_ptr, global_scale.data_ptr<at::Half>(),
-        b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
-        a_tmp.data_ptr<at::Half>(), sorted_token_ids.data_ptr(),
-        expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(),
-        topk_weights.data_ptr(), moe_block_size, top_k, mul_topk_weights, is_ep,
-        size_m, size_n, size_k, workspace.data_ptr(), b_q_type, has_act_order,
-        is_k_full, has_zp, num_groups, group_size, dev,
+        c_tmp.data_ptr<float>(), b_bias.data_ptr<at::Half>(), scales_ptr,
+        global_scale.data_ptr<at::Half>(), b_zeros.data_ptr(), g_idx.data_ptr(),
+        perm.data_ptr(), a_tmp.data_ptr<at::Half>(),
+        sorted_token_ids.data_ptr(), expert_ids.data_ptr(),
+        num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(),
+        moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k,
+        workspace.data_ptr(), b_q_type, has_bias, has_act_order, is_k_full,
+        has_zp, num_groups, group_size, dev,
         at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
         use_atomic_add, use_fp32_reduce, is_zp_float);
   } else if (a.scalar_type() == at::ScalarType::BFloat16) {
     void* scales_ptr;
     if (b_q_type == vllm::kFE2M1f) {
-      scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
+      if (group_size == 16)
+        scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
+      else if (group_size == 32)
+        scales_ptr = b_scales.data_ptr<at::Float8_e8m0fnu>();
+      else
+        TORCH_CHECK(false,
+                    "float4_e2m1f only supports group_size == 16 (NVFP4) ",
+                    "and group_size == 32 (MXFP4)");
     } else {
       scales_ptr = b_scales.data_ptr<at::BFloat16>();
     }

     MARLIN_NAMESPACE_NAME::marlin_mm<nv_bfloat16>(
         a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
-        c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(), scales_ptr,
+        c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
+        b_bias.data_ptr<at::BFloat16>(), scales_ptr,
         global_scale.data_ptr<at::BFloat16>(), b_zeros.data_ptr(),
         g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
         sorted_token_ids.data_ptr(), expert_ids.data_ptr(),
         num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(),
         moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k,
-        workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
-        num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
-        thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float);
+        workspace.data_ptr(), b_q_type, has_bias, has_act_order, is_k_full,
+        has_zp, num_groups, group_size, dev,
+        at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
+        use_atomic_add, use_fp32_reduce, is_zp_float);
   } else {
     TORCH_CHECK(false,
                 "moe_wna16_marlin_gemm only supports bfloat16 and float16");
@@ -45,8 +45,6 @@ void moe_permute(
   auto copy_topk_ids = topk_ids.clone();  // copy topk_ids for preprocess
   auto permuted_experts_id = torch::empty_like(topk_ids);
   auto sorted_row_idx = torch::empty_like(inv_permuted_idx);
-  auto align_expert_first_token_offset =
-      torch::zeros_like(expert_first_token_offset);

   CubKeyValueSorter sorter{};
   int64_t* valid_num_ptr = nullptr;
@@ -85,12 +83,14 @@ void moe_permute(
       });

-  // get m_indices and update expert_first_token_offset with align block
-  getMIndices(get_ptr<int64_t>(expert_first_token_offset),
-              get_ptr<int64_t>(align_expert_first_token_offset),
-              get_ptr<int>(m_indices), n_local_expert, align_block_size_value,
-              stream);
+  // this is only required for DeepGemm and not required for CUTLASS group gemm
   if (align_block_size.has_value()) {
-    // update align_expert_first_token_offset
+    auto align_expert_first_token_offset =
+        torch::zeros_like(expert_first_token_offset);
+    getMIndices(get_ptr<int64_t>(expert_first_token_offset),
+                get_ptr<int64_t>(align_expert_first_token_offset),
+                get_ptr<int>(m_indices), n_local_expert, align_block_size_value,
+                stream);
     expert_first_token_offset.copy_(align_expert_first_token_offset);
   }
 }
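After this change the alignment pass, and its scratch tensor, only exist when align_block_size is supplied, i.e. for the DeepGEMM-style path that needs every expert's token segment padded to a block multiple. A hedged CPU sketch of what "aligning" expert_first_token_offset means (the real work happens inside getMIndices on the GPU; this only mirrors the intent):

#include <cstdio>
#include <vector>

// Round each expert's token span up to a multiple of `block`. `offsets` has
// n_expert + 1 entries; offsets[i+1] - offsets[i] is expert i's token count.
std::vector<long> align_offsets(const std::vector<long>& offsets, long block) {
  std::vector<long> aligned(offsets.size(), 0);
  for (size_t i = 1; i < offsets.size(); ++i) {
    long count = offsets[i] - offsets[i - 1];
    long padded = (count + block - 1) / block * block;  // round up
    aligned[i] = aligned[i - 1] + padded;
  }
  return aligned;
}

int main() {
  std::vector<long> off{0, 3, 3, 10};  // experts with 3, 0, 7 tokens
  for (long v : align_offsets(off, 4)) printf("%ld ", v);  // 0 4 4 12
}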
@@ -195,19 +195,14 @@ void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights,
                  torch::Tensor& expert_first_token_offset,
                  torch::Tensor& src_row_id2dst_row_id_map,
                  torch::Tensor& m_indices) {
-  TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0");
+  TORCH_CHECK(false, "moe_permute is not supported on CUDA < 12.0");
 }

-void moe_unpermute(const torch::Tensor& input,
-                   const torch::Tensor& topk_weights, torch::Tensor& topk_ids,
-                   const torch::Tensor& token_expert_indices,
-                   const std::optional<torch::Tensor>& expert_map,
-                   int64_t n_expert, int64_t n_local_expert, int64_t topk,
-                   const std::optional<int64_t>& align_block_size,
-                   torch::Tensor& permuted_input,
-                   torch::Tensor& expert_first_token_offset,
-                   torch::Tensor& src_row_id2dst_row_id_map,
-                   torch::Tensor& m_indices) {
+void moe_unpermute(
+    const torch::Tensor& permuted_hidden_states,
+    const torch::Tensor& topk_weights, const torch::Tensor& inv_permuted_idx,
+    const std::optional<torch::Tensor>& expert_first_token_offset, int64_t topk,
+    torch::Tensor& hidden_states) {
   TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0");
 }

@@ -224,4 +219,4 @@ bool moe_permute_unpermute_supported() {
 TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
   m.impl("moe_permute", &moe_permute);
   m.impl("moe_unpermute", &moe_unpermute);
-}
+}
@@ -188,7 +188,9 @@ __launch_bounds__(TPB) __global__ void moeTopK(
  It fuses the softmax, max and argmax into a single kernel.

  Limitations:
-  1) This implementation is intended for when the number of experts is a small power of 2.
+  1) This implementation is optimized for when the number of experts is a small power of 2.
+     Additionally it also supports when number of experts is multiple of 64 which is still
+     faster than the computing softmax and topK separately (only tested on CUDA yet).
  2) This implementation assumes k is small, but will work for any k.
 */

@@ -198,8 +200,6 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
     int* source_rows, const int k, const int start_expert, const int end_expert)
 {
     // We begin by enforcing compile time assertions and setting up compile time constants.
-    static_assert(VPT == (VPT & -VPT), "VPT must be power of 2");
-    static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2");
     static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2");
     static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16");

@@ -407,12 +407,10 @@ struct TopkConstants
 };
 } // namespace detail

-template <int EXPERTS, int WARPS_PER_TB, int WARP_SIZE_PARAM, typename IndType>
+template <int EXPERTS, int WARPS_PER_TB, int WARP_SIZE_PARAM, int MAX_BYTES_PER_LDG, typename IndType>
 void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, IndType* indices,
     int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
 {
-    static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
-
     static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
     using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG, WARP_SIZE_PARAM>;
     static constexpr int VPT = Constants::VPT;
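Making the load width a template parameter is what the multiple-of-64 cases further down rely on; the clamp itself is unchanged. A compile-time check of BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS) for a few expert counts:

#include <algorithm>
#include <cstdio>

// Tiny expert counts load the whole row in one instruction; larger ones cap
// at the width the caller selected (16 bytes for power-of-2 counts, 8 bytes
// for the multiple-of-64 cases).
template <int EXPERTS, int MAX_BYTES_PER_LDG>
constexpr int bytes_per_ldg() {
  return std::min(MAX_BYTES_PER_LDG, int(sizeof(float)) * EXPERTS);
}

int main() {
  static_assert(bytes_per_ldg<2, 16>() == 8);     // 2 floats -> one 8-byte load
  static_assert(bytes_per_ldg<256, 16>() == 16);  // capped at 16 bytes
  static_assert(bytes_per_ldg<192, 8>() == 8);    // multiple-of-64 path
  printf("ok\n");
}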
@@ -425,21 +423,27 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
         input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert);
 }

-#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB)                                \
-    switch (warpSize) {                                                          \
-        case 32:                                                                 \
-            topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 32>(      \
-                gating_output, nullptr, topk_weights, topk_indices,              \
-                token_expert_indices, num_tokens, topk, 0, num_experts, stream); \
-            break;                                                               \
-        case 64:                                                                 \
-            topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 64>(      \
-                gating_output, nullptr, topk_weights, topk_indices,              \
-                token_expert_indices, num_tokens, topk, 0, num_experts, stream); \
-            break;                                                               \
-        default:                                                                 \
-            TORCH_CHECK(false, "Unsupported warp size: ", warpSize);             \
-    }
+#ifndef USE_ROCM
+#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES)                          \
+    static_assert(WARP_SIZE == 32,                                                    \
+                  "Unsupported warp size. Only 32 is supported for CUDA");            \
+    topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, WARP_SIZE, MAX_BYTES>( \
+        gating_output, nullptr, topk_weights, topk_indices,                           \
+        token_expert_indices, num_tokens, topk, 0, num_experts, stream);
+#else
+#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES)                             \
+    if (WARP_SIZE == 64) {                                                               \
+        topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 64, MAX_BYTES>(       \
+            gating_output, nullptr, topk_weights, topk_indices,                          \
+            token_expert_indices, num_tokens, topk, 0, num_experts, stream);             \
+    } else if (WARP_SIZE == 32) {                                                        \
+        topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 32, MAX_BYTES>(       \
+            gating_output, nullptr, topk_weights, topk_indices,                          \
+            token_expert_indices, num_tokens, topk, 0, num_experts, stream);             \
+    } else {                                                                             \
+        assert(false && "Unsupported warp size. Only 32 and 64 are supported for ROCm"); \
+    }
+#endif

 template <typename IndType>
 void topkGatingSoftmaxKernelLauncher(
@@ -453,38 +457,64 @@ void topkGatingSoftmaxKernelLauncher(
     const int topk,
     cudaStream_t stream) {
     static constexpr int WARPS_PER_TB = 4;
-    auto warpSize = WARP_SIZE;
+    static constexpr int BYTES_PER_LDG_POWER_OF_2 = 16;
+#ifndef USE_ROCM
+    static constexpr int BYTES_PER_LDG_MULTIPLE_64 = 8;
+#endif
     switch (num_experts) {
         case 1:
-            LAUNCH_SOFTMAX(1, WARPS_PER_TB);
+            LAUNCH_SOFTMAX(1, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
             break;
         case 2:
-            LAUNCH_SOFTMAX(2, WARPS_PER_TB);
+            LAUNCH_SOFTMAX(2, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
             break;
         case 4:
-            LAUNCH_SOFTMAX(4, WARPS_PER_TB);
+            LAUNCH_SOFTMAX(4, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
             break;
         case 8:
-            LAUNCH_SOFTMAX(8, WARPS_PER_TB);
+            LAUNCH_SOFTMAX(8, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
             break;
         case 16:
-            LAUNCH_SOFTMAX(16, WARPS_PER_TB);
+            LAUNCH_SOFTMAX(16, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
             break;
         case 32:
-            LAUNCH_SOFTMAX(32, WARPS_PER_TB);
+            LAUNCH_SOFTMAX(32, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
             break;
         case 64:
-            LAUNCH_SOFTMAX(64, WARPS_PER_TB);
+            LAUNCH_SOFTMAX(64, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
             break;
         case 128:
-            LAUNCH_SOFTMAX(128, WARPS_PER_TB);
+            LAUNCH_SOFTMAX(128, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
             break;
         case 256:
-            LAUNCH_SOFTMAX(256, WARPS_PER_TB);
+            LAUNCH_SOFTMAX(256, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
             break;
+        case 512:
+            LAUNCH_SOFTMAX(512, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
+            break;
+        // (CUDA only) support multiples of 64 when num_experts is not power of 2.
+        // ROCm uses WARP_SIZE 64 so 8 bytes loading won't fit for some of num_experts,
+        // alternatively we can test 4 bytes loading and enable it in future.
+#ifndef USE_ROCM
+        case 192:
+            LAUNCH_SOFTMAX(192, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
+            break;
+        case 320:
+            LAUNCH_SOFTMAX(320, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
+            break;
+        case 384:
+            LAUNCH_SOFTMAX(384, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
+            break;
+        case 448:
+            LAUNCH_SOFTMAX(448, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
+            break;
+        case 576:
+            LAUNCH_SOFTMAX(576, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
+            break;
+#endif
         default: {
             TORCH_CHECK(softmax_workspace != nullptr,
-                "softmax_workspace must be provided for num_experts that are not a power of 2.");
+                "softmax_workspace must be provided for num_experts that are not a power of 2 or multiple of 64.");
             static constexpr int TPB = 256;
             moeSoftmax<TPB><<<num_tokens, TPB, 0, stream>>>(
                 gating_output, nullptr, softmax_workspace, num_experts);
@@ -35,7 +35,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {

   m.def(
       "moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
-      "Tensor! b_q_weight, Tensor! b_scales, Tensor? global_scale, Tensor? "
+      "Tensor! b_q_weight, Tensor? b_bias_or_none,"
+      "Tensor! b_scales, Tensor? global_scale, Tensor? "
      "b_zeros_or_none,"
      "Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
      "Tensor sorted_token_ids,"
csrc/ops.h
@@ -138,6 +138,8 @@ void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);

 void fatrelu_and_mul(torch::Tensor& out, torch::Tensor& input,
                      double threshold);
+void swigluoai_and_mul(torch::Tensor& out, torch::Tensor& input,
+                       double alpha = 1.702, double limit = 7.0);

 void gelu_new(torch::Tensor& out, torch::Tensor& input);

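For context on the new declaration: swigluoai_and_mul suggests the clamped SwiGLU variant used by gpt-oss-style models. A hedged scalar sketch of that formulation (assumed here from the alpha/limit defaults, not read from this diff):

#include <algorithm>
#include <cmath>
#include <cstdio>

// Assumed gpt-oss-style clamped SwiGLU: the gate half is clamped from above
// at `limit`, the linear half to [-limit, limit], the gate uses a sigmoid
// with slope alpha (1.702 approximates GELU), and the linear half is shifted
// by one before the product.
float swigluoai(float gate, float linear, float alpha = 1.702f,
                float limit = 7.0f) {
  gate = std::min(gate, limit);
  linear = std::clamp(linear, -limit, limit);
  float glu = gate / (1.0f + std::exp(-alpha * gate));  // gate * sigmoid(a*g)
  return (linear + 1.0f) * glu;
}

int main() { printf("%f\n", swigluoai(2.0f, 0.5f)); }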
@@ -145,22 +147,6 @@ void gelu_fast(torch::Tensor& out, torch::Tensor& input);

 void gelu_quick(torch::Tensor& out, torch::Tensor& input);

-void advance_step_flashattn(int64_t num_seqs, int64_t num_queries,
-                            int64_t block_size, torch::Tensor& input_tokens,
-                            torch::Tensor& sampled_token_ids,
-                            torch::Tensor& input_positions,
-                            torch::Tensor& seq_lens,
-                            torch::Tensor& slot_mapping,
-                            torch::Tensor& block_tables);
-
-void advance_step_flashinfer(
-    int64_t num_seqs, int64_t num_queries, int64_t block_size,
-    torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
-    torch::Tensor& input_positions, torch::Tensor& seq_lens,
-    torch::Tensor& slot_mapping, torch::Tensor& block_tables,
-    torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
-    torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds);
-
 void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope,
                         torch::Tensor const& q_pe,
                         torch::Tensor const& kv_c_and_k_pe_cache,
@@ -170,15 +156,6 @@ void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope,
 torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor);

 #ifndef USE_ROCM
-torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
-                        const torch::Tensor& codebooks,
-                        const torch::Tensor& scales,
-                        const std::vector<int64_t>& codebook_partition_sizes,
-                        const std::optional<torch::Tensor>& bias);
-
-torch::Tensor aqlm_dequant(
-    const torch::Tensor& codes, const torch::Tensor& codebooks,
-    const std::vector<int64_t>& codebook_partition_sizes);
-
 torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
                        torch::Tensor _scaling_factors, torch::Tensor _zeros,
@@ -252,6 +229,11 @@ void get_cutlass_moe_mm_data(
     const int64_t num_experts, const int64_t n, const int64_t k,
     const std::optional<torch::Tensor>& blockscale_offsets);

+void get_cutlass_moe_mm_problem_sizes(
+    const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
+    torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
+    const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets);
+
 void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
                                   torch::Tensor& problem_sizes1,
                                   torch::Tensor& problem_sizes2,
@@ -1,336 +0,0 @@
/*
 * The goal of this GPU kernel is to advance input tensors on the GPU directly
 * PR: https://github.com/vllm-project/vllm/pull/6338
 * Current restrictions:
 *     1. Specialized for DraftModelRunner
 *     2. Supports flash_attn only
 */

#include "advance_step.cuh"

namespace prepare_inputs {

//
template <int const num_threads>
__global__ void advance_step_flashattn_kernel(
    int num_seqs, int num_queries, int block_size, long* input_tokens_ptr,
    long const* sampled_token_ids_ptr, long* input_positions_ptr,
    int* seq_lens_ptr, long* slot_mapping_ptr, int const* block_tables_ptr,
    int64_t const block_tables_stride) {
  int const n_pad = num_seqs - num_queries;
  if (n_pad && blockIdx.x == 0) {
    // Handle cuda graph padding
    int const offset = num_queries;
    for (int i = threadIdx.x; i < n_pad; i += blockDim.x) {
      input_tokens_ptr[offset + i] = 0;
      input_positions_ptr[offset + i] = 0;
      slot_mapping_ptr[offset + i] = -1;
    }
  }

  int num_query_blocks = div_ceil(num_queries, num_threads);

  if (blockIdx.x >= num_query_blocks) {
    return;
  }

  int cur_query_id = blockIdx.x * num_threads + threadIdx.x;

  if (cur_query_id >= num_queries) {
    return;
  }

  // Update input_tokens
  input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id];

  int seq_len = seq_lens_ptr[cur_query_id];
  int next_seq_len = seq_len + 1;
  int next_input_pos = next_seq_len - 1;

  // Update seq_lens
  seq_lens_ptr[cur_query_id] = next_seq_len;
  // Update input_positions
  input_positions_ptr[cur_query_id] = next_input_pos;

  int const* seq_block_tables_ptr =
      block_tables_ptr + block_tables_stride * cur_query_id;

  int block_index = next_input_pos / block_size;
  int block_offset = next_input_pos % block_size;

  int slot_num = seq_block_tables_ptr[block_index] * block_size + block_offset;
  // Update slot_mapping
  slot_mapping_ptr[cur_query_id] = slot_num;
}

inline void verify_tensor(std::string const& name, torch::Tensor const& t,
                          int64_t const size_0, int64_t const size_1,
                          c10::ScalarType const type) {
  bool size_0_cond = true;
  if (size_0 != -1) {
    size_0_cond = t.size(0) == size_0;
  }

  bool size_1_cond = true;
  if (size_1 != -1) {
    size_1_cond = t.size(1) == size_1;
  }

  bool is_contiguous = t.is_contiguous();
  bool same_type = t.dtype() == type;

  bool pass = size_0_cond && size_1_cond && is_contiguous && same_type;
  if (!pass) {
    TORCH_CHECK(false, "tensor: name = ", name, ", shape = ", t.sizes(),
                " is_cont = ", t.is_contiguous(), ", type = ", t.dtype(),
                " is not as expected: shape = [", size_0, ", ", size_1,
                "], type = ", type);
  }
}

/// each thread processes a block per query
__global__ void advance_step_flashinfer_kernel(
    int num_threads, int num_seqs, int num_queries, int block_size,
    long* input_tokens_ptr, long const* sampled_token_ids_ptr,
    long* input_positions_ptr, int* seq_lens_ptr, long* slot_mapping_ptr,
    int const* block_tables_ptr, int64_t const block_tables_stride,
    int* paged_kv_last_page_len_ptr, int* block_table_bound_ptr) {
  int const n_pad = num_seqs - num_queries;
  if (n_pad && blockIdx.x == 0) {
    // Handle cuda graph padding
    int const offset = num_queries;
    for (int i = threadIdx.x; i < n_pad; i += blockDim.x) {
      input_tokens_ptr[offset + i] = 0;
      input_positions_ptr[offset + i] = 0;
      slot_mapping_ptr[offset + i] = -1;
    }
  }
  int num_query_blocks = div_ceil(num_queries, num_threads);

  if (blockIdx.x < num_query_blocks) {
    int cur_query_id = blockIdx.x * num_threads + threadIdx.x;

    if (cur_query_id < num_queries) {
      // Update input_tokens
      input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id];

      int seq_len = seq_lens_ptr[cur_query_id];
      int next_seq_len = seq_len + 1;
      int next_input_pos = next_seq_len - 1;

      // Update seq_lens
      seq_lens_ptr[cur_query_id] = next_seq_len;
      // Update input_positions
      input_positions_ptr[cur_query_id] = next_input_pos;

      int const* seq_block_tables_ptr =
          block_tables_ptr + block_tables_stride * cur_query_id;

      int block_index = next_input_pos / block_size;
      int block_offset = next_input_pos % block_size;

      // Update paged_kv_last_page_len
      paged_kv_last_page_len_ptr[cur_query_id] = block_offset + 1;

      int slot_num =
          seq_block_tables_ptr[block_index] * block_size + block_offset;
      // Update slot_mapping
      slot_mapping_ptr[cur_query_id] = slot_num;
      block_table_bound_ptr[cur_query_id] = div_ceil(next_seq_len, block_size);
    }
  }
}

__global__ void advance_step_flashinfer_indptr_kernel(
    int num_threads, int num_seqs, int num_queries, int* paged_kv_indptr_ptr,
    int* block_table_bound_ptr) {
  int idx = blockIdx.x * num_threads + threadIdx.x;
  // Update paged_kv_indptr
  if (idx == 0) {
    paged_kv_indptr_ptr[idx] = 0;
  }
  if (idx < num_queries) {
    int sum = 0;
    for (int i = 0; i <= idx; ++i) {
      sum += block_table_bound_ptr[i];
    }
    paged_kv_indptr_ptr[idx + 1] = sum;
  }
}

__global__ void advance_step_flashinfer_indices_kernel(
    int num_seqs, int num_queries, int const* block_tables_ptr,
    int64_t const max_num_blocks_per_seq, int* paged_kv_indices_ptr,
    int* paged_kv_indptr_ptr, int* block_table_bound_ptr) {
  // note: max_num_blocks_per_seq = block_tables.stride(0)
  int tid = blockIdx.x * blockDim.x + threadIdx.x;

  // when cuda graphs are enabled, paged_kv_indptr tensor
  // has to be updated for the padded queries
  // tid represents a query# for paged_kv_indptr tensor
  if (num_queries < tid && tid <= num_seqs) {
    paged_kv_indptr_ptr[tid] = paged_kv_indptr_ptr[num_queries];
  }

  // each thread processes a block_ptr in block_tables
  // block_tables shape: [num_queries, max_num_blocks_per_seq]
  // paged_kv_indices is flattened block_tables.
  for (int idx = tid; idx < (num_seqs * max_num_blocks_per_seq);
       idx += (gridDim.x * blockDim.x)) {
    // block_tables-row = paged_kv_indptr[queryNum]
    int queryNum = idx / max_num_blocks_per_seq;
    int col = idx % max_num_blocks_per_seq;
    if (queryNum < num_queries && col < block_table_bound_ptr[queryNum]) {
      int indices_arr_idx = paged_kv_indptr_ptr[queryNum] + col;
      int block_tables_idx = queryNum * max_num_blocks_per_seq + col;
      paged_kv_indices_ptr[indices_arr_idx] =
          block_tables_ptr[block_tables_idx];
    }
  }
}

void advance_step_flashattn(int num_seqs, int num_queries, int block_size,
                            torch::Tensor& input_tokens,       // type: long
                            torch::Tensor& sampled_token_ids,  // type: long
                            torch::Tensor& input_positions,    // type: long
                            torch::Tensor& seq_lens,           // type: int
                            torch::Tensor& slot_mapping,       // type: long
                            torch::Tensor& block_tables) {     // type: int

  if (logging) {
    printf("advance_step_flashattn:\n");
    printf("  num_seqs = %d\n", num_seqs);
    printf("  num_queries = %d\n", num_queries);
    printf("  block_size = %d\n", block_size);
  }
  // Verify all tensors
  verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong);
  verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1,
                at::kLong);
  verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong);
  verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt);
  verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong);
  verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt);

  int dev = sampled_token_ids.get_device();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);

  int blocks;
  cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);

  advance_step_flashattn_kernel<max_threads>
      <<<blocks, max_threads, 0, stream>>>(
          num_seqs, num_queries, block_size,
          reinterpret_cast<long*>(input_tokens.data_ptr()),
          reinterpret_cast<long const*>(sampled_token_ids.data_ptr()),
          reinterpret_cast<long*>(input_positions.data_ptr()),
          reinterpret_cast<int*>(seq_lens.data_ptr()),
          reinterpret_cast<long*>(slot_mapping.data_ptr()),
          reinterpret_cast<int const*>(block_tables.data_ptr()),
          block_tables.stride(0));
}

void advance_step_flashinfer(
    int num_seqs, int num_queries, int block_size,
    torch::Tensor& input_tokens,            // type: long
    torch::Tensor& sampled_token_ids,       // type: long
    torch::Tensor& input_positions,         // type: long
    torch::Tensor& seq_lens,                // type: int
    torch::Tensor& slot_mapping,            // type: long
    torch::Tensor& block_tables,            // type: int
    torch::Tensor& paged_kv_indices,        // type: int
    torch::Tensor& paged_kv_indptr,         // type: int
    torch::Tensor& paged_kv_last_page_len,  // type: int
    torch::Tensor& block_table_bound) {     // type: int

  if (logging) {
    printf("advance_step_flashinfer:\n");
    printf("  num_seqs = %d\n", num_seqs);
    printf("  num_queries = %d\n", num_queries);
    printf("  block_size = %d\n", block_size);
    printf("  block_tables.stride(0) = %zu\n", block_tables.stride(0));
  }
  // Verify all tensors
  verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong);
  // verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1,
  //               at::kLong);
  verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong);
  verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt);
  verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong);
  verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt);

  verify_tensor("paged_kv_indices", paged_kv_indices, -1, -1, at::kInt);
  verify_tensor("paged_kv_indptr", paged_kv_indptr, num_seqs + 1, -1, at::kInt);
  verify_tensor("paged_kv_last_page_len", paged_kv_last_page_len, num_seqs, -1,
                at::kInt);

  verify_tensor("block_table_bound", block_table_bound, num_seqs, -1, at::kInt);

  int dev = sampled_token_ids.get_device();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);

  int blocks;
  int threads;
  cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
  cudaDeviceGetAttribute(&threads, cudaDevAttrMaxThreadsPerBlock, dev);

  TORCH_CHECK((blocks * threads > num_queries),
              "multi-step: not enough threads to map to num_queries = ",
              num_queries, " block_tables.stride(0) = ", block_tables.stride(0),
              " blocks = ", blocks, " max_threads = ", threads);
  if (logging) {
    printf("launching kernels with %d blocks and %d threads\n", blocks,
           threads);
  }
  advance_step_flashinfer_kernel<<<blocks, threads, 0, stream>>>(
      threads, num_seqs, num_queries, block_size,
      reinterpret_cast<long*>(input_tokens.data_ptr()),
      reinterpret_cast<long const*>(sampled_token_ids.data_ptr()),
      reinterpret_cast<long*>(input_positions.data_ptr()),
      reinterpret_cast<int*>(seq_lens.data_ptr()),
      reinterpret_cast<long*>(slot_mapping.data_ptr()),
      reinterpret_cast<int const*>(block_tables.data_ptr()),
      block_tables.stride(0),
      reinterpret_cast<int*>(paged_kv_last_page_len.data_ptr()),
      reinterpret_cast<int*>(block_table_bound.data_ptr()));

  advance_step_flashinfer_indptr_kernel<<<blocks, threads, 0, stream>>>(
      threads, num_seqs, num_queries,
      reinterpret_cast<int*>(paged_kv_indptr.data_ptr()),
      reinterpret_cast<int*>(block_table_bound.data_ptr()));

  advance_step_flashinfer_indices_kernel<<<blocks, threads, 0, stream>>>(
      num_seqs, num_queries,
      reinterpret_cast<int const*>(block_tables.data_ptr()),
      block_tables.stride(0),
      reinterpret_cast<int*>(paged_kv_indices.data_ptr()),
      reinterpret_cast<int*>(paged_kv_indptr.data_ptr()),
      reinterpret_cast<int*>(block_table_bound.data_ptr()));
}

}  // namespace prepare_inputs

void advance_step_flashattn(int64_t num_seqs, int64_t num_queries,
                            int64_t block_size, torch::Tensor& input_tokens,
                            torch::Tensor& sampled_token_ids,
                            torch::Tensor& input_positions,
                            torch::Tensor& seq_lens,
                            torch::Tensor& slot_mapping,
                            torch::Tensor& block_tables) {
  prepare_inputs::advance_step_flashattn(
      num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
      input_positions, seq_lens, slot_mapping, block_tables);
}

void advance_step_flashinfer(
    int64_t num_seqs, int64_t num_queries, int64_t block_size,
    torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
    torch::Tensor& input_positions, torch::Tensor& seq_lens,
    torch::Tensor& slot_mapping, torch::Tensor& block_tables,
    torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
    torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bound) {
  prepare_inputs::advance_step_flashinfer(
      num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
      input_positions, seq_lens, slot_mapping, block_tables, paged_kv_indices,
      paged_kv_indptr, paged_kv_last_page_len, block_table_bound);
}
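The deleted kernels all derive the paged-KV slot the same way; for reference, the mapping they computed:

#include <cstdio>

// Paged-KV slot for a token at position `pos`: pick the physical block from
// the sequence's block table, then offset within the block.
long slot_for(const int* block_table, int block_size, int pos) {
  return static_cast<long>(block_table[pos / block_size]) * block_size +
         pos % block_size;
}

int main() {
  int table[] = {7, 3, 9};  // logical block -> physical block
  printf("%ld\n", slot_for(table, 16, 35));  // block 2 (=9), offset 3 -> 147
}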
@@ -1,19 +0,0 @@
#pragma once

#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>

namespace prepare_inputs {

static constexpr int max_threads = 256;
static constexpr bool logging = false;

constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }

}  // namespace prepare_inputs
Some files were not shown because too many files have changed in this diff.