Compare commits
260 Commits
gemma3n-mm
...
gpu_ids2
| Author | SHA1 | Date | |
|---|---|---|---|
| ab153be252 | |||
| 5e53c89a74 | |||
| c66e38ea4c | |||
| 251595368f | |||
| 4bed167768 | |||
| b140416abf | |||
| 5b8366b61a | |||
| c7753a9809 | |||
| 4b9a9435bb | |||
| 3482fd7e4e | |||
| 77f77a951e | |||
| 1a4f35e2ea | |||
| be1e128dfb | |||
| 65393ee064 | |||
| dc221ad72d | |||
| 7571a4a7e5 | |||
| f67d986dd1 | |||
| cc876d0f29 | |||
| fdfd409f8f | |||
| ffbcc9e757 | |||
| 59389c927b | |||
| 8f2720def9 | |||
| ad6c2e1a0b | |||
| 49e8c7ea25 | |||
| 805d62ca88 | |||
| b7d9e9416f | |||
| 7c12a765aa | |||
| cd587c93ef | |||
| 332d4cb17b | |||
| bf03ff3575 | |||
| 47043eb678 | |||
| 31b96d1c64 | |||
| e59ba9e142 | |||
| 403b481573 | |||
| 138709f8d1 | |||
| 0bbac1c1b4 | |||
| a3e4e85ece | |||
| eb58f5953d | |||
| 4ac9c33f78 | |||
| efe73d0575 | |||
| 853487bc1b | |||
| 9ff2af6d2b | |||
| 70ca5484f5 | |||
| 5358cce5ff | |||
| 2155e95ef1 | |||
| f95570a52d | |||
| b6e7e3d58f | |||
| e760fcef22 | |||
| 6bbf1795b7 | |||
| 9e0ef888f0 | |||
| 97abeb1daa | |||
| 34dad19e7b | |||
| 6db31e7a27 | |||
| 977180c912 | |||
| c40784c794 | |||
| baed180aa0 | |||
| 0b407479ef | |||
| 5eaf570050 | |||
| d8ee5a2ca4 | |||
| b9fca83256 | |||
| 32dffc2772 | |||
| c438183e99 | |||
| baba0389f7 | |||
| c6c22f16d3 | |||
| dd382e0fe3 | |||
| 849590a2a7 | |||
| a4c23314c0 | |||
| b942c094e3 | |||
| b4bab81660 | |||
| b91cb3fa5c | |||
| 71d1d75b7a | |||
| 72d14d0eed | |||
| e34d130c16 | |||
| 7721ef1786 | |||
| 8369b7c2a9 | |||
| 3eb4ad53f3 | |||
| 90a2769f20 | |||
| e60d422f19 | |||
| 0d914c81a2 | |||
| 6e428cdd7a | |||
| 93b9d9f499 | |||
| af107d5a0e | |||
| 31c5d0a1b7 | |||
| afb7cff1b9 | |||
| d2e841a10a | |||
| 14601f5fba | |||
| 042d131f39 | |||
| 8e807cdfa4 | |||
| e601efcb10 | |||
| 22dd9c2730 | |||
| a6d795d593 | |||
| a37d75bbec | |||
| edd270bc78 | |||
| 110df74332 | |||
| 1ad69e8375 | |||
| b8a498c9b2 | |||
| 923147b5e8 | |||
| 45877ef740 | |||
| 6e4bef1bea | |||
| 4ff79a136e | |||
| 448acad31e | |||
| eb0b2d2f08 | |||
| 3112271f6e | |||
| 1fd471e957 | |||
| 2c5ebec064 | |||
| 2e610deb72 | |||
| 6e2c19ce22 | |||
| 47db8c2c15 | |||
| 462b269280 | |||
| c18b3b8e8b | |||
| 9528e3a05e | |||
| 9fb52e523a | |||
| e202dd2736 | |||
| 43813e6361 | |||
| cede942b87 | |||
| fe1e924811 | |||
| 4548c03c50 | |||
| 40b86aa05e | |||
| 432870829d | |||
| f73d02aadc | |||
| c5ebe040ac | |||
| 8d763cb891 | |||
| cf4cd53982 | |||
| 32c9be2200 | |||
| 8aeaa910a2 | |||
| 906e05d840 | |||
| ef9a2990ae | |||
| 7e90870491 | |||
| d3f05c9248 | |||
| c108781c85 | |||
| 3d184b95b8 | |||
| 2f35a022e6 | |||
| ffe00ef77a | |||
| 5561681d04 | |||
| fbd62d8750 | |||
| 2e26f9156a | |||
| 9e5452ee34 | |||
| 0e3fe896e2 | |||
| 1caca5a589 | |||
| 783921d889 | |||
| 4a98edff1f | |||
| a7bab0c9e5 | |||
| 25950dca9b | |||
| a4113b035c | |||
| 7e1665b089 | |||
| 8d1096e7db | |||
| 8d775dd30a | |||
| 78fe77534b | |||
| 2f2fcb31b8 | |||
| 1dba2c4ebe | |||
| 71d6de3a26 | |||
| 536fd33003 | |||
| 619b9f5c7e | |||
| d1b689c445 | |||
| 9854dc9040 | |||
| ff5c60fad8 | |||
| 6f1229f91d | |||
| 1819fbda63 | |||
| 7f0367109e | |||
| fb14d53cf6 | |||
| b024a42e93 | |||
| cb97f2bfc5 | |||
| 359200f6ac | |||
| 220aee902a | |||
| 67d25eca05 | |||
| 363528de27 | |||
| 4ff61ababa | |||
| 0ec3779df7 | |||
| b616f6a53d | |||
| 2e25bb12a8 | |||
| 9965c47d0d | |||
| 059d4cdb49 | |||
| bdb84e26b0 | |||
| 3dd359147d | |||
| 657f2f301a | |||
| a1aafc827a | |||
| 139508a418 | |||
| d265414dbc | |||
| 48fb076cbc | |||
| c1909e7e8c | |||
| b95877509b | |||
| 706ff13224 | |||
| ccbfb1d1c9 | |||
| 9e5552aa13 | |||
| 0c600b9ab6 | |||
| e303dcf523 | |||
| ae9c4d416f | |||
| d853520b3e | |||
| ba51aea65e | |||
| 8452946c06 | |||
| 2e7cbf2d7d | |||
| 7da296be04 | |||
| b205e8467d | |||
| be0cfb2b68 | |||
| 1a03dd496b | |||
| 27b8017636 | |||
| 9ec1e3065a | |||
| 9dae7d46bf | |||
| 7058d7dd5d | |||
| a0389e0554 | |||
| 3be8d312a2 | |||
| 3abfe22154 | |||
| e81fbefe8a | |||
| 9290de5667 | |||
| 7f280d69c9 | |||
| 02cabff207 | |||
| 3d19d47d91 | |||
| 8acb4badee | |||
| 314af8617c | |||
| 0e96cc9b7e | |||
| ecad851cbd | |||
| ed70f3c64f | |||
| 650d5dbd04 | |||
| 9025a9a705 | |||
| c05596f1a3 | |||
| 787b13389e | |||
| 96453cfa83 | |||
| b1c1fe35a5 | |||
| 08d81f1014 | |||
| 6cc1e7d96d | |||
| 9909726d2a | |||
| 22e9d42040 | |||
| 86debab54c | |||
| be250bbc67 | |||
| 27949354fa | |||
| bd5038af07 | |||
| a2f14dc8f9 | |||
| 92ee7baaf9 | |||
| 7151f92241 | |||
| e28533a16f | |||
| 6d42ce8315 | |||
| ded1fb635b | |||
| 97d9524fe9 | |||
| d8cf819a9a | |||
| 551ef1631a | |||
| 2863befce3 | |||
| 2965c99c86 | |||
| 2062c0723d | |||
| 1c50e100a9 | |||
| 3ee56e26be | |||
| 8fe7fc8634 | |||
| e936e401de | |||
| f5dfa07531 | |||
| 022c58b80f | |||
| 19108ef311 | |||
| 5a52f389dd | |||
| 65b1cbb138 | |||
| 6c9837a761 | |||
| 6f2f53a82d | |||
| 7b1895e6ce | |||
| 4d36693687 | |||
| daec9dea6e | |||
| daceac57c7 | |||
| 8615d9776f | |||
| 7b460c25f9 | |||
| f719772281 | |||
| d45417b804 | |||
| a29e62ea34 | |||
| e53be6f00a | |||
| c329ceca6d |
@ -46,6 +46,6 @@ while getopts "m:b:l:f:t:" OPT; do
|
||||
done
|
||||
|
||||
lm_eval --model vllm \
|
||||
--model_args "pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,distributed_executor_backend=ray,trust_remote_code=true,max_model_len=4096" \
|
||||
--model_args "pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,add_bos_token=true,trust_remote_code=true,max_model_len=4096" \
|
||||
--tasks gsm8k --num_fewshot "$FEWSHOT" --limit "$LIMIT" \
|
||||
--batch_size "$BATCH_SIZE"
|
||||
|
||||
@ -18,12 +18,14 @@ RTOL = 0.08
|
||||
|
||||
def launch_lm_eval(eval_config, tp_size):
|
||||
trust_remote_code = eval_config.get("trust_remote_code", False)
|
||||
max_model_len = eval_config.get("max_model_len", 4096)
|
||||
model_args = (
|
||||
f"pretrained={eval_config['model_name']},"
|
||||
f"tensor_parallel_size={tp_size},"
|
||||
f"enforce_eager=true,"
|
||||
f"add_bos_token=true,"
|
||||
f"trust_remote_code={trust_remote_code}"
|
||||
f"trust_remote_code={trust_remote_code},"
|
||||
f"max_model_len={max_model_len}"
|
||||
)
|
||||
results = lm_eval.simple_evaluate(
|
||||
model="vllm",
|
||||
|
||||
@ -11,7 +11,7 @@ See [vLLM performance dashboard](https://perf.vllm.ai) for the latest performanc
|
||||
|
||||
## Performance benchmark quick overview
|
||||
|
||||
**Benchmarking Coverage**: latency, throughput and fix-qps serving on A100 (the support for FP8 benchmark on H100 is coming!), with different models.
|
||||
**Benchmarking Coverage**: latency, throughput and fix-qps serving on A100 (the support for FP8 benchmark on H100 is coming!) and Intel® Xeon® Processors, with different models.
|
||||
|
||||
**Benchmarking Duration**: about 1hr.
|
||||
|
||||
@ -31,13 +31,27 @@ Performance benchmark will be triggered when:
|
||||
- A PR being merged into vllm.
|
||||
- Every commit for those PRs with `perf-benchmarks` label AND `ready` label.
|
||||
|
||||
Manually Trigger the benchmark
|
||||
|
||||
```bash
|
||||
bash .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh
|
||||
```
|
||||
|
||||
Runtime environment variables:
|
||||
- `ON_CPU`: set the value to '1' on Intel® Xeon® Processors. Default value is 0.
|
||||
- `SERVING_JSON`: JSON file to use for the serving tests. Default value is empty string (use default file).
|
||||
- `LATENCY_JSON`: JSON file to use for the latency tests. Default value is empty string (use default file).
|
||||
- `THROUGHPUT_JSON`: JSON file to use for the throughout tests. Default value is empty string (use default file).
|
||||
- `REMOTE_HOST`: IP for the remote vLLM service to benchmark. Default value is empty string.
|
||||
- `REMOTE_PORT`: Port for the remote vLLM service to benchmark. Default value is empty string.
|
||||
|
||||
Nightly benchmark will be triggered when:
|
||||
- Every commit for those PRs with `perf-benchmarks` label and `nightly-benchmarks` label.
|
||||
|
||||
## Performance benchmark details
|
||||
|
||||
See [performance-benchmarks-descriptions.md](performance-benchmarks-descriptions.md) for detailed descriptions, and use `tests/latency-tests.json`, `tests/throughput-tests.json`, `tests/serving-tests.json` to configure the test cases.
|
||||
|
||||
> NOTE: For Intel® Xeon® Processors, use `tests/latency-tests-cpu.json`, `tests/throughput-tests-cpu.json`, `tests/serving-tests-cpu.json` instead.
|
||||
### Latency test
|
||||
|
||||
Here is an example of one test inside `latency-tests.json`:
|
||||
@ -119,6 +133,30 @@ If you do not see the table, please wait till the benchmark finish running.
|
||||
The json version of the table (together with the json version of the benchmark) will be also attached to the markdown file.
|
||||
The raw benchmarking results (in the format of json files) are in the `Artifacts` tab of the benchmarking.
|
||||
|
||||
The `compare-json-results.py` helps to compare benchmark results JSON files converted using `convert-results-json-to-markdown.py`.
|
||||
When run, benchmark script generates results under `benchmark/results` folder, along with the `benchmark_results.md` and `benchmark_results.json`.
|
||||
`compare-json-results.py` compares two `benchmark_results.json` files and provides performance ratio e.g. for Output Tput, Median TTFT and Median TPOT.
|
||||
|
||||
Here is an example using the script to compare result_a and result_b without detail test name.
|
||||
`python3 compare-json-results.py -f results_a/benchmark_results.json -f results_b/benchmark_results.json --ignore_test_name`
|
||||
|
||||
| | results_a/benchmark_results.json | results_b/benchmark_results.json | perf_ratio |
|
||||
|----|----------------------------------------|----------------------------------------|----------|
|
||||
| 0 | 142.633982 | 156.526018 | 1.097396 |
|
||||
| 1 | 241.620334 | 294.018783 | 1.216863 |
|
||||
| 2 | 218.298905 | 262.664916 | 1.203235 |
|
||||
| 3 | 242.743860 | 299.816190 | 1.235113 |
|
||||
|
||||
Here is an example using the script to compare result_a and result_b with detail test name.
|
||||
`python3 compare-json-results.py -f results_a/benchmark_results.json -f results_b/benchmark_results.json`
|
||||
| | results_a/benchmark_results.json_name | results_a/benchmark_results.json | results_b/benchmark_results.json_name | results_b/benchmark_results.json | perf_ratio |
|
||||
|---|---------------------------------------------|----------------------------------------|---------------------------------------------|----------------------------------------|----------|
|
||||
| 0 | serving_llama8B_tp1_sharegpt_qps_1 | 142.633982 | serving_llama8B_tp1_sharegpt_qps_1 | 156.526018 | 1.097396 |
|
||||
| 1 | serving_llama8B_tp1_sharegpt_qps_16 | 241.620334 | serving_llama8B_tp1_sharegpt_qps_16 | 294.018783 | 1.216863 |
|
||||
| 2 | serving_llama8B_tp1_sharegpt_qps_4 | 218.298905 | serving_llama8B_tp1_sharegpt_qps_4 | 262.664916 | 1.203235 |
|
||||
| 3 | serving_llama8B_tp1_sharegpt_qps_inf | 242.743860 | serving_llama8B_tp1_sharegpt_qps_inf | 299.816190 | 1.235113 |
|
||||
| 4 | serving_llama8B_tp2_random_1024_128_qps_1 | 96.613390 | serving_llama8B_tp4_random_1024_128_qps_1 | 108.404853 | 1.122048 |
|
||||
|
||||
## Nightly test details
|
||||
|
||||
See [nightly-descriptions.md](nightly-descriptions.md) for the detailed description on test workload, models and docker containers of benchmarking other llm engines.
|
||||
|
||||
@ -4,7 +4,8 @@
|
||||
- Input length: 32 tokens.
|
||||
- Output length: 128 tokens.
|
||||
- Batch size: fixed (8).
|
||||
- Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B.
|
||||
- GPU Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B.
|
||||
- CPU Models: llama-3.1 8B.
|
||||
- Evaluation metrics: end-to-end latency (mean, median, p99).
|
||||
|
||||
{latency_tests_markdown_table}
|
||||
@ -14,7 +15,8 @@
|
||||
- Input length: randomly sample 200 prompts from ShareGPT dataset (with fixed random seed).
|
||||
- Output length: the corresponding output length of these 200 prompts.
|
||||
- Batch size: dynamically determined by vllm to achieve maximum throughput.
|
||||
- Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B.
|
||||
- GPU Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B.
|
||||
- CPU Models: llama-3.1 8B.
|
||||
- Evaluation metrics: throughput.
|
||||
|
||||
{throughput_tests_markdown_table}
|
||||
@ -25,12 +27,18 @@
|
||||
- Output length: the corresponding output length of these 200 prompts.
|
||||
- Batch size: dynamically determined by vllm and the arrival pattern of the requests.
|
||||
- **Average QPS (query per second)**: 1, 4, 16 and inf. QPS = inf means all requests come at once. For other QPS values, the arrival time of each query is determined using a random Poisson process (with fixed random seed).
|
||||
- Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B.
|
||||
- We also added a speculative decoding test for llama-3 70B, under QPS 2
|
||||
- GPU Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B.
|
||||
- We also added a speculative decoding test for llama-3 70B on GPU, under QPS 2
|
||||
- CPU Models: llama-3.1 8B.
|
||||
- Evaluation metrics: throughput, TTFT (time to the first token, with mean, median and p99), ITL (inter-token latency, with mean, median and p99).
|
||||
- For CPU, we added random dataset tests to benchmark fixed input/output length with 100 prompts.
|
||||
|
||||
{serving_tests_markdown_table}
|
||||
|
||||
## Platform Information
|
||||
|
||||
{platform_markdown_table}
|
||||
|
||||
## json version of the benchmarking tables
|
||||
|
||||
This section contains the data of the markdown tables above in JSON format.
|
||||
|
||||
@ -0,0 +1,66 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def compare_data_columns(
|
||||
files, name_column, data_column, drop_column, ignore_test_name=False
|
||||
):
|
||||
print("\ncompare_data_column: " + data_column)
|
||||
frames = []
|
||||
compare_frames = []
|
||||
for file in files:
|
||||
data_df = pd.read_json(file)
|
||||
serving_df = data_df.dropna(subset=[drop_column], ignore_index=True)
|
||||
if ignore_test_name is False:
|
||||
serving_df = serving_df.rename(columns={name_column: file + "_name"})
|
||||
frames.append(serving_df[file + "_name"])
|
||||
serving_df = serving_df.rename(columns={data_column: file})
|
||||
frames.append(serving_df[file])
|
||||
compare_frames.append(serving_df[file])
|
||||
if len(compare_frames) >= 2:
|
||||
# Compare numbers among two files
|
||||
ratio_df = compare_frames[1] / compare_frames[0]
|
||||
frames.append(ratio_df)
|
||||
compare_frames.pop(1)
|
||||
|
||||
concat_df = pd.concat(frames, axis=1)
|
||||
return concat_df
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"-f", "--file", action="append", type=str, help="input file name"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ignore_test_name", action="store_true", help="ignore_test_name or not"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
files = args.file
|
||||
print("comparing : " + ", ".join(files))
|
||||
|
||||
drop_column = "P99"
|
||||
name_column = "Test name"
|
||||
data_cols_to_compare = ["Output Tput (tok/s)", "Median TTFT (ms)", "Median"]
|
||||
html_msgs_for_data_cols = [
|
||||
"Compare Output Tokens /n",
|
||||
"Median TTFT /n",
|
||||
"Median TPOT /n",
|
||||
]
|
||||
ignore_test_name = args.ignore_test_name
|
||||
with open("perf_comparison.html", "w") as text_file:
|
||||
for i in range(len(data_cols_to_compare)):
|
||||
output_df = compare_data_columns(
|
||||
files,
|
||||
name_column,
|
||||
data_cols_to_compare[i],
|
||||
drop_column,
|
||||
ignore_test_name=ignore_test_name,
|
||||
)
|
||||
print(output_df)
|
||||
html = output_df.to_html()
|
||||
text_file.write(html_msgs_for_data_cols[i])
|
||||
text_file.write(html)
|
||||
@ -3,9 +3,11 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
from importlib import util
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
import psutil
|
||||
from tabulate import tabulate
|
||||
|
||||
results_folder = Path("results/")
|
||||
@ -29,11 +31,11 @@ throughput_results = []
|
||||
throughput_results_column_mapping = {
|
||||
"test_name": "Test name",
|
||||
"gpu_type": "GPU",
|
||||
# "num_requests": "# of req.",
|
||||
# "total_num_tokens": "Total # of tokens",
|
||||
# "elapsed_time": "Elapsed time (s)",
|
||||
"num_requests": "# of req.",
|
||||
"total_num_tokens": "Total # of tokens",
|
||||
"elapsed_time": "Elapsed time (s)",
|
||||
"requests_per_second": "Tput (req/s)",
|
||||
# "tokens_per_second": "Tput (tok/s)",
|
||||
"tokens_per_second": "Tput (tok/s)",
|
||||
}
|
||||
|
||||
# serving results and the keys that will be printed into markdown
|
||||
@ -41,16 +43,18 @@ serving_results = []
|
||||
serving_column_mapping = {
|
||||
"test_name": "Test name",
|
||||
"gpu_type": "GPU",
|
||||
# "completed": "# of req.",
|
||||
"completed": "# of req.",
|
||||
"request_throughput": "Tput (req/s)",
|
||||
# "input_throughput": "Input Tput (tok/s)",
|
||||
# "output_throughput": "Output Tput (tok/s)",
|
||||
"total_token_throughput": "Total Token Tput (tok/s)",
|
||||
"output_throughput": "Output Tput (tok/s)",
|
||||
"total_input_tokens": "Total input tokens",
|
||||
"total_output_tokens": "Total output tokens",
|
||||
"mean_ttft_ms": "Mean TTFT (ms)",
|
||||
"median_ttft_ms": "Median TTFT (ms)",
|
||||
"p99_ttft_ms": "P99 TTFT (ms)",
|
||||
# "mean_tpot_ms": "Mean TPOT (ms)",
|
||||
# "median_tpot_ms": "Median",
|
||||
# "p99_tpot_ms": "P99",
|
||||
"mean_tpot_ms": "Mean TPOT (ms)",
|
||||
"median_tpot_ms": "Median",
|
||||
"p99_tpot_ms": "P99",
|
||||
"mean_itl_ms": "Mean ITL (ms)",
|
||||
"median_itl_ms": "Median ITL (ms)",
|
||||
"p99_itl_ms": "P99 ITL (ms)",
|
||||
@ -75,6 +79,20 @@ def results_to_json(latency, throughput, serving):
|
||||
)
|
||||
|
||||
|
||||
def get_size_with_unit(bytes, suffix="B"):
|
||||
"""
|
||||
Scale bytes to its proper format
|
||||
e.g:
|
||||
1253656 => '1.20MB'
|
||||
1253656678 => '1.17GB'
|
||||
"""
|
||||
factor = 1024
|
||||
for unit in ["", "K", "M", "G", "T", "P"]:
|
||||
if bytes < factor:
|
||||
return f"{bytes:.2f}{unit}{suffix}"
|
||||
bytes /= factor
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# collect results
|
||||
for test_file in results_folder.glob("*.json"):
|
||||
@ -155,6 +173,27 @@ if __name__ == "__main__":
|
||||
serving_results = pd.DataFrame.from_dict(serving_results)
|
||||
throughput_results = pd.DataFrame.from_dict(throughput_results)
|
||||
|
||||
svmem = psutil.virtual_memory()
|
||||
platform_data = {
|
||||
"Physical cores": [psutil.cpu_count(logical=False)],
|
||||
"Total cores": [psutil.cpu_count(logical=True)],
|
||||
"Total Memory": [get_size_with_unit(svmem.total)],
|
||||
}
|
||||
|
||||
if util.find_spec("numa") is not None:
|
||||
from numa import info
|
||||
|
||||
platform_data["Total NUMA nodes"] = [info.get_num_configured_nodes()]
|
||||
|
||||
if util.find_spec("cpuinfo") is not None:
|
||||
from cpuinfo import get_cpu_info
|
||||
|
||||
platform_data["CPU Brand"] = [get_cpu_info()["brand_raw"]]
|
||||
|
||||
platform_results = pd.DataFrame.from_dict(
|
||||
platform_data, orient="index", columns=["Platform Info"]
|
||||
)
|
||||
|
||||
raw_results_json = results_to_json(
|
||||
latency_results, throughput_results, serving_results
|
||||
)
|
||||
@ -200,6 +239,9 @@ if __name__ == "__main__":
|
||||
throughput_md_table = tabulate(
|
||||
throughput_results, headers="keys", tablefmt="pipe", showindex=False
|
||||
)
|
||||
platform_md_table = tabulate(
|
||||
platform_results, headers="keys", tablefmt="pipe", showindex=True
|
||||
)
|
||||
|
||||
# document the result
|
||||
with open(results_folder / "benchmark_results.md", "w") as f:
|
||||
@ -211,6 +253,7 @@ if __name__ == "__main__":
|
||||
latency_tests_markdown_table=latency_md_table,
|
||||
throughput_tests_markdown_table=throughput_md_table,
|
||||
serving_tests_markdown_table=serving_md_table,
|
||||
platform_markdown_table=platform_md_table,
|
||||
benchmarking_results_in_json_string=processed_results_json,
|
||||
)
|
||||
f.write(results)
|
||||
|
||||
@ -31,6 +31,20 @@ check_gpus() {
|
||||
echo "GPU type is $gpu_type"
|
||||
}
|
||||
|
||||
check_cpus() {
|
||||
# check the number of CPUs and NUMA Node and GPU type.
|
||||
declare -g numa_count=$(python3 -c "from numa import info;numa_size = info.get_num_configured_nodes(); print(numa_size)")
|
||||
if [[ $numa_count -gt 0 ]]; then
|
||||
echo "NUMA found."
|
||||
echo $numa_count
|
||||
else
|
||||
echo "Need at least 1 NUMA to run benchmarking."
|
||||
exit 1
|
||||
fi
|
||||
declare -g gpu_type="cpu"
|
||||
echo "GPU type is $gpu_type"
|
||||
}
|
||||
|
||||
check_hf_token() {
|
||||
# check if HF_TOKEN is available and valid
|
||||
if [[ -z "$HF_TOKEN" ]]; then
|
||||
@ -69,6 +83,22 @@ json2args() {
|
||||
echo "$args"
|
||||
}
|
||||
|
||||
json2envs() {
|
||||
# transforms the JSON string to environment variables.
|
||||
# example:
|
||||
# input: { "VLLM_CPU_KVCACHE_SPACE": 5 }
|
||||
# output: VLLM_CPU_KVCACHE_SPACE=5
|
||||
local json_string=$1
|
||||
local args=$(
|
||||
echo "$json_string" | jq -r '
|
||||
to_entries |
|
||||
map((.key ) + "=" + (.value | tostring)) |
|
||||
join(" ")
|
||||
'
|
||||
)
|
||||
echo "$args"
|
||||
}
|
||||
|
||||
wait_for_server() {
|
||||
# wait for vllm server to start
|
||||
# return 1 if vllm server crashes
|
||||
@ -158,15 +188,24 @@ run_latency_tests() {
|
||||
# get arguments
|
||||
latency_params=$(echo "$params" | jq -r '.parameters')
|
||||
latency_args=$(json2args "$latency_params")
|
||||
latency_environment_variables=$(echo "$params" | jq -r '.environment_variables')
|
||||
latency_envs=$(json2envs "$latency_environment_variables")
|
||||
|
||||
# check if there is enough GPU to run the test
|
||||
tp=$(echo "$latency_params" | jq -r '.tensor_parallel_size')
|
||||
if [[ $gpu_count -lt $tp ]]; then
|
||||
echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name."
|
||||
continue
|
||||
if [ "$ON_CPU" == "1" ];then
|
||||
if [[ $numa_count -lt $tp ]]; then
|
||||
echo "Required tensor-parallel-size $tp but only $numa_count NUMA nodes found. Skip testcase $test_name."
|
||||
continue
|
||||
fi
|
||||
else
|
||||
if [[ $gpu_count -lt $tp ]]; then
|
||||
echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name."
|
||||
continue
|
||||
fi
|
||||
fi
|
||||
|
||||
latency_command="python3 benchmark_latency.py \
|
||||
latency_command=" $latency_envs python3 benchmark_latency.py \
|
||||
--output-json $RESULTS_FOLDER/${test_name}.json \
|
||||
$latency_args"
|
||||
|
||||
@ -216,15 +255,24 @@ run_throughput_tests() {
|
||||
# get arguments
|
||||
throughput_params=$(echo "$params" | jq -r '.parameters')
|
||||
throughput_args=$(json2args "$throughput_params")
|
||||
throughput_environment_variables=$(echo "$params" | jq -r '.environment_variables')
|
||||
throughput_envs=$(json2envs "$throughput_environment_variables")
|
||||
|
||||
# check if there is enough GPU to run the test
|
||||
tp=$(echo "$throughput_params" | jq -r '.tensor_parallel_size')
|
||||
if [[ $gpu_count -lt $tp ]]; then
|
||||
echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name."
|
||||
continue
|
||||
if [ "$ON_CPU" == "1" ];then
|
||||
if [[ $numa_count -lt $tp ]]; then
|
||||
echo "Required tensor-parallel-size $tp but only $numa_count NUMA nodes found. Skip testcase $test_name."
|
||||
continue
|
||||
fi
|
||||
else
|
||||
if [[ $gpu_count -lt $tp ]]; then
|
||||
echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name."
|
||||
continue
|
||||
fi
|
||||
fi
|
||||
|
||||
throughput_command="python3 benchmark_throughput.py \
|
||||
throughput_command=" $throughput_envs python3 benchmark_throughput.py \
|
||||
--output-json $RESULTS_FOLDER/${test_name}.json \
|
||||
$throughput_args"
|
||||
|
||||
@ -272,18 +320,27 @@ run_serving_tests() {
|
||||
|
||||
# get client and server arguments
|
||||
server_params=$(echo "$params" | jq -r '.server_parameters')
|
||||
server_envs=$(echo "$params" | jq -r '.server_environment_variables')
|
||||
client_params=$(echo "$params" | jq -r '.client_parameters')
|
||||
server_args=$(json2args "$server_params")
|
||||
server_envs=$(json2envs "$server_envs")
|
||||
client_args=$(json2args "$client_params")
|
||||
qps_list=$(echo "$params" | jq -r '.qps_list')
|
||||
qps_list=$(echo "$qps_list" | jq -r '.[] | @sh')
|
||||
echo "Running over qps list $qps_list"
|
||||
|
||||
# check if there is enough GPU to run the test
|
||||
# check if there is enough resources to run the test
|
||||
tp=$(echo "$server_params" | jq -r '.tensor_parallel_size')
|
||||
if [[ $gpu_count -lt $tp ]]; then
|
||||
echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name."
|
||||
continue
|
||||
if [ "$ON_CPU" == "1" ];then
|
||||
if [[ $numa_count -lt $tp ]]; then
|
||||
echo "Required tensor-parallel-size $tp but only $numa_count NUMA nodes found. Skip testcase $test_name."
|
||||
continue
|
||||
fi
|
||||
else
|
||||
if [[ $gpu_count -lt $tp ]]; then
|
||||
echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name."
|
||||
continue
|
||||
fi
|
||||
fi
|
||||
|
||||
# check if server model and client model is aligned
|
||||
@ -294,23 +351,33 @@ run_serving_tests() {
|
||||
continue
|
||||
fi
|
||||
|
||||
server_command="python3 \
|
||||
server_command="$server_envs python3 \
|
||||
-m vllm.entrypoints.openai.api_server \
|
||||
$server_args"
|
||||
|
||||
# run the server
|
||||
echo "Running test case $test_name"
|
||||
echo "Server command: $server_command"
|
||||
bash -c "$server_command" &
|
||||
server_pid=$!
|
||||
|
||||
# wait until the server is alive
|
||||
if wait_for_server; then
|
||||
echo ""
|
||||
echo "vllm server is up and running."
|
||||
# support remote vllm server
|
||||
client_remote_args=""
|
||||
if [[ -z "${REMOTE_HOST}" ]]; then
|
||||
bash -c "$server_command" &
|
||||
server_pid=$!
|
||||
# wait until the server is alive
|
||||
if wait_for_server; then
|
||||
echo ""
|
||||
echo "vLLM server is up and running."
|
||||
else
|
||||
echo ""
|
||||
echo "vLLM failed to start within the timeout period."
|
||||
fi
|
||||
else
|
||||
echo ""
|
||||
echo "vllm failed to start within the timeout period."
|
||||
server_command="Using Remote Server $REMOTE_HOST $REMOTE_PORT"
|
||||
if [[ ${REMOTE_PORT} ]]; then
|
||||
client_remote_args=" --host=$REMOTE_HOST --port=$REMOTE_PORT "
|
||||
else
|
||||
client_remote_args=" --host=$REMOTE_HOST "
|
||||
fi
|
||||
fi
|
||||
|
||||
# iterate over different QPS
|
||||
@ -332,7 +399,7 @@ run_serving_tests() {
|
||||
--result-filename ${new_test_name}.json \
|
||||
--request-rate $qps \
|
||||
--metadata "tensor_parallel_size=$tp" \
|
||||
$client_args"
|
||||
$client_args $client_remote_args "
|
||||
|
||||
echo "Running test case $test_name with qps $qps"
|
||||
echo "Client command: $client_command"
|
||||
@ -360,7 +427,14 @@ run_serving_tests() {
|
||||
}
|
||||
|
||||
main() {
|
||||
check_gpus
|
||||
local ARCH
|
||||
ARCH=''
|
||||
if [ "$ON_CPU" == "1" ];then
|
||||
check_cpus
|
||||
ARCH='-cpu'
|
||||
else
|
||||
check_gpus
|
||||
fi
|
||||
check_hf_token
|
||||
|
||||
# Set to v1 to run v1 benchmark
|
||||
@ -386,9 +460,9 @@ main() {
|
||||
QUICK_BENCHMARK_ROOT=../.buildkite/nightly-benchmarks/
|
||||
|
||||
# benchmarking
|
||||
run_serving_tests $QUICK_BENCHMARK_ROOT/tests/serving-tests.json
|
||||
run_latency_tests $QUICK_BENCHMARK_ROOT/tests/latency-tests.json
|
||||
run_throughput_tests $QUICK_BENCHMARK_ROOT/tests/throughput-tests.json
|
||||
run_serving_tests $QUICK_BENCHMARK_ROOT/tests/"${SERVING_JSON:-serving-tests$ARCH.json}"
|
||||
run_latency_tests $QUICK_BENCHMARK_ROOT/tests/"${LATENCY_JSON:-latency-tests$ARCH.json}"
|
||||
run_throughput_tests $QUICK_BENCHMARK_ROOT/tests/"${THROUGHPUT_JSON:-throughput-tests$ARCH.json}"
|
||||
|
||||
# postprocess benchmarking results
|
||||
pip install tabulate pandas
|
||||
|
||||
30
.buildkite/nightly-benchmarks/tests/latency-tests-cpu.json
Normal file
30
.buildkite/nightly-benchmarks/tests/latency-tests-cpu.json
Normal file
@ -0,0 +1,30 @@
|
||||
[
|
||||
{
|
||||
"test_name": "latency_llama8B_tp1",
|
||||
"environment_variables": {
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 1,
|
||||
"load_format": "dummy",
|
||||
"num_iters_warmup": 5,
|
||||
"num_iters": 15
|
||||
}
|
||||
},
|
||||
{
|
||||
"test_name": "latency_llama8B_tp4",
|
||||
"environment_variables": {
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 4,
|
||||
"load_format": "dummy",
|
||||
"num_iters_warmup": 5,
|
||||
"num_iters": 15
|
||||
}
|
||||
}
|
||||
]
|
||||
158
.buildkite/nightly-benchmarks/tests/serving-tests-cpu.json
Normal file
158
.buildkite/nightly-benchmarks/tests/serving-tests-cpu.json
Normal file
@ -0,0 +1,158 @@
|
||||
[
|
||||
{
|
||||
"test_name": "serving_llama8B_tp1_sharegpt",
|
||||
"qps_list": [1, 4, 16, "inf"],
|
||||
"server_environment_variables": {
|
||||
"VLLM_RPC_TIMEOUT": 100000,
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
"VLLM_ENGINE_ITERATION_TIMEOUT_S": 120,
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"server_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 1,
|
||||
"dtype": "bfloat16",
|
||||
"distributed_executor_backend": "mp",
|
||||
"block_size": 128,
|
||||
"trust_remote_code": "",
|
||||
"disable_log_stats": "",
|
||||
"disable_log_requests": "",
|
||||
"enforce_eager": "",
|
||||
"load_format": "dummy"
|
||||
},
|
||||
"client_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"backend": "vllm",
|
||||
"dataset_name": "sharegpt",
|
||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
"max_concurrency": 60,
|
||||
"num_prompts": 200
|
||||
}
|
||||
},
|
||||
{
|
||||
"test_name": "serving_llama8B_tp2_sharegpt",
|
||||
"qps_list": [1, 4, 16, "inf"],
|
||||
"server_environment_variables": {
|
||||
"VLLM_RPC_TIMEOUT": 100000,
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
"VLLM_ENGINE_ITERATION_TIMEOUT_S": 120,
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"server_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 2,
|
||||
"dtype": "bfloat16",
|
||||
"distributed_executor_backend": "mp",
|
||||
"block_size": 128,
|
||||
"trust_remote_code": "",
|
||||
"disable_log_stats": "",
|
||||
"disable_log_requests": "",
|
||||
"enforce_eager": "",
|
||||
"load_format": "dummy"
|
||||
},
|
||||
"client_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"backend": "vllm",
|
||||
"dataset_name": "sharegpt",
|
||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
"max_concurrency": 60,
|
||||
"num_prompts": 200
|
||||
}
|
||||
},
|
||||
{
|
||||
"test_name": "serving_llama8B_tp4_sharegpt",
|
||||
"qps_list": [1, 4, 16, "inf"],
|
||||
"server_environment_variables": {
|
||||
"VLLM_RPC_TIMEOUT": 100000,
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
"VLLM_ENGINE_ITERATION_TIMEOUT_S": 120,
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"server_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 4,
|
||||
"dtype": "bfloat16",
|
||||
"distributed_executor_backend": "mp",
|
||||
"block_size": 128,
|
||||
"trust_remote_code": "",
|
||||
"disable_log_stats": "",
|
||||
"disable_log_requests": "",
|
||||
"enforce_eager": "",
|
||||
"load_format": "dummy"
|
||||
},
|
||||
"client_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"backend": "vllm",
|
||||
"dataset_name": "sharegpt",
|
||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
"max_concurrency": 60,
|
||||
"num_prompts": 200
|
||||
}
|
||||
},
|
||||
{
|
||||
"test_name": "serving_llama8B_tp4_random_1024_128",
|
||||
"qps_list": [1, 4, 16, "inf"],
|
||||
"server_environment_variables": {
|
||||
"VLLM_RPC_TIMEOUT": 100000,
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
"VLLM_ENGINE_ITERATION_TIMEOUT_S": 120,
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"server_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 4,
|
||||
"dtype": "bfloat16",
|
||||
"distributed_executor_backend": "mp",
|
||||
"block_size": 128,
|
||||
"trust_remote_code": "",
|
||||
"enable_chunked_prefill": "",
|
||||
"disable_log_stats": "",
|
||||
"disable_log_requests": "",
|
||||
"enforce_eager": "",
|
||||
"load_format": "dummy"
|
||||
},
|
||||
"client_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"backend": "vllm",
|
||||
"dataset_name": "random",
|
||||
"random-input-len": 1024,
|
||||
"random-output-len": 128,
|
||||
"ignore-eos": "",
|
||||
"max_concurrency": 100,
|
||||
"num_prompts": 100
|
||||
}
|
||||
},
|
||||
{
|
||||
"test_name": "serving_llama8B_pp6_random_1024_128",
|
||||
"qps_list": [1, 4, 16, "inf"],
|
||||
"server_environment_variables": {
|
||||
"VLLM_RPC_TIMEOUT": 100000,
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
"VLLM_ENGINE_ITERATION_TIMEOUT_S": 120,
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"server_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"pipeline_parallel_size": 6,
|
||||
"dtype": "bfloat16",
|
||||
"distributed_executor_backend": "mp",
|
||||
"block_size": 128,
|
||||
"trust_remote_code": "",
|
||||
"enable_chunked_prefill": "",
|
||||
"disable_log_stats": "",
|
||||
"disable_log_requests": "",
|
||||
"enforce_eager": "",
|
||||
"load_format": "dummy"
|
||||
},
|
||||
"client_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"backend": "vllm",
|
||||
"dataset_name": "random",
|
||||
"random-input-len": 1024,
|
||||
"random-output-len": 128,
|
||||
"ignore-eos": "",
|
||||
"max_concurrency": 100,
|
||||
"num_prompts": 100
|
||||
}
|
||||
}
|
||||
]
|
||||
@ -0,0 +1,32 @@
|
||||
[
|
||||
{
|
||||
"test_name": "throughput_llama8B_tp1",
|
||||
"environment_variables": {
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 1,
|
||||
"load_format": "dummy",
|
||||
"dataset": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
"num_prompts": 200,
|
||||
"backend": "vllm"
|
||||
}
|
||||
},
|
||||
{
|
||||
"test_name": "throughput_llama8B_tp4",
|
||||
"environment_variables": {
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 4,
|
||||
"load_format": "dummy",
|
||||
"dataset": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
"num_prompts": 200,
|
||||
"backend": "vllm"
|
||||
}
|
||||
}
|
||||
]
|
||||
@ -52,7 +52,7 @@ steps:
|
||||
queue: cpu_queue_postmerge
|
||||
commands:
|
||||
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain -f docker/Dockerfile ."
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain -f docker/Dockerfile ."
|
||||
- "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT"
|
||||
|
||||
- label: "Annotate release workflow"
|
||||
@ -101,7 +101,7 @@ steps:
|
||||
queue: cpu_queue_postmerge
|
||||
commands:
|
||||
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest --progress plain --target vllm-openai -f docker/Dockerfile.cpu ."
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --build-arg VLLM_CPU_AVX512BF16=true --build-arg VLLM_CPU_AVX512VNNI=true --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest --progress plain --target vllm-openai -f docker/Dockerfile.cpu ."
|
||||
- "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest"
|
||||
- "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version)"
|
||||
env:
|
||||
|
||||
@ -107,10 +107,9 @@ fi
|
||||
|
||||
if [[ $commands == *" kernels/attention"* ]]; then
|
||||
commands="${commands} \
|
||||
--ignore=kernels/attention/stest_attention_selector.py \
|
||||
--ignore=kernels/attention/test_attention_selector.py \
|
||||
--ignore=kernels/attention/test_blocksparse_attention.py \
|
||||
--ignore=kernels/attention/test_encoder_decoder_attn.py \
|
||||
--ignore=kernels/attention/test_attention_selector.py \
|
||||
--ignore=kernels/attention/test_flash_attn.py \
|
||||
--ignore=kernels/attention/test_flashinfer.py \
|
||||
--ignore=kernels/attention/test_prefix_prefill.py \
|
||||
|
||||
@ -48,9 +48,16 @@ function cpu_tests() {
|
||||
# Run basic model test
|
||||
docker exec cpu-test-"$NUMA_NODE" bash -c "
|
||||
set -e
|
||||
pytest -v -s tests/kernels/attention/test_cache.py -m cpu_model
|
||||
pytest -v -s tests/kernels/attention/test_mla_decode_cpu.py -m cpu_model
|
||||
pytest -v -s tests/models/language/generation -m cpu_model
|
||||
# Note: disable until supports V1
|
||||
# pytest -v -s tests/kernels/attention/test_cache.py -m cpu_model
|
||||
# pytest -v -s tests/kernels/attention/test_mla_decode_cpu.py -m cpu_model
|
||||
|
||||
# Note: disable Bart until supports V1
|
||||
pytest -v -s tests/models/language/generation -m cpu_model \
|
||||
--ignore=tests/models/language/generation/test_bart.py
|
||||
VLLM_CPU_SGL_KERNEL=1 pytest -v -s tests/models/language/generation -m cpu_model \
|
||||
--ignore=tests/models/language/generation/test_bart.py
|
||||
|
||||
pytest -v -s tests/models/language/pooling -m cpu_model
|
||||
pytest -v -s tests/models/multimodal/generation \
|
||||
--ignore=tests/models/multimodal/generation/test_mllama.py \
|
||||
@ -61,20 +68,14 @@ function cpu_tests() {
|
||||
docker exec cpu-test-"$NUMA_NODE" bash -c "
|
||||
set -e
|
||||
pytest -s -v \
|
||||
tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_static_setup \
|
||||
tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_dynamic_per_token"
|
||||
tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_logprobs[False-10-32-neuralmagic/Llama-3.2-1B-quantized.w8a8]"
|
||||
|
||||
# Note: disable it until supports V1
|
||||
# Run AWQ test
|
||||
docker exec cpu-test-"$NUMA_NODE" bash -c "
|
||||
set -e
|
||||
VLLM_USE_V1=0 pytest -s -v \
|
||||
tests/quantization/test_ipex_quant.py"
|
||||
|
||||
# Run chunked-prefill and prefix-cache test
|
||||
docker exec cpu-test-"$NUMA_NODE" bash -c "
|
||||
set -e
|
||||
pytest -s -v -k cpu_model \
|
||||
tests/basic_correctness/test_chunked_prefill.py"
|
||||
# docker exec cpu-test-"$NUMA_NODE" bash -c "
|
||||
# set -e
|
||||
# VLLM_USE_V1=0 pytest -s -v \
|
||||
# tests/quantization/test_ipex_quant.py"
|
||||
|
||||
# online serving
|
||||
docker exec cpu-test-"$NUMA_NODE" bash -c "
|
||||
@ -98,4 +99,4 @@ function cpu_tests() {
|
||||
|
||||
# All of CPU tests are expected to be finished less than 40 mins.
|
||||
export -f cpu_tests
|
||||
timeout 1h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE"
|
||||
timeout 1.5h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE"
|
||||
|
||||
@ -2,10 +2,34 @@
|
||||
|
||||
# This script build the CPU docker image and run the offline inference inside the container.
|
||||
# It serves a sanity check for compilation and basic model usage.
|
||||
set -ex
|
||||
set -exuo pipefail
|
||||
|
||||
# Try building the docker image
|
||||
docker build -t hpu-test-env -f docker/Dockerfile.hpu .
|
||||
cat <<EOF | docker build -t hpu-plugin-v1-test-env -f - .
|
||||
FROM 1.22-413-pt2.7.1:latest
|
||||
|
||||
COPY ./ /workspace/vllm
|
||||
|
||||
WORKDIR /workspace/vllm
|
||||
|
||||
RUN pip install -v -r requirements/hpu.txt
|
||||
RUN pip install git+https://github.com/vllm-project/vllm-gaudi.git
|
||||
|
||||
ENV no_proxy=localhost,127.0.0.1
|
||||
ENV PT_HPU_ENABLE_LAZY_COLLECTIVES=true
|
||||
|
||||
RUN VLLM_TARGET_DEVICE=hpu python3 setup.py install
|
||||
|
||||
# install development dependencies (for testing)
|
||||
RUN python3 -m pip install -e tests/vllm_test_utils
|
||||
|
||||
WORKDIR /workspace/
|
||||
|
||||
RUN git clone https://github.com/vllm-project/vllm-gaudi.git
|
||||
|
||||
RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks
|
||||
|
||||
EOF
|
||||
|
||||
# Setup cleanup
|
||||
# certain versions of HPU software stack have a bug that can
|
||||
@ -14,13 +38,21 @@ docker build -t hpu-test-env -f docker/Dockerfile.hpu .
|
||||
# functions, while other platforms only need one remove_docker_container
|
||||
# function.
|
||||
EXITCODE=1
|
||||
remove_docker_containers() { docker rm -f hpu-test || true; docker rm -f hpu-test-tp2 || true; }
|
||||
remove_docker_containers_and_exit() { remove_docker_containers; exit $EXITCODE; }
|
||||
trap remove_docker_containers_and_exit EXIT
|
||||
remove_docker_containers() { docker rm -f hpu-plugin-v1-test || true; }
|
||||
trap 'remove_docker_containers; exit $EXITCODE;' EXIT
|
||||
remove_docker_containers
|
||||
|
||||
# Run the image and launch offline inference
|
||||
docker run --runtime=habana --name=hpu-test --network=host -e HABANA_VISIBLE_DEVICES=all -e VLLM_SKIP_WARMUP=true --entrypoint="" hpu-test-env python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m
|
||||
docker run --runtime=habana --name=hpu-test-tp2 --network=host -e HABANA_VISIBLE_DEVICES=all -e VLLM_SKIP_WARMUP=true --entrypoint="" hpu-test-env python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --tensor-parallel-size 2
|
||||
echo "Running HPU plugin v1 test"
|
||||
docker run --rm --runtime=habana --name=hpu-plugin-v1-test --network=host \
|
||||
-e HABANA_VISIBLE_DEVICES=all \
|
||||
hpu-plugin-v1-test-env \
|
||||
/bin/bash "/workspace/vllm-gaudi/tests/upstream_tests/ci_tests.sh"
|
||||
|
||||
EXITCODE=$?
|
||||
if [ $EXITCODE -eq 0 ]; then
|
||||
echo "Test with basic model passed"
|
||||
else
|
||||
echo "Test with basic model FAILED with exit code: $EXITCODE" >&2
|
||||
fi
|
||||
|
||||
# The trap will handle the container removal and final exit.
|
||||
@ -11,8 +11,8 @@ container_name="xpu_${BUILDKITE_COMMIT}_$(tr -dc A-Za-z0-9 < /dev/urandom | head
|
||||
docker build -t ${image_name} -f docker/Dockerfile.xpu .
|
||||
|
||||
# Setup cleanup
|
||||
remove_docker_container() {
|
||||
docker rm -f "${container_name}" || true;
|
||||
remove_docker_container() {
|
||||
docker rm -f "${container_name}" || true;
|
||||
docker image rm -f "${image_name}" || true;
|
||||
docker system prune -f || true;
|
||||
}
|
||||
@ -26,7 +26,9 @@ docker run \
|
||||
--name "${container_name}" \
|
||||
"${image_name}" \
|
||||
sh -c '
|
||||
VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m
|
||||
VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m -tp 2
|
||||
VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
|
||||
VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray
|
||||
VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp
|
||||
cd tests
|
||||
pytest -v -s v1/core
|
||||
'
|
||||
|
||||
@ -22,16 +22,6 @@ trap remove_docker_container EXIT
|
||||
# Remove the container that might not be cleaned up in the previous run.
|
||||
remove_docker_container
|
||||
|
||||
# Build docker image.
|
||||
# TODO: build the image outside the script and share the image with other
|
||||
# tpu test if building time is too long.
|
||||
DOCKER_BUILDKIT=1 docker build \
|
||||
--build-arg max_jobs=16 \
|
||||
--build-arg USE_SCCACHE=1 \
|
||||
--build-arg GIT_REPO_CHECK=0 \
|
||||
--tag vllm/vllm-tpu-bm \
|
||||
--progress plain -f docker/Dockerfile.tpu .
|
||||
|
||||
LOG_ROOT=$(mktemp -d)
|
||||
# If mktemp fails, set -e will cause the script to exit.
|
||||
echo "Results will be stored in: $LOG_ROOT"
|
||||
|
||||
14
.buildkite/scripts/tpu/quantized_v6e_1.env
Normal file
14
.buildkite/scripts/tpu/quantized_v6e_1.env
Normal file
@ -0,0 +1,14 @@
|
||||
# Environment config
|
||||
TEST_NAME=llama8bw8a8
|
||||
CONTAINER_NAME=vllm-tpu
|
||||
|
||||
# vllm config
|
||||
MODEL=RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8
|
||||
MAX_NUM_SEQS=128
|
||||
MAX_NUM_BATCHED_TOKENS=1024
|
||||
TENSOR_PARALLEL_SIZE=1
|
||||
MAX_MODEL_LEN=2048
|
||||
DOWNLOAD_DIR=/mnt/disks/persist
|
||||
EXPECTED_THROUGHPUT=10.0
|
||||
INPUT_LEN=1800
|
||||
OUTPUT_LEN=128
|
||||
@ -155,6 +155,7 @@ steps:
|
||||
- examples/offline_inference/rlhf_colocate.py
|
||||
- tests/examples/offline_inference/data_parallel.py
|
||||
- tests/v1/test_async_llm_dp.py
|
||||
- tests/v1/test_external_lb_dp.py
|
||||
- tests/v1/engine/test_engine_core_client.py
|
||||
commands:
|
||||
# test with tp=2 and external_dp=2
|
||||
@ -163,8 +164,9 @@ steps:
|
||||
# test with tp=2 and pp=2
|
||||
- PP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
|
||||
# test with internal dp
|
||||
- python3 ../examples/offline_inference/data_parallel.py
|
||||
- python3 ../examples/offline_inference/data_parallel.py --enforce-eager
|
||||
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
|
||||
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py
|
||||
- pytest -v -s v1/engine/test_engine_core_client.py::test_kv_cache_events_dp
|
||||
- pytest -v -s distributed/test_utils.py
|
||||
- pytest -v -s compile/test_basic_correctness.py
|
||||
@ -215,7 +217,7 @@ steps:
|
||||
##### 1 GPU test #####
|
||||
|
||||
- label: Regression Test # 5min
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/test_regression
|
||||
@ -225,7 +227,7 @@ steps:
|
||||
working_dir: "/vllm-workspace/tests" # optional
|
||||
|
||||
- label: Engine Test # 10min
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/engine
|
||||
@ -280,7 +282,7 @@ steps:
|
||||
- python3 offline_inference/llm_engine_example.py
|
||||
- python3 offline_inference/audio_language.py --seed 0
|
||||
- python3 offline_inference/vision_language.py --seed 0
|
||||
- python3 offline_inference/vision_language_embedding.py --seed 0
|
||||
- python3 offline_inference/vision_language_pooling.py --seed 0
|
||||
- python3 offline_inference/vision_language_multi_image.py --seed 0
|
||||
- VLLM_USE_V1=0 python3 others/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 others/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
|
||||
- python3 offline_inference/encoder_decoder.py
|
||||
@ -338,7 +340,7 @@ steps:
|
||||
parallelism: 4
|
||||
|
||||
- label: PyTorch Compilation Unit Tests
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
@ -420,7 +422,7 @@ steps:
|
||||
- pytest -v -s kernels/mamba
|
||||
|
||||
- label: Tensorizer Test # 11min
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
soft_fail: true
|
||||
source_file_dependencies:
|
||||
- vllm/model_executor/model_loader
|
||||
@ -512,7 +514,7 @@ steps:
|
||||
##### models test #####
|
||||
|
||||
- label: Basic Models Test # 24min
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
@ -536,6 +538,17 @@ steps:
|
||||
- pip freeze | grep -E 'torch'
|
||||
- pytest -v -s models/language -m core_model
|
||||
|
||||
- label: Language Models Test (Hybrid) # 35 min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/models/language/generation
|
||||
commands:
|
||||
# Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile.
|
||||
- pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8'
|
||||
- pytest -v -s models/language/generation -m hybrid_model
|
||||
|
||||
- label: Language Models Test (Extended Generation) # 1hr20min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
optional: true
|
||||
@ -545,7 +558,7 @@ steps:
|
||||
commands:
|
||||
# Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile.
|
||||
- pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8'
|
||||
- pytest -v -s models/language/generation -m 'not core_model'
|
||||
- pytest -v -s models/language/generation -m '(not core_model) and (not hybrid_model)'
|
||||
|
||||
- label: Language Models Test (Extended Pooling) # 36min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
@ -590,7 +603,7 @@ steps:
|
||||
- pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=0) and not core_model'
|
||||
|
||||
- label: Multi-Modal Models Test (Extended) 3
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
optional: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
@ -671,10 +684,12 @@ steps:
|
||||
- vllm/worker/model_runner.py
|
||||
- entrypoints/llm/test_collective_rpc.py
|
||||
- tests/v1/test_async_llm_dp.py
|
||||
- tests/v1/test_external_lb_dp.py
|
||||
- tests/v1/entrypoints/openai/test_multi_api_servers.py
|
||||
- vllm/v1/engine/
|
||||
commands:
|
||||
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
|
||||
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py
|
||||
- DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py
|
||||
- pytest -v -s entrypoints/llm/test_collective_rpc.py
|
||||
- pytest -v -s ./compile/test_basic_correctness.py
|
||||
|
||||
2
.github/CODEOWNERS
vendored
2
.github/CODEOWNERS
vendored
@ -16,7 +16,7 @@
|
||||
/vllm/lora @jeejeelee
|
||||
/vllm/reasoning @aarnphm
|
||||
/vllm/entrypoints @aarnphm
|
||||
CMakeLists.txt @tlrmchlsmth
|
||||
CMakeLists.txt @tlrmchlsmth @LucasWilkinson
|
||||
|
||||
# Any change to the VllmConfig changes can have a large user-facing impact,
|
||||
# so spam a lot of people
|
||||
|
||||
35
.github/mergify.yml
vendored
35
.github/mergify.yml
vendored
@ -27,6 +27,22 @@ pull_request_rules:
|
||||
add:
|
||||
- ci/build
|
||||
|
||||
- name: label-deepseek
|
||||
description: Automatically apply deepseek label
|
||||
conditions:
|
||||
- or:
|
||||
- files~=^examples/.*deepseek.*\.py
|
||||
- files~=^tests/.*deepseek.*\.py
|
||||
- files~=^vllm/entrypoints/openai/tool_parsers/.*deepseek.*\.py
|
||||
- files~=^vllm/model_executor/models/.*deepseek.*\.py
|
||||
- files~=^vllm/reasoning/.*deepseek.*\.py
|
||||
- files~=^vllm/transformers_utils/.*deepseek.*\.py
|
||||
- title~=(?i)DeepSeek
|
||||
actions:
|
||||
label:
|
||||
add:
|
||||
- deepseek
|
||||
|
||||
- name: label-frontend
|
||||
description: Automatically apply frontend label
|
||||
conditions:
|
||||
@ -58,14 +74,23 @@ pull_request_rules:
|
||||
- files~=^vllm/multimodal/
|
||||
- files~=^tests/multimodal/
|
||||
- files~=^tests/models/multimodal/
|
||||
- files~=^tests/models/*/audio_language/
|
||||
- files~=^tests/models/*/vision_language/
|
||||
- files=tests/models/test_vision.py
|
||||
actions:
|
||||
label:
|
||||
add:
|
||||
- multi-modality
|
||||
|
||||
- name: label-new-model
|
||||
description: Automatically apply new-model label
|
||||
conditions:
|
||||
- and:
|
||||
- files~=^vllm/model_executor/models/
|
||||
- files=vllm/model_executor/models/registry.py
|
||||
actions:
|
||||
label:
|
||||
add:
|
||||
- new-model
|
||||
|
||||
- name: label-performance
|
||||
description: Automatically apply performance label
|
||||
conditions:
|
||||
@ -140,8 +165,14 @@ pull_request_rules:
|
||||
conditions:
|
||||
- or:
|
||||
- files~=^vllm/spec_decode/
|
||||
- files~=^vllm/v1/spec_decode/
|
||||
- files=vllm/model_executor/layers/spec_decode_base_sampler.py
|
||||
- files~=^tests/spec_decode/
|
||||
- files~=^tests/v1/spec_decode/
|
||||
- files~=^examples/.*(spec_decode|mlpspeculator|eagle|speculation).*\.py
|
||||
- files~=^vllm/model_executor/models/.*eagle.*\.py
|
||||
- files=vllm/model_executor/models/mlp_speculator.py
|
||||
- files~=^vllm/transformers_utils/configs/(eagle|medusa|mlp_speculator)\.py
|
||||
actions:
|
||||
label:
|
||||
add:
|
||||
|
||||
2
.github/workflows/lint-and-deploy.yaml
vendored
2
.github/workflows/lint-and-deploy.yaml
vendored
@ -68,7 +68,7 @@ jobs:
|
||||
export AWS_ACCESS_KEY_ID=minioadmin
|
||||
export AWS_SECRET_ACCESS_KEY=minioadmin
|
||||
sleep 30 && kubectl -n ns-vllm logs -f "$(kubectl -n ns-vllm get pods | awk '/deployment/ {print $1;exit}')" &
|
||||
helm install --wait --wait-for-jobs --timeout 5m0s --debug --create-namespace --namespace=ns-vllm test-vllm examples/online_serving/chart-helm -f examples/online_serving/chart-helm/values.yaml --set secrets.s3endpoint=http://minio:9000 --set secrets.s3bucketname=testbucket --set secrets.s3accesskeyid=$AWS_ACCESS_KEY_ID --set secrets.s3accesskey=$AWS_SECRET_ACCESS_KEY --set resources.requests.cpu=1 --set resources.requests.memory=4Gi --set resources.limits.cpu=2 --set resources.limits.memory=5Gi --set image.env[0].name=VLLM_CPU_KVCACHE_SPACE --set image.env[1].name=VLLM_LOGGING_LEVEL --set-string image.env[0].value="1" --set-string image.env[1].value="DEBUG" --set-string extraInit.s3modelpath="opt-125m/" --set-string 'resources.limits.nvidia\.com/gpu=0' --set-string 'resources.requests.nvidia\.com/gpu=0' --set-string image.repository="vllm-cpu-env"
|
||||
helm install --wait --wait-for-jobs --timeout 5m0s --debug --create-namespace --namespace=ns-vllm test-vllm examples/online_serving/chart-helm -f examples/online_serving/chart-helm/values.yaml --set secrets.s3endpoint=http://minio:9000 --set secrets.s3bucketname=testbucket --set secrets.s3accesskeyid=$AWS_ACCESS_KEY_ID --set secrets.s3accesskey=$AWS_SECRET_ACCESS_KEY --set resources.requests.cpu=1 --set resources.requests.memory=4Gi --set resources.limits.cpu=2 --set resources.limits.memory=5Gi --set image.env[0].name=VLLM_CPU_KVCACHE_SPACE --set image.env[1].name=VLLM_LOGGING_LEVEL --set image.env[2].name=VLLM_CPU_CI_ENV --set-string image.env[0].value="1" --set-string image.env[1].value="DEBUG" --set-string image.env[2].value="1" --set-string extraInit.s3modelpath="opt-125m/" --set-string 'resources.limits.nvidia\.com/gpu=0' --set-string 'resources.requests.nvidia\.com/gpu=0' --set-string image.repository="vllm-cpu-env"
|
||||
|
||||
- name: curl test
|
||||
run: |
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@ -146,6 +146,7 @@ venv.bak/
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
docs/argparse
|
||||
docs/examples
|
||||
|
||||
# mypy
|
||||
|
||||
@ -160,10 +160,17 @@ repos:
|
||||
types: [python]
|
||||
pass_filenames: false
|
||||
additional_dependencies: [pathspec, regex]
|
||||
- id: validate-config
|
||||
name: Validate configuration has default values and that each field has a docstring
|
||||
entry: python tools/validate_config.py
|
||||
language: python
|
||||
types: [python]
|
||||
pass_filenames: true
|
||||
files: vllm/config.py|tests/test_config.py
|
||||
# Keep `suggestion` last
|
||||
- id: suggestion
|
||||
name: Suggestion
|
||||
entry: bash -c 'echo "To bypass pre-commit hooks, add --no-verify to git commit."'
|
||||
entry: bash -c 'echo "To bypass all the pre-commit hooks, add --no-verify to git commit. To skip a specific hook, prefix the commit command with SKIP=<hook-id>."'
|
||||
language: system
|
||||
verbose: true
|
||||
pass_filenames: false
|
||||
|
||||
@ -171,6 +171,15 @@ if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND VLLM_GPU_FLAGS "--threads=${NVCC_THREADS}")
|
||||
endif()
|
||||
|
||||
#
|
||||
# Set nvcc fatbin compression.
|
||||
#
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
list(APPEND VLLM_GPU_FLAGS "-Xfatbin" "-compress-all" "-compress-mode=size")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
||||
#
|
||||
# Use FetchContent for C++ dependencies that are compiled as part of vLLM's build process.
|
||||
@ -232,7 +241,6 @@ endif()
|
||||
|
||||
set(VLLM_EXT_SRC
|
||||
"csrc/mamba/mamba_ssm/selective_scan_fwd.cu"
|
||||
"csrc/mamba/causal_conv1d/causal_conv1d.cu"
|
||||
"csrc/cache_kernels.cu"
|
||||
"csrc/attention/paged_attention_v1.cu"
|
||||
"csrc/attention/paged_attention_v2.cu"
|
||||
@ -259,7 +267,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
|
||||
|
||||
# Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building.
|
||||
set(CUTLASS_REVISION "v3.9.2" CACHE STRING "CUTLASS revision to use")
|
||||
set(CUTLASS_REVISION "v4.0.0" CACHE STRING "CUTLASS revision to use")
|
||||
|
||||
# Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
|
||||
if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR})
|
||||
@ -393,7 +401,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
|
||||
# CUDA 12.0 or later
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_ARCHS)
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu"
|
||||
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu"
|
||||
@ -409,7 +417,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
|
||||
message(STATUS "Building scaled_mm_c3x_sm90 for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_ARCHS)
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building scaled_mm_c3x_sm90 as CUDA Compiler version is "
|
||||
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
|
||||
"later if you intend on running FP8 quantized models on "
|
||||
@ -420,10 +428,40 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
||||
# The cutlass_scaled_mm kernels for Geforce Blackwell SM120 (c3x, i.e. CUTLASS 3.x) require
|
||||
# CUDA 12.8 or later
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0;12.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu"
|
||||
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu"
|
||||
)
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM120=1")
|
||||
# Let scaled_mm_c2x know it doesn't need to build these arches
|
||||
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
|
||||
message(STATUS "Building scaled_mm_c3x_sm120 for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building scaled_mm_c3x_sm120 as CUDA Compiler version is "
|
||||
"not >= 12.8, we recommend upgrading to CUDA 12.8 or "
|
||||
"later if you intend on running FP8 quantized models on "
|
||||
"Blackwell.")
|
||||
else()
|
||||
message(STATUS "Not building scaled_mm_c3x_120 as no compatible archs found "
|
||||
"in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
||||
# The cutlass_scaled_mm kernels for Blackwell SM100 (c3x, i.e. CUTLASS 3.x)
|
||||
# require CUDA 12.8 or later
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS)
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu"
|
||||
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu"
|
||||
@ -438,7 +476,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
|
||||
message(STATUS "Building scaled_mm_c3x_sm100 for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS)
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building scaled_mm_c3x_sm100 as CUDA Compiler version is "
|
||||
"not >= 12.8, we recommend upgrading to CUDA 12.8 or "
|
||||
"later if you intend on running FP8 quantized models on "
|
||||
@ -481,7 +519,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor
|
||||
# require CUDA 12.2 or later (and only work on Hopper).
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_ARCHS)
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.2 AND SCALED_MM_ARCHS)
|
||||
set(SRCS "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
@ -490,7 +528,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SPARSE_SCALED_MM_C3X=1")
|
||||
message(STATUS "Building sparse_scaled_mm_c3x for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_ARCHS)
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.2 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building sparse_scaled_mm_c3x kernels as CUDA Compiler version is "
|
||||
"not >= 12.2, we recommend upgrading to CUDA 12.2 or later "
|
||||
"if you intend on running FP8 sparse quantized models on Hopper.")
|
||||
@ -502,7 +540,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
|
||||
# FP4 Archs and flags
|
||||
cuda_archs_loose_intersection(FP4_ARCHS "10.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND FP4_ARCHS)
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/quantization/fp4/nvfp4_quant_kernels.cu"
|
||||
"csrc/quantization/fp4/nvfp4_experts_quant.cu"
|
||||
@ -523,7 +561,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
|
||||
# CUTLASS MLA Archs and flags
|
||||
cuda_archs_loose_intersection(MLA_ARCHS "10.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND MLA_ARCHS)
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND MLA_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/attention/mla/cutlass_mla_kernels.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
@ -562,7 +600,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
"if you intend on running FP8 quantized MoE models on Hopper.")
|
||||
else()
|
||||
message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
|
||||
"in CUDA target architectures")
|
||||
"in CUDA target architectures.")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@ -574,7 +612,37 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${CUTLASS_MOE_DATA_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
endif()
|
||||
message(STATUS "Building moe_data for archs: ${CUTLASS_MOE_DATA_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
|
||||
message(STATUS "Not building moe_data as CUDA Compiler version is "
|
||||
"not >= 12.3, we recommend upgrading to CUDA 12.3 or later "
|
||||
"if you intend on running FP8 quantized MoE models on Hopper or Blackwell.")
|
||||
else()
|
||||
message(STATUS "Not building moe_data as no compatible archs found "
|
||||
"in CUDA target architectures.")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
set(SRCS "csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
|
||||
message(STATUS "Building blockwise_scaled_group_mm_sm100 for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building blockwise_scaled_group_mm_sm100 kernels as CUDA Compiler version is "
|
||||
"not >= 12.8, we recommend upgrading to CUDA 12.8 or later "
|
||||
"if you intend on running FP8 quantized MoE models on Blackwell.")
|
||||
else()
|
||||
message(STATUS "Not building blockwise_scaled_group_mm_sm100 as no compatible archs found "
|
||||
"in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
#
|
||||
# Machete kernels
|
||||
@ -582,7 +650,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# The machete kernels only work on hopper and require CUDA 12.0 or later.
|
||||
# Only build Machete kernels if we are building for something compatible with sm90a
|
||||
cuda_archs_loose_intersection(MACHETE_ARCHS "9.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND MACHETE_ARCHS)
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND MACHETE_ARCHS)
|
||||
#
|
||||
# For the Machete kernels we automatically generate sources for various
|
||||
# preselected input type pairs and schedules.
|
||||
@ -634,7 +702,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
|
||||
message(STATUS "Building Machete kernels for archs: ${MACHETE_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0
|
||||
AND MACHETE_ARCHS)
|
||||
message(STATUS "Not building Machete kernels as CUDA Compiler version is "
|
||||
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
|
||||
|
||||
@ -701,6 +701,7 @@ class HuggingFaceDataset(BenchmarkDataset):
|
||||
self,
|
||||
dataset_path: str,
|
||||
dataset_split: str,
|
||||
no_stream: bool = False,
|
||||
dataset_subset: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
@ -708,6 +709,7 @@ class HuggingFaceDataset(BenchmarkDataset):
|
||||
|
||||
self.dataset_split = dataset_split
|
||||
self.dataset_subset = dataset_subset
|
||||
self.load_stream = not no_stream
|
||||
self.load_data()
|
||||
|
||||
def load_data(self) -> None:
|
||||
@ -716,7 +718,7 @@ class HuggingFaceDataset(BenchmarkDataset):
|
||||
self.dataset_path,
|
||||
name=self.dataset_subset,
|
||||
split=self.dataset_split,
|
||||
streaming=True,
|
||||
streaming=self.load_stream,
|
||||
)
|
||||
self.data = self.data.shuffle(seed=self.random_seed)
|
||||
|
||||
|
||||
@ -551,7 +551,7 @@ async def benchmark(
|
||||
"total_input_tokens": metrics.total_input,
|
||||
"total_output_tokens": metrics.total_output,
|
||||
"request_throughput": metrics.request_throughput,
|
||||
"request_goodput:": metrics.request_goodput if goodput_config_dict else None,
|
||||
"request_goodput": metrics.request_goodput if goodput_config_dict else None,
|
||||
"output_throughput": metrics.output_throughput,
|
||||
"total_token_throughput": metrics.total_token_throughput,
|
||||
"input_lens": [output.prompt_len for output in outputs],
|
||||
@ -825,6 +825,7 @@ def main(args: argparse.Namespace):
|
||||
dataset_subset=args.hf_subset,
|
||||
dataset_split=args.hf_split,
|
||||
random_seed=args.seed,
|
||||
no_stream=args.no_stream,
|
||||
).sample(
|
||||
num_requests=args.num_prompts,
|
||||
tokenizer=tokenizer,
|
||||
@ -1033,6 +1034,11 @@ def create_argument_parser():
|
||||
help="Path to the sharegpt/sonnet dataset. "
|
||||
"Or the huggingface dataset ID if using HF dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-stream",
|
||||
action="store_true",
|
||||
help="Do not load the dataset in streaming mode.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-concurrency",
|
||||
type=int,
|
||||
|
||||
@ -356,6 +356,7 @@ def get_requests(args, tokenizer):
|
||||
elif args.dataset_name == "burstgpt":
|
||||
dataset_cls = BurstGPTDataset
|
||||
elif args.dataset_name == "hf":
|
||||
common_kwargs["no_stream"] = args.no_stream
|
||||
if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_cls = VisionArenaDataset
|
||||
common_kwargs["dataset_subset"] = None
|
||||
@ -610,6 +611,11 @@ def create_argument_parser():
|
||||
help="Name of the dataset to benchmark on.",
|
||||
default="sharegpt",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-stream",
|
||||
action="store_true",
|
||||
help="Do not load the dataset in streaming mode.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
import copy
|
||||
import itertools
|
||||
|
||||
141
benchmarks/kernels/bench_nvfp4_gemm.py
Normal file
141
benchmarks/kernels/bench_nvfp4_gemm.py
Normal file
@ -0,0 +1,141 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
import copy
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
from weight_shapes import WEIGHT_SHAPES
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
from vllm.triton_utils import triton
|
||||
|
||||
if not current_platform.has_device_capability(100):
|
||||
raise RuntimeError("NVFP4 requires compute capability of 10.0 (Blackwell)")
|
||||
|
||||
|
||||
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
|
||||
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
|
||||
|
||||
PROVIDER_CFGS = {
|
||||
"torch-bf16": dict(enabled=True),
|
||||
"nvfp4": dict(no_a_quant=False, enabled=True),
|
||||
"nvfp4-noquant": dict(no_a_quant=True, enabled=True),
|
||||
}
|
||||
|
||||
_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]
|
||||
|
||||
|
||||
def _quant_weight_nvfp4(b: torch.Tensor, device: str):
|
||||
# Compute global scale for weight
|
||||
b_amax = torch.abs(b).max().to(torch.float32)
|
||||
b_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax
|
||||
b_fp4, scale_b_fp4 = ops.scaled_fp4_quant(b, b_global_scale)
|
||||
return b_fp4, scale_b_fp4, b_global_scale
|
||||
|
||||
|
||||
def build_nvfp4_runner(cfg, a, b, dtype, device):
|
||||
b_fp4, scale_b_fp4, b_global_scale = _quant_weight_nvfp4(b, device)
|
||||
|
||||
# Compute global scale for activation
|
||||
# NOTE: This is generally provided ahead-of-time by the model checkpoint.
|
||||
a_amax = torch.abs(a).max().to(torch.float32)
|
||||
a_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax
|
||||
|
||||
# Alpha for the GEMM operation
|
||||
alpha = 1.0 / (a_global_scale * b_global_scale)
|
||||
|
||||
if cfg["no_a_quant"]:
|
||||
# Pre-quantize activation
|
||||
a_fp4, scale_a_fp4 = ops.scaled_fp4_quant(a, a_global_scale)
|
||||
|
||||
def run():
|
||||
return ops.cutlass_scaled_fp4_mm(
|
||||
a_fp4, b_fp4, scale_a_fp4, scale_b_fp4, alpha, dtype
|
||||
)
|
||||
|
||||
return run
|
||||
|
||||
# Quantize activation on-the-fly
|
||||
def run():
|
||||
a_fp4, scale_a_fp4 = ops.scaled_fp4_quant(a, a_global_scale)
|
||||
return ops.cutlass_scaled_fp4_mm(
|
||||
a_fp4, b_fp4, scale_a_fp4, scale_b_fp4, alpha, dtype
|
||||
)
|
||||
|
||||
return run
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size"],
|
||||
x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384],
|
||||
x_log=False,
|
||||
line_arg="provider",
|
||||
line_vals=_enabled,
|
||||
line_names=_enabled,
|
||||
ylabel="TFLOP/s (larger is better)",
|
||||
plot_name="BF16 vs NVFP4 GEMMs",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(batch_size, provider, N, K):
|
||||
M = batch_size
|
||||
device = "cuda"
|
||||
dtype = torch.bfloat16
|
||||
|
||||
a = torch.randn((M, K), device=device, dtype=dtype)
|
||||
b = torch.randn((N, K), device=device, dtype=dtype)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
if provider == "torch-bf16":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: torch.nn.functional.linear(a, b), quantiles=quantiles
|
||||
)
|
||||
else:
|
||||
cfg = PROVIDER_CFGS[provider]
|
||||
run_quant = build_nvfp4_runner(cfg, a, b, dtype, device)
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: run_quant(), quantiles=quantiles
|
||||
)
|
||||
|
||||
to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
|
||||
return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)
|
||||
|
||||
|
||||
def prepare_shapes(args):
|
||||
out = []
|
||||
for model, tp_size in itertools.product(args.models, args.tp_sizes):
|
||||
for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
|
||||
KN[tp_dim] //= tp_size
|
||||
KN.append(model)
|
||||
out.append(KN)
|
||||
return out
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--models",
|
||||
nargs="+",
|
||||
type=str,
|
||||
default=["meta-llama/Llama-3.1-8B-Instruct"],
|
||||
choices=list(WEIGHT_SHAPES.keys()),
|
||||
)
|
||||
parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1])
|
||||
args = parser.parse_args()
|
||||
|
||||
for K, N, model in prepare_shapes(args):
|
||||
print(f"{model}, N={N} K={K}, BF16 vs NVFP4 GEMMs TFLOP/s:")
|
||||
benchmark.run(
|
||||
print_data=True,
|
||||
show_plots=True,
|
||||
save_path=f"bench_nvfp4_res_n{N}_k{K}",
|
||||
N=N,
|
||||
K=K,
|
||||
)
|
||||
|
||||
print("Benchmark finished!")
|
||||
@ -113,6 +113,7 @@ def bench_run(
|
||||
w2_scale: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
per_act_token: bool,
|
||||
num_repeats: int,
|
||||
):
|
||||
for _ in range(num_repeats):
|
||||
@ -124,7 +125,8 @@ def bench_run(
|
||||
topk_ids,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
a1_scale=a_scale,
|
||||
per_act_token,
|
||||
a1_scale=None,
|
||||
)
|
||||
|
||||
def run_cutlass_from_graph(
|
||||
@ -148,7 +150,8 @@ def bench_run(
|
||||
topk_ids,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
a1_scale=a_scale,
|
||||
per_act_token,
|
||||
a1_scale=None,
|
||||
)
|
||||
|
||||
def run_triton_from_graph(
|
||||
@ -227,6 +230,7 @@ def bench_run(
|
||||
"w2_q": w2_q,
|
||||
"w1_scale": w1_scale,
|
||||
"w2_scale": w2_scale,
|
||||
"per_act_token": per_act_token,
|
||||
# cuda graph params
|
||||
"cutlass_graph": cutlass_graph,
|
||||
"triton_graph": triton_graph,
|
||||
@ -287,12 +291,13 @@ def bench_run(
|
||||
w2_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
per_act_token,
|
||||
num_warmup,
|
||||
)
|
||||
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, num_runs)", # noqa: E501
|
||||
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
|
||||
@ -234,8 +234,10 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
|
||||
|
||||
fn = lambda: ops.gptq_marlin_gemm(
|
||||
a=bt.a,
|
||||
c=None,
|
||||
b_q_weight=w_q,
|
||||
b_scales=w_s,
|
||||
global_scale=None,
|
||||
b_zeros=w_zp,
|
||||
g_idx=g_idx,
|
||||
perm=sort_indices,
|
||||
|
||||
@ -620,7 +620,7 @@ def main(args: argparse.Namespace):
|
||||
4096,
|
||||
]
|
||||
else:
|
||||
batch_sizes = [args.batch_size]
|
||||
batch_sizes = args.batch_size
|
||||
|
||||
use_deep_gemm = bool(args.use_deep_gemm)
|
||||
|
||||
@ -728,7 +728,7 @@ if __name__ == "__main__":
|
||||
)
|
||||
parser.add_argument("--use-deep-gemm", action="store_true")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--batch-size", type=int, required=False)
|
||||
parser.add_argument("--batch-size", type=int, nargs="+", required=False)
|
||||
parser.add_argument("--tune", action="store_true")
|
||||
parser.add_argument("--trust-remote-code", action="store_true")
|
||||
parser.add_argument("--model-prefix", type=str, required=False)
|
||||
|
||||
@ -12,9 +12,8 @@ endif()
|
||||
#
|
||||
# Define environment variables for special configurations
|
||||
#
|
||||
if(DEFINED ENV{VLLM_CPU_AVX512BF16})
|
||||
set(ENABLE_AVX512BF16 ON)
|
||||
endif()
|
||||
set(ENABLE_AVX512BF16 $ENV{VLLM_CPU_AVX512BF16})
|
||||
set(ENABLE_AVX512VNNI $ENV{VLLM_CPU_AVX512VNNI})
|
||||
|
||||
include_directories("${CMAKE_SOURCE_DIR}/csrc")
|
||||
|
||||
@ -96,12 +95,30 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)
|
||||
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
|
||||
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
|
||||
list(APPEND CXX_COMPILE_FLAGS "-mavx512bf16")
|
||||
set(ENABLE_AVX512BF16 ON)
|
||||
else()
|
||||
set(ENABLE_AVX512BF16 OFF)
|
||||
message(WARNING "Disable AVX512-BF16 ISA support, requires gcc/g++ >= 12.3")
|
||||
endif()
|
||||
else()
|
||||
set(ENABLE_AVX512BF16 OFF)
|
||||
message(WARNING "Disable AVX512-BF16 ISA support, no avx512_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512BF16=1.")
|
||||
endif()
|
||||
|
||||
find_isa(${CPUINFO} "avx512_vnni" AVX512VNNI_FOUND)
|
||||
if (AVX512VNNI_FOUND OR ENABLE_AVX512VNNI)
|
||||
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
|
||||
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
|
||||
list(APPEND CXX_COMPILE_FLAGS "-mavx512vnni")
|
||||
set(ENABLE_AVX512VNNI ON)
|
||||
else()
|
||||
set(ENABLE_AVX512VNNI OFF)
|
||||
message(WARNING "Disable AVX512-VNNI ISA support, requires gcc/g++ >= 12.3")
|
||||
endif()
|
||||
else()
|
||||
set(ENABLE_AVX512VNNI OFF)
|
||||
message(WARNING "Disable AVX512-VNNI ISA support, no avx512_vnni found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512VNNI=1.")
|
||||
endif()
|
||||
|
||||
elseif (AVX2_FOUND)
|
||||
list(APPEND CXX_COMPILE_FLAGS "-mavx2")
|
||||
@ -148,17 +165,32 @@ else()
|
||||
endif()
|
||||
|
||||
#
|
||||
# Build oneDNN for W8A8 GEMM kernels (only for x86-AVX512 platforms)
|
||||
#
|
||||
if (AVX512_FOUND AND NOT AVX512_DISABLED)
|
||||
# Build oneDNN for W8A8 GEMM kernels (only for x86-AVX512 /ARM platforms)
|
||||
# Flag to enable ACL kernels for AARCH64 platforms
|
||||
if ( VLLM_BUILD_ACL STREQUAL "ON")
|
||||
set(USE_ACL ON)
|
||||
else()
|
||||
set(USE_ACL OFF)
|
||||
endif()
|
||||
|
||||
if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND)
|
||||
FetchContent_Declare(
|
||||
oneDNN
|
||||
GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
|
||||
GIT_TAG v3.7.1
|
||||
GIT_TAG v3.8.1
|
||||
GIT_PROGRESS TRUE
|
||||
GIT_SHALLOW TRUE
|
||||
)
|
||||
|
||||
if(USE_ACL)
|
||||
find_library(ARM_COMPUTE_LIBRARY NAMES arm_compute PATHS $ENV{ACL_ROOT_DIR}/build/)
|
||||
if(NOT ARM_COMPUTE_LIBRARY)
|
||||
message(FATAL_ERROR "Could not find ARM Compute Library: please set ACL_ROOT_DIR")
|
||||
endif()
|
||||
set(ONEDNN_AARCH64_USE_ACL "ON")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ENV{ACL_ROOT_DIR}/build/")
|
||||
endif()
|
||||
|
||||
set(ONEDNN_LIBRARY_TYPE "STATIC")
|
||||
set(ONEDNN_BUILD_DOC "OFF")
|
||||
set(ONEDNN_BUILD_EXAMPLES "OFF")
|
||||
@ -231,11 +263,29 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)
|
||||
"csrc/cpu/quant.cpp"
|
||||
"csrc/cpu/shm.cpp"
|
||||
${VLLM_EXT_SRC})
|
||||
if (ENABLE_AVX512BF16 AND ENABLE_AVX512VNNI)
|
||||
set(VLLM_EXT_SRC
|
||||
"csrc/cpu/sgl-kernels/gemm.cpp"
|
||||
"csrc/cpu/sgl-kernels/gemm_int8.cpp"
|
||||
"csrc/cpu/sgl-kernels/gemm_fp8.cpp"
|
||||
"csrc/cpu/sgl-kernels/moe.cpp"
|
||||
"csrc/cpu/sgl-kernels/moe_int8.cpp"
|
||||
"csrc/cpu/sgl-kernels/moe_fp8.cpp"
|
||||
${VLLM_EXT_SRC})
|
||||
add_compile_definitions(-DCPU_CAPABILITY_AVX512)
|
||||
endif()
|
||||
elseif(POWER10_FOUND)
|
||||
set(VLLM_EXT_SRC
|
||||
"csrc/cpu/quant.cpp"
|
||||
${VLLM_EXT_SRC})
|
||||
endif()
|
||||
if (ASIMD_FOUND)
|
||||
set(VLLM_EXT_SRC
|
||||
"csrc/cpu/quant.cpp"
|
||||
${VLLM_EXT_SRC})
|
||||
endif()
|
||||
|
||||
message(STATUS "CPU extension source files: ${VLLM_EXT_SRC}")
|
||||
|
||||
#
|
||||
# Define extension targets
|
||||
|
||||
@ -38,7 +38,7 @@ else()
|
||||
FetchContent_Declare(
|
||||
vllm-flash-attn
|
||||
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
|
||||
GIT_TAG 5f3644181c7a15345ce20bfc65af117d3601b524
|
||||
GIT_TAG 1c2624e53c078854e0637ee566c72fe2107e75f4
|
||||
GIT_PROGRESS TRUE
|
||||
# Don't share the vllm-flash-attn build between build types
|
||||
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
|
||||
|
||||
@ -265,8 +265,8 @@ macro(set_gencode_flags_for_srcs)
|
||||
endmacro()
|
||||
|
||||
#
|
||||
# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form
|
||||
# `<major>.<minor>[letter]` compute the "loose intersection" with the
|
||||
# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form
|
||||
# `<major>.<minor>[letter]` compute the "loose intersection" with the
|
||||
# `TGT_CUDA_ARCHS` list of gencodes. We also support the `+PTX` suffix in
|
||||
# `SRC_CUDA_ARCHS` which indicates that the PTX code should be built when there
|
||||
# is a CUDA_ARCH in `TGT_CUDA_ARCHS` that is equal to or larger than the
|
||||
@ -278,7 +278,7 @@ endmacro()
|
||||
# in `SRC_CUDA_ARCHS` that is less or equal to the version in `TGT_CUDA_ARCHS`.
|
||||
# We have special handling for x.0a, if x.0a is in `SRC_CUDA_ARCHS` and x.0 is
|
||||
# in `TGT_CUDA_ARCHS` then we should remove x.0a from `SRC_CUDA_ARCHS` and add
|
||||
# x.0a to the result (and remove x.0 from TGT_CUDA_ARCHS).
|
||||
# x.0a to the result (and remove x.0 from TGT_CUDA_ARCHS).
|
||||
# The result is stored in `OUT_CUDA_ARCHS`.
|
||||
#
|
||||
# Example:
|
||||
@ -313,21 +313,16 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
|
||||
# if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should
|
||||
# remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS
|
||||
set(_CUDA_ARCHS)
|
||||
if ("9.0a" IN_LIST _SRC_CUDA_ARCHS)
|
||||
list(REMOVE_ITEM _SRC_CUDA_ARCHS "9.0a")
|
||||
if ("9.0" IN_LIST TGT_CUDA_ARCHS)
|
||||
list(REMOVE_ITEM _TGT_CUDA_ARCHS "9.0")
|
||||
set(_CUDA_ARCHS "9.0a")
|
||||
foreach(_arch ${_SRC_CUDA_ARCHS})
|
||||
if(_arch MATCHES "\\a$")
|
||||
list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}")
|
||||
string(REPLACE "a" "" _base "${_arch}")
|
||||
if ("${_base}" IN_LIST TGT_CUDA_ARCHS)
|
||||
list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_base}")
|
||||
list(APPEND _CUDA_ARCHS "${_arch}")
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if ("10.0a" IN_LIST _SRC_CUDA_ARCHS)
|
||||
list(REMOVE_ITEM _SRC_CUDA_ARCHS "10.0a")
|
||||
if ("10.0" IN_LIST TGT_CUDA_ARCHS)
|
||||
list(REMOVE_ITEM _TGT_CUDA_ARCHS "10.0")
|
||||
set(_CUDA_ARCHS "10.0a")
|
||||
endif()
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
list(SORT _SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
|
||||
|
||||
@ -359,7 +354,7 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
|
||||
endforeach()
|
||||
|
||||
list(REMOVE_DUPLICATES _CUDA_ARCHS)
|
||||
|
||||
|
||||
# reapply +PTX suffix to architectures that requested PTX
|
||||
set(_FINAL_ARCHS)
|
||||
foreach(_arch ${_CUDA_ARCHS})
|
||||
@ -370,7 +365,7 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
|
||||
endif()
|
||||
endforeach()
|
||||
set(_CUDA_ARCHS ${_FINAL_ARCHS})
|
||||
|
||||
|
||||
set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE)
|
||||
endfunction()
|
||||
|
||||
|
||||
@ -33,6 +33,8 @@ namespace vec_op {
|
||||
#endif
|
||||
|
||||
#define FORCE_INLINE __attribute__((always_inline)) inline
|
||||
// Number of elements in single ASIMD vector of given Datatype
|
||||
#define NUM_ELEMENTS_REG(vec) (sizeof(vec) / sizeof(vec[0]))
|
||||
|
||||
namespace {
|
||||
template <typename T, T... indexes, typename F>
|
||||
@ -86,8 +88,8 @@ struct FP16Vec16 : public Vec<FP16Vec16> {
|
||||
}
|
||||
|
||||
void save(void* ptr, const int elem_num) const {
|
||||
int full_blocks = elem_num / 8;
|
||||
int remainder = elem_num % 8;
|
||||
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
|
||||
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
|
||||
|
||||
if (full_blocks > 0) {
|
||||
vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]);
|
||||
@ -197,6 +199,25 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
|
||||
vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[2]), v.val[3])}) {};
|
||||
|
||||
void save(void* ptr) const { *reinterpret_cast<bfloat16x8x2_t*>(ptr) = reg; };
|
||||
void save(void* ptr, const int elem_num) const {
|
||||
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
|
||||
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
|
||||
for (int i = 0; i < full_blocks; i++)
|
||||
vst1q_bf16(
|
||||
reinterpret_cast<__bf16*>(ptr) + NUM_ELEMENTS_REG(reg.val[0]) * i,
|
||||
reg.val[i]);
|
||||
if (remainder > 0) {
|
||||
bfloat16x8_t temp = reg.val[full_blocks];
|
||||
bfloat16_t* base = reinterpret_cast<bfloat16_t*>(ptr) + full_blocks * 8;
|
||||
if (remainder > 0) base[0] = vgetq_lane_bf16(temp, 0);
|
||||
if (remainder > 1) base[1] = vgetq_lane_bf16(temp, 1);
|
||||
if (remainder > 2) base[2] = vgetq_lane_bf16(temp, 2);
|
||||
if (remainder > 3) base[3] = vgetq_lane_bf16(temp, 3);
|
||||
if (remainder > 4) base[4] = vgetq_lane_bf16(temp, 4);
|
||||
if (remainder > 5) base[5] = vgetq_lane_bf16(temp, 5);
|
||||
if (remainder > 6) base[6] = vgetq_lane_bf16(temp, 6);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
struct BF16Vec32 : public Vec<BF16Vec32> {
|
||||
@ -213,6 +234,25 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
|
||||
: reg({vec8_data.reg, vec8_data.reg, vec8_data.reg, vec8_data.reg}) {};
|
||||
|
||||
void save(void* ptr) const { *reinterpret_cast<bfloat16x8x4_t*>(ptr) = reg; };
|
||||
void save(void* ptr, const int elem_num) const {
|
||||
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
|
||||
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
|
||||
for (int i = 0; i < full_blocks; i++)
|
||||
vst1q_bf16(
|
||||
reinterpret_cast<__bf16*>(ptr) + NUM_ELEMENTS_REG(reg.val[0]) * i,
|
||||
reg.val[i]);
|
||||
if (remainder > 0) {
|
||||
bfloat16x8_t temp = reg.val[full_blocks];
|
||||
bfloat16_t* base = reinterpret_cast<bfloat16_t*>(ptr) + full_blocks * 8;
|
||||
base[0] = vgetq_lane_bf16(temp, 0);
|
||||
if (remainder > 1) base[1] = vgetq_lane_bf16(temp, 1);
|
||||
if (remainder > 2) base[2] = vgetq_lane_bf16(temp, 2);
|
||||
if (remainder > 3) base[3] = vgetq_lane_bf16(temp, 3);
|
||||
if (remainder > 4) base[4] = vgetq_lane_bf16(temp, 4);
|
||||
if (remainder > 5) base[5] = vgetq_lane_bf16(temp, 5);
|
||||
if (remainder > 6) base[6] = vgetq_lane_bf16(temp, 6);
|
||||
}
|
||||
};
|
||||
};
|
||||
#endif
|
||||
|
||||
@ -372,6 +412,48 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
|
||||
}
|
||||
};
|
||||
|
||||
struct INT32Vec16 : public Vec<INT32Vec16> {
|
||||
constexpr static int VEC_ELEM_NUM = 16;
|
||||
union AliasReg {
|
||||
int32x4x4_t reg;
|
||||
int32_t values[VEC_ELEM_NUM];
|
||||
};
|
||||
int32x4x4_t reg;
|
||||
|
||||
explicit INT32Vec16(const void* ptr) {
|
||||
reg.val[0] = vld1q_s32(reinterpret_cast<const int32_t*>(ptr));
|
||||
reg.val[1] = vld1q_s32(reinterpret_cast<const int32_t*>(ptr) + 4);
|
||||
reg.val[2] = vld1q_s32(reinterpret_cast<const int32_t*>(ptr) + 8);
|
||||
reg.val[3] = vld1q_s32(reinterpret_cast<const int32_t*>(ptr) + 12);
|
||||
}
|
||||
|
||||
void save(int32_t* ptr) const {
|
||||
vst1q_s32(ptr, reg.val[0]);
|
||||
vst1q_s32(ptr + 4, reg.val[1]);
|
||||
vst1q_s32(ptr + 8, reg.val[2]);
|
||||
vst1q_s32(ptr + 12, reg.val[3]);
|
||||
};
|
||||
|
||||
void save(int32_t* ptr, const int elem_num) const {
|
||||
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
|
||||
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
|
||||
|
||||
for (int i = 0; i < full_blocks; i++)
|
||||
vst1q_s32(
|
||||
reinterpret_cast<__int32_t*>(ptr) + NUM_ELEMENTS_REG(reg.val[0]) * i,
|
||||
reg.val[i]);
|
||||
|
||||
if (remainder > 0) {
|
||||
int32x4_t temp = reg.val[full_blocks];
|
||||
int32_t* base = reinterpret_cast<int32_t*>(ptr) + full_blocks * 4;
|
||||
if (remainder > 0) base[0] = vgetq_lane_s32(temp, 0);
|
||||
if (remainder > 1) base[1] = vgetq_lane_s32(temp, 1);
|
||||
if (remainder > 2) base[2] = vgetq_lane_s32(temp, 2);
|
||||
if (remainder > 3) base[3] = vgetq_lane_s32(temp, 3);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
constexpr static int VEC_ELEM_NUM = 16;
|
||||
union AliasReg {
|
||||
@ -434,7 +516,12 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
reg.val[2] = vcvt_f32_f16(vget_low_f16(v.reg.val[1]));
|
||||
reg.val[3] = vcvt_f32_f16(vget_high_f16(v.reg.val[1]));
|
||||
};
|
||||
|
||||
explicit FP32Vec16(const INT32Vec16& v) {
|
||||
reg.val[0] = vcvtq_f32_s32(v.reg.val[0]);
|
||||
reg.val[1] = vcvtq_f32_s32(v.reg.val[1]);
|
||||
reg.val[2] = vcvtq_f32_s32(v.reg.val[2]);
|
||||
reg.val[3] = vcvtq_f32_s32(v.reg.val[3]);
|
||||
};
|
||||
FP32Vec16 operator+(const FP32Vec16& b) const {
|
||||
return FP32Vec16(float32x4x4_t({vaddq_f32(reg.val[0], b.reg.val[0]),
|
||||
vaddq_f32(reg.val[1], b.reg.val[1]),
|
||||
@ -463,6 +550,85 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
vdivq_f32(reg.val[3], b.reg.val[3])}));
|
||||
};
|
||||
|
||||
FP32Vec16 clamp(const FP32Vec16& min, const FP32Vec16& max) const {
|
||||
return FP32Vec16(float32x4x4_t(
|
||||
{vminq_f32(max.reg.val[0], vmaxq_f32(min.reg.val[0], reg.val[0])),
|
||||
vminq_f32(max.reg.val[1], vmaxq_f32(min.reg.val[1], reg.val[1])),
|
||||
vminq_f32(max.reg.val[2], vmaxq_f32(min.reg.val[2], reg.val[2])),
|
||||
vminq_f32(max.reg.val[3], vmaxq_f32(min.reg.val[3], reg.val[3]))}));
|
||||
};
|
||||
|
||||
FP32Vec16 max(const FP32Vec16& b) const {
|
||||
return FP32Vec16(float32x4x4_t({vmaxq_f32(b.reg.val[0], reg.val[0]),
|
||||
vmaxq_f32(b.reg.val[1], reg.val[1]),
|
||||
vmaxq_f32(b.reg.val[2], reg.val[2]),
|
||||
vmaxq_f32(b.reg.val[3], reg.val[3])}));
|
||||
};
|
||||
|
||||
FP32Vec16 max(const FP32Vec16& b, const int elem_num) const {
|
||||
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
|
||||
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
|
||||
float32x4x4_t temp;
|
||||
|
||||
for (int i = 0; i < full_blocks; i++)
|
||||
temp.val[i] = vmaxq_f32(b.reg.val[i], reg.val[i]);
|
||||
|
||||
if (remainder > 0) {
|
||||
float max_v = std::max(vgetq_lane_f32(reg.val[full_blocks], 0),
|
||||
vgetq_lane_f32(b.reg.val[full_blocks], 0));
|
||||
temp.val[full_blocks] = vsetq_lane_f32(max_v, temp.val[full_blocks], 0);
|
||||
}
|
||||
if (remainder > 1) {
|
||||
float max_v = std::max(vgetq_lane_f32(reg.val[full_blocks], 1),
|
||||
vgetq_lane_f32(b.reg.val[full_blocks], 1));
|
||||
temp.val[full_blocks] = vsetq_lane_f32(max_v, temp.val[full_blocks], 1);
|
||||
}
|
||||
if (remainder > 2) {
|
||||
float max_v = std::max(vgetq_lane_f32(reg.val[full_blocks], 2),
|
||||
vgetq_lane_f32(b.reg.val[full_blocks], 2));
|
||||
temp.val[full_blocks] = vsetq_lane_f32(max_v, temp.val[full_blocks], 2);
|
||||
}
|
||||
return FP32Vec16(temp);
|
||||
};
|
||||
|
||||
FP32Vec16 min(const FP32Vec16& b) const {
|
||||
return FP32Vec16(float32x4x4_t({
|
||||
vminq_f32(b.reg.val[0], reg.val[0]),
|
||||
vminq_f32(b.reg.val[1], reg.val[1]),
|
||||
vminq_f32(b.reg.val[2], reg.val[2]),
|
||||
vminq_f32(b.reg.val[3], reg.val[3]),
|
||||
}));
|
||||
};
|
||||
FP32Vec16 min(const FP32Vec16& b, const int elem_num) const {
|
||||
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
|
||||
const int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
|
||||
float32x4x4_t temp;
|
||||
for (int i = 0; i < full_blocks; i++)
|
||||
temp.val[i] = vminq_f32(b.reg.val[i], reg.val[i]);
|
||||
|
||||
if (remainder > 0) {
|
||||
float min_v = std::min(vgetq_lane_f32(reg.val[full_blocks], 0),
|
||||
vgetq_lane_f32(b.reg.val[full_blocks], 0));
|
||||
temp.val[full_blocks] = vsetq_lane_f32(min_v, temp.val[full_blocks], 0);
|
||||
}
|
||||
if (remainder > 1) {
|
||||
float min_v = std::min(vgetq_lane_f32(reg.val[full_blocks], 1),
|
||||
vgetq_lane_f32(b.reg.val[full_blocks], 1));
|
||||
temp.val[full_blocks] = vsetq_lane_f32(min_v, temp.val[full_blocks], 1);
|
||||
}
|
||||
if (remainder > 2) {
|
||||
float min_v = std::min(vgetq_lane_f32(reg.val[full_blocks], 2),
|
||||
vgetq_lane_f32(b.reg.val[full_blocks], 2));
|
||||
temp.val[full_blocks] = vsetq_lane_f32(min_v, temp.val[full_blocks], 2);
|
||||
}
|
||||
|
||||
return FP32Vec16(temp);
|
||||
};
|
||||
FP32Vec16 abs() const {
|
||||
return FP32Vec16(
|
||||
float32x4x4_t({vabsq_f32(reg.val[0]), vabsq_f32(reg.val[1]),
|
||||
vabsq_f32(reg.val[2]), vabsq_f32(reg.val[3])}));
|
||||
}
|
||||
float reduce_sum() const {
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
@ -473,6 +639,24 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
return answer;
|
||||
};
|
||||
|
||||
float reduce_max() const {
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
float max_v = std::numeric_limits<float>::lowest();
|
||||
unroll_loop<int, VEC_ELEM_NUM>(
|
||||
[&max_v, &ar](int i) { max_v = std::max(max_v, ar.values[i]); });
|
||||
return max_v;
|
||||
}
|
||||
|
||||
float reduce_min() const {
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
float min_v = std::numeric_limits<float>::max();
|
||||
unroll_loop<int, VEC_ELEM_NUM>(
|
||||
[&min_v, &ar](int i) { min_v = std::min(min_v, ar.values[i]); });
|
||||
return min_v;
|
||||
}
|
||||
|
||||
template <int group_size>
|
||||
float reduce_sub_sum(int idx) {
|
||||
static_assert(VEC_ELEM_NUM % group_size == 0);
|
||||
@ -493,6 +677,83 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
vst1q_f32(ptr + 8, reg.val[2]);
|
||||
vst1q_f32(ptr + 12, reg.val[3]);
|
||||
};
|
||||
|
||||
void save(float* ptr, const int elem_num) const {
|
||||
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
|
||||
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
|
||||
|
||||
for (int i = 0; i < full_blocks; i++)
|
||||
vst1q_f32(
|
||||
reinterpret_cast<float32_t*>(ptr) + NUM_ELEMENTS_REG(reg.val[0]) * i,
|
||||
reg.val[i]);
|
||||
|
||||
if (remainder > 0) {
|
||||
float32x4_t temp = reg.val[full_blocks];
|
||||
float* base = reinterpret_cast<float32_t*>(ptr) +
|
||||
full_blocks * NUM_ELEMENTS_REG(reg.val[0]);
|
||||
if (remainder > 0) base[0] = vgetq_lane_f32(temp, 0);
|
||||
if (remainder > 1) base[1] = vgetq_lane_f32(temp, 1);
|
||||
if (remainder > 2) base[2] = vgetq_lane_f32(temp, 2);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct INT8Vec16 : public Vec<INT8Vec16> {
|
||||
constexpr static int VEC_ELEM_NUM = 16;
|
||||
union AliasReg {
|
||||
int8x16_t reg;
|
||||
int8_t values[VEC_ELEM_NUM];
|
||||
};
|
||||
int8x16_t reg;
|
||||
|
||||
explicit INT8Vec16(const FP32Vec16& vec) {
|
||||
// Convert each 128-bit float32 vector to int32
|
||||
int32x4_t part0 =
|
||||
vcvtq_s32_f32(vec.reg.val[0]); // Convert first 128-bit block
|
||||
int32x4_t part1 =
|
||||
vcvtq_s32_f32(vec.reg.val[1]); // Convert second 128-bit block
|
||||
int32x4_t part2 =
|
||||
vcvtq_s32_f32(vec.reg.val[2]); // Convert third 128-bit block
|
||||
int32x4_t part3 =
|
||||
vcvtq_s32_f32(vec.reg.val[3]); // Convert fourth 128-bit block
|
||||
|
||||
// Narrow each 32-bit vector to 8 bits and combine
|
||||
int8x8_t lower =
|
||||
vqmovn_s16(vcombine_s16(vqmovn_s32(part0), vqmovn_s32(part1)));
|
||||
int8x8_t upper =
|
||||
vqmovn_s16(vcombine_s16(vqmovn_s32(part2), vqmovn_s32(part3)));
|
||||
reg = vcombine_s8(lower, upper); // Combine to form a single 128-bit vector
|
||||
}
|
||||
|
||||
void save(int8_t* ptr) const { vst1q_s8(ptr, reg); };
|
||||
|
||||
void save(int8_t* ptr, const int elem_num) const {
|
||||
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg);
|
||||
int remainder = elem_num % NUM_ELEMENTS_REG(reg);
|
||||
|
||||
for (int i = 0; i < full_blocks; i++)
|
||||
vst1q_s8(reinterpret_cast<int8_t*>(ptr) + NUM_ELEMENTS_REG(reg) * i, reg);
|
||||
if (remainder > 0) {
|
||||
int8x16_t temp = reg;
|
||||
int8_t* base =
|
||||
reinterpret_cast<int8_t*>(ptr) + full_blocks * NUM_ELEMENTS_REG(reg);
|
||||
if (remainder > 0) base[0] = vgetq_lane_s8(temp, 0);
|
||||
if (remainder > 1) base[1] = vgetq_lane_s8(temp, 1);
|
||||
if (remainder > 2) base[2] = vgetq_lane_s8(temp, 2);
|
||||
if (remainder > 3) base[3] = vgetq_lane_s8(temp, 3);
|
||||
if (remainder > 4) base[4] = vgetq_lane_s8(temp, 4);
|
||||
if (remainder > 5) base[5] = vgetq_lane_s8(temp, 5);
|
||||
if (remainder > 6) base[6] = vgetq_lane_s8(temp, 6);
|
||||
if (remainder > 7) base[7] = vgetq_lane_s8(temp, 7);
|
||||
if (remainder > 8) base[8] = vgetq_lane_s8(temp, 8);
|
||||
if (remainder > 9) base[9] = vgetq_lane_s8(temp, 9);
|
||||
if (remainder > 10) base[10] = vgetq_lane_s8(temp, 10);
|
||||
if (remainder > 11) base[11] = vgetq_lane_s8(temp, 11);
|
||||
if (remainder > 12) base[12] = vgetq_lane_s8(temp, 12);
|
||||
if (remainder > 13) base[13] = vgetq_lane_s8(temp, 13);
|
||||
if (remainder > 14) base[14] = vgetq_lane_s8(temp, 14);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
|
||||
@ -57,6 +57,7 @@ class DNNLPrimitiveHelper {
|
||||
// Note: Due to the limitation of oneDNN
|
||||
// (https://github.com/oneapi-src/oneDNN/issues/1636), the quantized bias is
|
||||
// not supported.
|
||||
|
||||
template <typename OutputT, typename BiasT>
|
||||
static void gemm_s8s8_jit(const int8_t* a, const int8_t* b, OutputT* c,
|
||||
const BiasT* bias, dnnl_dim_t M, dnnl_dim_t N,
|
||||
@ -90,6 +91,27 @@ class DNNLPrimitiveHelper {
|
||||
}
|
||||
|
||||
dnnl::matmul::primitive_desc matmul_pd;
|
||||
// Create memory descriptors with format_tag::any for the primitive. This
|
||||
// enables the matmul primitive to choose memory layouts for an
|
||||
// optimized primitive implementation, and these layouts may differ from the
|
||||
// ones provided by the user.
|
||||
#ifdef __aarch64__
|
||||
auto mat_src_md = dnnl::memory::desc({M, K}, dnnl::memory::data_type::s8,
|
||||
dnnl::memory::format_tag::any);
|
||||
auto mat_weights_md = dnnl::memory::desc(
|
||||
{K, N}, dnnl::memory::data_type::s8, dnnl::memory::format_tag::any);
|
||||
auto mat_dst_md =
|
||||
dnnl::memory::desc({M, N}, OutputType, dnnl::memory::format_tag::any);
|
||||
if (bias) {
|
||||
dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1});
|
||||
matmul_pd = dnnl::matmul::primitive_desc(default_engine(), mat_src_md,
|
||||
mat_weights_md, bias_md,
|
||||
mat_dst_md, attr);
|
||||
} else {
|
||||
matmul_pd = dnnl::matmul::primitive_desc(
|
||||
default_engine(), mat_src_md, mat_weights_md, mat_dst_md, attr);
|
||||
}
|
||||
#else
|
||||
if (bias) {
|
||||
dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1});
|
||||
matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
|
||||
@ -98,6 +120,7 @@ class DNNLPrimitiveHelper {
|
||||
matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
|
||||
c_md, attr);
|
||||
}
|
||||
#endif
|
||||
dnnl::matmul matmul(matmul_pd);
|
||||
|
||||
auto& engine = default_engine();
|
||||
@ -111,24 +134,34 @@ class DNNLPrimitiveHelper {
|
||||
(void*)b_scales);
|
||||
|
||||
auto& stream = default_stream();
|
||||
|
||||
auto mat_src_mem = a_m;
|
||||
auto mat_weights_mem = b_m;
|
||||
auto mat_dst_mem = c_m;
|
||||
#ifdef __aarch64__
|
||||
if (matmul_pd.weights_desc() != b_m.get_desc()) {
|
||||
mat_weights_mem = dnnl::memory(matmul_pd.weights_desc(), engine);
|
||||
dnnl::reorder(b_m, mat_weights_mem).execute(stream, b_m, mat_weights_mem);
|
||||
}
|
||||
#endif
|
||||
if constexpr (InputNoScale) {
|
||||
if (bias) {
|
||||
dnnl::memory::desc bias_md({N}, BiasType, {1});
|
||||
dnnl::memory bias_m(bias_md, engine, (void*)bias);
|
||||
matmul.execute(
|
||||
stream, {
|
||||
{DNNL_ARG_SRC, a_m},
|
||||
{DNNL_ARG_WEIGHTS, b_m},
|
||||
{DNNL_ARG_SRC, mat_src_mem},
|
||||
{DNNL_ARG_WEIGHTS, mat_weights_mem},
|
||||
{DNNL_ARG_BIAS, bias_m},
|
||||
{DNNL_ARG_DST, c_m},
|
||||
{DNNL_ARG_DST, mat_dst_mem},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
|
||||
});
|
||||
} else {
|
||||
matmul.execute(
|
||||
stream, {
|
||||
{DNNL_ARG_SRC, a_m},
|
||||
{DNNL_ARG_WEIGHTS, b_m},
|
||||
{DNNL_ARG_DST, c_m},
|
||||
{DNNL_ARG_SRC, mat_src_mem},
|
||||
{DNNL_ARG_WEIGHTS, mat_weights_mem},
|
||||
{DNNL_ARG_DST, mat_dst_mem},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
|
||||
});
|
||||
}
|
||||
@ -138,19 +171,19 @@ class DNNLPrimitiveHelper {
|
||||
dnnl::memory bias_m(bias_md, engine, (void*)bias);
|
||||
matmul.execute(
|
||||
stream, {
|
||||
{DNNL_ARG_SRC, a_m},
|
||||
{DNNL_ARG_WEIGHTS, b_m},
|
||||
{DNNL_ARG_SRC, mat_src_mem},
|
||||
{DNNL_ARG_WEIGHTS, mat_weights_mem},
|
||||
{DNNL_ARG_BIAS, bias_m},
|
||||
{DNNL_ARG_DST, c_m},
|
||||
{DNNL_ARG_DST, mat_dst_mem},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
|
||||
});
|
||||
} else {
|
||||
matmul.execute(
|
||||
stream, {
|
||||
{DNNL_ARG_SRC, a_m},
|
||||
{DNNL_ARG_WEIGHTS, b_m},
|
||||
{DNNL_ARG_DST, c_m},
|
||||
{DNNL_ARG_SRC, mat_src_mem},
|
||||
{DNNL_ARG_WEIGHTS, mat_weights_mem},
|
||||
{DNNL_ARG_DST, mat_dst_mem},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
|
||||
});
|
||||
@ -170,5 +203,4 @@ class DNNLPrimitiveHelper {
|
||||
return stream;
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
@ -36,7 +36,7 @@ struct KernelVecType<c10::Half> {
|
||||
using cvt_vec_type = vec_op::FP32Vec16;
|
||||
};
|
||||
|
||||
#ifdef __AVX512F__
|
||||
#if defined(__AVX512F__) || defined(__aarch64__)
|
||||
template <bool AZP, typename scalar_t>
|
||||
void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
|
||||
const float* scale, const int32_t* azp,
|
||||
@ -598,8 +598,9 @@ void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
|
||||
const float* scale, const int32_t* azp,
|
||||
const int num_tokens,
|
||||
const int hidden_size) {
|
||||
TORCH_CHECK(
|
||||
false, "static_scaled_int8_quant_impl requires AVX512/powerpc64 support.")
|
||||
TORCH_CHECK(false,
|
||||
"static_scaled_int8_quant_impl requires AVX512/powerpc64/AArch64 "
|
||||
"support.")
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
@ -607,9 +608,9 @@ void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
|
||||
float* scale, int32_t* azp,
|
||||
const int num_tokens,
|
||||
const int hidden_size) {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"dynamic_scaled_int8_quant_impl requires AVX512/powerpc64 support.")
|
||||
TORCH_CHECK(false,
|
||||
"dynamic_scaled_int8_quant_impl requires "
|
||||
"AVX512/powerpc64/AArch64 support.")
|
||||
}
|
||||
|
||||
template <bool PerChannel, typename scalar_t>
|
||||
@ -617,7 +618,8 @@ void static_quant_epilogue(const float* input, scalar_t* output,
|
||||
const float a_scale, const float* b_scale,
|
||||
const int32_t* azp_with_adj, const int num_tokens,
|
||||
const int hidden_size) {
|
||||
TORCH_CHECK(false, "static_quant_epilogue requires AVX512/powerpc64 support.")
|
||||
TORCH_CHECK(
|
||||
false, "static_quant_epilogue requires AVX512/powerpc64/AArch64 support.")
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
@ -626,8 +628,9 @@ void dynamic_quant_epilogue(const float* input, scalar_t* output,
|
||||
const int32_t* azp, const int32_t* azp_with_adj,
|
||||
const scalar_t* bias, const int num_tokens,
|
||||
const int hidden_size) {
|
||||
TORCH_CHECK(false,
|
||||
"dynamic_quant_epilogue requires AVX512/powerpc64 support.")
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"dynamic_quant_epilogue requires AVX512/powerpc64/AArch64 support.")
|
||||
}
|
||||
#endif
|
||||
} // namespace
|
||||
|
||||
238
csrc/cpu/sgl-kernels/common.h
Normal file
238
csrc/cpu/sgl-kernels/common.h
Normal file
@ -0,0 +1,238 @@
|
||||
// Adapted from
|
||||
// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/Parallel.h>
|
||||
#include <ATen/record_function.h>
|
||||
|
||||
// clang-format off
|
||||
|
||||
#if defined(_OPENMP)
|
||||
#include <omp.h>
|
||||
#endif
|
||||
|
||||
namespace {
|
||||
|
||||
// dispatch bool
|
||||
#define AT_DISPATCH_BOOL(BOOL_V, BOOL_NAME, ...) \
|
||||
[&] { \
|
||||
if (BOOL_V) { \
|
||||
constexpr bool BOOL_NAME = true; \
|
||||
return __VA_ARGS__(); \
|
||||
} else { \
|
||||
constexpr bool BOOL_NAME = false; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
}()
|
||||
|
||||
// dispatch: bfloat16, float16, int8_t, fp8_e4m3
|
||||
#define CPU_DISPATCH_PACKED_TYPES(TYPE, ...) \
|
||||
[&] { \
|
||||
switch (TYPE) { \
|
||||
case at::ScalarType::BFloat16 : { \
|
||||
using packed_t = at::BFloat16; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
case at::ScalarType::Half: { \
|
||||
using packed_t = at::Half; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
case at::ScalarType::Char : { \
|
||||
using packed_t = int8_t; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
case at::ScalarType::Float8_e4m3fn : { \
|
||||
using packed_t = at::Float8_e4m3fn; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
default: \
|
||||
TORCH_CHECK(false, "Unsupported floating data type.\n"); \
|
||||
} \
|
||||
}()
|
||||
|
||||
#define UNUSED(x) (void)(x)
|
||||
|
||||
#define CHECK_CPU(x) TORCH_CHECK(x.device().type() == at::kCPU, #x " must be a CPU tensor")
|
||||
|
||||
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_LAST_DIM_CONTIGUOUS(x) \
|
||||
TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x "must be contiguous at last dimention")
|
||||
|
||||
#define CHECK_INPUT(x) \
|
||||
CHECK_CPU(x); \
|
||||
CHECK_CONTIGUOUS(x)
|
||||
#define CHECK_LAST_DIM_CONTIGUOUS_INPUT(x) \
|
||||
CHECK_CPU(x); \
|
||||
CHECK_LAST_DIM_CONTIGUOUS(x)
|
||||
|
||||
#define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")
|
||||
|
||||
#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
|
||||
|
||||
// parallel routines
|
||||
constexpr int GRAIN_SIZE = 1024;
|
||||
|
||||
template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
|
||||
inline T div_up(T x, T y) { return (x + y - 1) / y; }
|
||||
|
||||
template <typename T>
|
||||
inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) {
|
||||
#if 0
|
||||
// onednn partition pattern
|
||||
T& n_my = n_end;
|
||||
if (nth <= 1 || n == 0) {
|
||||
n_start = 0;
|
||||
n_my = n;
|
||||
} else {
|
||||
T n1 = div_up(n, nth);
|
||||
T n2 = n1 - 1;
|
||||
T T1 = n - n2 * nth;
|
||||
n_my = ith < T1 ? n1 : n2;
|
||||
n_start = ith <= T1 ? ith*n1 : T1 * n1 + (ith - T1) * n2;
|
||||
}
|
||||
n_end += n_start;
|
||||
#else
|
||||
// pytorch aten partition pattern
|
||||
T n_my = div_up(n, nth);
|
||||
n_start = ith * n_my;
|
||||
n_end = std::min(n_start + n_my, n);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename func_t>
|
||||
inline void parallel_for(int n, const func_t& f) {
|
||||
#if defined(_OPENMP)
|
||||
#pragma omp parallel
|
||||
{
|
||||
int nth = omp_get_num_threads();
|
||||
int ith = omp_get_thread_num();
|
||||
int tbegin, tend;
|
||||
balance211(n, nth, ith, tbegin, tend);
|
||||
f(tbegin, tend);
|
||||
}
|
||||
#else
|
||||
f(0, n);
|
||||
#endif
|
||||
}
|
||||
|
||||
// for 1d parallel, use `actual_nth`
|
||||
// for 2d parallel, use even nths, e.g. 43->42
|
||||
int inline adjust_num_threads(int m) {
|
||||
int actual_nth = at::get_num_threads();
|
||||
if (m == 1) {
|
||||
return actual_nth;
|
||||
}
|
||||
return std::max(1, (actual_nth >> 1) * 2);
|
||||
}
|
||||
|
||||
template <typename func_t>
|
||||
inline void parallel_2d(int m, int n, const func_t& f) {
|
||||
|
||||
// make sure we have even num_threads
|
||||
int nth = adjust_num_threads(m);
|
||||
|
||||
// [NOTE] thread blocking:
|
||||
//
|
||||
// 1) prefer square block per thread
|
||||
// 2) use even number of CPU cores
|
||||
// 3) use all `num_threads` cores
|
||||
//
|
||||
// we have:
|
||||
// TM * TN = T
|
||||
// BM / TM = BN / TN
|
||||
// then:
|
||||
// TM = ((BM / BN) * T) ^ 0.5
|
||||
//
|
||||
float r = float(m) / n;
|
||||
int nth_m = std::ceil(std::sqrt(r * nth));
|
||||
int nth_n = 1;
|
||||
for (; nth_m > 0; --nth_m) {
|
||||
nth_n = nth / nth_m;
|
||||
if (nth_m * nth_n == nth) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(_OPENMP)
|
||||
#pragma omp parallel num_threads(nth)
|
||||
{
|
||||
int ith = omp_get_thread_num();
|
||||
int ith_m = ith / nth_n;
|
||||
int ith_n = ith % nth_n;
|
||||
|
||||
int thread_block_m = div_up(m, nth_m);
|
||||
int thread_block_n = div_up(n, nth_n);
|
||||
|
||||
int begin_m = ith_m * thread_block_m;
|
||||
int end_m = std::min(m, begin_m + thread_block_m);
|
||||
int begin_n = ith_n * thread_block_n;
|
||||
int end_n = std::min(n, begin_n + thread_block_n);
|
||||
|
||||
f(begin_m, end_m, begin_n, end_n);
|
||||
}
|
||||
#else
|
||||
f(0, m, 0, n);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
int get_cache_blocks(int BLOCK_SIZE, int K) {
|
||||
// L2 2MB and ratio of 50%
|
||||
const int L2_size = 2048 * 1024 >> 1;
|
||||
return std::max(1, int(L2_size / (BLOCK_SIZE * K * sizeof(T))));
|
||||
}
|
||||
|
||||
// data indexing for dimension collapse
|
||||
template <typename T>
|
||||
inline T data_index_init(T offset) {
|
||||
return offset;
|
||||
}
|
||||
|
||||
template <typename T, typename... Args>
|
||||
inline T data_index_init(T offset, T& x, const T& X, Args&&... args) {
|
||||
offset = data_index_init(offset, std::forward<Args>(args)...);
|
||||
x = offset % X;
|
||||
return offset / X;
|
||||
}
|
||||
|
||||
inline bool data_index_step() {
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T, typename... Args>
|
||||
inline bool data_index_step(T& x, const T& X, Args&&... args) {
|
||||
if (data_index_step(std::forward<Args>(args)...)) {
|
||||
x = ((x + 1) == X) ? 0 : (x + 1);
|
||||
return x == 0;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// forced unroll for perf critical path
|
||||
|
||||
#if __has_attribute(always_inline)
|
||||
#define ALWAYS_INLINE __attribute__((__always_inline__)) inline
|
||||
#else
|
||||
#define ALWAYS_INLINE inline
|
||||
#endif
|
||||
|
||||
template <int n>
|
||||
struct Unroll {
|
||||
template <typename Func, typename... Args>
|
||||
ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
|
||||
Unroll<n - 1>{}(f, args...);
|
||||
f(std::integral_constant<int, n - 1>{}, args...);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Unroll<1> {
|
||||
template <typename Func, typename... Args>
|
||||
ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
|
||||
f(std::integral_constant<int, 0>{}, args...);
|
||||
}
|
||||
};
|
||||
|
||||
} // anonymous namespace
|
||||
464
csrc/cpu/sgl-kernels/gemm.cpp
Normal file
464
csrc/cpu/sgl-kernels/gemm.cpp
Normal file
@ -0,0 +1,464 @@
|
||||
// Adapted from
|
||||
// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu
|
||||
|
||||
#include "common.h"
|
||||
#include "vec.h"
|
||||
#include "gemm.h"
|
||||
|
||||
// clang-format off
|
||||
|
||||
namespace {
|
||||
|
||||
// packed layout:
|
||||
// quants {N, K} int8_t
|
||||
// comp {N} int32_t
|
||||
template <int BLOCK_N>
|
||||
inline void s8s8_compensation(int8_t* __restrict__ packed, int K) {
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
constexpr int COLS = BLOCK_N / 16;
|
||||
__m512i vcomp[COLS];
|
||||
|
||||
for (int col = 0; col < COLS; ++col) {
|
||||
vcomp[col] = _mm512_setzero_si512();
|
||||
}
|
||||
|
||||
const int64_t offset = BLOCK_N * K;
|
||||
const __m512i off = _mm512_set1_epi8(static_cast<char>(0x80));
|
||||
for (int k = 0; k < K / 4; ++k) {
|
||||
for (int col = 0; col < COLS; ++col) {
|
||||
__m512i vb = _mm512_loadu_si512((const __m512i *)(packed + k * BLOCK_N * 4 + col * 64));
|
||||
vcomp[col] = _mm512_dpbusd_epi32(vcomp[col], off, vb);
|
||||
}
|
||||
}
|
||||
|
||||
for (int col = 0; col < COLS; ++col) {
|
||||
_mm512_storeu_si512((__m512i *)(packed + offset + col * 64), vcomp[col]);
|
||||
}
|
||||
#else
|
||||
TORCH_CHECK(false, "s8s8_compensation not implemented!");
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert to vnni format
|
||||
// from [N, K] to [K/2, N, 2] for bfloat16 and float16
|
||||
template <typename packed_t>
|
||||
inline void pack_vnni(packed_t* __restrict__ packed, const packed_t* __restrict__ weight, int N, int K) {
|
||||
const int VNNI_BLK = 2;
|
||||
for (int n = 0; n < N; ++n) {
|
||||
for (int k = 0; k < K / VNNI_BLK; ++k) {
|
||||
for (int d = 0; d < VNNI_BLK; ++d) {
|
||||
packed[k * N * VNNI_BLK + n * VNNI_BLK + d] = weight[n * K + k * VNNI_BLK + d];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void pack_vnni<int8_t>(int8_t* __restrict__ packed, const int8_t* __restrict__ weight, int N, int K) {
|
||||
constexpr int BLOCK_N = block_size_n();
|
||||
TORCH_CHECK(N == BLOCK_N);
|
||||
|
||||
const int VNNI_BLK = 4;
|
||||
for (int n = 0; n < N; ++n) {
|
||||
for (int k = 0; k < K / VNNI_BLK; ++k) {
|
||||
for (int d = 0; d < VNNI_BLK; ++d) {
|
||||
packed[k * N * VNNI_BLK + n * VNNI_BLK + d] = weight[n * K + k * VNNI_BLK + d];
|
||||
}
|
||||
}
|
||||
}
|
||||
s8s8_compensation<BLOCK_N>(packed, K);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input, int64_t size) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= size - kVecSize; d += kVecSize) {
|
||||
fVec data0 = fVec::loadu(input + d);
|
||||
fVec data1 = fVec::loadu(input + d + fVec::size());
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
|
||||
out_vec.store(out + d);
|
||||
}
|
||||
for (; d < size; ++d) {
|
||||
out[d] = static_cast<scalar_t>(input[d]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void copy_add_stub(scalar_t* __restrict__ out, const float* __restrict__ input, const float* __restrict__ bias, int64_t size) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= size - kVecSize; d += kVecSize) {
|
||||
fVec data0 = fVec::loadu(input + d) + fVec::loadu(bias + d);
|
||||
fVec data1 = fVec::loadu(input + d + fVec::size()) + fVec::loadu(bias + d + fVec::size());
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
|
||||
out_vec.store(out + d);
|
||||
}
|
||||
for (; d < size; ++d) {
|
||||
out[d] = static_cast<scalar_t>(input[d] + bias[d]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, bool has_bias, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_nn {
|
||||
static inline void apply(
|
||||
const scalar_t* __restrict__ A, const scalar_t* __restrict__ B, scalar_t* __restrict__ C,
|
||||
const float* __restrict__ bias, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) {
|
||||
TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
template <bool has_bias, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_nn<at::BFloat16, has_bias, BLOCK_M, BLOCK_N> {
|
||||
static inline void apply(
|
||||
const at::BFloat16* __restrict__ A, const at::BFloat16* __restrict__ B, at::BFloat16* __restrict__ C,
|
||||
const float* __restrict__ bias, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) {
|
||||
|
||||
constexpr int ROWS = BLOCK_M;
|
||||
constexpr int COLS = BLOCK_N / 16;
|
||||
|
||||
// prefetch distance
|
||||
constexpr int PREFETCH_SIZE_K = 0;
|
||||
|
||||
__m512bh va;
|
||||
__m512bh vb[COLS];
|
||||
__m512 vc[ROWS * COLS];
|
||||
|
||||
auto loadc = [&](auto i) {
|
||||
constexpr int col = i % COLS;
|
||||
if constexpr (has_bias) {
|
||||
vc[i] = _mm512_loadu_ps(bias + col * 16);
|
||||
} else {
|
||||
vc[i] = _mm512_set1_ps(0.f);
|
||||
}
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(loadc);
|
||||
|
||||
const int64_t K2 = K >> 1;
|
||||
const int64_t lda2 = lda >> 1;
|
||||
const int64_t ldb2 = ldb; // ldb * 2 >> 1;
|
||||
const float* a_ptr = reinterpret_cast<const float*>(A);
|
||||
const float* b_ptr = reinterpret_cast<const float*>(B);
|
||||
|
||||
auto compute = [&](auto i, int64_t k) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
|
||||
if constexpr (col == 0) {
|
||||
va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k]));
|
||||
}
|
||||
if constexpr (row == 0) {
|
||||
vb[col] = (__m512bh)(_mm512_loadu_si512(b_ptr + k * ldb2 + col * 16));
|
||||
if constexpr (PREFETCH_SIZE_K > 0) {
|
||||
_mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0);
|
||||
}
|
||||
}
|
||||
vc[i] = _mm512_dpbf16_ps(vc[i], va, vb[col]);
|
||||
};
|
||||
for (int64_t k = 0; k < K2; ++k) {
|
||||
Unroll<ROWS * COLS>{}(compute, k);
|
||||
}
|
||||
|
||||
auto storec = [&](auto i) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
// for COLS = 2, 4 use 512bit store
|
||||
// for COLS = 1, 3 use 256bit store
|
||||
if constexpr (COLS % 2 == 0) {
|
||||
if constexpr (col % 2 == 0) {
|
||||
_mm512_storeu_si512(
|
||||
reinterpret_cast<__m512i*>((C + row * ldc + col * 16)),
|
||||
(__m512i)(_mm512_cvtne2ps_pbh(vc[row * COLS + col + 1], vc[row * COLS + col])));
|
||||
}
|
||||
} else {
|
||||
_mm256_storeu_si256(
|
||||
reinterpret_cast<__m256i*>(C + row * ldc + col * 16),
|
||||
(__m256i)(_mm512_cvtneps_pbh(vc[i])));
|
||||
}
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(storec);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \
|
||||
tinygemm_kernel_nn<scalar_t, has_bias, MB_SIZE, NB_SIZE>::apply( \
|
||||
A + mb_start * lda, B + nb_start * 2, C + mb_start * ldc + nb_start, \
|
||||
has_bias ? bias + nb_start : nullptr, K, lda, ldb, ldc);
|
||||
|
||||
template <typename scalar_t, bool has_bias>
|
||||
struct brgemm {
|
||||
static inline void apply(
|
||||
const scalar_t* __restrict__ A, const scalar_t* __restrict__ B, scalar_t* __restrict__ C,
|
||||
float* __restrict__ Ctmp, const float* __restrict__ bias,
|
||||
int64_t M, int64_t N, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) {
|
||||
|
||||
constexpr int BLOCK_N = block_size_n();
|
||||
at::native::cpublas::brgemm(
|
||||
M, N, K, lda, ldb, BLOCK_N, /* add_C */false,
|
||||
A, B, Ctmp);
|
||||
|
||||
// copy from Ctmp to C
|
||||
for (int64_t m = 0; m < M; ++m) {
|
||||
if constexpr (has_bias) {
|
||||
copy_add_stub(C + m * ldc, Ctmp + m * BLOCK_N, bias, N);
|
||||
} else {
|
||||
copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename scalar_t, bool has_bias>
|
||||
void tinygemm_kernel(
|
||||
const scalar_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
float* __restrict__ Ctmp,
|
||||
const float* __restrict__ bias,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg) {
|
||||
|
||||
if (brg) {
|
||||
brgemm<scalar_t, has_bias>::apply(
|
||||
A, B, C, Ctmp, bias,
|
||||
M, N, K, lda, ldb, ldc);
|
||||
return;
|
||||
}
|
||||
|
||||
// pattern: 1-4-16
|
||||
constexpr int64_t BLOCK_M = 4;
|
||||
constexpr int64_t BLOCK_N = 64;
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
for (int mb = 0; mb < MB; ++mb) {
|
||||
int64_t mb_start = mb * BLOCK_M;
|
||||
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
|
||||
for (int64_t nb = 0; nb < NB; ++nb) {
|
||||
int64_t nb_start = nb * BLOCK_N;
|
||||
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
|
||||
|
||||
switch(mb_size << 4 | nb_size >> 4) {
|
||||
// mb_size = 1
|
||||
case 0x12: LAUNCH_TINYGEMM_KERNEL_NN(1, 32); break;
|
||||
case 0x14: LAUNCH_TINYGEMM_KERNEL_NN(1, 64); break;
|
||||
// mb_size = 2
|
||||
case 0x22: LAUNCH_TINYGEMM_KERNEL_NN(2, 32); break;
|
||||
case 0x24: LAUNCH_TINYGEMM_KERNEL_NN(2, 64); break;
|
||||
// mb_size = 3
|
||||
case 0x32: LAUNCH_TINYGEMM_KERNEL_NN(3, 32); break;
|
||||
case 0x34: LAUNCH_TINYGEMM_KERNEL_NN(3, 64); break;
|
||||
// mb_size = 4
|
||||
case 0x42: LAUNCH_TINYGEMM_KERNEL_NN(4, 32); break;
|
||||
case 0x44: LAUNCH_TINYGEMM_KERNEL_NN(4, 64); break;
|
||||
default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void weight_packed_linear_kernel_impl(
|
||||
scalar_t* __restrict__ out,
|
||||
const scalar_t* __restrict__ mat1,
|
||||
const scalar_t* __restrict__ mat2,
|
||||
const float* __restrict__ bias,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t mat1_strideM,
|
||||
int64_t out_strideM) {
|
||||
|
||||
constexpr int64_t BLOCK_M = block_size_m();
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
|
||||
// use avx512-bf16 when a) M is small; b) dtype is bfloat16, otherwise use amx
|
||||
const bool use_brgemm = (M > 4) || (!std::is_same_v<scalar_t, at::BFloat16>);
|
||||
|
||||
// l2 cache block for n
|
||||
int64_t cache_blocks_nb = get_cache_blocks<scalar_t>(BLOCK_N, K);
|
||||
|
||||
// parallel on [MB, NB]
|
||||
AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
|
||||
parallel_2d(MB, NB, [&](int64_t begin_mb, int64_t end_mb, int64_t begin_nb, int64_t end_nb) {
|
||||
|
||||
// for brgemm, use float32 for accumulate
|
||||
alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
|
||||
|
||||
for (int64_t nbb = begin_nb; nbb < end_nb; nbb += cache_blocks_nb) {
|
||||
for (int64_t mb = begin_mb; mb < end_mb; ++mb) {
|
||||
for (int64_t nb = nbb; nb < std::min(nbb + cache_blocks_nb, end_nb); ++nb) {
|
||||
|
||||
int64_t mb_start = mb * BLOCK_M;
|
||||
int64_t mb_size = std::min(M - mb_start, BLOCK_M);
|
||||
int64_t nb_start = nb * BLOCK_N;
|
||||
int64_t nb_size = std::min(N - nb_start, BLOCK_N);
|
||||
|
||||
tinygemm_kernel<scalar_t, has_bias>(
|
||||
/* A */ mat1 + mb_start * mat1_strideM,
|
||||
/* B */ mat2 + nb_start * K /* nb * BLOCK_N * K */,
|
||||
/* C */ out + mb_start * out_strideM + nb_start,
|
||||
/* Ctmp*/ Ctmp,
|
||||
/* bias*/ bias + nb_start,
|
||||
/* M */ mb_size,
|
||||
/* N */ nb_size,
|
||||
/* K */ K,
|
||||
/* lda */ mat1_strideM,
|
||||
/* ldb */ nb_size,
|
||||
/* ldc */ out_strideM,
|
||||
/* brg */ use_brgemm);
|
||||
}}}
|
||||
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
// tinygemm interface
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(const scalar_t* __restrict__ A, const scalar_t* __restrict__ B, scalar_t* __restrict__ C,
|
||||
float* __restrict__ Ctmp, int64_t M, int64_t N, int64_t K, int64_t lda, int64_t ldb, int64_t ldc, bool brg) {
|
||||
tinygemm_kernel<scalar_t, false>(A, B, C, Ctmp, nullptr, M, N, K, lda, ldb, ldc, brg);
|
||||
}
|
||||
|
||||
#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \
|
||||
template void tinygemm_kernel<TYPE>( \
|
||||
const TYPE* __restrict__ A, const TYPE* __restrict__ B, TYPE* __restrict__ C, \
|
||||
float* __restrict__ Ctmp, int64_t M, int64_t N, int64_t K, int64_t lda, \
|
||||
int64_t ldb, int64_t ldc, bool brg)
|
||||
|
||||
INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16);
|
||||
INSTANTIATE_TINYGEMM_TEMPLATE(at::Half);
|
||||
|
||||
at::Tensor convert_weight_packed(at::Tensor& weight) {
|
||||
// for 3d moe weights
|
||||
// weight : [E, OC, IC]
|
||||
// w1 : [E, 2N, K]
|
||||
// w2 : [E, K, N]
|
||||
CHECK_INPUT(weight);
|
||||
|
||||
const int64_t ndim = weight.ndimension();
|
||||
TORCH_CHECK(ndim == 2 || ndim == 3, "expect weight to be 2d or 3d, got ", ndim, "d tensor.");
|
||||
const auto st = weight.scalar_type();
|
||||
const int64_t E = ndim == 3 ? weight.size(0) : 1;
|
||||
const int64_t OC = ndim == 3 ? weight.size(1) : weight.size(0);
|
||||
const int64_t IC = ndim == 3 ? weight.size(2) : weight.size(1);
|
||||
|
||||
// we handle 2 TILE_N at a time.
|
||||
TORCH_CHECK(OC % TILE_N == 0, "invalid weight out features ", OC);
|
||||
TORCH_CHECK(IC % TILE_K == 0, "invalid weight input features ", IC);
|
||||
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
const int64_t NB = div_up(OC, BLOCK_N);
|
||||
|
||||
// use phony sizes here [E, OC, IC], for each [E], [OC, IC] -> [IC / 2, OC, 2]
|
||||
auto packed_weight = at::empty({}, weight.options());
|
||||
const int64_t stride = OC * IC;
|
||||
|
||||
TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf || st == at::kChar || st == at::kFloat8_e4m3fn,
|
||||
"expect weight to be bfloat16, float16, int8 or fp8_e4m3.");
|
||||
|
||||
CPU_DISPATCH_PACKED_TYPES(st, [&] {
|
||||
// adjust most inner dimension size
|
||||
const int packed_row_size = get_row_size<packed_t>(IC);
|
||||
auto sizes = weight.sizes().vec();
|
||||
sizes[ndim - 1] = packed_row_size;
|
||||
packed_weight.resize_(sizes);
|
||||
|
||||
const packed_t* w_data = weight.data_ptr<packed_t>();
|
||||
packed_t* packed_data = packed_weight.data_ptr<packed_t>();
|
||||
|
||||
// parallel on {E, NB}
|
||||
at::parallel_for(0, E * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
int64_t e{0}, nb{0};
|
||||
data_index_init(begin, e, E, nb, NB);
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
UNUSED(i);
|
||||
|
||||
int64_t n = nb * BLOCK_N;
|
||||
int64_t n_size = std::min(BLOCK_N, OC - n);
|
||||
pack_vnni<packed_t>(
|
||||
packed_data + e * OC * packed_row_size + n * packed_row_size,
|
||||
w_data + e * stride + n * IC,
|
||||
n_size,
|
||||
IC);
|
||||
|
||||
// move to the next index
|
||||
data_index_step(e, E, nb, NB);
|
||||
}
|
||||
});
|
||||
});
|
||||
return packed_weight;
|
||||
}
|
||||
|
||||
// mat1 : [M, K]
|
||||
// mat2 : [N, K]
|
||||
// bias : [N]
|
||||
// out : [M, N]
|
||||
//
|
||||
at::Tensor weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2,
|
||||
const std::optional<at::Tensor>& bias, bool is_vnni) {
|
||||
RECORD_FUNCTION(
|
||||
"sgl-kernel::weight_packed_linear", std::vector<c10::IValue>({mat1, mat2, bias}));
|
||||
|
||||
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
|
||||
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1);
|
||||
CHECK_INPUT(mat2);
|
||||
|
||||
int64_t M = mat1.size(0);
|
||||
int64_t N = mat2.size(0);
|
||||
int64_t K = mat2.size(1);
|
||||
CHECK_EQ(mat1.size(1), K);
|
||||
CHECK_DIM(2, mat1);
|
||||
CHECK_DIM(2, mat2);
|
||||
|
||||
auto out = at::empty({M, N}, mat1.options());
|
||||
|
||||
// strides
|
||||
int64_t mat1_strideM = mat1.stride(0);
|
||||
int64_t out_strideM = out.stride(0);
|
||||
|
||||
const bool has_bias = bias.has_value();
|
||||
const float* bias_data = nullptr;
|
||||
if (has_bias) {
|
||||
CHECK_EQ(bias.value().size(0), N);
|
||||
bias_data = bias.value().data_ptr<float>();
|
||||
}
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(mat1.scalar_type(), "weight_packed_linear_kernel_impl", [&] {
|
||||
weight_packed_linear_kernel_impl<scalar_t>(
|
||||
out.data_ptr<scalar_t>(),
|
||||
mat1.data_ptr<scalar_t>(),
|
||||
packed_w.data_ptr<scalar_t>(),
|
||||
bias_data,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
mat1_strideM,
|
||||
out_strideM);
|
||||
});
|
||||
|
||||
return out;
|
||||
}
|
||||
266
csrc/cpu/sgl-kernels/gemm.h
Normal file
266
csrc/cpu/sgl-kernels/gemm.h
Normal file
@ -0,0 +1,266 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/native/CPUBlas.h>
|
||||
|
||||
// clang-format off
|
||||
|
||||
// amx-bf16
|
||||
#define TILE_M 16
|
||||
#define TILE_N 16
|
||||
#define TILE_K 32
|
||||
|
||||
// block size for AMX gemm
|
||||
constexpr int block_size_m() { return 2 * TILE_M; }
|
||||
constexpr int block_size_n() { return 2 * TILE_N; }
|
||||
|
||||
// define threshold using brgemm (intel AMX)
|
||||
template <typename T> inline bool can_use_brgemm(int M);
|
||||
template <> inline bool can_use_brgemm<at::BFloat16>(int M) { return M > 4; }
|
||||
template <> inline bool can_use_brgemm<at::Half>(int M) { return true; }
|
||||
// TODO: add u8s8 brgemm, this requires PyTorch 2.7
|
||||
template <> inline bool can_use_brgemm<int8_t>(int M) { return false; }
|
||||
template <> inline bool can_use_brgemm<at::Float8_e4m3fn>(int M) { return M > 4; }
|
||||
template <> inline bool can_use_brgemm<at::quint4x2>(int M) { return M > 4; }
|
||||
|
||||
// work around compiler internal error
|
||||
#define BLOCK_K 128 // 4 * TILE_K
|
||||
|
||||
// adjust leading dimension size for K
|
||||
template <typename T>
|
||||
inline int64_t get_row_size(int64_t K) {
|
||||
return K;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline int64_t get_row_size<int8_t>(int64_t K) {
|
||||
return K + sizeof(int32_t);
|
||||
}
|
||||
|
||||
inline int64_t get_row_size(int64_t K, bool use_int8_w8a8) {
|
||||
return use_int8_w8a8 ? K + sizeof(int32_t) : K;
|
||||
}
|
||||
|
||||
// pack weight to vnni format
|
||||
at::Tensor convert_weight_packed(at::Tensor& weight);
|
||||
|
||||
// moe implementations for int8 w8a8
|
||||
template <typename scalar_t>
|
||||
void fused_experts_int8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic1,
|
||||
scalar_t* __restrict__ ic2,
|
||||
uint8_t* __restrict__ A_tmp,
|
||||
float* __restrict__ C_tmp,
|
||||
uint8_t* __restrict__ Aq_tmp,
|
||||
float* __restrict__ As_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const int8_t* __restrict__ packed_w1,
|
||||
const int8_t* __restrict__ packed_w2,
|
||||
const float* __restrict__ w1s,
|
||||
const float* __restrict__ w2s,
|
||||
const float* __restrict__ topk_weights,
|
||||
const int32_t* __restrict__ sorted_ids,
|
||||
const int32_t* __restrict__ expert_ids,
|
||||
const int32_t* __restrict__ offsets,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t E,
|
||||
int64_t topk,
|
||||
int64_t num_tokens_post_pad);
|
||||
|
||||
// moe implementations for fp8 w8a16
|
||||
template <typename scalar_t>
|
||||
void fused_experts_fp8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic0,
|
||||
scalar_t* __restrict__ ic1,
|
||||
scalar_t* __restrict__ ic2,
|
||||
scalar_t* __restrict__ A_tmp,
|
||||
scalar_t* __restrict__ B_tmp,
|
||||
float* __restrict__ C_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w1,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w2,
|
||||
const float* __restrict__ w1s,
|
||||
const float* __restrict__ w2s,
|
||||
int64_t block_size_N,
|
||||
int64_t block_size_K,
|
||||
const float* __restrict__ topk_weights,
|
||||
const int32_t* __restrict__ sorted_ids,
|
||||
const int32_t* __restrict__ expert_ids,
|
||||
const int32_t* __restrict__ offsets,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t E,
|
||||
int64_t topk,
|
||||
int64_t num_tokens_post_pad);
|
||||
|
||||
// moe implementations for int4 w4a16
|
||||
template <typename scalar_t>
|
||||
void fused_experts_int4_w4a16_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic0,
|
||||
scalar_t* __restrict__ ic1,
|
||||
scalar_t* __restrict__ ic2,
|
||||
scalar_t* __restrict__ A_tmp,
|
||||
scalar_t* __restrict__ B_tmp,
|
||||
float* __restrict__ C_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const at::quint4x2* __restrict__ packed_w1,
|
||||
const at::quint4x2* __restrict__ packed_w2,
|
||||
const uint8_t* __restrict__ w1z,
|
||||
const uint8_t* __restrict__ w2z,
|
||||
const scalar_t* __restrict__ w1s,
|
||||
const scalar_t* __restrict__ w2s,
|
||||
int group_size,
|
||||
const float* __restrict__ topk_weights,
|
||||
const int32_t* __restrict__ sorted_ids,
|
||||
const int32_t* __restrict__ expert_ids,
|
||||
const int32_t* __restrict__ offsets,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t E,
|
||||
int64_t topk,
|
||||
int64_t num_tokens_post_pad);
|
||||
|
||||
// shared expert implememntation for int8 w8a8
|
||||
template <typename scalar_t>
|
||||
void shared_expert_int8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic1,
|
||||
float* __restrict__ C_tmp,
|
||||
uint8_t* __restrict__ Aq_tmp,
|
||||
float* __restrict__ As_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const int8_t* __restrict__ packed_w1,
|
||||
const int8_t* __restrict__ packed_w2,
|
||||
const float* __restrict__ w1s,
|
||||
const float* __restrict__ w2s,
|
||||
const scalar_t* __restrict__ fused_experts_out,
|
||||
float routed_scaling_factor,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K);
|
||||
|
||||
template <typename scalar_t>
|
||||
void shared_expert_fp8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic0,
|
||||
scalar_t* __restrict__ ic1,
|
||||
scalar_t* __restrict__ B_tmp,
|
||||
float* __restrict__ C_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w1,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w2,
|
||||
const float* __restrict__ w1s,
|
||||
const float* __restrict__ w2s,
|
||||
int64_t block_size_N,
|
||||
int64_t block_size_K,
|
||||
const scalar_t* __restrict__ fused_experts_out,
|
||||
float routed_scaling_factor,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K);
|
||||
|
||||
// tinygemm interface
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(
|
||||
const scalar_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
float* __restrict__ Ctmp,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg);
|
||||
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(
|
||||
const uint8_t* __restrict__ A,
|
||||
const int8_t* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
int32_t* __restrict__ Ctmp,
|
||||
const float* __restrict__ As,
|
||||
const float* __restrict__ Bs,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg);
|
||||
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(
|
||||
const scalar_t* __restrict__ A,
|
||||
const at::Float8_e4m3fn* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
scalar_t* __restrict__ Btmp,
|
||||
float* __restrict__ Ctmp,
|
||||
const float* __restrict__ scale,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg,
|
||||
int64_t block_size_K);
|
||||
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(
|
||||
const scalar_t* __restrict__ A,
|
||||
const at::quint4x2* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
const uint8_t* __restrict__ Bz,
|
||||
const scalar_t* __restrict__ Bs,
|
||||
scalar_t* __restrict__ Btmp,
|
||||
float* __restrict__ Ctmp,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int group_size,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
int64_t strideBz,
|
||||
int64_t strideBs,
|
||||
bool brg);
|
||||
|
||||
// TODO: debug print, remove me later
|
||||
inline void print_16x32i(const __m512i x) {
|
||||
int32_t a[16];
|
||||
_mm512_storeu_si512((__m512i *)a, x);
|
||||
|
||||
for (int i = 0; i < 16; i++){
|
||||
std::cout << a[i] << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
inline void print_16x32(const __m512 x) {
|
||||
float a[16];
|
||||
_mm512_storeu_ps((__m512 *)a, x);
|
||||
|
||||
for (int i = 0; i < 16; i++){
|
||||
std::cout << a[i] << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
|
||||
inline void print_32x8u(const __m256i x) {
|
||||
uint8_t a[32];
|
||||
_mm256_storeu_si256((__m256i *)a, x);
|
||||
|
||||
for (int i = 0; i < 32; ++i) {
|
||||
std::cout << int32_t(a[i]) << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
530
csrc/cpu/sgl-kernels/gemm_fp8.cpp
Normal file
530
csrc/cpu/sgl-kernels/gemm_fp8.cpp
Normal file
@ -0,0 +1,530 @@
|
||||
// Adapted from
|
||||
// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu
|
||||
|
||||
#include "common.h"
|
||||
#include "vec.h"
|
||||
#include "gemm.h"
|
||||
|
||||
// clang-format off
|
||||
|
||||
// we use 4x32 for BLOCK_M
|
||||
#define BLOCK_SIZE_M_SCALE 4
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input, int64_t size) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= size - kVecSize; d += kVecSize) {
|
||||
fVec data0 = fVec::loadu(input + d);
|
||||
fVec data1 = fVec::loadu(input + d + fVec::size());
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
|
||||
out_vec.store(out + d);
|
||||
}
|
||||
for (; d < size; ++d) {
|
||||
out[d] = static_cast<scalar_t>(input[d]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void copy_add_stub(scalar_t* __restrict__ out, const float* __restrict__ input, const float* __restrict__ bias, int64_t size) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= size - kVecSize; d += kVecSize) {
|
||||
fVec data0 = fVec::loadu(input + d) + fVec::loadu(bias + d);
|
||||
fVec data1 = fVec::loadu(input + d + fVec::size()) + fVec::loadu(bias + d + fVec::size());
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
|
||||
out_vec.store(out + d);
|
||||
}
|
||||
for (; d < size; ++d) {
|
||||
out[d] = static_cast<scalar_t>(input[d] + bias[d]);
|
||||
}
|
||||
}
|
||||
|
||||
inline void unpack_B(
|
||||
at::BFloat16* __restrict__ Btmp,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_B,
|
||||
int N,
|
||||
int K,
|
||||
int ldb,
|
||||
int ldb_tmp,
|
||||
float scale) {
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
// [K/2, N, 2]
|
||||
const int K2 = K >> 1;
|
||||
const int ldb2 = ldb; // ldb * 2 >> 1;
|
||||
const uint16_t* b_ptr = reinterpret_cast<const uint16_t*>(packed_B);
|
||||
const __m512 vd = _mm512_set1_ps(scale);
|
||||
|
||||
constexpr int BLOCK_N = block_size_n();
|
||||
static_assert(BLOCK_N == 32);
|
||||
|
||||
// prefetch distance
|
||||
constexpr int PREFETCH_SIZE_K = 64;
|
||||
|
||||
#pragma GCC unroll 4
|
||||
for (int k = 0; k < K2; ++k) {
|
||||
__m512i b8 = _mm512_loadu_si512(b_ptr + k * ldb2);
|
||||
if constexpr (PREFETCH_SIZE_K > 0) {
|
||||
_mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2, _MM_HINT_T0);
|
||||
}
|
||||
|
||||
__m256i b8_0 = _mm512_extracti32x8_epi32(b8, 0);
|
||||
__m256i b8_1 = _mm512_extracti32x8_epi32(b8, 1);
|
||||
|
||||
__m512bh bf16_0 = CVT_FP8_TO_BF16(b8_0);
|
||||
__m512bh bf16_1 = CVT_FP8_TO_BF16(b8_1);
|
||||
|
||||
// Apply scale
|
||||
__m512 f0_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 0));
|
||||
__m512 f0_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 1));
|
||||
__m512 f1_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 0));
|
||||
__m512 f1_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 1));
|
||||
|
||||
f0_lo = _mm512_mul_ps(f0_lo, vd);
|
||||
f0_hi = _mm512_mul_ps(f0_hi, vd);
|
||||
f1_lo = _mm512_mul_ps(f1_lo, vd);
|
||||
f1_hi = _mm512_mul_ps(f1_hi, vd);
|
||||
|
||||
bf16_0 = _mm512_cvtne2ps_pbh(f0_hi, f0_lo);
|
||||
bf16_1 = _mm512_cvtne2ps_pbh(f1_hi, f1_lo);
|
||||
|
||||
_mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + 0, (__m512i)bf16_0);
|
||||
_mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + 32, (__m512i)bf16_1);
|
||||
}
|
||||
#else
|
||||
TORCH_CHECK(false, "unpack_B: scalar path not implemented!");
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename packed_t, bool has_bias, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_nn {
|
||||
static inline void apply(
|
||||
const scalar_t* __restrict__ A, const packed_t* __restrict__ B, scalar_t* __restrict__ C,
|
||||
const float* __restrict__ bias, const float* __restrict__ scale, int K, int lda, int ldb, int ldc, int64_t block_size_K) {
|
||||
TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
template <bool has_bias, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_nn<at::BFloat16, at::Float8_e4m3fn, has_bias, BLOCK_M, BLOCK_N> {
|
||||
static inline void apply(
|
||||
const at::BFloat16* __restrict__ A, const at::Float8_e4m3fn* __restrict__ B, at::BFloat16* __restrict__ C,
|
||||
const float* __restrict__ bias, const float* __restrict__ scale, int K, int lda, int ldb, int ldc, int64_t block_size_K) {
|
||||
|
||||
constexpr int ROWS = BLOCK_M;
|
||||
constexpr int COLS = BLOCK_N / 16;
|
||||
|
||||
const int KB = div_up(K, BLOCK_K);
|
||||
|
||||
// prefetch distance
|
||||
constexpr int PREFETCH_SIZE_K = 64;
|
||||
constexpr int PREFETCH_SIZE_KB = 1;
|
||||
|
||||
__m512bh va;
|
||||
__m512bh vb[COLS];
|
||||
__m512 vc[ROWS * COLS];
|
||||
__m512 vsum[ROWS * COLS];
|
||||
|
||||
// block quant scale
|
||||
__m512 vscale;
|
||||
|
||||
auto loadc = [&](auto i) {
|
||||
constexpr int col = i % COLS;
|
||||
if constexpr (has_bias) {
|
||||
vc[i] = _mm512_loadu_ps(bias + col * 16);
|
||||
} else {
|
||||
vc[i] = _mm512_setzero_ps();
|
||||
}
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(loadc);
|
||||
|
||||
const int lda2 = lda >> 1;
|
||||
const int ldb2 = ldb; // ldb * 2 >> 1;
|
||||
const float* a_ptr = reinterpret_cast<const float*>(A);
|
||||
const uint16_t* b_ptr = reinterpret_cast<const uint16_t*>(B);
|
||||
|
||||
auto compute = [&](auto i, int k) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
|
||||
if constexpr (col == 0) {
|
||||
va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k]));
|
||||
if constexpr (PREFETCH_SIZE_K > 0) {
|
||||
_mm_prefetch(a_ptr + row * lda2 + k + PREFETCH_SIZE_K, _MM_HINT_T0);
|
||||
}
|
||||
}
|
||||
if constexpr (row == 0) {
|
||||
if constexpr (col % 2 == 0) {
|
||||
__m512i b8 = _mm512_loadu_si512(b_ptr + k * ldb2 + col * 16);
|
||||
if constexpr (PREFETCH_SIZE_K > 0) {
|
||||
_mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0);
|
||||
}
|
||||
vb[col + 0] = CVT_FP8_TO_BF16(_mm512_extracti32x8_epi32(b8, 0));
|
||||
vb[col + 1] = CVT_FP8_TO_BF16(_mm512_extracti32x8_epi32(b8, 1));
|
||||
}
|
||||
}
|
||||
vsum[i] = _mm512_dpbf16_ps(vsum[i], va, vb[col]);
|
||||
};
|
||||
|
||||
constexpr int BLOCK_K2 = BLOCK_K >> 1;
|
||||
for (int kb = 0; kb < KB; ++kb) {
|
||||
int kb_start = kb * BLOCK_K2;
|
||||
int kb_end = std::min(K, kb_start + BLOCK_K2);
|
||||
// 1. load scale vector
|
||||
vscale = _mm512_set1_ps(scale[kb]);
|
||||
if constexpr (PREFETCH_SIZE_KB > 0) {
|
||||
_mm_prefetch(scale + kb + PREFETCH_SIZE_KB, _MM_HINT_T0);
|
||||
}
|
||||
// 2. zero vsum for each block
|
||||
Unroll<ROWS * COLS>{}([&](auto i) {
|
||||
vsum[i] = _mm512_setzero_ps();
|
||||
});
|
||||
// 3. accumulate across each block
|
||||
for (int k = kb_start; k < kb_end; ++k) {
|
||||
Unroll<ROWS * COLS>{}(compute, k);
|
||||
}
|
||||
// 4. apply scale
|
||||
Unroll<ROWS * COLS>{}([&](auto i) {
|
||||
vc[i] = _mm512_fmadd_ps(vsum[i], vscale, vc[i]);
|
||||
});
|
||||
}
|
||||
|
||||
auto storec = [&](auto i) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
// for COLS = 2,4 use 512bit store
|
||||
if constexpr (col % 2 == 0) {
|
||||
_mm512_storeu_si512(
|
||||
reinterpret_cast<__m512i*>((C + row * ldc + col * 16)),
|
||||
(__m512i)(_mm512_cvtne2ps_pbh(vc[row * COLS + col + 1], vc[row * COLS + col])));
|
||||
}
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(storec);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \
|
||||
tinygemm_kernel_nn<scalar_t, at::Float8_e4m3fn, has_bias, MB_SIZE, NB_SIZE>::apply( \
|
||||
A + mb_start * lda, B + nb_start * 2, C + mb_start * ldc + nb_start, \
|
||||
has_bias ? bias + nb_start : nullptr, scale, K, lda, ldb, ldc, block_size_K);
|
||||
|
||||
template <typename scalar_t, typename packed_t, bool has_bias>
|
||||
struct brgemm {
|
||||
static inline void apply(
|
||||
const scalar_t* __restrict__ A,
|
||||
const packed_t* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
scalar_t* __restrict__ Btmp,
|
||||
float* __restrict__ Ctmp,
|
||||
const float* __restrict__ bias,
|
||||
const float* __restrict__ scale,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int lda,
|
||||
int ldb,
|
||||
int ldc) {
|
||||
TORCH_CHECK(false, "struct brgemm: primary template not implemented!");
|
||||
}
|
||||
};
|
||||
|
||||
template <bool has_bias>
|
||||
struct brgemm<at::BFloat16, at::Float8_e4m3fn, has_bias> {
|
||||
static inline void apply(
|
||||
const at::BFloat16* __restrict__ A,
|
||||
const at::Float8_e4m3fn* __restrict__ B,
|
||||
at::BFloat16* __restrict__ C,
|
||||
at::BFloat16* __restrict__ Btmp,
|
||||
float* __restrict__ Ctmp,
|
||||
const float* __restrict__ bias,
|
||||
const float* __restrict__ scale,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int lda,
|
||||
int ldb,
|
||||
int ldc) {
|
||||
|
||||
constexpr int BLOCK_N = block_size_n();
|
||||
|
||||
// [K, BLOCK_N] -> [K / 2, BLOCK_N * 2]
|
||||
const int ldb_tmp = BLOCK_N;
|
||||
|
||||
for (int k = 0; k < K; k += BLOCK_K) {
|
||||
int kb_size = std::min(BLOCK_K, K - k);
|
||||
|
||||
int idx = k >> 7; // k / BLOCK_K where BLOCK_K = 128
|
||||
unpack_B(Btmp + k * ldb_tmp, B + k * ldb, N, kb_size, ldb, ldb_tmp, scale[idx]);
|
||||
}
|
||||
|
||||
at::native::cpublas::brgemm(
|
||||
M, N, K, lda, ldb_tmp, BLOCK_N, /* add_C */ false, A, Btmp, Ctmp);
|
||||
|
||||
// copy from Ctmp to C
|
||||
for (int m = 0; m < M; ++m) {
|
||||
if constexpr (has_bias) {
|
||||
copy_add_stub(C + m * ldc, Ctmp + m * BLOCK_N, bias, N);
|
||||
} else {
|
||||
copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename scalar_t, bool has_bias>
|
||||
void tinygemm_kernel(
|
||||
const scalar_t* __restrict__ A,
|
||||
const at::Float8_e4m3fn* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
scalar_t* __restrict__ Btmp,
|
||||
float* __restrict__ Ctmp,
|
||||
const float* __restrict__ scale,
|
||||
const float* __restrict__ bias,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg,
|
||||
int64_t block_size_K) {
|
||||
|
||||
if (brg) {
|
||||
brgemm<scalar_t, at::Float8_e4m3fn, has_bias>::apply(
|
||||
A, B, C, Btmp, Ctmp, bias, scale, M, N, K, lda, ldb, ldc);
|
||||
return;
|
||||
}
|
||||
|
||||
// pattern: 1-4-16
|
||||
constexpr int64_t BLOCK_M = 4;
|
||||
constexpr int64_t BLOCK_N = 64;
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
for (int mb = 0; mb < MB; ++mb) {
|
||||
int64_t mb_start = mb * BLOCK_M;
|
||||
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
|
||||
for (int64_t nb = 0; nb < NB; ++nb) {
|
||||
int64_t nb_start = nb * BLOCK_N;
|
||||
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
|
||||
|
||||
switch(mb_size << 4 | nb_size >> 4) {
|
||||
case 0x12: LAUNCH_TINYGEMM_KERNEL_NN(1, 32); break;
|
||||
case 0x22: LAUNCH_TINYGEMM_KERNEL_NN(2, 32); break;
|
||||
case 0x32: LAUNCH_TINYGEMM_KERNEL_NN(3, 32); break;
|
||||
case 0x42: LAUNCH_TINYGEMM_KERNEL_NN(4, 32); break;
|
||||
default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void fp8_scaled_mm_kernel_impl(
|
||||
scalar_t* __restrict__ out,
|
||||
const scalar_t* __restrict__ mat1,
|
||||
const at::Float8_e4m3fn* __restrict__ mat2,
|
||||
const float* __restrict__ scales2,
|
||||
const float* __restrict__ bias,
|
||||
scalar_t* __restrict__ buffer,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t mat1_strideM,
|
||||
int64_t out_strideM,
|
||||
int64_t block_size_N,
|
||||
int64_t block_size_K,
|
||||
int64_t buffer_size_per_thread) {
|
||||
|
||||
constexpr int64_t BLOCK_M = block_size_m() * BLOCK_SIZE_M_SCALE;
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
|
||||
const int64_t scale_size_K = div_up(K, block_size_K);
|
||||
const int64_t blocks_n_per_group = block_size_N / BLOCK_N;
|
||||
|
||||
const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(M);
|
||||
|
||||
// parallel on [MB, NB]
|
||||
AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
|
||||
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
int64_t mb{0}, nb{0};
|
||||
data_index_init(begin, mb, MB, nb, NB);
|
||||
|
||||
int tid = at::get_thread_num();
|
||||
scalar_t* __restrict__ Btmp = buffer + tid * buffer_size_per_thread;
|
||||
float* __restrict__ Ctmp = (float*)((void*)(Btmp + BLOCK_N * K));
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
UNUSED(i);
|
||||
const float* scale_ptr = scales2 + (nb / blocks_n_per_group) * scale_size_K;
|
||||
|
||||
int64_t mb_start = mb * BLOCK_M;
|
||||
int64_t mb_size = std::min(M - mb_start, BLOCK_M);
|
||||
int64_t nb_start = nb * BLOCK_N;
|
||||
int64_t nb_size = std::min(N - nb_start, BLOCK_N);
|
||||
|
||||
tinygemm_kernel<scalar_t, has_bias>(
|
||||
/* A */ mat1 + mb_start * mat1_strideM,
|
||||
/* B */ mat2 + nb_start * K, // nb * BLOCK_N * K
|
||||
/* C */ out + mb_start * out_strideM + nb_start,
|
||||
/* Btmp */ Btmp,
|
||||
/* Ctmp */ Ctmp,
|
||||
/* scale */ scale_ptr,
|
||||
/* bias */ bias + nb_start,
|
||||
/* M */ mb_size,
|
||||
/* N */ nb_size,
|
||||
/* K */ K,
|
||||
/* lda */ mat1_strideM,
|
||||
/* ldb */ nb_size,
|
||||
/* ldc */ out_strideM,
|
||||
/* brg */ use_brgemm,
|
||||
/* block_size_K */ block_size_K);
|
||||
|
||||
// move to the next index
|
||||
data_index_step(mb, MB, nb, NB);
|
||||
}
|
||||
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
// tinygemm interface
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(
|
||||
const scalar_t* __restrict__ A,
|
||||
const at::Float8_e4m3fn* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
scalar_t* __restrict__ Btmp,
|
||||
float* __restrict__ Ctmp,
|
||||
const float* __restrict__ scale,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg,
|
||||
int64_t block_size_K) {
|
||||
tinygemm_kernel<scalar_t, false>(A, B, C, Btmp, Ctmp, scale, nullptr, M, N, K, lda, ldb, ldc, brg, block_size_K);
|
||||
}
|
||||
|
||||
#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \
|
||||
template void tinygemm_kernel<TYPE>( \
|
||||
const TYPE* __restrict__ A, \
|
||||
const at::Float8_e4m3fn* __restrict__ B, \
|
||||
TYPE* __restrict__ C, \
|
||||
TYPE* __restrict__ Btmp, \
|
||||
float* __restrict__ Ctmp, \
|
||||
const float* __restrict__ scale, \
|
||||
int64_t M, \
|
||||
int64_t N, \
|
||||
int64_t K, \
|
||||
int64_t lda, \
|
||||
int64_t ldb, \
|
||||
int64_t ldc, \
|
||||
bool brg, \
|
||||
int64_t block_size_K)
|
||||
|
||||
INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16);
|
||||
INSTANTIATE_TINYGEMM_TEMPLATE(at::Half);
|
||||
|
||||
at::Tensor fp8_scaled_mm_cpu(at::Tensor& mat1, at::Tensor& mat2, at::Tensor& scales2,
|
||||
std::vector<int64_t> block_size, std::optional<at::Tensor>& bias,
|
||||
at::ScalarType out_dtype, bool is_vnni) {
|
||||
RECORD_FUNCTION("sgl-kernel::fp8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales2, block_size, bias}));
|
||||
|
||||
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
|
||||
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1);
|
||||
CHECK_INPUT(mat2);
|
||||
CHECK_INPUT(scales2);
|
||||
TORCH_CHECK(scales2.scalar_type() == at::kFloat,
|
||||
"fp8_scaled_mm_cpu: expect scales2 to be float32.");
|
||||
|
||||
int64_t M = mat1.size(0);
|
||||
int64_t N = mat2.size(0);
|
||||
int64_t K = mat2.size(1);
|
||||
|
||||
CHECK_EQ(mat1.size(1), K);
|
||||
CHECK_DIM(2, mat1);
|
||||
CHECK_DIM(2, mat2);
|
||||
|
||||
TORCH_CHECK(block_size.size() == 2,
|
||||
"fp8_scaled_mm_cpu: expect block_size.size() to be 2.");
|
||||
|
||||
int64_t block_size_N = block_size[0];
|
||||
int64_t block_size_K = block_size[1];
|
||||
|
||||
constexpr int64_t BLOCK_M = block_size_m() * BLOCK_SIZE_M_SCALE;
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
TORCH_CHECK(block_size_N % BLOCK_N == 0, "fp8_scaled_mm_cpu: expect block_size_N to be multiples of BLOCK_N");
|
||||
TORCH_CHECK(block_size_K == BLOCK_K, "fp8_scaled_mm_cpu: expect block_size_K equals to BLOCK_K");
|
||||
CHECK_EQ(scales2.size(0), div_up(N, block_size_N));
|
||||
CHECK_EQ(scales2.size(1), div_up(K, block_size_K));
|
||||
|
||||
const auto st = mat1.scalar_type();
|
||||
TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf,
|
||||
"fp8_scaled_mm_cpu: expect A to be bfloat16 or half.");
|
||||
TORCH_CHECK(st == out_dtype,
|
||||
"fp8_scaled_mm_cpu: expect A has same dtype with out_dtype.");
|
||||
TORCH_CHECK(mat2.scalar_type() == at::kFloat8_e4m3fn,
|
||||
"fp8_scaled_mm_cpu: expect mat2 to be fp8_e4m3.");
|
||||
TORCH_CHECK(scales2.scalar_type() == at::kFloat,
|
||||
"fp8_scaled_mm_cpu: expect scales to be float32.");
|
||||
auto out = at::empty({M, N}, mat1.options().dtype(out_dtype));
|
||||
|
||||
// strides
|
||||
int64_t mat1_strideM = mat1.stride(0);
|
||||
int64_t out_strideM = out.stride(0);
|
||||
|
||||
const bool has_bias = bias.has_value();
|
||||
const float* bias_data = nullptr;
|
||||
if (has_bias) {
|
||||
CHECK_EQ(bias.value().size(0), N);
|
||||
bias_data = bias.value().data_ptr<float>();
|
||||
}
|
||||
|
||||
// Btmp : [T, BLOCK_N * K]
|
||||
// Ctmp : [T, BLOCK_M * BLOCK_N]
|
||||
int num_threads = at::get_num_threads();
|
||||
int64_t size_per_thread = BLOCK_N * K + BLOCK_M * BLOCK_N * 2;
|
||||
auto buffer = at::empty({num_threads, size_per_thread}, mat1.options());
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "fp8_scaled_mm_kernel_impl", [&] {
|
||||
fp8_scaled_mm_kernel_impl<scalar_t>(
|
||||
out.data_ptr<scalar_t>(),
|
||||
mat1.data_ptr<scalar_t>(),
|
||||
packed_w.data_ptr<at::Float8_e4m3fn>(),
|
||||
scales2.data_ptr<float>(),
|
||||
bias_data,
|
||||
buffer.data_ptr<scalar_t>(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
mat1_strideM,
|
||||
out_strideM,
|
||||
block_size_N,
|
||||
block_size_K,
|
||||
size_per_thread);
|
||||
});
|
||||
|
||||
return out;
|
||||
}
|
||||
440
csrc/cpu/sgl-kernels/gemm_int8.cpp
Normal file
440
csrc/cpu/sgl-kernels/gemm_int8.cpp
Normal file
@ -0,0 +1,440 @@
|
||||
// Adapted from
|
||||
// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu
|
||||
|
||||
#include "common.h"
|
||||
#include "vec.h"
|
||||
#include "gemm.h"
|
||||
|
||||
// clang-format off
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename scalar_t, bool has_bias, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_nn {
|
||||
static inline void apply(
|
||||
const uint8_t* __restrict__ A, const int8_t* __restrict__ B, scalar_t* __restrict__ C,
|
||||
const float* __restrict__ As, const float* __restrict__ Bs, const int32_t* __restrict__ Bcomp,
|
||||
const float* __restrict__ bias, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) {
|
||||
TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
template <bool has_bias, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_nn<at::BFloat16, has_bias, BLOCK_M, BLOCK_N> {
|
||||
static inline void apply(
|
||||
const uint8_t* __restrict__ A, const int8_t* __restrict__ B, at::BFloat16* __restrict__ C,
|
||||
const float* __restrict__ As, const float* __restrict__ Bs, const int32_t* __restrict__ Bcomp,
|
||||
const float* __restrict__ bias, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) {
|
||||
|
||||
constexpr int ROWS = BLOCK_M;
|
||||
constexpr int COLS = BLOCK_N / 16;
|
||||
static_assert(COLS % 2 == 0);
|
||||
|
||||
// prefetch distance
|
||||
constexpr int PREFETCH_SIZE_K = 0;
|
||||
|
||||
__m512i va;
|
||||
__m512i vb[COLS];
|
||||
__m512i vc[ROWS * COLS];
|
||||
__m512i vcomp[COLS];
|
||||
__m512 vd0;
|
||||
__m512 vd1[COLS];
|
||||
|
||||
// oops! 4x4 spills but luckly we use 4x2
|
||||
__m512 vbias[COLS];
|
||||
|
||||
// [NOTE]: s8s8 igemm compensation in avx512-vnni
|
||||
//
|
||||
// avx512-vnni has no s8s8, so we need to change s8s8 to u8s8 with compensate:
|
||||
//
|
||||
// a * b = (a + 128) * b - 128 * b
|
||||
// s s u s u s
|
||||
//
|
||||
// 1) 128 * b is pre-computed when packing B to vnni formats
|
||||
// 2) a + 128 is fused when dynamically quantize A
|
||||
//
|
||||
auto loadc = [&](auto i) {
|
||||
vc[i] = _mm512_set1_epi32(0);
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(loadc);
|
||||
|
||||
const int64_t K4 = K >> 2;
|
||||
const int64_t lda4 = lda >> 2;
|
||||
const int64_t ldb4 = ldb; // ldb * 4 >> 2;
|
||||
const int32_t* a_ptr = reinterpret_cast<const int32_t*>(A);
|
||||
const int32_t* b_ptr = reinterpret_cast<const int32_t*>(B);
|
||||
|
||||
auto compute = [&](auto i, int64_t k) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
|
||||
if constexpr (col == 0) {
|
||||
va = _mm512_set1_epi32(a_ptr[row * lda4 + k]);
|
||||
}
|
||||
if constexpr (row == 0) {
|
||||
vb[col] = _mm512_loadu_si512(b_ptr + k * ldb4 + col * 16);
|
||||
if constexpr (PREFETCH_SIZE_K > 0) {
|
||||
_mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb4 + col * 16, _MM_HINT_T0);
|
||||
}
|
||||
}
|
||||
vc[i] = _mm512_dpbusd_epi32(vc[i], va, vb[col]);
|
||||
};
|
||||
for (int64_t k = 0; k < K4; ++k) {
|
||||
Unroll<ROWS * COLS>{}(compute, k);
|
||||
}
|
||||
|
||||
auto storec = [&](auto i) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
|
||||
// load a scale
|
||||
if constexpr(col == 0) {
|
||||
vd0 = _mm512_set1_ps(As[row]);
|
||||
}
|
||||
// load b scale and vcomp per 2 vectors
|
||||
// also load bias if any
|
||||
if constexpr (row == 0) {
|
||||
if constexpr (col % 2 == 0) {
|
||||
vd1[col + 0] = _mm512_loadu_ps(Bs + col * 16);
|
||||
vd1[col + 1] = _mm512_loadu_ps(Bs + col * 16 + 16);
|
||||
vcomp[col + 0] = _mm512_loadu_si512(Bcomp + col * 16);
|
||||
vcomp[col + 1] = _mm512_loadu_si512(Bcomp + col * 16 + 16);
|
||||
if constexpr (has_bias) {
|
||||
vbias[col + 0] = _mm512_loadu_ps(bias + col * 16);
|
||||
vbias[col + 1] = _mm512_loadu_ps(bias + col * 16 + 16);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// for COLS = 2, 4 use 512bit store
|
||||
if constexpr (col % 2 == 0) {
|
||||
__m512 vc0 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[row * COLS + col + 0], vcomp[col + 0]));
|
||||
__m512 vc1 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[row * COLS + col + 1], vcomp[col + 1]));
|
||||
if constexpr (has_bias) {
|
||||
vc0 = _mm512_fmadd_ps(_mm512_mul_ps(vc0, vd0), vd1[col + 0], vbias[col + 0]);
|
||||
vc1 = _mm512_fmadd_ps(_mm512_mul_ps(vc1, vd0), vd1[col + 1], vbias[col + 1]);
|
||||
} else {
|
||||
vc0 = _mm512_mul_ps(_mm512_mul_ps(vc0, vd0), vd1[col + 0]);
|
||||
vc1 = _mm512_mul_ps(_mm512_mul_ps(vc1, vd0), vd1[col + 1]);
|
||||
}
|
||||
|
||||
_mm512_storeu_si512(
|
||||
reinterpret_cast<__m512i*>((C + row * ldc + col * 16)),
|
||||
(__m512i)(_mm512_cvtne2ps_pbh(vc1, vc0)));
|
||||
}
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(storec);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \
|
||||
tinygemm_kernel_nn<scalar_t, has_bias, MB_SIZE, NB_SIZE>::apply( \
|
||||
A + mb_start * lda, B + nb_start * 4, C + mb_start * ldc + nb_start, \
|
||||
As + mb_start, Bs + nb_start, Bcomp + nb_start, \
|
||||
has_bias ? bias + nb_start : nullptr, K, lda, ldb, ldc);
|
||||
|
||||
template <typename scalar_t, bool has_bias>
|
||||
void tinygemm_kernel(
|
||||
const uint8_t* __restrict__ A,
|
||||
const int8_t* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
int32_t* __restrict__ Ctmp,
|
||||
const float* __restrict__ As,
|
||||
const float* __restrict__ Bs,
|
||||
const float* __restrict__ bias,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg) {
|
||||
|
||||
// B compensation
|
||||
const int32_t* Bcomp = reinterpret_cast<const int32_t*>(B + block_size_n() * K);
|
||||
|
||||
// pattern: 1-4-16
|
||||
constexpr int64_t BLOCK_M = 4;
|
||||
constexpr int64_t BLOCK_N = 64;
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
for (int64_t mb = 0; mb < MB; ++mb) {
|
||||
int64_t mb_start = mb * BLOCK_M;
|
||||
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
|
||||
for (int64_t nb = 0; nb < NB; ++nb) {
|
||||
int64_t nb_start = nb * BLOCK_N;
|
||||
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
|
||||
|
||||
switch(mb_size << 4 | nb_size >> 4) {
|
||||
// mb_size = 1
|
||||
case 0x12: LAUNCH_TINYGEMM_KERNEL_NN(1, 32); break;
|
||||
case 0x14: LAUNCH_TINYGEMM_KERNEL_NN(1, 64); break;
|
||||
// mb_size = 2
|
||||
case 0x22: LAUNCH_TINYGEMM_KERNEL_NN(2, 32); break;
|
||||
case 0x24: LAUNCH_TINYGEMM_KERNEL_NN(2, 64); break;
|
||||
// mb_size = 3
|
||||
case 0x32: LAUNCH_TINYGEMM_KERNEL_NN(3, 32); break;
|
||||
case 0x34: LAUNCH_TINYGEMM_KERNEL_NN(3, 64); break;
|
||||
// mb_size = 4
|
||||
case 0x42: LAUNCH_TINYGEMM_KERNEL_NN(4, 32); break;
|
||||
case 0x44: LAUNCH_TINYGEMM_KERNEL_NN(4, 64); break;
|
||||
default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
void int8_scaled_mm_kernel_impl(
|
||||
scalar_t* __restrict__ out,
|
||||
const uint8_t* __restrict__ mat1,
|
||||
const int8_t* __restrict__ mat2,
|
||||
const float* __restrict__ scales1,
|
||||
const float* __restrict__ scales2,
|
||||
const float* __restrict__ bias,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K) {
|
||||
|
||||
constexpr int64_t BLOCK_M = block_size_m();
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
|
||||
// TODO: brgemm u8s8 depends on PyTorch 2.7 release.
|
||||
const bool use_brgemm = false;
|
||||
|
||||
// K + 4 after compensation
|
||||
const int64_t packed_row_size = get_row_size<int8_t>(K);
|
||||
|
||||
AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
|
||||
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
int64_t mb{0}, nb{0};
|
||||
data_index_init(begin, mb, MB, nb, NB);
|
||||
|
||||
// for brgemm, use int32_t for accumulate
|
||||
alignas(64) int32_t Ctmp[BLOCK_M * BLOCK_N];
|
||||
|
||||
for (int i = begin; i < end; ++i) {
|
||||
UNUSED(i);
|
||||
int mb_start = mb * BLOCK_M;
|
||||
int mb_size = std::min(M - mb_start, BLOCK_M);
|
||||
int nb_start = nb * BLOCK_N;
|
||||
int nb_size = std::min(N - nb_start, BLOCK_N);
|
||||
|
||||
tinygemm_kernel<scalar_t, has_bias>(
|
||||
/* A */ mat1 + mb_start * K,
|
||||
/* B */ mat2 + nb_start * packed_row_size /* nb * BLOCK_N * (K + 4) */,
|
||||
/* C */ out + mb_start * N + nb_start,
|
||||
/* Ctmp*/ Ctmp,
|
||||
/* As */ scales1 + mb_start,
|
||||
/* Bs */ scales2 + nb_start,
|
||||
/* bias*/ bias + nb_start,
|
||||
/* M */ mb_size,
|
||||
/* N */ nb_size,
|
||||
/* K */ K,
|
||||
/* lda */ K,
|
||||
/* ldb */ nb_size,
|
||||
/* ldc */ N,
|
||||
/* brg */ use_brgemm);
|
||||
|
||||
// move to the next index
|
||||
data_index_step(mb, MB, nb, NB);
|
||||
}
|
||||
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
// tinygemm interface
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(const uint8_t* __restrict__ A, const int8_t* __restrict__ B, scalar_t* __restrict__ C,
|
||||
int32_t* __restrict__ Ctmp, const float* __restrict__ As, const float* __restrict__ Bs,
|
||||
int64_t M, int64_t N, int64_t K, int64_t lda, int64_t ldb, int64_t ldc, bool brg) {
|
||||
tinygemm_kernel<scalar_t, false>(A, B, C, Ctmp, As, Bs, nullptr, M, N, K, lda, ldb, ldc, brg);
|
||||
}
|
||||
|
||||
#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \
|
||||
template void tinygemm_kernel<TYPE>( \
|
||||
const uint8_t* __restrict__ A, const int8_t* __restrict__ B, TYPE* __restrict__ C, \
|
||||
int32_t* __restrict__ Ctmp, const float* __restrict__ As, const float* __restrict__ Bs, \
|
||||
int64_t M, int64_t N, int64_t K, int64_t lda, int64_t ldb, int64_t ldc, bool brg)
|
||||
|
||||
INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16);
|
||||
INSTANTIATE_TINYGEMM_TEMPLATE(at::Half);
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> per_token_quant_int8_cpu(at::Tensor& A) {
|
||||
RECORD_FUNCTION("sgl-kernel::per_token_quant_int8_cpu", std::vector<c10::IValue>({A}));
|
||||
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(A);
|
||||
CHECK_DIM(2, A);
|
||||
|
||||
int64_t M = A.size(0);
|
||||
int64_t K = A.size(1);
|
||||
int64_t lda = A.stride(0);
|
||||
|
||||
const auto st = A.scalar_type();
|
||||
TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf,
|
||||
"per_token_quant_int8: expect A to be bfloat16 or half.");
|
||||
|
||||
auto Aq = at::empty({M, K}, A.options().dtype(at::kByte));
|
||||
auto As = at::empty({M}, A.options().dtype(at::kFloat));
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "per_token_quant_int8", [&] {
|
||||
uint8_t* __restrict__ Aq_data = Aq.data_ptr<uint8_t>();
|
||||
float* __restrict__ As_data = As.data_ptr<float>();
|
||||
const scalar_t* __restrict__ A_data = A.data_ptr<scalar_t>();
|
||||
|
||||
at::parallel_for(0, M, 0, [&] (int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
quantize_row_int8<scalar_t>(
|
||||
Aq_data + m * K,
|
||||
As_data[m],
|
||||
A_data + m * lda,
|
||||
K);
|
||||
}
|
||||
});
|
||||
});
|
||||
return std::make_tuple(Aq, As);
|
||||
}
|
||||
|
||||
// weight : static, per-channel, symmetric
|
||||
// activation : dynamic, per-token, symmetric
|
||||
//
|
||||
// mat1 : [M, K]
|
||||
// mat2 : [N, K]
|
||||
// scales1 : [M]
|
||||
// scales2 : [N]
|
||||
// bias : [N]
|
||||
// out : [M, N]
|
||||
//
|
||||
at::Tensor int8_scaled_mm_cpu(at::Tensor& mat1, at::Tensor& mat2,
|
||||
at::Tensor& scales1, at::Tensor& scales2,
|
||||
std::optional<at::Tensor>& bias, at::ScalarType out_dtype, bool is_vnni) {
|
||||
RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales1, scales2, bias}));
|
||||
|
||||
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
|
||||
|
||||
CHECK_INPUT(mat1);
|
||||
CHECK_INPUT(mat2);
|
||||
CHECK_INPUT(scales1);
|
||||
CHECK_INPUT(scales2);
|
||||
CHECK_DIM(2, mat1);
|
||||
CHECK_DIM(2, mat2);
|
||||
|
||||
int64_t M = mat1.size(0);
|
||||
int64_t N = mat2.size(0);
|
||||
int64_t K = mat1.size(1);
|
||||
|
||||
// see [NOTE]: s8s8 igemm compensation in avx512-vnni
|
||||
CHECK_EQ(mat2.size(1), (int64_t)(is_vnni ? K + sizeof(int32_t) : K));
|
||||
CHECK_EQ(scales1.numel(), M);
|
||||
CHECK_EQ(scales2.numel(), N);
|
||||
|
||||
TORCH_CHECK(mat1.scalar_type() == at::kByte, "int8_scaled_mm: expect mat1 to be uint8.");
|
||||
TORCH_CHECK(mat2.scalar_type() == at::kChar, "int8_scaled_mm: expect mat2 to be int8.");
|
||||
TORCH_CHECK(scales1.scalar_type() == at::kFloat && scales2.scalar_type() == at::kFloat,
|
||||
"int8_scaled_mm: expect scales to be float32.");
|
||||
|
||||
auto out = at::empty({M, N}, mat1.options().dtype(out_dtype));
|
||||
|
||||
const bool has_bias = bias.has_value();
|
||||
const float* bias_data = nullptr;
|
||||
if (has_bias) {
|
||||
CHECK_EQ(bias.value().size(0), N);
|
||||
bias_data = bias.value().data_ptr<float>();
|
||||
}
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "int8_scaled_mm_kernel_impl", [&] {
|
||||
int8_scaled_mm_kernel_impl<scalar_t>(
|
||||
out.data_ptr<scalar_t>(),
|
||||
mat1.data_ptr<uint8_t>(),
|
||||
packed_w.data_ptr<int8_t>(),
|
||||
scales1.data_ptr<float>(),
|
||||
scales2.data_ptr<float>(),
|
||||
bias_data,
|
||||
M,
|
||||
N,
|
||||
K);
|
||||
});
|
||||
return out;
|
||||
}
|
||||
|
||||
// fused `per_token_quant_int8_cpu` and `int8_scaled_mm_cpu`
|
||||
at::Tensor int8_scaled_mm_with_quant(at::Tensor& mat1, at::Tensor& mat2, at::Tensor& scales2,
|
||||
const std::optional<at::Tensor>& bias, at::ScalarType out_dtype, bool is_vnni) {
|
||||
RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales2, bias}));
|
||||
|
||||
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
|
||||
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1);
|
||||
CHECK_INPUT(mat2);
|
||||
CHECK_INPUT(scales2);
|
||||
CHECK_DIM(2, mat1);
|
||||
CHECK_DIM(2, mat2);
|
||||
|
||||
int64_t M = mat1.size(0);
|
||||
int64_t N = mat2.size(0);
|
||||
int64_t K = mat1.size(1);
|
||||
int64_t lda = mat1.stride(0);
|
||||
|
||||
// see [NOTE]: s8s8 igemm compensation in avx512-vnni
|
||||
CHECK_EQ(mat2.size(1), (int64_t)(is_vnni ? K + sizeof(int32_t) : K));
|
||||
CHECK_EQ(scales2.numel(), N);
|
||||
|
||||
const auto st = mat1.scalar_type();
|
||||
TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf,
|
||||
"int8_scaled_mm_with_quant: expect A to be bfloat16 or half.");
|
||||
TORCH_CHECK(st == out_dtype,
|
||||
"int8_scaled_mm_with_quant: expect A has same dtype with out_dtype.");
|
||||
TORCH_CHECK(mat2.scalar_type() == at::kChar,
|
||||
"int8_scaled_mm_with_quant: expect mat2 to be int8.");
|
||||
TORCH_CHECK(scales2.scalar_type() == at::kFloat,
|
||||
"int8_scaled_mm_with_quant: expect scales to be float32.");
|
||||
|
||||
const int64_t buffer_size = M * K + M * sizeof(float);
|
||||
auto buffer = at::empty({buffer_size}, mat1.options().dtype(at::kByte));
|
||||
auto out = at::empty({M, N}, mat1.options().dtype(out_dtype));
|
||||
|
||||
const bool has_bias = bias.has_value();
|
||||
const float* bias_data = nullptr;
|
||||
if (has_bias) {
|
||||
CHECK_EQ(bias.value().size(0), N);
|
||||
bias_data = bias.value().data_ptr<float>();
|
||||
}
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "int8_scaled_mm_with_quant_kernel_impl", [&] {
|
||||
uint8_t* __restrict__ Aq_data = buffer.data_ptr<uint8_t>();
|
||||
float* __restrict__ As_data = (float*)((void*)(Aq_data + M * K));
|
||||
const scalar_t* __restrict__ A_data = mat1.data_ptr<scalar_t>();
|
||||
|
||||
at::parallel_for(0, M, 0, [&] (int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
quantize_row_int8<scalar_t>(
|
||||
Aq_data + m * K,
|
||||
As_data[m],
|
||||
A_data + m * lda,
|
||||
K);
|
||||
}
|
||||
});
|
||||
|
||||
int8_scaled_mm_kernel_impl<scalar_t>(
|
||||
out.data_ptr<scalar_t>(),
|
||||
Aq_data,
|
||||
packed_w.data_ptr<int8_t>(),
|
||||
As_data,
|
||||
scales2.data_ptr<float>(),
|
||||
bias_data,
|
||||
M,
|
||||
N,
|
||||
K);
|
||||
});
|
||||
return out;
|
||||
}
|
||||
1330
csrc/cpu/sgl-kernels/moe.cpp
Normal file
1330
csrc/cpu/sgl-kernels/moe.cpp
Normal file
File diff suppressed because it is too large
Load Diff
502
csrc/cpu/sgl-kernels/moe_fp8.cpp
Normal file
502
csrc/cpu/sgl-kernels/moe_fp8.cpp
Normal file
@ -0,0 +1,502 @@
|
||||
// Adapted from
|
||||
// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu
|
||||
|
||||
#include "common.h"
|
||||
#include "gemm.h"
|
||||
#include "vec.h"
|
||||
|
||||
// clang-format off
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t size) {
|
||||
using Vec = at::vec::Vectorized<scalar_t>;
|
||||
// no remainder
|
||||
#pragma GCC unroll 4
|
||||
for (int64_t d = 0; d < size; d += Vec::size()) {
|
||||
Vec data = Vec::loadu(input + d);
|
||||
data.store(out + d);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void copy_mul_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, float weight, int64_t size) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
const fVec weight_vec = fVec(weight);
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= size - kVecSize; d += kVecSize) {
|
||||
bVec x = bVec::loadu(input + d);
|
||||
fVec x0, x1;
|
||||
std::tie(x0, x1) = at::vec::convert_to_float(x);
|
||||
x0 = x0 * weight_vec;
|
||||
x1 = x1 * weight_vec;
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(x0, x1);
|
||||
out_vec.store(out + d);
|
||||
}
|
||||
for (; d < size; ++d) {
|
||||
out[d] = static_cast<scalar_t>(input[d] * weight);
|
||||
}
|
||||
}
|
||||
|
||||
// acc from [topk, K] to [K]
|
||||
template <typename scalar_t>
|
||||
inline void sum_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t topk, int64_t K) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
if (topk == 1) {
|
||||
// do copy for topk = 1
|
||||
copy_stub(out, input, K);
|
||||
} else {
|
||||
// do sum for topk != 1
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= K - kVecSize; d += kVecSize) {
|
||||
fVec sum_fvec0 = fVec(0.f);
|
||||
fVec sum_fvec1 = fVec(0.f);
|
||||
for (int t = 0; t < topk; ++t) {
|
||||
bVec x_bvec = bVec::loadu(input + t * K + d);
|
||||
fVec x_fvec0, x_fvec1;
|
||||
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
|
||||
|
||||
sum_fvec0 += x_fvec0;
|
||||
sum_fvec1 += x_fvec1;
|
||||
}
|
||||
bVec out_bvec = convert_from_float_ext<scalar_t>(sum_fvec0, sum_fvec1);
|
||||
out_bvec.store(out + d);
|
||||
}
|
||||
for (; d < K; ++d) {
|
||||
float sum_val = 0.f;
|
||||
for (int t = 0; t < topk; ++t) {
|
||||
sum_val += static_cast<float>(input[t * K + d]);
|
||||
}
|
||||
out[d] = static_cast<scalar_t>(sum_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// out = input + input2 * scale
|
||||
template <typename scalar_t>
|
||||
inline void add_mul_stub(
|
||||
scalar_t* __restrict__ out,
|
||||
const scalar_t* __restrict__ input,
|
||||
const scalar_t* __restrict__ input2,
|
||||
float scale,
|
||||
int64_t size) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
const fVec s_vec = fVec(scale);
|
||||
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= size - kVecSize; d += kVecSize) {
|
||||
bVec x_bvec = bVec::loadu(input + d);
|
||||
fVec x0, x1;
|
||||
std::tie(x0, x1) = at::vec::convert_to_float(x_bvec);
|
||||
|
||||
bVec y_bvec = bVec::loadu(input2 + d);
|
||||
fVec y0, y1;
|
||||
std::tie(y0, y1) = at::vec::convert_to_float(y_bvec);
|
||||
|
||||
x0 = x0 + y0 * s_vec;
|
||||
x1 = x1 + y1 * s_vec;
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(x0, x1);
|
||||
out_vec.store(out + d);
|
||||
}
|
||||
for (; d < size; ++d) {
|
||||
out[d] = static_cast<scalar_t>(input[d] + float(input2[d]) * scale);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void silu_and_mul_stub(
|
||||
scalar_t* __restrict__ out,
|
||||
const scalar_t* __restrict__ input,
|
||||
const scalar_t* __restrict__ input2,
|
||||
int64_t size) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
const fVec one = fVec(1.f);
|
||||
|
||||
// no remainder
|
||||
#pragma GCC unroll 4
|
||||
for (int64_t d = 0; d < size; d += bVec::size()) {
|
||||
bVec x = bVec::loadu(input + d);
|
||||
fVec x0, x1;
|
||||
std::tie(x0, x1) = at::vec::convert_to_float(x);
|
||||
bVec y = bVec::loadu(input2 + d);
|
||||
fVec y0, y1;
|
||||
std::tie(y0, y1) = at::vec::convert_to_float(y);
|
||||
x0 = x0 / (one + x0.neg().exp_u20());
|
||||
x1 = x1 / (one + x1.neg().exp_u20());
|
||||
x0 = x0 * y0;
|
||||
x1 = x1 * y1;
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(x0, x1);
|
||||
out_vec.store(out + d);
|
||||
}
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
template <typename scalar_t>
|
||||
void fused_experts_fp8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic0,
|
||||
scalar_t* __restrict__ ic1,
|
||||
scalar_t* __restrict__ ic2,
|
||||
scalar_t* __restrict__ A_tmp,
|
||||
scalar_t* __restrict__ B_tmp,
|
||||
float* __restrict__ C_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w1,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w2,
|
||||
const float* __restrict__ w1s,
|
||||
const float* __restrict__ w2s,
|
||||
int64_t block_size_N,
|
||||
int64_t block_size_K,
|
||||
const float* __restrict__ topk_weights,
|
||||
const int32_t* __restrict__ sorted_ids,
|
||||
const int32_t* __restrict__ expert_ids,
|
||||
const int32_t* __restrict__ offsets,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t E,
|
||||
int64_t topk,
|
||||
int64_t num_tokens_post_pad) {
|
||||
|
||||
constexpr int64_t BLOCK_M = block_size_m();
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
|
||||
// stage 1: intermediate_cache0 = hidden_states @ w1
|
||||
const int64_t MB = div_up(num_tokens_post_pad, BLOCK_M);
|
||||
const int64_t NB = div_up(2 * N, BLOCK_N);
|
||||
int64_t scale_size_N = div_up(2 * N, block_size_N);
|
||||
int64_t scale_size_K = div_up(K, block_size_K);
|
||||
int64_t blocks_n_per_group = block_size_N / BLOCK_N;
|
||||
|
||||
const int64_t stride_e = 2 * N * K;
|
||||
const int64_t stride_n = K;
|
||||
|
||||
// here we only parallel on half of 2N to fuse silu_and_mul with gemm
|
||||
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
// get local pointers
|
||||
int tid = at::get_thread_num();
|
||||
scalar_t* __restrict__ A = A_tmp + tid * BLOCK_M * K;
|
||||
|
||||
bool is_brgemm_used = false;
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB;
|
||||
int64_t nb = i % NB;
|
||||
|
||||
int64_t n_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
// B shape [K, n_size] in vnni format
|
||||
int32_t expert_id = expert_ids[mb];
|
||||
const at::Float8_e4m3fn* __restrict__ B = packed_w1 + expert_id * stride_e + nb * BLOCK_N * stride_n;
|
||||
const float* __restrict__ Bs = w1s + expert_id * scale_size_N * scale_size_K + (nb / blocks_n_per_group) * scale_size_K;
|
||||
|
||||
// 1.a load A
|
||||
const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
|
||||
int64_t m_size = offsets[mb + 1] - offsets[mb];
|
||||
|
||||
const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(m_size);
|
||||
is_brgemm_used = is_brgemm_used || use_brgemm;
|
||||
|
||||
for (int64_t m = 0; m < m_size; ++m) {
|
||||
int32_t index = A_ids[m] / topk;
|
||||
copy_stub(A + m * K, input + index * K, K);
|
||||
}
|
||||
|
||||
const int64_t offset = offsets[mb];
|
||||
tinygemm_kernel<scalar_t>(
|
||||
/* A */ A,
|
||||
/* B */ B,
|
||||
/* C */ ic0 + offset * 2 * N + nb * BLOCK_N,
|
||||
/* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N),
|
||||
/* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
|
||||
/* scale */ Bs,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ K,
|
||||
/* lda */ K,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ 2 * N,
|
||||
/* brg */ use_brgemm,
|
||||
/* block_size_K */ block_size_K);
|
||||
}
|
||||
|
||||
if (is_brgemm_used) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
|
||||
// stage 1.5: intermediate_cache1 = silu(intermediate_cache0)
|
||||
at::parallel_for(0, M * topk, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
silu_and_mul_stub(
|
||||
ic1 + m * N,
|
||||
ic0 + m * 2 * N,
|
||||
ic0 + m * 2 * N + N,
|
||||
N);
|
||||
}
|
||||
});
|
||||
|
||||
// stage 2: intermediate_cache2 = intermediate_cache1 @ w2
|
||||
// w2 : [E, K, N] as [E, OC, IC]
|
||||
const int64_t OC = K; // rename K as OC
|
||||
const int64_t IC = N; // rename N as IC
|
||||
const int64_t MB2 = MB;
|
||||
const int64_t NB2 = div_up(OC, BLOCK_N);
|
||||
scale_size_N = div_up(K, block_size_N);
|
||||
scale_size_K = div_up(N, block_size_K);
|
||||
const int64_t stride_e2 = OC * IC;
|
||||
const int64_t stride_oc = IC;
|
||||
|
||||
// parallel on [MB2, NB2]
|
||||
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
|
||||
int tid = at::get_thread_num();
|
||||
alignas(64) scalar_t C[BLOCK_M * BLOCK_K];
|
||||
|
||||
bool is_brgemm_used = false;
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB2;
|
||||
int64_t nb = i % NB2;
|
||||
|
||||
int64_t m_size = offsets[mb + 1] - offsets[mb];
|
||||
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(m_size);
|
||||
is_brgemm_used = is_brgemm_used || use_brgemm;
|
||||
|
||||
// A ptr from ic1 of [M * topk, N] in sorted order
|
||||
// so as to avoid copy A to tmp buffer again
|
||||
const scalar_t* __restrict__ A = ic1 + offsets[mb] * N;
|
||||
const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
|
||||
|
||||
// B shape [IC, n_size] in vnni format
|
||||
int32_t expert_id = expert_ids[mb];
|
||||
const at::Float8_e4m3fn* __restrict__ B = packed_w2 + expert_id * stride_e2 + nb * BLOCK_N * stride_oc;
|
||||
const float* __restrict__ Bs = w2s + expert_id * scale_size_N * scale_size_K + (nb / blocks_n_per_group) * scale_size_K;
|
||||
|
||||
tinygemm_kernel<scalar_t>(
|
||||
/* A */ A,
|
||||
/* B */ B,
|
||||
/* C */ C,
|
||||
/* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N),
|
||||
/* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
|
||||
/* scale */ Bs,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ IC,
|
||||
/* lda */ IC,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ BLOCK_N,
|
||||
/* brg */ use_brgemm,
|
||||
/* block_size_K */ block_size_K);
|
||||
|
||||
// 2.b copy from C to ic2 in original order
|
||||
// and also mul topk_weights in float32
|
||||
for (int64_t m = 0; m < m_size; ++m) {
|
||||
int32_t index = A_ids[m];
|
||||
float weight = topk_weights[index];
|
||||
copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size);
|
||||
}
|
||||
}
|
||||
|
||||
if (is_brgemm_used) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
|
||||
// stage 3: out = intermediate_cache2.sum(dim=1)
|
||||
// from [M, topk, K] to [M, K]
|
||||
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
sum_stub(output + m * K, ic2 + m * topk * K, topk, K);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#define INSTANTIATE_MOE_FP8_TEMPLATE(TYPE) \
|
||||
template void fused_experts_fp8_kernel_impl<TYPE>( \
|
||||
TYPE* __restrict__ output, \
|
||||
TYPE* __restrict__ ic0, \
|
||||
TYPE* __restrict__ ic1, \
|
||||
TYPE* __restrict__ ic2, \
|
||||
TYPE* __restrict__ A_tmp, \
|
||||
TYPE* __restrict__ B_tmp, \
|
||||
float* __restrict__ C_tmp, \
|
||||
const TYPE* __restrict__ input, \
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w1, \
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w2, \
|
||||
const float* __restrict__ w1s, \
|
||||
const float* __restrict__ w2s, \
|
||||
int64_t block_size_N, \
|
||||
int64_t block_size_K, \
|
||||
const float* __restrict__ topk_weights, \
|
||||
const int32_t* __restrict__ sorted_ids, \
|
||||
const int32_t* __restrict__ expert_ids, \
|
||||
const int32_t* __restrict__ offsets, \
|
||||
int64_t M, \
|
||||
int64_t N, \
|
||||
int64_t K, \
|
||||
int64_t E, \
|
||||
int64_t topk, \
|
||||
int64_t num_tokens_post_pad)
|
||||
|
||||
INSTANTIATE_MOE_FP8_TEMPLATE(at::BFloat16);
|
||||
INSTANTIATE_MOE_FP8_TEMPLATE(at::Half);
|
||||
|
||||
template <typename scalar_t>
|
||||
void shared_expert_fp8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic0,
|
||||
scalar_t* __restrict__ ic1,
|
||||
scalar_t* __restrict__ B_tmp,
|
||||
float* __restrict__ C_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w1,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w2,
|
||||
const float* __restrict__ w1s,
|
||||
const float* __restrict__ w2s,
|
||||
int64_t block_size_N,
|
||||
int64_t block_size_K,
|
||||
const scalar_t* __restrict__ fused_experts_out,
|
||||
float routed_scaling_factor,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K) {
|
||||
|
||||
constexpr int64_t BLOCK_M = block_size_m();
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
|
||||
// stage 1: intermediate_cache0 = hidden_states @ w1
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(2 * N, BLOCK_N);
|
||||
int64_t scale_size_K = div_up(K, block_size_K);
|
||||
int64_t blocks_n_per_group = block_size_N / BLOCK_N;
|
||||
|
||||
const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(M);
|
||||
|
||||
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
int tid = at::get_thread_num();
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB;
|
||||
int64_t nb = i % NB;
|
||||
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
|
||||
int64_t n_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
tinygemm_kernel<scalar_t>(
|
||||
/* A */ input + mb * BLOCK_M * K,
|
||||
/* B */ packed_w1 + nb * BLOCK_N * K,
|
||||
/* C */ ic0 + mb * BLOCK_M * 2 * N + nb * BLOCK_N,
|
||||
/* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N),
|
||||
/* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
|
||||
/* scale */ w1s + (nb / blocks_n_per_group) * scale_size_K,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ K,
|
||||
/* lda */ K,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ 2 * N,
|
||||
/* brg */ use_brgemm,
|
||||
/* block_size_K */ block_size_K);
|
||||
}
|
||||
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
|
||||
// stage 1.5: intermediate_cache1 = silu(intermediate_cache0)
|
||||
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
silu_and_mul_stub(
|
||||
ic1 + m * N,
|
||||
ic0 + m * 2 * N,
|
||||
ic0 + m * 2 * N + N,
|
||||
N);
|
||||
}
|
||||
});
|
||||
|
||||
// stage 2: intermediate_cache2 = intermediate_cache1 @ w2
|
||||
// w2 : [K, N] as [OC, IC]
|
||||
const int64_t OC = K; // rename K as OC
|
||||
const int64_t IC = N; // rename N as IC
|
||||
const int64_t MB2 = MB;
|
||||
const int64_t NB2 = div_up(K, BLOCK_N);
|
||||
scale_size_K = div_up(N, block_size_K);
|
||||
|
||||
// parallel on [MB2, NB2]
|
||||
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
|
||||
int tid = at::get_thread_num();
|
||||
alignas(64) scalar_t C[BLOCK_M * BLOCK_K];
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB2;
|
||||
int64_t nb = i % NB2;
|
||||
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
|
||||
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
// 2.a gemm: C = A @ B
|
||||
tinygemm_kernel<scalar_t>(
|
||||
/* A */ ic1 + mb * BLOCK_M * N,
|
||||
/* B */ packed_w2 + nb * BLOCK_N * N,
|
||||
/* C */ C,
|
||||
/* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N),
|
||||
/* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
|
||||
/* scale */ w2s + (nb / blocks_n_per_group) * scale_size_K,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ IC,
|
||||
/* lda */ IC,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ BLOCK_N,
|
||||
/* brg */ use_brgemm,
|
||||
/* block_size_K */ block_size_K);
|
||||
|
||||
// 2.b copy from C to output and add fused_experts_out
|
||||
scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N;
|
||||
const scalar_t* __restrict__ fused_out = fused_experts_out + mb * BLOCK_M * K + nb * BLOCK_N;
|
||||
for (int64_t m = 0; m < m_size; ++m) {
|
||||
add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
}
|
||||
|
||||
#define INSTANTIATE_SHARED_EXPERT_FP8_TEMPLATE(TYPE) \
|
||||
template void shared_expert_fp8_kernel_impl<TYPE>( \
|
||||
TYPE* __restrict__ output, \
|
||||
TYPE* __restrict__ ic0, \
|
||||
TYPE* __restrict__ ic1, \
|
||||
TYPE* __restrict__ B_tmp, \
|
||||
float* __restrict__ C_tmp, \
|
||||
const TYPE* __restrict__ input, \
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w1, \
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w2, \
|
||||
const float* __restrict__ w1s, \
|
||||
const float* __restrict__ w2s, \
|
||||
int64_t block_size_N, \
|
||||
int64_t block_size_K, \
|
||||
const TYPE* __restrict__ fused_experts_out, \
|
||||
float routed_scaling_factor, \
|
||||
int64_t M, \
|
||||
int64_t N, \
|
||||
int64_t K)
|
||||
|
||||
INSTANTIATE_SHARED_EXPERT_FP8_TEMPLATE(at::BFloat16);
|
||||
INSTANTIATE_SHARED_EXPERT_FP8_TEMPLATE(at::Half);
|
||||
769
csrc/cpu/sgl-kernels/moe_int8.cpp
Normal file
769
csrc/cpu/sgl-kernels/moe_int8.cpp
Normal file
@ -0,0 +1,769 @@
|
||||
// Adapted from
|
||||
// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu
|
||||
|
||||
#include "common.h"
|
||||
#include "vec.h"
|
||||
#include "gemm.h"
|
||||
|
||||
// clang-format off
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t size) {
|
||||
using Vec = at::vec::Vectorized<scalar_t>;
|
||||
// no remainder
|
||||
#pragma GCC unroll 4
|
||||
for (int64_t d = 0; d < size; d += Vec::size()) {
|
||||
Vec data = Vec::loadu(input + d);
|
||||
data.store(out + d);
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void copy_stub<uint8_t>(uint8_t* __restrict__ out, const uint8_t* __restrict__ input, int64_t size) {
|
||||
// size might be 64x + 32
|
||||
std::memcpy(out, input, size * sizeof(uint8_t));
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void copy_mul_stub(scalar_t* __restrict__ out, const float* __restrict__ input, float weight, int64_t size) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
const fVec weight_vec = fVec(weight);
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= size - kVecSize; d += kVecSize) {
|
||||
fVec data0 = fVec::loadu(input + d) * weight_vec;
|
||||
fVec data1 = fVec::loadu(input + d + fVec::size()) * weight_vec;
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
|
||||
out_vec.store(out + d);
|
||||
}
|
||||
for (; d < size; ++d) {
|
||||
out[d] = static_cast<scalar_t>(input[d] * weight);
|
||||
}
|
||||
}
|
||||
|
||||
// acc from [topk, K] to [K]
|
||||
template <typename scalar_t>
|
||||
inline void sum_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t topk, int64_t K) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
if (topk == 1) {
|
||||
// do copy for topk = 1
|
||||
copy_stub(out, input, K);
|
||||
} else {
|
||||
// do sum for topk != 1
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= K - kVecSize; d += kVecSize) {
|
||||
fVec sum_fvec0 = fVec(0.f);
|
||||
fVec sum_fvec1 = fVec(0.f);
|
||||
for (int t = 0; t < topk; ++t) {
|
||||
bVec x_bvec = bVec::loadu(input + t * K + d);
|
||||
fVec x_fvec0, x_fvec1;
|
||||
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
|
||||
|
||||
sum_fvec0 += x_fvec0;
|
||||
sum_fvec1 += x_fvec1;
|
||||
}
|
||||
bVec out_bvec = convert_from_float_ext<scalar_t>(sum_fvec0, sum_fvec1);
|
||||
out_bvec.store(out + d);
|
||||
}
|
||||
for (; d < K; ++d) {
|
||||
float sum_val = 0.f;
|
||||
for (int t = 0; t < topk; ++t) {
|
||||
sum_val += static_cast<float>(input[t * K + d]);
|
||||
}
|
||||
out[d] = static_cast<scalar_t>(sum_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// out = input + input2 * scale
|
||||
template <typename scalar_t>
|
||||
inline void add_mul_stub(scalar_t* __restrict__ out, const float* __restrict__ input,
|
||||
const scalar_t* __restrict__ input2, float scale, int64_t size) {
|
||||
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
const fVec s_vec = fVec(scale);
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= size - kVecSize; d += kVecSize) {
|
||||
fVec x0 = fVec::loadu(input + d);
|
||||
fVec x1 = fVec::loadu(input + d + fVec::size());
|
||||
|
||||
bVec y_bvec = bVec::loadu(input2 + d);
|
||||
fVec y0, y1;
|
||||
std::tie(y0, y1) = at::vec::convert_to_float(y_bvec);
|
||||
|
||||
x0 = x0 + y0 * s_vec;
|
||||
x1 = x1 + y1 * s_vec;
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(x0, x1);
|
||||
out_vec.store(out + d);
|
||||
}
|
||||
for (; d < size; ++d) {
|
||||
out[d] = static_cast<scalar_t>(input[d] + float(input2[d]) * scale);
|
||||
}
|
||||
}
|
||||
|
||||
/// gemm for w13
|
||||
template <typename scalar_t, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_vnni {
|
||||
static inline void apply(
|
||||
const uint8_t* __restrict__ A, const int8_t* __restrict__ B0, const int8_t* __restrict__ B1, scalar_t* __restrict__ C,
|
||||
const float* __restrict__ As, const float* __restrict__ Bs0, const float* __restrict__ Bs1,
|
||||
const int32_t* __restrict__ Bcomp0, const int32_t* __restrict__ Bcomp1,
|
||||
int64_t K, int64_t lda, int64_t ldb, int64_t ldc) {
|
||||
TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
template <int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_vnni<at::BFloat16, BLOCK_M, BLOCK_N> {
|
||||
static inline void apply(
|
||||
const uint8_t* __restrict__ A, const int8_t* __restrict__ B0, const int8_t* __restrict__ B1, at::BFloat16* __restrict__ C,
|
||||
const float* __restrict__ As, const float* __restrict__ Bs0, const float* __restrict__ Bs1,
|
||||
const int32_t* __restrict__ Bcomp0, const int32_t* __restrict__ Bcomp1,
|
||||
int64_t K, int64_t lda, int64_t ldb, int64_t ldc) {
|
||||
|
||||
constexpr int ROWS = BLOCK_M;
|
||||
constexpr int COLS = BLOCK_N / 16;
|
||||
static_assert(COLS % 2 == 0);
|
||||
|
||||
__m512i va;
|
||||
__m512i vb0[COLS];
|
||||
__m512i vb1[COLS];
|
||||
__m512i vc0[ROWS * COLS];
|
||||
__m512i vc1[ROWS * COLS];
|
||||
__m512i vcomp0[COLS];
|
||||
__m512i vcomp1[COLS];
|
||||
__m512 was;
|
||||
__m512 vbs0[COLS];
|
||||
__m512 vbs1[COLS];
|
||||
|
||||
auto loadc = [&](auto i) {
|
||||
vc0[i] = _mm512_set1_epi32(0);
|
||||
vc1[i] = _mm512_set1_epi32(0);
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(loadc);
|
||||
|
||||
const int64_t K4 = K >> 2;
|
||||
const int64_t lda4 = lda >> 2;
|
||||
const int64_t ldb4 = ldb; // ldb * 4 >> 2;
|
||||
const int32_t* a_ptr = reinterpret_cast<const int32_t*>(A);
|
||||
const int32_t* b0_ptr = reinterpret_cast<const int32_t*>(B0);
|
||||
const int32_t* b1_ptr = reinterpret_cast<const int32_t*>(B1);
|
||||
|
||||
auto compute = [&](auto i, int64_t k) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
|
||||
if constexpr (col == 0) {
|
||||
va = _mm512_set1_epi32(a_ptr[row * lda4 + k]);
|
||||
}
|
||||
if constexpr (row == 0) {
|
||||
vb0[col] = _mm512_loadu_si512(b0_ptr + k * ldb4 + col * 16);
|
||||
vb1[col] = _mm512_loadu_si512(b1_ptr + k * ldb4 + col * 16);
|
||||
}
|
||||
vc0[i] = _mm512_dpbusd_epi32(vc0[i], va, vb0[col]);
|
||||
vc1[i] = _mm512_dpbusd_epi32(vc1[i], va, vb1[col]);
|
||||
};
|
||||
for (int64_t k = 0; k < K4; ++k) {
|
||||
Unroll<ROWS * COLS>{}(compute, k);
|
||||
}
|
||||
|
||||
auto scalec = [&](auto i) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
|
||||
// load a scale
|
||||
if constexpr(col == 0) {
|
||||
was = _mm512_set1_ps(As[row]);
|
||||
}
|
||||
// load b scale and vcomp
|
||||
if constexpr (row == 0) {
|
||||
vbs0[col] = _mm512_loadu_ps(Bs0 + col * 16);
|
||||
vbs1[col] = _mm512_loadu_ps(Bs1 + col * 16);
|
||||
vcomp0[col] = _mm512_loadu_si512(Bcomp0 + col * 16);
|
||||
vcomp1[col] = _mm512_loadu_si512(Bcomp1 + col * 16);
|
||||
}
|
||||
__m512 c0 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc0[i], vcomp0[col]));
|
||||
__m512 c1 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc1[i], vcomp1[col]));
|
||||
vc0[i] = _mm512_castps_si512(_mm512_mul_ps(_mm512_mul_ps(c0, was), vbs0[col]));
|
||||
vc1[i] = _mm512_castps_si512(_mm512_mul_ps(_mm512_mul_ps(c1, was), vbs1[col]));
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(scalec);
|
||||
|
||||
using Vec = at::vec::Vectorized<float>;
|
||||
const Vec one = Vec(1.f);
|
||||
auto storec = [&](auto i) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
// for COLS = 2, 4 use 512bit store
|
||||
if constexpr (col % 2 == 0) {
|
||||
Vec x0 = _mm512_castsi512_ps(vc0[row * COLS + col + 0]);
|
||||
Vec x1 = _mm512_castsi512_ps(vc0[row * COLS + col + 1]);
|
||||
Vec y0 = _mm512_castsi512_ps(vc1[row * COLS + col + 0]);
|
||||
Vec y1 = _mm512_castsi512_ps(vc1[row * COLS + col + 1]);
|
||||
// silu
|
||||
x0 = x0 / (one + x0.neg().exp_u20());
|
||||
x1 = x1 / (one + x1.neg().exp_u20());
|
||||
// mul
|
||||
x0 = x0 * y0;
|
||||
x1 = x1 * y1;
|
||||
|
||||
_mm512_storeu_si512(
|
||||
reinterpret_cast<__m512i*>((C + row * ldc + col * 16)),
|
||||
(__m512i)(_mm512_cvtne2ps_pbh(__m512(x1), __m512(x0))));
|
||||
}
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(storec);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
#define LAUNCH_TINYGEMM_KERNEL_VNNI(MB_SIZE, NB_SIZE) \
|
||||
tinygemm_kernel_vnni<scalar_t, MB_SIZE, NB_SIZE>::apply( \
|
||||
A + mb_start * lda, B0 + nb_start * 4, B1 + nb_start * 4, \
|
||||
C + mb_start * ldc + nb_start, As + mb_start, \
|
||||
Bs0 + nb_start, Bs1 + nb_start, Bcomp0 + nb_start, Bcomp1 + nb_start,\
|
||||
K, lda, ldb, ldc);
|
||||
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(
|
||||
const uint8_t* __restrict__ A,
|
||||
const int8_t* __restrict__ B0,
|
||||
const int8_t* __restrict__ B1,
|
||||
scalar_t* __restrict__ C,
|
||||
const float* __restrict__ As,
|
||||
const float* __restrict__ Bs0,
|
||||
const float* __restrict__ Bs1,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc) {
|
||||
|
||||
const int32_t* Bcomp0 = reinterpret_cast<const int32_t*>(B0 + block_size_n() * K);
|
||||
const int32_t* Bcomp1 = reinterpret_cast<const int32_t*>(B1 + block_size_n() * K);
|
||||
|
||||
// pattern: 1-(2+2)-(8+8)
|
||||
constexpr int64_t BLOCK_M = 4;
|
||||
constexpr int64_t BLOCK_N = 32;
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
for (int mb = 0; mb < MB; ++mb) {
|
||||
int64_t mb_start = mb * BLOCK_M;
|
||||
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
|
||||
for (int64_t nb = 0; nb < NB; ++nb) {
|
||||
int64_t nb_start = nb * BLOCK_N;
|
||||
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
|
||||
|
||||
switch(mb_size << 4 | nb_size >> 4) {
|
||||
case 0x12: LAUNCH_TINYGEMM_KERNEL_VNNI(1, 32); break;
|
||||
case 0x22: LAUNCH_TINYGEMM_KERNEL_VNNI(2, 32); break;
|
||||
case 0x32: LAUNCH_TINYGEMM_KERNEL_VNNI(3, 32); break;
|
||||
case 0x42: LAUNCH_TINYGEMM_KERNEL_VNNI(4, 32); break;
|
||||
default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// gemm for w2
|
||||
template <typename scalar_t, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_vnni2 {
|
||||
static inline void apply(
|
||||
const uint8_t* __restrict__ A, const int8_t* __restrict__ B, float* __restrict__ C,
|
||||
const float* __restrict__ As, const float* __restrict__ Bs, const int32_t* __restrict__ Bcomp,
|
||||
int64_t K, int64_t lda, int64_t ldb, int64_t ldc) {
|
||||
TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
template <int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_vnni2<at::BFloat16, BLOCK_M, BLOCK_N> {
|
||||
static inline void apply(
|
||||
const uint8_t* __restrict__ A, const int8_t* __restrict__ B, float* __restrict__ C,
|
||||
const float* __restrict__ As, const float* __restrict__ Bs, const int32_t* __restrict__ Bcomp,
|
||||
int64_t K, int64_t lda, int64_t ldb, int64_t ldc) {
|
||||
|
||||
constexpr int ROWS = BLOCK_M;
|
||||
constexpr int COLS = BLOCK_N / 16;
|
||||
static_assert(COLS % 2 == 0);
|
||||
|
||||
__m512i va;
|
||||
__m512i vb[COLS];
|
||||
__m512i vc[ROWS * COLS];
|
||||
__m512i vcomp[COLS];
|
||||
__m512 was;
|
||||
__m512 vbs[COLS];
|
||||
|
||||
auto loadc = [&](auto i) {
|
||||
vc[i] = _mm512_set1_epi32(0);
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(loadc);
|
||||
|
||||
const int64_t K4 = K >> 2;
|
||||
const int64_t lda4 = lda >> 2;
|
||||
const int64_t ldb4 = ldb; // ldb * 4 >> 2;
|
||||
const int32_t* a_ptr = reinterpret_cast<const int32_t*>(A);
|
||||
const int32_t* b_ptr = reinterpret_cast<const int32_t*>(B);
|
||||
|
||||
auto compute = [&](auto i, int64_t k) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
|
||||
if constexpr (col == 0) {
|
||||
va = _mm512_set1_epi32(a_ptr[row * lda4 + k]);
|
||||
}
|
||||
if constexpr (row == 0) {
|
||||
vb[col] = _mm512_loadu_si512(b_ptr + k * ldb4 + col * 16);
|
||||
}
|
||||
vc[i] = _mm512_dpbusd_epi32(vc[i], va, vb[col]);
|
||||
};
|
||||
for (int64_t k = 0; k < K4; ++k) {
|
||||
Unroll<ROWS * COLS>{}(compute, k);
|
||||
}
|
||||
|
||||
auto storec = [&](auto i) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
|
||||
// load a scale
|
||||
if constexpr(col == 0) {
|
||||
was = _mm512_set1_ps(As[row]);
|
||||
}
|
||||
// load b scale and vcomp per 2 vectors
|
||||
// also load bias if any
|
||||
if constexpr (row == 0) {
|
||||
if constexpr (col % 2 == 0) {
|
||||
vbs[col + 0] = _mm512_loadu_ps(Bs + col * 16);
|
||||
vbs[col + 1] = _mm512_loadu_ps(Bs + col * 16 + 16);
|
||||
vcomp[col + 0] = _mm512_loadu_si512(Bcomp + col * 16);
|
||||
vcomp[col + 1] = _mm512_loadu_si512(Bcomp + col * 16 + 16);
|
||||
}
|
||||
}
|
||||
__m512 x = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[i], vcomp[col]));
|
||||
x = _mm512_mul_ps(_mm512_mul_ps(x, was), vbs[col]);
|
||||
_mm512_storeu_ps(reinterpret_cast<__m512*>(C + row * ldc + col * 16), x);
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(storec);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
#define LAUNCH_TINYGEMM_KERNEL_VNNI2(MB_SIZE, NB_SIZE) \
|
||||
tinygemm_kernel_vnni2<scalar_t, MB_SIZE, NB_SIZE>::apply( \
|
||||
A + mb_start * lda, B + nb_start * 4, C + mb_start * ldc + nb_start, \
|
||||
As + mb_start, Bs + nb_start, Bcomp + nb_start, \
|
||||
K, lda, ldb, ldc);
|
||||
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(
|
||||
const uint8_t* __restrict__ A,
|
||||
const int8_t* __restrict__ B,
|
||||
float* __restrict__ C,
|
||||
const float* __restrict__ As,
|
||||
const float* __restrict__ Bs,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc) {
|
||||
|
||||
// B compensation
|
||||
const int32_t* Bcomp = reinterpret_cast<const int32_t*>(B + block_size_n() * K);
|
||||
|
||||
// pattern: 1-4-16
|
||||
constexpr int64_t BLOCK_M = 4;
|
||||
constexpr int64_t BLOCK_N = 64;
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
for (int64_t mb = 0; mb < MB; ++mb) {
|
||||
int64_t mb_start = mb * BLOCK_M;
|
||||
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
|
||||
for (int64_t nb = 0; nb < NB; ++nb) {
|
||||
int64_t nb_start = nb * BLOCK_N;
|
||||
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
|
||||
|
||||
switch(mb_size << 4 | nb_size >> 4) {
|
||||
case 0x12: LAUNCH_TINYGEMM_KERNEL_VNNI2(1, 32); break;
|
||||
case 0x22: LAUNCH_TINYGEMM_KERNEL_VNNI2(2, 32); break;
|
||||
case 0x32: LAUNCH_TINYGEMM_KERNEL_VNNI2(3, 32); break;
|
||||
case 0x42: LAUNCH_TINYGEMM_KERNEL_VNNI2(4, 32); break;
|
||||
default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
template <typename scalar_t>
|
||||
void fused_experts_int8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic1,
|
||||
scalar_t* __restrict__ ic2,
|
||||
uint8_t* __restrict__ A_tmp,
|
||||
float* __restrict__ C_tmp,
|
||||
uint8_t* __restrict__ Aq_tmp,
|
||||
float* __restrict__ As_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const int8_t* __restrict__ packed_w1,
|
||||
const int8_t* __restrict__ packed_w2,
|
||||
const float* __restrict__ w1s,
|
||||
const float* __restrict__ w2s,
|
||||
const float* __restrict__ topk_weights,
|
||||
const int32_t* __restrict__ sorted_ids,
|
||||
const int32_t* __restrict__ expert_ids,
|
||||
const int32_t* __restrict__ offsets,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t E,
|
||||
int64_t topk,
|
||||
int64_t num_tokens_post_pad) {
|
||||
|
||||
// handle 2 tiles per block
|
||||
constexpr int64_t BLOCK_M = block_size_m();
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
|
||||
// stage 0: quantize input to uint8, [M, K]
|
||||
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
quantize_row_int8<scalar_t>(
|
||||
Aq_tmp + m * K,
|
||||
As_tmp[m],
|
||||
input + m * K,
|
||||
K);
|
||||
}
|
||||
});
|
||||
|
||||
// stage 1: intermediate_cache1 = silu(hidden_states @ w1)
|
||||
const int64_t MB = div_up(num_tokens_post_pad, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
|
||||
// strides for w1: [E, 2N, K]
|
||||
TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not multiples of ", BLOCK_N);
|
||||
|
||||
// K and N are packed for int8
|
||||
const int64_t packed_K = get_row_size<int8_t>(K);
|
||||
const int64_t packed_N = get_row_size<int8_t>(N);
|
||||
|
||||
const int64_t stride_e = 2 * N * packed_K;
|
||||
const int64_t stride_n = packed_K;
|
||||
// here we only parallel on half of 2N to fuse silu_and_mul with gemm
|
||||
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
// get local pointers
|
||||
int tid = at::get_thread_num();
|
||||
uint8_t* __restrict__ A = A_tmp + tid * BLOCK_M * K;
|
||||
|
||||
alignas(64) float As[BLOCK_M];
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB;
|
||||
int64_t nb = i % NB;
|
||||
|
||||
// nb0 from top half and nb1 from bottom half
|
||||
int64_t nb0 = nb, nb1 = nb + NB;
|
||||
int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N);
|
||||
|
||||
// B shape [K, n_size] in vnni format
|
||||
int32_t expert_id = expert_ids[mb];
|
||||
const int8_t* __restrict__ B0 = packed_w1 + expert_id * stride_e + nb0 * BLOCK_N * stride_n;
|
||||
const int8_t* __restrict__ B1 = packed_w1 + expert_id * stride_e + nb1 * BLOCK_N * stride_n;
|
||||
const float* __restrict__ Bs0 = w1s + expert_id * 2 * N + nb0 * BLOCK_N;
|
||||
const float* __restrict__ Bs1 = w1s + expert_id * 2 * N + nb1 * BLOCK_N;
|
||||
|
||||
// 1.a load A
|
||||
const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
|
||||
int64_t m_size = offsets[mb + 1] - offsets[mb];
|
||||
|
||||
for (int64_t m = 0; m < m_size; ++m) {
|
||||
int32_t index = A_ids[m] / topk;
|
||||
copy_stub(A + m * K, Aq_tmp + index * K, K);
|
||||
As[m] = As_tmp[index];
|
||||
}
|
||||
|
||||
// fused 1.b: silu_and_mul(A @ B0, A @ B1)
|
||||
const int64_t offset = offsets[mb];
|
||||
tinygemm_kernel(
|
||||
/* A */ A,
|
||||
/* B0 */ B0,
|
||||
/* B1 */ B1,
|
||||
/* C */ ic1 + offset * N + nb * BLOCK_N,
|
||||
/* As */ As,
|
||||
/* Bs0 */ Bs0,
|
||||
/* Bs1 */ Bs1,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ K,
|
||||
/* lda */ K,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ N);
|
||||
}
|
||||
});
|
||||
|
||||
// stage 1.5: quantize ic1 to uint8, [M * topk, N]
|
||||
at::parallel_for(0, M * topk, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
quantize_row_int8<scalar_t>(
|
||||
Aq_tmp + m * N,
|
||||
As_tmp[m],
|
||||
ic1 + m * N,
|
||||
N);
|
||||
}
|
||||
});
|
||||
|
||||
// stage 2: intermediate_cache2 = intermediate_cache1 @ w2
|
||||
// w2 : [E, K, N] as [E, OC, IC]
|
||||
const int64_t OC = K; // rename K as OC
|
||||
const int64_t IC = N; // rename N as IC
|
||||
const int64_t MB2 = MB;
|
||||
const int64_t NB2 = div_up(OC, BLOCK_N);
|
||||
const int64_t stride_e2 = OC * packed_N;
|
||||
const int64_t stride_oc = packed_N;
|
||||
|
||||
// parallel on [MB2, NB2]
|
||||
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
|
||||
// get local pointers
|
||||
int tid = at::get_thread_num();
|
||||
// we won't be using C1 for gemm2
|
||||
float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB2;
|
||||
int64_t nb = i % NB2;
|
||||
|
||||
int64_t m_size = offsets[mb + 1] - offsets[mb];
|
||||
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
// A ptr from ic1 of [M * topk, N] in sorted order
|
||||
// so as to avoid copy A to tmp buffer again
|
||||
const uint8_t* __restrict__ A = Aq_tmp + offsets[mb] * N;
|
||||
const float* __restrict__ As = As_tmp + offsets[mb];
|
||||
const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
|
||||
|
||||
// B shape [IC, n_size] in vnni format
|
||||
int32_t expert_id = expert_ids[mb];
|
||||
const int8_t* __restrict__ B = packed_w2 + expert_id * stride_e2 + nb * BLOCK_N * stride_oc;
|
||||
const float* __restrict__ Bs = w2s + expert_id * K + nb * BLOCK_N;
|
||||
|
||||
// 2.a gemm: C = A @ B
|
||||
tinygemm_kernel<scalar_t>(
|
||||
/* A */ A,
|
||||
/* B */ B,
|
||||
/* C */ C,
|
||||
/* As */ As,
|
||||
/* Bs */ Bs,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ IC,
|
||||
/* lda */ IC,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ BLOCK_N);
|
||||
|
||||
// 2.b copy from C to ic2 in original order
|
||||
// and also mul topk_weights in float32
|
||||
for (int64_t m = 0; m < m_size; ++m) {
|
||||
int32_t index = A_ids[m];
|
||||
float weight = topk_weights[index];
|
||||
copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// stage 3: out = intermediate_cache2.sum(dim=1)
|
||||
// from [M, topk, K] to [M, K]
|
||||
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
sum_stub(output + m * K, ic2 + m * topk * K, topk, K);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#define INSTANTIATE_MOE_INT8_TEMPLATE(TYPE) \
|
||||
template void fused_experts_int8_kernel_impl<TYPE> ( \
|
||||
TYPE* __restrict__ output, TYPE* __restrict__ ic1, \
|
||||
TYPE* __restrict__ ic2, uint8_t* __restrict__ A_tmp, \
|
||||
float* __restrict__ C_tmp, uint8_t* __restrict__ Aq_tmp, \
|
||||
float* __restrict__ As_tmp, const TYPE* __restrict__ input, \
|
||||
const int8_t* __restrict__ packed_w1, const int8_t* __restrict__ packed_w2, \
|
||||
const float* __restrict__ w1s, const float* __restrict__ w2s, \
|
||||
const float* __restrict__ topk_weights, const int32_t* __restrict__ sorted_ids, \
|
||||
const int32_t* __restrict__ expert_ids, const int32_t* __restrict__ offsets, \
|
||||
int64_t M, int64_t N, int64_t K, int64_t E, int64_t topk, int64_t num_tokens_post_pad)
|
||||
|
||||
INSTANTIATE_MOE_INT8_TEMPLATE(at::BFloat16);
|
||||
INSTANTIATE_MOE_INT8_TEMPLATE(at::Half);
|
||||
|
||||
template <typename scalar_t>
|
||||
void shared_expert_int8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic1,
|
||||
float* __restrict__ C_tmp,
|
||||
uint8_t* __restrict__ Aq_tmp,
|
||||
float* __restrict__ As_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const int8_t* __restrict__ packed_w1,
|
||||
const int8_t* __restrict__ packed_w2,
|
||||
const float* __restrict__ w1s,
|
||||
const float* __restrict__ w2s,
|
||||
const scalar_t* __restrict__ fused_experts_out,
|
||||
float routed_scaling_factor,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K) {
|
||||
|
||||
// handle 2 tiles per block
|
||||
constexpr int64_t BLOCK_M = block_size_m();
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
|
||||
// stage 0: quantize input to uint8, [M, K]
|
||||
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
quantize_row_int8<scalar_t>(
|
||||
Aq_tmp + m * K,
|
||||
As_tmp[m],
|
||||
input + m * K,
|
||||
K);
|
||||
}
|
||||
});
|
||||
|
||||
// stage 1: intermediate_cache1 = silu(hidden_states @ w1)
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
|
||||
TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not multiples of ", BLOCK_N);
|
||||
|
||||
// K and N are packed for int8
|
||||
const int64_t packed_K = get_row_size<int8_t>(K);
|
||||
const int64_t packed_N = get_row_size<int8_t>(N);
|
||||
const int64_t stride_n = packed_K;
|
||||
|
||||
// here we only parallel on half of 2N to fuse silu_and_mul with gemm
|
||||
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB;
|
||||
int64_t nb = i % NB;
|
||||
|
||||
// nb0 from top half and nb1 from bottom half
|
||||
int64_t nb0 = nb, nb1 = nb + NB;
|
||||
int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N);
|
||||
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
|
||||
|
||||
// A shape [m_size, K]
|
||||
const uint8_t* A = Aq_tmp + mb * BLOCK_M * K;
|
||||
const float* As = As_tmp + mb * BLOCK_M;
|
||||
|
||||
// B shape [K, n_size] in vnni format
|
||||
const int8_t* __restrict__ B0 = packed_w1 + nb0 * BLOCK_N * stride_n;
|
||||
const int8_t* __restrict__ B1 = packed_w1 + nb1 * BLOCK_N * stride_n;
|
||||
const float* __restrict__ Bs0 = w1s + nb0 * BLOCK_N;
|
||||
const float* __restrict__ Bs1 = w1s + nb1 * BLOCK_N;
|
||||
|
||||
// fused 1.b: silu_and_mul(A @ B0, A @ B1)
|
||||
tinygemm_kernel(
|
||||
/* A */ A,
|
||||
/* B0 */ B0,
|
||||
/* B1 */ B1,
|
||||
/* C */ ic1 + mb * BLOCK_M * N + nb * BLOCK_N,
|
||||
/* As */ As,
|
||||
/* Bs0 */ Bs0,
|
||||
/* Bs1 */ Bs1,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ K,
|
||||
/* lda */ K,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ N);
|
||||
}
|
||||
});
|
||||
|
||||
// stage 1.5: quantize ic1 to uint8, [M * topk, N]
|
||||
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
quantize_row_int8<scalar_t>(
|
||||
Aq_tmp + m * N,
|
||||
As_tmp[m],
|
||||
ic1 + m * N,
|
||||
N);
|
||||
}
|
||||
});
|
||||
|
||||
// stage 2: intermediate_cache2 = intermediate_cache1 @ w2
|
||||
// w2 : [K, N] as [OC, IC]
|
||||
const int64_t OC = K; // rename K as OC
|
||||
const int64_t IC = N; // rename N as IC
|
||||
const int64_t MB2 = MB;
|
||||
const int64_t NB2 = div_up(OC, BLOCK_N);
|
||||
const int64_t stride_oc = packed_N;
|
||||
|
||||
// parallel on [MB2, NB2]
|
||||
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
|
||||
// get local pointers
|
||||
int tid = at::get_thread_num();
|
||||
// we won't be using C1 for gemm2
|
||||
float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB2;
|
||||
int64_t nb = i % NB2;
|
||||
|
||||
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
|
||||
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
// A shape [m_size, IC]
|
||||
const uint8_t* __restrict__ A = Aq_tmp + mb * BLOCK_M * N;
|
||||
const float* __restrict__ As = As_tmp + mb * BLOCK_M;
|
||||
|
||||
// B shape [IC, n_size] in vnni format
|
||||
const int8_t* __restrict__ B = packed_w2 + nb * BLOCK_N * stride_oc;
|
||||
const float* __restrict__ Bs = w2s + nb * BLOCK_N;
|
||||
|
||||
// 2.a gemm: C = A @ B
|
||||
tinygemm_kernel<scalar_t>(
|
||||
/* A */ A,
|
||||
/* B */ B,
|
||||
/* C */ C,
|
||||
/* As */ As,
|
||||
/* Bs */ Bs,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ IC,
|
||||
/* lda */ IC,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ BLOCK_N);
|
||||
|
||||
// 2.b copy from C to output and add fused_experts_out
|
||||
scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N;
|
||||
const scalar_t* __restrict__ fused_out = fused_experts_out + mb * BLOCK_M * K + nb * BLOCK_N;
|
||||
for (int64_t m = 0; m < m_size; ++m) {
|
||||
add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#define INSTANTIATE_SHARED_EXPERT_INT8_TEMPLATE(TYPE) \
|
||||
template void shared_expert_int8_kernel_impl<TYPE> ( \
|
||||
TYPE* __restrict__ output, TYPE* __restrict__ ic1, \
|
||||
float* __restrict__ C_tmp, uint8_t* __restrict__ Aq_tmp, \
|
||||
float* __restrict__ As_tmp, const TYPE* __restrict__ input, \
|
||||
const int8_t* __restrict__ packed_w1, const int8_t* __restrict__ packed_w2, \
|
||||
const float* __restrict__ w1s, const float* __restrict__ w2s, \
|
||||
const TYPE* __restrict__ fused_experts_out, float routed_scaling_factor, \
|
||||
int64_t M, int64_t N, int64_t K)
|
||||
|
||||
INSTANTIATE_SHARED_EXPERT_INT8_TEMPLATE(at::BFloat16);
|
||||
INSTANTIATE_SHARED_EXPERT_INT8_TEMPLATE(at::Half);
|
||||
308
csrc/cpu/sgl-kernels/vec.h
Normal file
308
csrc/cpu/sgl-kernels/vec.h
Normal file
@ -0,0 +1,308 @@
|
||||
// Adapted from
|
||||
// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu
|
||||
|
||||
#pragma once
|
||||
|
||||
// clang-format off
|
||||
|
||||
#if defined(__AVX512F__) && defined(__AVX512BF16__) && defined(__AMX_BF16__)
|
||||
#define CPU_CAPABILITY_AVX512
|
||||
#endif
|
||||
|
||||
#include <ATen/cpu/vec/functional.h>
|
||||
#include <ATen/cpu/vec/vec.h>
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace at::vec;
|
||||
|
||||
template <typename scalar_t,
|
||||
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
|
||||
inline Vectorized<scalar_t> convert_from_float_ext(const Vectorized<float>& a, const Vectorized<float>& b) {
|
||||
return at::vec::convert_from_float<scalar_t>(a, b);
|
||||
}
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
|
||||
// `at::vec::convert_from_float<>` from PyTorch doesn't have avx512-bf16 intrinsics
|
||||
// use native instruction for bfloat16->float32 conversion
|
||||
template <>
|
||||
inline Vectorized<at::BFloat16> convert_from_float_ext<at::BFloat16>(const Vectorized<float>& a, const Vectorized<float>& b) {
|
||||
return (__m512i)(_mm512_cvtne2ps_pbh(__m512(b), __m512(a)));
|
||||
}
|
||||
|
||||
#define CVT_BF16_TO_FP32(a) \
|
||||
_mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16))
|
||||
|
||||
#define CVT_FP16_TO_FP32(a) \
|
||||
_mm512_cvtps_ph(a, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC))
|
||||
|
||||
// this doesn't hanel NaN.
|
||||
inline __m512bh cvt_e4m3_bf16_intrinsic_no_nan(__m256i fp8_vec) {
|
||||
const __m512i x = _mm512_cvtepu8_epi16(fp8_vec);
|
||||
|
||||
const __m512i mant = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x07)), 4);
|
||||
const __m512i raw_exp = _mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x78)), 3);
|
||||
const __m512i exp = _mm512_slli_epi16(_mm512_add_epi16(raw_exp, _mm512_set1_epi16(120)), 7);
|
||||
const __m512i nonsign = _mm512_or_si512(exp, mant);
|
||||
|
||||
const __m512i sign = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x80)), 8);
|
||||
const __m512i combined = _mm512_or_si512(nonsign, sign);
|
||||
|
||||
const __mmask32 is_nonzero = _mm512_cmpneq_epi16_mask(x, _mm512_setzero_si512());
|
||||
return (__m512bh)_mm512_maskz_mov_epi16(is_nonzero, combined);
|
||||
}
|
||||
|
||||
inline __m512bh cvt_e4m3_bf16_intrinsic_without_denorm(__m256i fp8_vec) {
|
||||
// The following conversion is without denorm behavior, that is to say,
|
||||
// Max subnorm : S.0000.111 = 0.875 ∗ 2**(−6)
|
||||
// Min subnorm : S.0000.001 = 2**(−9)
|
||||
// 0.0019 ~ 0.0137 cannot be converted correctly.
|
||||
__m512i x = _mm512_cvtepu8_epi16(fp8_vec);
|
||||
auto mask = _mm512_cmpneq_epi16_mask(
|
||||
_mm512_and_si512(x, _mm512_set1_epi16(127)),
|
||||
_mm512_setzero_si512()); // mask = x & 0x7f
|
||||
auto mask_nan = _mm512_cmpneq_epi16_mask(
|
||||
_mm512_and_si512(x, _mm512_set1_epi16(127)),
|
||||
_mm512_set1_epi16(127)); // mask_nan = x & 0x7f
|
||||
auto mantissa = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(7)), 4); // mantissa = (x & 7) << 4
|
||||
auto exponent = _mm512_add_epi16(
|
||||
_mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(120)), 3),
|
||||
_mm512_set1_epi16(120)); // exponent = (((x >> 3) & 15) + 120)
|
||||
auto nonsign = _mm512_maskz_mov_epi16(mask, _mm512_or_si512(mantissa, _mm512_slli_epi16(exponent, 7)));
|
||||
nonsign = _mm512_mask_mov_epi16(_mm512_set1_epi16(0x7fff), mask_nan, nonsign); // deal with Nan
|
||||
return (__m512bh)(_mm512_or_si512(
|
||||
nonsign,
|
||||
_mm512_slli_epi16(
|
||||
_mm512_and_si512(x, _mm512_set1_epi16(128)),
|
||||
8))); // add sign (x & 128) << 8
|
||||
}
|
||||
|
||||
inline __m512bh cvt_e4m3_bf16_intrinsic_with_denorm(__m256i fp8_vec) {
|
||||
__m512i x = _mm512_cvtepu8_epi16(fp8_vec);
|
||||
__m512i lg2mant = _mm512_mask_mov_epi16(
|
||||
_mm512_mask_mov_epi16(
|
||||
_mm512_setzero_si512(), _mm512_test_epi16_mask(x, _mm512_set1_epi16(2)), _mm512_set1_epi16(1)),
|
||||
_mm512_test_epi16_mask(x, _mm512_set1_epi16(4)),
|
||||
_mm512_set1_epi16(2));
|
||||
return (__m512bh)(_mm512_or_si512(
|
||||
_mm512_maskz_mov_epi16(
|
||||
_mm512_cmpneq_epi16_mask(_mm512_and_si512(x, _mm512_set1_epi16(127)), _mm512_setzero_si512()),
|
||||
_mm512_mask_blend_epi16(
|
||||
_mm512_test_epi16_mask(x, _mm512_set1_epi16(120)),
|
||||
_mm512_or_si512(
|
||||
_mm512_and_si512(
|
||||
_mm512_sllv_epi16(
|
||||
_mm512_and_si512(x, _mm512_set1_epi16(3)), _mm512_sub_epi16(_mm512_set1_epi16(7), lg2mant)),
|
||||
_mm512_set1_epi16(0x007f)),
|
||||
_mm512_slli_epi16(_mm512_add_epi16(lg2mant, _mm512_set1_epi16(118)), 7)),
|
||||
_mm512_or_si512(
|
||||
_mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(7)), 4),
|
||||
_mm512_slli_epi16(
|
||||
_mm512_add_epi16(
|
||||
_mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(120)), 3), _mm512_set1_epi16(120)),
|
||||
7)))),
|
||||
_mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(128)), 8)));
|
||||
}
|
||||
|
||||
inline __m512bh CVT_FP8_TO_BF16(__m256i a) {
|
||||
#ifdef SGLANG_CPU_FP8_CVT_FTZ
|
||||
return cvt_e4m3_bf16_intrinsic_no_nan(a);
|
||||
#else
|
||||
return cvt_e4m3_bf16_intrinsic_with_denorm(a);
|
||||
#endif
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
// vector to scalar reduction
|
||||
#if defined(CPU_CAPABILITY_AVX512) && 0
|
||||
inline float vec_reduce_sum(const Vectorized<float>& a) {
|
||||
return _mm512_reduce_add_ps(__m512(a));
|
||||
}
|
||||
|
||||
inline float vec_reduce_max(const Vectorized<float>& a) {
|
||||
return _mm512_reduce_max_ps(__m512(a));
|
||||
}
|
||||
#else
|
||||
inline float vec_reduce_sum(const Vectorized<float>& a) {
|
||||
return vec_reduce_all([](Vectorized<float>& x, Vectorized<float>& y) { return x + y; }, a);
|
||||
}
|
||||
|
||||
inline float vec_reduce_max(const Vectorized<float>& a) {
|
||||
return vec_reduce_all([](Vectorized<float>& x, Vectorized<float>& y) { return maximum(x, y); }, a);
|
||||
}
|
||||
#endif
|
||||
|
||||
// https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282
|
||||
template <typename scalar_t>
|
||||
inline void quantize_row_int8(uint8_t* __restrict__ Aq, float& As,
|
||||
const scalar_t* __restrict__ A, int64_t K, float eps = 1e-7) {
|
||||
|
||||
float amax = 0.f; // absolute max
|
||||
for (int64_t k = 0; k < K; ++k) {
|
||||
const float val = static_cast<float>(A[k]);
|
||||
amax = std::max(amax, std::abs(val));
|
||||
}
|
||||
|
||||
amax = std::max(amax, eps);
|
||||
const float scale = amax / 127;
|
||||
const float inv_scale = 127 / amax;
|
||||
|
||||
for (int64_t k = 0; k < K; ++k) {
|
||||
const float val = static_cast<float>(A[k]) * inv_scale;
|
||||
Aq[k] = (uint8_t)(std::round(val)) + 128;
|
||||
}
|
||||
As = scale;
|
||||
}
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
template <>
|
||||
inline void quantize_row_int8<at::BFloat16>(uint8_t* __restrict__ Aq, float& As,
|
||||
const at::BFloat16* __restrict__ A, int64_t K, float eps) {
|
||||
|
||||
const __m512 signBit = _mm512_set1_ps(-0.0f);
|
||||
const __m512i off = _mm512_set1_epi32(128);
|
||||
|
||||
// K is 32x, no remainder
|
||||
float amax = 0.f;
|
||||
__m512 vamax0 = _mm512_set1_ps(0.f);
|
||||
__m512 vamax1 = _mm512_set1_ps(0.f);
|
||||
for (int64_t k = 0; k < K; k += 32) {
|
||||
__m512i va = _mm512_loadu_si512((void*)(A + k));
|
||||
__m512 va0 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 0));
|
||||
__m512 va1 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 1));
|
||||
vamax0 = _mm512_max_ps(vamax0, _mm512_andnot_ps(signBit, va0));
|
||||
vamax1 = _mm512_max_ps(vamax1, _mm512_andnot_ps(signBit, va1));
|
||||
}
|
||||
amax = _mm512_reduce_max_ps(_mm512_max_ps(vamax0, vamax1));
|
||||
amax = std::max(amax, eps);
|
||||
const float scale = amax / 127;
|
||||
const float inv_scale = 127 / amax;
|
||||
const __m512 vd = _mm512_set1_ps(inv_scale);
|
||||
|
||||
for (int64_t k = 0; k < K; k += 32) {
|
||||
__m512i va = _mm512_loadu_si512((void*)(A + k));
|
||||
__m512 va0 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 0));
|
||||
__m512 va1 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 1));
|
||||
va0 = _mm512_mul_ps(va0, vd);
|
||||
va1 = _mm512_mul_ps(va1, vd);
|
||||
va0 = _mm512_roundscale_ps(va0, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
|
||||
va1 = _mm512_roundscale_ps(va1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
|
||||
__m128i i0 = _mm512_cvtepi32_epi8(_mm512_add_epi32(_mm512_cvtps_epi32(va0), off));
|
||||
__m128i i1 = _mm512_cvtepi32_epi8(_mm512_add_epi32(_mm512_cvtps_epi32(va1), off));
|
||||
_mm256_storeu_si256(reinterpret_cast<__m256i*>(Aq + k), _mm256_set_m128i(i1, i0));
|
||||
}
|
||||
As = scale;
|
||||
}
|
||||
#endif
|
||||
|
||||
// transpose utils
|
||||
// taken from my PR in ggml: https://github.com/ggml-org/llama.cpp/pull/8998
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
inline void transpose_16x16_32bit(__m512i * v) {
|
||||
__m512i v1[16];
|
||||
v1[0] = _mm512_unpacklo_epi32(v[0], v[1]);
|
||||
v1[1] = _mm512_unpackhi_epi32(v[0], v[1]);
|
||||
v1[2] = _mm512_unpacklo_epi32(v[2], v[3]);
|
||||
v1[3] = _mm512_unpackhi_epi32(v[2], v[3]);
|
||||
v1[4] = _mm512_unpacklo_epi32(v[4], v[5]);
|
||||
v1[5] = _mm512_unpackhi_epi32(v[4], v[5]);
|
||||
v1[6] = _mm512_unpacklo_epi32(v[6], v[7]);
|
||||
v1[7] = _mm512_unpackhi_epi32(v[6], v[7]);
|
||||
v1[8] = _mm512_unpacklo_epi32(v[8], v[9]);
|
||||
v1[9] = _mm512_unpackhi_epi32(v[8], v[9]);
|
||||
v1[10] = _mm512_unpacklo_epi32(v[10], v[11]);
|
||||
v1[11] = _mm512_unpackhi_epi32(v[10], v[11]);
|
||||
v1[12] = _mm512_unpacklo_epi32(v[12], v[13]);
|
||||
v1[13] = _mm512_unpackhi_epi32(v[12], v[13]);
|
||||
v1[14] = _mm512_unpacklo_epi32(v[14], v[15]);
|
||||
v1[15] = _mm512_unpackhi_epi32(v[14], v[15]);
|
||||
|
||||
v[0] = _mm512_unpacklo_epi64(v1[0], v1[2]);
|
||||
v[1] = _mm512_unpackhi_epi64(v1[0], v1[2]);
|
||||
v[2] = _mm512_unpacklo_epi64(v1[1], v1[3]);
|
||||
v[3] = _mm512_unpackhi_epi64(v1[1], v1[3]);
|
||||
v[4] = _mm512_unpacklo_epi64(v1[4], v1[6]);
|
||||
v[5] = _mm512_unpackhi_epi64(v1[4], v1[6]);
|
||||
v[6] = _mm512_unpacklo_epi64(v1[5], v1[7]);
|
||||
v[7] = _mm512_unpackhi_epi64(v1[5], v1[7]);
|
||||
v[8] = _mm512_unpacklo_epi64(v1[8], v1[10]);
|
||||
v[9] = _mm512_unpackhi_epi64(v1[8], v1[10]);
|
||||
v[10] = _mm512_unpacklo_epi64(v1[9], v1[11]);
|
||||
v[11] = _mm512_unpackhi_epi64(v1[9], v1[11]);
|
||||
v[12] = _mm512_unpacklo_epi64(v1[12], v1[14]);
|
||||
v[13] = _mm512_unpackhi_epi64(v1[12], v1[14]);
|
||||
v[14] = _mm512_unpacklo_epi64(v1[13], v1[15]);
|
||||
v[15] = _mm512_unpackhi_epi64(v1[13], v1[15]);
|
||||
|
||||
v1[0] = _mm512_shuffle_i32x4(v[0], v[4], 0x88);
|
||||
v1[1] = _mm512_shuffle_i32x4(v[1], v[5], 0x88);
|
||||
v1[2] = _mm512_shuffle_i32x4(v[2], v[6], 0x88);
|
||||
v1[3] = _mm512_shuffle_i32x4(v[3], v[7], 0x88);
|
||||
v1[4] = _mm512_shuffle_i32x4(v[0], v[4], 0xdd);
|
||||
v1[5] = _mm512_shuffle_i32x4(v[1], v[5], 0xdd);
|
||||
v1[6] = _mm512_shuffle_i32x4(v[2], v[6], 0xdd);
|
||||
v1[7] = _mm512_shuffle_i32x4(v[3], v[7], 0xdd);
|
||||
v1[8] = _mm512_shuffle_i32x4(v[8], v[12], 0x88);
|
||||
v1[9] = _mm512_shuffle_i32x4(v[9], v[13], 0x88);
|
||||
v1[10] = _mm512_shuffle_i32x4(v[10], v[14], 0x88);
|
||||
v1[11] = _mm512_shuffle_i32x4(v[11], v[15], 0x88);
|
||||
v1[12] = _mm512_shuffle_i32x4(v[8], v[12], 0xdd);
|
||||
v1[13] = _mm512_shuffle_i32x4(v[9], v[13], 0xdd);
|
||||
v1[14] = _mm512_shuffle_i32x4(v[10], v[14], 0xdd);
|
||||
v1[15] = _mm512_shuffle_i32x4(v[11], v[15], 0xdd);
|
||||
|
||||
v[0] = _mm512_shuffle_i32x4(v1[0], v1[8], 0x88);
|
||||
v[1] = _mm512_shuffle_i32x4(v1[1], v1[9], 0x88);
|
||||
v[2] = _mm512_shuffle_i32x4(v1[2], v1[10], 0x88);
|
||||
v[3] = _mm512_shuffle_i32x4(v1[3], v1[11], 0x88);
|
||||
v[4] = _mm512_shuffle_i32x4(v1[4], v1[12], 0x88);
|
||||
v[5] = _mm512_shuffle_i32x4(v1[5], v1[13], 0x88);
|
||||
v[6] = _mm512_shuffle_i32x4(v1[6], v1[14], 0x88);
|
||||
v[7] = _mm512_shuffle_i32x4(v1[7], v1[15], 0x88);
|
||||
v[8] = _mm512_shuffle_i32x4(v1[0], v1[8], 0xdd);
|
||||
v[9] = _mm512_shuffle_i32x4(v1[1], v1[9], 0xdd);
|
||||
v[10] = _mm512_shuffle_i32x4(v1[2], v1[10], 0xdd);
|
||||
v[11] = _mm512_shuffle_i32x4(v1[3], v1[11], 0xdd);
|
||||
v[12] = _mm512_shuffle_i32x4(v1[4], v1[12], 0xdd);
|
||||
v[13] = _mm512_shuffle_i32x4(v1[5], v1[13], 0xdd);
|
||||
v[14] = _mm512_shuffle_i32x4(v1[6], v1[14], 0xdd);
|
||||
v[15] = _mm512_shuffle_i32x4(v1[7], v1[15], 0xdd);
|
||||
}
|
||||
|
||||
// remove warning : ignoring attributes on template argument ‘__m512i’ [-Wignored-attributes]
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wignored-attributes"
|
||||
|
||||
// transpose from [2, 32] to [32, 2]
|
||||
inline std::tuple<__m512i, __m512i> transpose_2x32_16bit(__m512i r0, __m512i r1) {
|
||||
// r0: {a0, a1, ..., a31}
|
||||
// r1: {b0, b1, ..., b31}
|
||||
//
|
||||
// d0: {a0, b0, ..., a15, b15}
|
||||
// d1: {a16, b16, ..., a31, b31}
|
||||
//
|
||||
__m512i d0 = _mm512_unpacklo_epi16(r0, r1);
|
||||
__m512i d1 = _mm512_unpackhi_epi16(r0, r1);
|
||||
r0 = _mm512_shuffle_i32x4(d0, d1, 0x88);
|
||||
r1 = _mm512_shuffle_i32x4(d0, d1, 0xdd);
|
||||
d0 = _mm512_shuffle_i32x4(r0, r1, 0x88);
|
||||
d1 = _mm512_shuffle_i32x4(r0, r1, 0xdd);
|
||||
return std::make_tuple(d0, d1);
|
||||
}
|
||||
#pragma GCC diagnostic pop
|
||||
|
||||
#endif
|
||||
|
||||
// TODO: debug print, remove me later
|
||||
template<typename scalar_t>
|
||||
void print_array(scalar_t* ptr, int size) {
|
||||
for (int d = 0; d < size; ++d) {
|
||||
if (d % 16 == 0) { std::cout << std::endl; }
|
||||
std::cout << ptr[d] << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
178
csrc/cpu/shm.cpp
178
csrc/cpu/shm.cpp
@ -7,9 +7,10 @@
|
||||
|
||||
namespace {
|
||||
#define MAX_SHM_RANK_NUM 8
|
||||
#define MAX_THREAD_NUM 12
|
||||
#define PER_THREAD_SHM_BUFFER_BYTES (4 * 1024 * 1024)
|
||||
#define MIN_THREAD_PROCESS_SIZE (8 * 1024)
|
||||
#define PER_THREAD_SHM_BUFFER_BYTES (2 * 1024 * 1024)
|
||||
static_assert(PER_THREAD_SHM_BUFFER_BYTES % 2 == 0);
|
||||
#define PER_THREAD_SHM_BUFFER_OFFSET (PER_THREAD_SHM_BUFFER_BYTES >> 1)
|
||||
#define MIN_THREAD_PROCESS_SIZE (256)
|
||||
#define MAX_P2P_SEND_TENSOR_NUM 8
|
||||
|
||||
template <typename scalar_t>
|
||||
@ -32,10 +33,10 @@ struct KernelVecType<c10::Half> {
|
||||
using scalar_vec_t = vec_op::FP16Vec16;
|
||||
};
|
||||
|
||||
enum class ThreadSHMStat : char { THREAD_READY = 0, SHM_DATA_READY, DONE };
|
||||
|
||||
struct ThreadSHMContext {
|
||||
volatile ThreadSHMStat thread_stats[MAX_SHM_RANK_NUM];
|
||||
volatile char _curr_thread_stamp;
|
||||
volatile char _ready_thread_stamp;
|
||||
char _padding1[6];
|
||||
int thread_id;
|
||||
int thread_num;
|
||||
int rank;
|
||||
@ -44,14 +45,19 @@ struct ThreadSHMContext {
|
||||
int swizzled_ranks[MAX_SHM_RANK_NUM];
|
||||
void* thread_shm_ptrs[MAX_SHM_RANK_NUM];
|
||||
ThreadSHMContext* shm_contexts[MAX_SHM_RANK_NUM];
|
||||
size_t _thread_buffer_mask;
|
||||
char _padding2[56];
|
||||
|
||||
ThreadSHMContext(const int thread_id, const int thread_num, const int rank,
|
||||
const int group_size, void* thread_shm_ptr)
|
||||
: thread_id(thread_id),
|
||||
: _curr_thread_stamp(1),
|
||||
_ready_thread_stamp(0),
|
||||
thread_id(thread_id),
|
||||
thread_num(thread_num),
|
||||
rank(rank),
|
||||
group_size(group_size),
|
||||
_spinning_count(0) {
|
||||
_spinning_count(0),
|
||||
_thread_buffer_mask(0) {
|
||||
static_assert(sizeof(ThreadSHMContext) % 64 == 0);
|
||||
TORCH_CHECK(group_size <= MAX_SHM_RANK_NUM);
|
||||
TORCH_CHECK((size_t)this % 64 == 0);
|
||||
@ -60,7 +66,6 @@ struct ThreadSHMContext {
|
||||
shm_contexts[i] = nullptr;
|
||||
thread_shm_ptrs[i] = nullptr;
|
||||
swizzled_ranks[i] = (i + rank) % group_size;
|
||||
thread_stats[i] = ThreadSHMStat::DONE;
|
||||
}
|
||||
set_context(rank, this, thread_shm_ptr);
|
||||
}
|
||||
@ -77,59 +82,66 @@ struct ThreadSHMContext {
|
||||
|
||||
template <typename T>
|
||||
T* get_thread_shm_ptr(int rank) {
|
||||
return reinterpret_cast<T*>(thread_shm_ptrs[rank]);
|
||||
return reinterpret_cast<T*>(
|
||||
reinterpret_cast<int8_t*>(thread_shm_ptrs[rank]) +
|
||||
(PER_THREAD_SHM_BUFFER_OFFSET & _thread_buffer_mask));
|
||||
}
|
||||
|
||||
void next_buffer() { _thread_buffer_mask ^= 0xFFFFFFFFFFFFFFFF; }
|
||||
|
||||
char get_curr_stamp() const { return _curr_thread_stamp; }
|
||||
|
||||
char get_ready_stamp() const { return _ready_thread_stamp; }
|
||||
|
||||
void next_stamp() {
|
||||
_mm_mfence();
|
||||
_curr_thread_stamp += 1;
|
||||
}
|
||||
|
||||
void commit_ready_stamp() {
|
||||
_mm_mfence();
|
||||
_ready_thread_stamp = _curr_thread_stamp;
|
||||
}
|
||||
|
||||
int get_swizzled_rank(int idx) { return swizzled_ranks[idx]; }
|
||||
|
||||
void wait_for_all(ThreadSHMStat prev_stat) {
|
||||
for (int idx = 0; idx < group_size; ++idx) {
|
||||
template <typename Cond>
|
||||
void wait_for_all(Cond&& cond) {
|
||||
for (int idx = 1; idx < group_size; ++idx) {
|
||||
int rank = get_swizzled_rank(idx);
|
||||
while (thread_stats[rank] == prev_stat) {
|
||||
++_spinning_count;
|
||||
_mm_pause();
|
||||
}
|
||||
wait_for_one(rank, std::forward<Cond>(cond));
|
||||
}
|
||||
vec_op::mem_barrier();
|
||||
}
|
||||
|
||||
void wait_for_one(int rank, ThreadSHMStat prev_stat) {
|
||||
while (thread_stats[rank] == prev_stat) {
|
||||
template <typename Cond>
|
||||
void wait_for_one(int rank, Cond&& cond) {
|
||||
ThreadSHMContext* rank_ctx = shm_contexts[rank];
|
||||
for (;;) {
|
||||
char local_curr_stamp = get_curr_stamp();
|
||||
char local_ready_stamp = get_ready_stamp();
|
||||
char rank_curr_stamp = rank_ctx->get_curr_stamp();
|
||||
char rank_ready_stamp = rank_ctx->get_ready_stamp();
|
||||
if (cond(local_curr_stamp, local_ready_stamp, rank_curr_stamp,
|
||||
rank_ready_stamp)) {
|
||||
break;
|
||||
}
|
||||
++_spinning_count;
|
||||
_mm_pause();
|
||||
}
|
||||
vec_op::mem_barrier();
|
||||
}
|
||||
|
||||
void set_thread_stat(ThreadSHMStat stat) {
|
||||
for (int idx = 0; idx < group_size; ++idx) {
|
||||
int rank = get_swizzled_rank(idx);
|
||||
shm_contexts[rank]->thread_stats[this->rank] = stat;
|
||||
}
|
||||
static bool check_no_buffer_conflict(char local_curr_stamp,
|
||||
char local_ready_stamp,
|
||||
char rank_curr_stamp,
|
||||
char rank_ready_stamp) {
|
||||
char temp = rank_curr_stamp + 2;
|
||||
return local_curr_stamp != temp;
|
||||
}
|
||||
|
||||
void set_thread_stat(int target_rank, ThreadSHMStat stat) {
|
||||
for (int idx = 0; idx < group_size; ++idx) {
|
||||
int rank = get_swizzled_rank(idx);
|
||||
shm_contexts[rank]->thread_stats[target_rank] = stat;
|
||||
}
|
||||
}
|
||||
|
||||
// barrier for all ranks in the group, used for all2all ops
|
||||
// DONE -> THREAD_READY -> SHM_DATA_READY -> DONE -> ...
|
||||
void barrier(ThreadSHMStat next_stat) {
|
||||
if (next_stat == ThreadSHMStat::THREAD_READY) {
|
||||
set_thread_stat(ThreadSHMStat::THREAD_READY);
|
||||
wait_for_all(ThreadSHMStat::DONE);
|
||||
} else if (next_stat == ThreadSHMStat::SHM_DATA_READY) {
|
||||
set_thread_stat(ThreadSHMStat::SHM_DATA_READY);
|
||||
wait_for_all(ThreadSHMStat::THREAD_READY);
|
||||
} else if (next_stat == ThreadSHMStat::DONE) {
|
||||
set_thread_stat(ThreadSHMStat::DONE);
|
||||
wait_for_all(ThreadSHMStat::SHM_DATA_READY);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Invalid next_stat to barrier.");
|
||||
}
|
||||
static bool check_stamp_ready(char local_curr_stamp, char local_ready_stamp,
|
||||
char rank_curr_stamp, char rank_ready_stamp) {
|
||||
char temp = local_curr_stamp + 1;
|
||||
return (local_curr_stamp == rank_ready_stamp) || (temp == rank_ready_stamp);
|
||||
}
|
||||
|
||||
std::string to_string() const {
|
||||
@ -164,7 +176,7 @@ class SHMManager {
|
||||
const int group_size)
|
||||
: _rank(rank),
|
||||
_group_size(group_size),
|
||||
_thread_num(std::min(torch::get_num_threads(), MAX_THREAD_NUM)),
|
||||
_thread_num(torch::get_num_threads()),
|
||||
_shm_names({""}),
|
||||
_shared_mem_ptrs({nullptr}),
|
||||
_shm_ctx(nullptr) {
|
||||
@ -326,7 +338,8 @@ void shm_cc_loop(ThreadSHMContext* ctx, int64_t elem_num, F&& inner_func) {
|
||||
(total_units_num + thread_num - 1) / thread_num;
|
||||
int64_t per_unit_elem_num = MIN_THREAD_PROCESS_SIZE / sizeof(scalar_t);
|
||||
int64_t max_per_thread_iteration_elem_num =
|
||||
PER_THREAD_SHM_BUFFER_BYTES / sizeof(scalar_t);
|
||||
(PER_THREAD_SHM_BUFFER_BYTES >> 1) /
|
||||
sizeof(scalar_t); // Note: double buffer
|
||||
int64_t per_thread_elem_num = per_unit_elem_num * per_thread_units_num;
|
||||
|
||||
#pragma omp parallel for schedule(static, 1)
|
||||
@ -336,10 +349,13 @@ void shm_cc_loop(ThreadSHMContext* ctx, int64_t elem_num, F&& inner_func) {
|
||||
int64_t curr_elem_num =
|
||||
std::min(max_per_thread_iteration_elem_num, end - offset);
|
||||
ThreadSHMContext* thread_ctx = ctx + i;
|
||||
bool fast_mode = ((end - offset) <= max_per_thread_iteration_elem_num);
|
||||
|
||||
while (curr_elem_num > 0) {
|
||||
inner_func(thread_ctx, offset, curr_elem_num);
|
||||
inner_func(thread_ctx, offset, curr_elem_num, fast_mode);
|
||||
|
||||
thread_ctx->next_stamp();
|
||||
thread_ctx->next_buffer();
|
||||
offset += max_per_thread_iteration_elem_num;
|
||||
curr_elem_num = std::min(max_per_thread_iteration_elem_num, end - offset);
|
||||
}
|
||||
@ -397,7 +413,7 @@ void all_reduce_sum_impl(ThreadSHMContext* ctx, scalar_t* data,
|
||||
shm_cc_ops::shm_cc_loop<scalar_t>(
|
||||
ctx, elem_num,
|
||||
[&](ThreadSHMContext* thread_ctx, int64_t data_offset,
|
||||
int64_t data_elem_num) {
|
||||
int64_t data_elem_num, bool fast_mode) {
|
||||
int rank = thread_ctx->rank;
|
||||
scalar_t* thread_shm_ptr =
|
||||
thread_ctx->get_thread_shm_ptr<scalar_t>(rank);
|
||||
@ -410,16 +426,17 @@ void all_reduce_sum_impl(ThreadSHMContext* ctx, scalar_t* data,
|
||||
thread_ctx->get_swizzled_rank(idx + 1));
|
||||
});
|
||||
|
||||
thread_ctx->barrier(ThreadSHMStat::THREAD_READY);
|
||||
if (!fast_mode) {
|
||||
thread_ctx->wait_for_all(ThreadSHMContext::check_no_buffer_conflict);
|
||||
}
|
||||
|
||||
shm_cc_ops::memcpy_to_shm(thread_shm_ptr, thread_data_ptr,
|
||||
thread_data_elem_num);
|
||||
|
||||
thread_ctx->barrier(ThreadSHMStat::SHM_DATA_READY);
|
||||
|
||||
thread_ctx->commit_ready_stamp();
|
||||
int64_t aligned_data_elem_num =
|
||||
(data_elem_num / vec_elem_num) * vec_elem_num;
|
||||
int64_t i = 0;
|
||||
thread_ctx->wait_for_all(ThreadSHMContext::check_stamp_ready);
|
||||
#pragma GCC unroll 4
|
||||
for (; i < aligned_data_elem_num; i += vec_elem_num) {
|
||||
vec_t local_data(thread_data_ptr + i); // load from cache
|
||||
@ -447,8 +464,6 @@ void all_reduce_sum_impl(ThreadSHMContext* ctx, scalar_t* data,
|
||||
reduced_data.save(thread_data_ptr + i,
|
||||
data_elem_num - aligned_data_elem_num);
|
||||
}
|
||||
|
||||
thread_ctx->barrier(ThreadSHMStat::DONE);
|
||||
});
|
||||
|
||||
return;
|
||||
@ -488,18 +503,18 @@ void shm_gather_impl(ThreadSHMContext* ctx, scalar_t* data, size_t elem_num,
|
||||
shm_cc_ops::shm_cc_loop<scalar_t>(
|
||||
ctx, elem_num,
|
||||
[&](ThreadSHMContext* thread_ctx, int64_t data_offset,
|
||||
int64_t data_elem_num) {
|
||||
int64_t data_elem_num, bool fast_mode) {
|
||||
int rank = thread_ctx->rank;
|
||||
scalar_t* thread_shm_ptr =
|
||||
thread_ctx->get_thread_shm_ptr<scalar_t>(rank);
|
||||
|
||||
thread_ctx->barrier(ThreadSHMStat::THREAD_READY);
|
||||
|
||||
shm_cc_ops::memcpy_to_shm(thread_shm_ptr, data + data_offset,
|
||||
data_elem_num * sizeof(scalar_t));
|
||||
|
||||
thread_ctx->barrier(ThreadSHMStat::SHM_DATA_READY);
|
||||
if (!fast_mode) {
|
||||
thread_ctx->wait_for_all(ThreadSHMContext::check_no_buffer_conflict);
|
||||
}
|
||||
|
||||
shm_cc_ops::memcpy(thread_shm_ptr, data + data_offset,
|
||||
data_elem_num * sizeof(scalar_t));
|
||||
thread_ctx->commit_ready_stamp();
|
||||
if (rank == dst) {
|
||||
shm_cc_ops::memcpy(outputs[rank] + data_offset, data + data_offset,
|
||||
data_elem_num * sizeof(scalar_t));
|
||||
@ -508,12 +523,12 @@ void shm_gather_impl(ThreadSHMContext* ctx, scalar_t* data, size_t elem_num,
|
||||
scalar_t* src_ptr =
|
||||
thread_ctx->get_thread_shm_ptr<scalar_t>(src_rank); // shm
|
||||
scalar_t* dst_ptr = outputs[src_rank] + data_offset;
|
||||
shm_cc_ops::memcpy_from_shm(dst_ptr, src_ptr,
|
||||
data_elem_num * sizeof(scalar_t));
|
||||
thread_ctx->wait_for_one(src_rank,
|
||||
ThreadSHMContext::check_stamp_ready);
|
||||
shm_cc_ops::memcpy(dst_ptr, src_ptr,
|
||||
data_elem_num * sizeof(scalar_t));
|
||||
}
|
||||
}
|
||||
|
||||
thread_ctx->barrier(ThreadSHMStat::DONE);
|
||||
});
|
||||
|
||||
return;
|
||||
@ -599,7 +614,7 @@ struct TensorListMeta {
|
||||
int8_t _padding[40];
|
||||
};
|
||||
|
||||
void shm_send_tensor_list_impl(ThreadSHMContext* ctx,
|
||||
void shm_send_tensor_list_impl(ThreadSHMContext* ctx, int64_t dst,
|
||||
const std::vector<torch::Tensor>& tensor_list) {
|
||||
CPU_KERNEL_GUARD_IN(shm_send_tensor_list_impl)
|
||||
std::vector<torch::Tensor> tensor_list_with_metadata;
|
||||
@ -620,12 +635,11 @@ void shm_send_tensor_list_impl(ThreadSHMContext* ctx,
|
||||
shm_cc_ops::shm_cc_loop<int8_t>(
|
||||
ctx, metadata->total_bytes,
|
||||
[&](ThreadSHMContext* thread_ctx, int64_t data_offset,
|
||||
int64_t data_elem_num) {
|
||||
int64_t data_elem_num, bool fast_mode) {
|
||||
int rank = thread_ctx->rank;
|
||||
// Wait until the receiver set the stat to DONE
|
||||
thread_ctx->wait_for_one(rank, ThreadSHMStat::SHM_DATA_READY);
|
||||
|
||||
int64_t curr_shm_offset = 0;
|
||||
thread_ctx->wait_for_one(dst,
|
||||
ThreadSHMContext::check_no_buffer_conflict);
|
||||
while (curr_shm_offset < data_elem_num) {
|
||||
MemPiece frag = metadata->get_data(data_offset + curr_shm_offset);
|
||||
frag.size = std::min(frag.size, data_elem_num - curr_shm_offset);
|
||||
@ -634,8 +648,7 @@ void shm_send_tensor_list_impl(ThreadSHMContext* ctx,
|
||||
frag.ptr, frag.size);
|
||||
curr_shm_offset += frag.size;
|
||||
}
|
||||
|
||||
thread_ctx->set_thread_stat(rank, ThreadSHMStat::SHM_DATA_READY);
|
||||
thread_ctx->commit_ready_stamp();
|
||||
});
|
||||
}
|
||||
|
||||
@ -646,8 +659,7 @@ std::vector<torch::Tensor> shm_recv_tensor_list_impl(ThreadSHMContext* ctx,
|
||||
torch::Tensor metadata_tensor =
|
||||
torch::empty({sizeof(TensorListMeta)}, options);
|
||||
|
||||
// Wait until the sender set the stat of the thread 0 to SHM_DATA_READY
|
||||
ctx->wait_for_one(src, ThreadSHMStat::DONE);
|
||||
ctx->wait_for_one(src, ThreadSHMContext::check_stamp_ready);
|
||||
shm_cc_ops::memcpy(metadata_tensor.data_ptr(),
|
||||
ctx->get_thread_shm_ptr<void>(src),
|
||||
sizeof(TensorListMeta));
|
||||
@ -664,9 +676,8 @@ std::vector<torch::Tensor> shm_recv_tensor_list_impl(ThreadSHMContext* ctx,
|
||||
shm_cc_ops::shm_cc_loop<int8_t>(
|
||||
ctx, metadata.total_bytes,
|
||||
[&](ThreadSHMContext* thread_ctx, int64_t data_offset,
|
||||
int64_t data_elem_num) {
|
||||
// Wait until the sender set the stat to SHM_DATA_READY
|
||||
thread_ctx->wait_for_one(src, ThreadSHMStat::DONE);
|
||||
int64_t data_elem_num, bool fast_mode) {
|
||||
ctx->wait_for_one(src, ThreadSHMContext::check_stamp_ready);
|
||||
int64_t curr_shm_offset = 0;
|
||||
while (curr_shm_offset < data_elem_num) {
|
||||
MemPiece frag = metadata.get_data(data_offset + curr_shm_offset);
|
||||
@ -677,8 +688,6 @@ std::vector<torch::Tensor> shm_recv_tensor_list_impl(ThreadSHMContext* ctx,
|
||||
frag.size);
|
||||
curr_shm_offset += frag.size;
|
||||
}
|
||||
|
||||
thread_ctx->set_thread_stat(src, ThreadSHMStat::DONE);
|
||||
});
|
||||
|
||||
std::vector<torch::Tensor> tensor_list;
|
||||
@ -756,7 +765,8 @@ void shm_send_tensor_list(int64_t handle,
|
||||
int64_t dst) {
|
||||
CPU_KERNEL_GUARD_IN(shm_send_tensor_list)
|
||||
shm_send_tensor_list_impl(
|
||||
SHMManager::get_singleton_instance(handle)->get_shm_ctx(), tensor_list);
|
||||
SHMManager::get_singleton_instance(handle)->get_shm_ctx(), dst,
|
||||
tensor_list);
|
||||
CPU_KERNEL_GUARD_OUT(shm_send_tensor_list)
|
||||
}
|
||||
|
||||
@ -778,4 +788,4 @@ std::string join_shm_manager(int64_t handle, const std::string& name) {
|
||||
TORCH_CHECK(shm_manager);
|
||||
shm_manager->join(name);
|
||||
return shm_manager->get_shm_ctx()->to_string();
|
||||
}
|
||||
}
|
||||
|
||||
@ -50,6 +50,27 @@ void shm_send_tensor_list(int64_t handle,
|
||||
|
||||
std::vector<torch::Tensor> shm_recv_tensor_list(int64_t handle, int64_t src);
|
||||
|
||||
at::Tensor weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
bool is_vnni);
|
||||
|
||||
at::Tensor convert_weight_packed(at::Tensor& weight);
|
||||
|
||||
at::Tensor fused_experts_cpu(
|
||||
at::Tensor& hidden_states, at::Tensor& w1, at::Tensor& w2,
|
||||
at::Tensor& topk_weights, at::Tensor& topk_ids, bool inplace,
|
||||
bool use_int8_w8a8, bool use_fp8_w8a16,
|
||||
const std::optional<at::Tensor>& w1_scale,
|
||||
const std::optional<at::Tensor>& w2_scale,
|
||||
const std::optional<std::vector<int64_t>> block_size,
|
||||
const std::optional<at::Tensor>& a1_scale,
|
||||
const std::optional<at::Tensor>& a2_scale, bool is_vnni);
|
||||
|
||||
at::Tensor int8_scaled_mm_with_quant(at::Tensor& mat1, at::Tensor& mat2,
|
||||
at::Tensor& scales2,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
at::ScalarType out_dtype, bool is_vnni);
|
||||
|
||||
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
// vLLM custom ops
|
||||
|
||||
@ -130,8 +151,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding);
|
||||
|
||||
// Quantization
|
||||
#ifdef __AVX512F__
|
||||
#if defined(__AVX512F__) || defined(__aarch64__)
|
||||
at::Tag stride_tag = at::Tag::needs_fixed_stride_order;
|
||||
|
||||
// Compute int8 quantized tensor for given scaling factor.
|
||||
ops.def(
|
||||
"static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
|
||||
@ -214,6 +236,28 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.def("shm_recv_tensor_list(int handle, int src) -> Tensor[](a)",
|
||||
&shm_recv_tensor_list);
|
||||
#endif
|
||||
|
||||
// sgl-kernels
|
||||
#if defined(__AVX512BF16__) && defined(__AVX512F__) && defined(__AVX512VNNI__)
|
||||
ops.def(
|
||||
"weight_packed_linear(Tensor(a0!) mat1, Tensor(a1!) mat2, Tensor(a2!)? "
|
||||
"bias, bool is_vnni) -> Tensor");
|
||||
ops.impl("weight_packed_linear", torch::kCPU, &weight_packed_linear);
|
||||
ops.def("convert_weight_packed(Tensor! weight) -> Tensor");
|
||||
ops.impl("convert_weight_packed", torch::kCPU, &convert_weight_packed);
|
||||
ops.def(
|
||||
"fused_experts_cpu(Tensor! hidden_states, Tensor w1, Tensor w2, Tensor "
|
||||
"topk_weights, Tensor topk_ids, bool inplace, bool use_int8_w8a8, bool "
|
||||
"use_fp8_w8a16, Tensor? w1_scale, Tensor? w2_scale, SymInt[]? "
|
||||
"block_size, Tensor? a1_scale, Tensor? a2_scale, bool is_vnni) -> "
|
||||
"Tensor");
|
||||
ops.impl("fused_experts_cpu", torch::kCPU, &fused_experts_cpu);
|
||||
ops.def(
|
||||
"int8_scaled_mm_with_quant(Tensor mat1, Tensor mat2, Tensor scales2, "
|
||||
"Tensor? bias, ScalarType out_dtype, bool is_vnni) -> Tensor");
|
||||
ops.impl("int8_scaled_mm_with_quant", torch::kCPU,
|
||||
&int8_scaled_mm_with_quant);
|
||||
#endif
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
|
||||
@ -153,7 +153,7 @@ struct ScaledEpilogueBias
|
||||
cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
|
||||
|
||||
using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::homogeneous_multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
@ -210,7 +210,7 @@ struct ScaledEpilogueBiasAzp
|
||||
EVTComputeAzp>;
|
||||
|
||||
using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::homogeneous_multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
@ -288,7 +288,7 @@ struct ScaledEpilogueBiasAzpToken
|
||||
EVTComputeAcc>;
|
||||
|
||||
using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::homogeneous_multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
|
||||
@ -195,7 +195,7 @@ struct ScaledEpilogueBias
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
|
||||
|
||||
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::homogeneous_multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
@ -238,7 +238,7 @@ struct ScaledEpilogueColumnBias
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
|
||||
|
||||
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::homogeneous_multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
@ -295,7 +295,7 @@ struct ScaledEpilogueBiasAzp
|
||||
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAzp>;
|
||||
|
||||
using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::homogeneous_multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
@ -371,7 +371,7 @@ struct ScaledEpilogueBiasAzpToken
|
||||
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAcc>;
|
||||
|
||||
using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::homogeneous_multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
|
||||
@ -45,7 +45,6 @@
|
||||
#include "cute/algorithm/functional.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cute/algorithm/gemm.hpp"
|
||||
#include "cute/tensor_predicate.hpp"
|
||||
#include "cute/numeric/arithmetic_tuple.hpp"
|
||||
|
||||
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
|
||||
|
||||
@ -1,656 +0,0 @@
|
||||
// clang-format off
|
||||
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_fwd.cu
|
||||
// and https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_update.cu
|
||||
#include <torch/all.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "causal_conv1d.h"
|
||||
#include <c10/util/BFloat16.h>
|
||||
#include <c10/util/Half.h>
|
||||
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
||||
|
||||
#include <cub/block/block_load.cuh>
|
||||
#include <cub/block/block_store.cuh>
|
||||
|
||||
#ifdef USE_ROCM
|
||||
namespace cub = hipcub;
|
||||
#endif
|
||||
|
||||
#include "static_switch.h"
|
||||
|
||||
|
||||
|
||||
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
||||
|
||||
#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
|
||||
if (ITYPE == at::ScalarType::Half) { \
|
||||
using input_t = at::Half; \
|
||||
using weight_t = at::Half; \
|
||||
__VA_ARGS__(); \
|
||||
} else if (ITYPE == at::ScalarType::BFloat16) { \
|
||||
using input_t = at::BFloat16; \
|
||||
using weight_t = at::BFloat16; \
|
||||
__VA_ARGS__(); \
|
||||
} else if (ITYPE == at::ScalarType::Float) { \
|
||||
using input_t = float; \
|
||||
using weight_t = float; \
|
||||
__VA_ARGS__(); \
|
||||
} else { \
|
||||
AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
|
||||
}
|
||||
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
|
||||
void set_conv_params_fwd(ConvParamsBase ¶ms,
|
||||
// sizes
|
||||
const size_t batch,
|
||||
const size_t dim,
|
||||
const size_t seqlen,
|
||||
const size_t width,
|
||||
// device pointers
|
||||
const at::Tensor x,
|
||||
const at::Tensor weight,
|
||||
const at::Tensor out,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
bool silu_activation,
|
||||
int64_t pad_slot_id,
|
||||
const std::optional<at::Tensor>& query_start_loc = std::nullopt,
|
||||
const std::optional<at::Tensor>& cache_indices = std::nullopt,
|
||||
const std::optional<at::Tensor>& has_initial_state = std::nullopt) {
|
||||
|
||||
// Reset the parameters
|
||||
memset(¶ms, 0, sizeof(params));
|
||||
|
||||
params.batch = batch;
|
||||
params.dim = dim;
|
||||
params.seqlen = seqlen;
|
||||
params.width = width;
|
||||
params.pad_slot_id = pad_slot_id;
|
||||
|
||||
params.silu_activation = silu_activation;
|
||||
|
||||
// Set the pointers and strides.
|
||||
params.x_ptr = x.data_ptr();
|
||||
params.weight_ptr = weight.data_ptr();
|
||||
params.bias_ptr = bias.has_value() ? bias.value().data_ptr() : nullptr;
|
||||
params.out_ptr = out.data_ptr();
|
||||
// All stride are in elements, not bytes.
|
||||
params.query_start_loc_ptr = query_start_loc.has_value() ? query_start_loc.value().data_ptr() : nullptr;
|
||||
params.cache_indices_ptr = cache_indices.has_value() ? cache_indices.value().data_ptr() : nullptr;
|
||||
params.has_initial_state_ptr = has_initial_state.has_value() ? has_initial_state.value().data_ptr() : nullptr;
|
||||
const bool varlen = params.query_start_loc_ptr != nullptr;
|
||||
params.x_batch_stride = x.stride(varlen ? 1 : 0);
|
||||
params.x_c_stride = x.stride(varlen ? 0 : 1);
|
||||
params.x_l_stride = x.stride(varlen ? 1 : -1);
|
||||
params.weight_c_stride = weight.stride(0);
|
||||
params.weight_width_stride = weight.stride(1);
|
||||
params.out_batch_stride = out.stride(varlen ? 1 : 0);
|
||||
params.out_c_stride = out.stride(varlen ? 0 : 1);
|
||||
params.out_l_stride = out.stride(varlen ? 1 : -1);
|
||||
}
|
||||
|
||||
|
||||
void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
|
||||
const std::optional<at::Tensor> &bias_,
|
||||
const std::optional<at::Tensor> &conv_states,
|
||||
const std::optional<at::Tensor> &query_start_loc,
|
||||
const std::optional<at::Tensor> &cache_indices,
|
||||
const std::optional<at::Tensor> &has_initial_state,
|
||||
bool silu_activation,
|
||||
// used to identify padding entries if cache_indices provided
|
||||
// in case of padding, the kernel will return early
|
||||
int64_t pad_slot_id) {
|
||||
auto input_type = x.scalar_type();
|
||||
auto weight_type = weight.scalar_type();
|
||||
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
||||
TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
|
||||
|
||||
TORCH_CHECK(x.is_cuda());
|
||||
TORCH_CHECK(weight.is_cuda());
|
||||
|
||||
const bool varlen = query_start_loc.has_value() ? true : false;
|
||||
const auto sizes = x.sizes();
|
||||
const int batch_size = varlen ? query_start_loc.value().sizes()[0] - 1 : sizes[0];
|
||||
const int dim = varlen ? sizes[0] : sizes[1];
|
||||
const int seqlen = varlen ? sizes[1] : sizes[2];
|
||||
const int width = weight.size(-1);
|
||||
if (varlen){
|
||||
CHECK_SHAPE(x, dim, seqlen);
|
||||
}
|
||||
else {
|
||||
CHECK_SHAPE(x, batch_size, dim, seqlen);
|
||||
}
|
||||
CHECK_SHAPE(weight, dim, width);
|
||||
|
||||
|
||||
|
||||
if (bias_.has_value()) {
|
||||
auto bias = bias_.value();
|
||||
TORCH_CHECK(bias.scalar_type() == weight_type);
|
||||
TORCH_CHECK(bias.is_cuda());
|
||||
TORCH_CHECK(bias.stride(-1) == 1);
|
||||
CHECK_SHAPE(bias, dim);
|
||||
}
|
||||
|
||||
|
||||
if (has_initial_state.has_value()) {
|
||||
auto has_initial_state_ = has_initial_state.value();
|
||||
TORCH_CHECK(has_initial_state_.scalar_type() == at::ScalarType::Bool);
|
||||
TORCH_CHECK(has_initial_state_.is_cuda());
|
||||
CHECK_SHAPE(has_initial_state_, batch_size);
|
||||
}
|
||||
|
||||
|
||||
if (query_start_loc.has_value()) {
|
||||
auto query_start_loc_ = query_start_loc.value();
|
||||
TORCH_CHECK(query_start_loc_.scalar_type() == at::ScalarType::Int);
|
||||
TORCH_CHECK(query_start_loc_.is_cuda());
|
||||
}
|
||||
|
||||
|
||||
if (cache_indices.has_value()) {
|
||||
auto cache_indices_ = cache_indices.value();
|
||||
TORCH_CHECK(cache_indices_.scalar_type() == at::ScalarType::Int);
|
||||
TORCH_CHECK(cache_indices_.is_cuda());
|
||||
CHECK_SHAPE(cache_indices_, batch_size);
|
||||
}
|
||||
|
||||
at::Tensor out = x;
|
||||
|
||||
ConvParamsBase params;
|
||||
set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
|
||||
bias_,
|
||||
silu_activation,
|
||||
pad_slot_id,
|
||||
query_start_loc,
|
||||
cache_indices,
|
||||
has_initial_state
|
||||
);
|
||||
|
||||
if (conv_states.has_value()) {
|
||||
auto conv_states_ = conv_states.value();
|
||||
TORCH_CHECK(conv_states_.scalar_type() == input_type);
|
||||
TORCH_CHECK(conv_states_.is_cuda());
|
||||
params.conv_states_ptr = conv_states_.data_ptr();
|
||||
params.conv_states_batch_stride = conv_states_.stride(0);
|
||||
params.conv_states_c_stride = conv_states_.stride(1);
|
||||
params.conv_states_l_stride = conv_states_.stride(2);
|
||||
} else {
|
||||
params.conv_states_ptr = nullptr;
|
||||
}
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] {
|
||||
causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
void causal_conv1d_update(const at::Tensor &x,
|
||||
const at::Tensor &conv_state,
|
||||
const at::Tensor &weight,
|
||||
const std::optional<at::Tensor> &bias_,
|
||||
bool silu_activation,
|
||||
const std::optional<at::Tensor> &cache_seqlens_,
|
||||
const std::optional<at::Tensor> &conv_state_indices_,
|
||||
// used to identify padding entries if cache_indices provided
|
||||
// in case of padding, the kernel will return early
|
||||
int64_t pad_slot_id) {
|
||||
auto input_type = x.scalar_type();
|
||||
auto weight_type = weight.scalar_type();
|
||||
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
||||
TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
|
||||
TORCH_CHECK(weight_type == input_type, "weight type must equal to input type, other variations are disabled due to binary size limitations");
|
||||
TORCH_CHECK(conv_state.scalar_type() == input_type);
|
||||
|
||||
TORCH_CHECK(x.is_cuda());
|
||||
TORCH_CHECK(conv_state.is_cuda());
|
||||
TORCH_CHECK(weight.is_cuda());
|
||||
|
||||
const auto sizes = x.sizes();
|
||||
const int batch_size = sizes[0];
|
||||
const int dim = sizes[1];
|
||||
const int seqlen = sizes[2];
|
||||
const int width = weight.size(-1);
|
||||
const int conv_state_len = conv_state.size(2);
|
||||
TORCH_CHECK(conv_state_len >= width - 1);
|
||||
|
||||
CHECK_SHAPE(x, batch_size, dim, seqlen);
|
||||
CHECK_SHAPE(weight, dim, width);
|
||||
|
||||
TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
|
||||
|
||||
if (bias_.has_value()) {
|
||||
auto bias = bias_.value();
|
||||
TORCH_CHECK(bias.scalar_type() == weight_type);
|
||||
TORCH_CHECK(bias.is_cuda());
|
||||
TORCH_CHECK(bias.stride(-1) == 1);
|
||||
CHECK_SHAPE(bias, dim);
|
||||
}
|
||||
|
||||
at::Tensor out = x;
|
||||
|
||||
ConvParamsBase params;
|
||||
set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
|
||||
bias_,
|
||||
silu_activation,
|
||||
pad_slot_id);
|
||||
params.conv_state_ptr = conv_state.data_ptr();
|
||||
params.conv_state_len = conv_state_len;
|
||||
// All stride are in elements, not bytes.
|
||||
params.conv_state_batch_stride = conv_state.stride(0);
|
||||
params.conv_state_c_stride = conv_state.stride(1);
|
||||
params.conv_state_l_stride = conv_state.stride(2);
|
||||
|
||||
if (cache_seqlens_.has_value()) {
|
||||
auto cache_seqlens = cache_seqlens_.value();
|
||||
TORCH_CHECK(cache_seqlens.scalar_type() == torch::kInt32);
|
||||
TORCH_CHECK(cache_seqlens.is_cuda());
|
||||
TORCH_CHECK(cache_seqlens.stride(-1) == 1);
|
||||
CHECK_SHAPE(cache_seqlens, batch_size);
|
||||
params.cache_seqlens = cache_seqlens.data_ptr<int32_t>();
|
||||
} else {
|
||||
params.cache_seqlens = nullptr;
|
||||
}
|
||||
|
||||
if (conv_state_indices_.has_value()) {
|
||||
auto conv_state_indices = conv_state_indices_.value();
|
||||
TORCH_CHECK(conv_state_indices.scalar_type() == torch::kInt32)
|
||||
TORCH_CHECK(conv_state_indices.is_cuda());
|
||||
TORCH_CHECK(conv_state_indices.stride(0) == 1)
|
||||
CHECK_SHAPE(conv_state_indices, batch_size);
|
||||
|
||||
int conv_state_entries = conv_state.size(0);
|
||||
CHECK_SHAPE(conv_state, conv_state_entries, dim, conv_state_len);
|
||||
|
||||
params.conv_state_indices_ptr = conv_state_indices.data_ptr<int32_t>();
|
||||
} else {
|
||||
CHECK_SHAPE(conv_state, batch_size, dim, conv_state_len);
|
||||
params.conv_state_indices_ptr = nullptr;
|
||||
}
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] {
|
||||
causal_conv1d_update_cuda<input_t, weight_t>(params, stream);
|
||||
});
|
||||
}
|
||||
|
||||
template<int kNThreads_, int kWidth_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
|
||||
struct Causal_conv1d_fwd_kernel_traits {
|
||||
using input_t = input_t_;
|
||||
using weight_t = weight_t_;
|
||||
static constexpr int kNThreads = kNThreads_;
|
||||
static constexpr int kWidth = kWidth_;
|
||||
static constexpr int kNBytes = sizeof(input_t);
|
||||
static_assert(kNBytes == 2 || kNBytes == 4);
|
||||
static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
|
||||
static_assert(kWidth <= kNElts);
|
||||
static constexpr bool kIsVecLoad = kIsVecLoad_;
|
||||
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
|
||||
using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
||||
using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>;
|
||||
using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
||||
using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, 1, cub::BLOCK_STORE_DIRECT>;
|
||||
static constexpr int kSmemIOSize = kIsVecLoad
|
||||
? 0
|
||||
: custom_max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)});
|
||||
static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts;
|
||||
static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize;
|
||||
};
|
||||
|
||||
template<typename Ktraits>
|
||||
__global__ __launch_bounds__(Ktraits::kNThreads)
|
||||
void causal_conv1d_fwd_kernel(ConvParamsBase params) {
|
||||
constexpr int kWidth = Ktraits::kWidth;
|
||||
constexpr int kNThreads = Ktraits::kNThreads;
|
||||
constexpr int kNElts = Ktraits::kNElts;
|
||||
constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
|
||||
using input_t = typename Ktraits::input_t;
|
||||
using vec_t = typename Ktraits::vec_t;
|
||||
using weight_t = typename Ktraits::weight_t;
|
||||
|
||||
// Shared memory.
|
||||
extern __shared__ char smem_[];
|
||||
auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
|
||||
auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_);
|
||||
auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
|
||||
auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
|
||||
vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize);
|
||||
|
||||
const bool kVarlen = params.query_start_loc_ptr != nullptr;
|
||||
const int tidx = threadIdx.x;
|
||||
const int batch_id = blockIdx.x;
|
||||
const int channel_id = blockIdx.y;
|
||||
const int *query_start_loc = kVarlen ? reinterpret_cast<int *>(params.query_start_loc_ptr) : nullptr;
|
||||
const int sequence_start_index = kVarlen ? query_start_loc[batch_id] : batch_id;
|
||||
const int seqlen = kVarlen ? query_start_loc[batch_id + 1] - sequence_start_index : params.seqlen;
|
||||
|
||||
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + sequence_start_index * params.x_batch_stride
|
||||
+ channel_id * params.x_c_stride;
|
||||
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
|
||||
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + sequence_start_index * params.out_batch_stride
|
||||
+ channel_id * params.out_c_stride;
|
||||
float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
|
||||
|
||||
bool has_initial_state = params.has_initial_state_ptr == nullptr ? false
|
||||
: reinterpret_cast<bool *>(params.has_initial_state_ptr)[batch_id];
|
||||
|
||||
int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
|
||||
: reinterpret_cast<int *>(params.cache_indices_ptr);
|
||||
int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
|
||||
// cache_index == params.pad_slot_id is defined as padding, so we exit early
|
||||
if (cache_index == params.pad_slot_id){
|
||||
return;
|
||||
}
|
||||
input_t *conv_states = params.conv_states_ptr == nullptr ? nullptr
|
||||
: reinterpret_cast<input_t *>(params.conv_states_ptr) + cache_index * params.conv_states_batch_stride + channel_id * params.conv_states_c_stride;
|
||||
|
||||
// Thread 0 will load the last elements of the previous chunk, so we initialize those to 0.
|
||||
if (tidx == 0) {
|
||||
input_t initial_state[kNElts] = {0};
|
||||
if (has_initial_state) {
|
||||
#pragma unroll
|
||||
for (int w = 0; w < kWidth - 1; ++w){ initial_state[kNElts - 1 - (kWidth - 2) + w ] = conv_states[w]; }
|
||||
}
|
||||
smem_exchange[kNThreads - 1] = reinterpret_cast<vec_t *>(initial_state)[0];
|
||||
}
|
||||
|
||||
float weight_vals[kWidth];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
|
||||
|
||||
constexpr int kChunkSize = kNThreads * kNElts;
|
||||
const int n_chunks = (seqlen + kChunkSize - 1) / kChunkSize;
|
||||
for (int chunk = 0; chunk < n_chunks; ++chunk) {
|
||||
input_t x_vals_load[2 * kNElts] = {0};
|
||||
if constexpr(kIsVecLoad) {
|
||||
typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (seqlen - chunk * kChunkSize) / kNElts);
|
||||
} else {
|
||||
__syncthreads();
|
||||
typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), seqlen - chunk * kChunkSize);
|
||||
}
|
||||
x += kChunkSize;
|
||||
__syncthreads();
|
||||
// Thread kNThreads - 1 don't write yet, so that thread 0 can read
|
||||
// the last elements of the previous chunk.
|
||||
if (tidx < kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
|
||||
__syncthreads();
|
||||
reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1];
|
||||
__syncthreads();
|
||||
// Now thread kNThreads - 1 can write the last elements of the current chunk.
|
||||
if (tidx == kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
|
||||
|
||||
float x_vals[2 * kNElts];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); }
|
||||
|
||||
float out_vals[kNElts];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNElts; ++i) {
|
||||
out_vals[i] = bias_val;
|
||||
#pragma unroll
|
||||
for (int w = 0; w < kWidth; ++w) {
|
||||
out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)];
|
||||
}
|
||||
}
|
||||
|
||||
if (params.silu_activation) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNElts; ++i) {
|
||||
out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i]));
|
||||
}
|
||||
}
|
||||
|
||||
input_t out_vals_store[kNElts];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; }
|
||||
if constexpr(kIsVecLoad) {
|
||||
typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(out), reinterpret_cast<vec_t (&)[1]>(out_vals_store), (seqlen - chunk * kChunkSize) / kNElts);
|
||||
} else {
|
||||
typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, seqlen - chunk * kChunkSize);
|
||||
}
|
||||
out += kChunkSize;
|
||||
|
||||
int final_state_position = ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize);
|
||||
// in case the final state is separated between the last "smem_exchange" and
|
||||
// and the one before it (chunk = n_chunks - 1 and chunk = n_chunks - 2),
|
||||
// (which occurs when `final_state_position` is a non-positive index)
|
||||
// we load the correct data from smem_exchange from both chunks, the last chunk iteration and the one before it
|
||||
if (conv_states != nullptr && final_state_position < 0 && seqlen > kWidth){
|
||||
input_t vals_load[kNElts] = {0};
|
||||
if ((chunk == n_chunks - 2) && (tidx == kNThreads - 1)){
|
||||
// chunk = n_chunks - 2, a segment of the final state sits in the last index
|
||||
reinterpret_cast<vec_t *>(vals_load)[0] = smem_exchange[kNThreads - 1];
|
||||
#pragma unroll
|
||||
for (int w = 0; w < -final_state_position; ++w){
|
||||
conv_states[w] = vals_load[kNElts + final_state_position + w];
|
||||
}
|
||||
}
|
||||
if ((chunk == n_chunks - 1) && tidx == 0){
|
||||
// chunk = n_chunks - 1, the second segment of the final state first positions
|
||||
reinterpret_cast<vec_t *>(vals_load)[0] = smem_exchange[0];
|
||||
for (int w = -final_state_position; w < kWidth - 1; ++w){
|
||||
conv_states[w] = vals_load[w + final_state_position];
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Final state is stored in the smem_exchange last token slot,
|
||||
// in case seqlen < kWidth, we would need to take the final state from the
|
||||
// initial state which is stored in conv_states
|
||||
// in case seqlen > kWidth, we would need to load the last kWidth - 1 data
|
||||
// and load it into conv_state accordingly
|
||||
int last_thread = ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize) / kNElts;
|
||||
if (conv_states != nullptr && tidx == last_thread) {
|
||||
input_t x_vals_load[kNElts * 2] = {0};
|
||||
// in case we are on the first kWidth tokens
|
||||
if (last_thread == 0 && seqlen < kWidth){
|
||||
// Need to take the initial state
|
||||
reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[0];
|
||||
const int offset = seqlen - (kWidth - 1);
|
||||
#pragma unroll
|
||||
for (int w = 0; w < kWidth - 1; ++w){
|
||||
// pad the existing state
|
||||
if ((w - seqlen) >= 0 && has_initial_state) { conv_states[w - seqlen] = conv_states[w]; }
|
||||
else if ((w - seqlen) >= 0 && !has_initial_state) { conv_states[w - seqlen] = input_t(0.0f); }
|
||||
}
|
||||
#pragma unroll
|
||||
for (int w = 0; w < kWidth - 1; ++w){
|
||||
if (offset + w >= 0)
|
||||
conv_states[w] = x_vals_load[offset + w ];
|
||||
}
|
||||
}
|
||||
else {
|
||||
// in case the final state is in between the threads data
|
||||
const int offset = ((seqlen - (kWidth - 1)) % (kNElts));
|
||||
if ((offset + kWidth - 2) >= kNElts && (last_thread + 1 < kNThreads)){
|
||||
// In case last_thread == kNThreads - 1, accessing last_thread + 1 will result in a
|
||||
// illegal access error on H100.
|
||||
// Therefore, we access last_thread + 1, only if the final state data sits there
|
||||
reinterpret_cast<vec_t *>(x_vals_load)[1] = smem_exchange[last_thread + 1];
|
||||
}
|
||||
reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[last_thread];
|
||||
#pragma unroll
|
||||
for (int w = 0; w < kWidth - 1; ++w){
|
||||
conv_states[w] = x_vals_load[offset + w ];
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
|
||||
void causal_conv1d_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) {
|
||||
static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
|
||||
const bool kVarlen = params.query_start_loc_ptr != nullptr;
|
||||
BOOL_SWITCH(params.seqlen % kNElts == 0 && !kVarlen, kIsVecLoad, [&] {
|
||||
using Ktraits = Causal_conv1d_fwd_kernel_traits<kNThreads, kWidth, kIsVecLoad, input_t, weight_t>;
|
||||
constexpr int kSmemSize = Ktraits::kSmemSize;
|
||||
dim3 grid(params.batch, params.dim);
|
||||
|
||||
auto kernel = &causal_conv1d_fwd_kernel<Ktraits>;
|
||||
|
||||
if (kSmemSize >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
(void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
||||
std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
|
||||
}
|
||||
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
||||
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
}
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) {
|
||||
if (params.width == 2) {
|
||||
causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream);
|
||||
} else if (params.width == 3) {
|
||||
causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream);
|
||||
} else if (params.width == 4) {
|
||||
causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template void causal_conv1d_fwd_cuda<float, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
template void causal_conv1d_fwd_cuda<at::Half, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
template void causal_conv1d_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
|
||||
|
||||
|
||||
|
||||
template<int kNThreads_, int kWidth_, typename input_t_, typename weight_t_>
|
||||
struct Causal_conv1d_update_kernel_traits {
|
||||
using input_t = input_t_;
|
||||
using weight_t = weight_t_;
|
||||
static constexpr int kNThreads = kNThreads_;
|
||||
static constexpr int kWidth = kWidth_;
|
||||
static constexpr int kNBytes = sizeof(input_t);
|
||||
static_assert(kNBytes == 2 || kNBytes == 4);
|
||||
};
|
||||
|
||||
template<typename Ktraits, bool kIsCircularBuffer>
|
||||
__global__ __launch_bounds__(Ktraits::kNThreads)
|
||||
void causal_conv1d_update_kernel(ConvParamsBase params) {
|
||||
constexpr int kWidth = Ktraits::kWidth;
|
||||
constexpr int kNThreads = Ktraits::kNThreads;
|
||||
using input_t = typename Ktraits::input_t;
|
||||
using weight_t = typename Ktraits::weight_t;
|
||||
|
||||
const int tidx = threadIdx.x;
|
||||
const int batch_id = blockIdx.x;
|
||||
const int channel_id = blockIdx.y * kNThreads + tidx;
|
||||
if (channel_id >= params.dim) return;
|
||||
|
||||
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
|
||||
+ channel_id * params.x_c_stride;
|
||||
|
||||
// If params.conv_state_batch_indices is set, then the conv state is gathered from the conv state tensor
|
||||
// along the batch axis. Otherwise, the conv state coordinate is the same as the batch id.
|
||||
const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr
|
||||
? batch_id
|
||||
: params.conv_state_indices_ptr[batch_id];
|
||||
// conv_state_batch_coord == params.pad_slot_id is defined as padding so we exit early
|
||||
if (conv_state_batch_coord == params.pad_slot_id){
|
||||
return;
|
||||
}
|
||||
input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr)
|
||||
+ conv_state_batch_coord * params.conv_state_batch_stride
|
||||
+ channel_id * params.conv_state_c_stride;
|
||||
|
||||
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
|
||||
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
|
||||
+ channel_id * params.out_c_stride;
|
||||
float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
|
||||
|
||||
int state_len = params.conv_state_len;
|
||||
int advance_len = params.seqlen;
|
||||
int cache_seqlen = kIsCircularBuffer ? params.cache_seqlens[batch_id] % state_len : 0;
|
||||
int update_idx = cache_seqlen - (kWidth - 1);
|
||||
update_idx = update_idx < 0 ? update_idx + state_len : update_idx;
|
||||
|
||||
float weight_vals[kWidth] = {0};
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
|
||||
|
||||
float x_vals[kWidth] = {0};
|
||||
if constexpr (!kIsCircularBuffer) {
|
||||
#pragma unroll 2
|
||||
for (int i = 0; i < state_len - advance_len - (kWidth - 1); ++i) {
|
||||
conv_state[i * params.conv_state_l_stride] = conv_state[(i + advance_len) * params.conv_state_l_stride];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth - 1; ++i) {
|
||||
input_t state_val = conv_state[(state_len - (kWidth - 1) + i) * params.conv_state_l_stride];
|
||||
if (i < advance_len + (kWidth - 1) && state_len - advance_len - (kWidth - 1) + i >= 0) {
|
||||
conv_state[(state_len - advance_len - (kWidth - 1) + i) * params.conv_state_l_stride] = state_val;
|
||||
}
|
||||
x_vals[i] = float(state_val);
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth - 1; ++i, update_idx = update_idx + 1 >= state_len ? update_idx + 1 - state_len : update_idx + 1) {
|
||||
input_t state_val = conv_state[update_idx * params.conv_state_l_stride];
|
||||
x_vals[i] = float(state_val);
|
||||
}
|
||||
}
|
||||
#pragma unroll 2
|
||||
for (int i = 0; i < params.seqlen; ++i) {
|
||||
input_t x_val = x[i * params.x_l_stride];
|
||||
if constexpr (!kIsCircularBuffer) {
|
||||
if (i < advance_len && state_len - advance_len + i >= 0) {
|
||||
conv_state[(state_len - advance_len + i) * params.conv_state_l_stride] = x_val;
|
||||
}
|
||||
} else {
|
||||
conv_state[update_idx * params.conv_state_l_stride] = x_val;
|
||||
++update_idx;
|
||||
update_idx = update_idx >= state_len ? update_idx - state_len : update_idx;
|
||||
}
|
||||
x_vals[kWidth - 1] = float(x_val);
|
||||
float out_val = bias_val;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < kWidth; ++j) { out_val += weight_vals[j] * x_vals[j]; }
|
||||
if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); }
|
||||
out[i * params.out_l_stride] = input_t(out_val);
|
||||
// Shift the input buffer by 1
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = x_vals[i + 1]; }
|
||||
}
|
||||
}
|
||||
|
||||
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
|
||||
void causal_conv1d_update_launch(ConvParamsBase ¶ms, cudaStream_t stream) {
|
||||
using Ktraits = Causal_conv1d_update_kernel_traits<kNThreads, kWidth, input_t, weight_t>;
|
||||
dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads);
|
||||
auto kernel = params.cache_seqlens == nullptr
|
||||
? &causal_conv1d_update_kernel<Ktraits, false>
|
||||
: &causal_conv1d_update_kernel<Ktraits, true>;
|
||||
kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
}
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream) {
|
||||
if (params.width == 2) {
|
||||
causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream);
|
||||
} else if (params.width == 3) {
|
||||
causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream);
|
||||
} else if (params.width == 4) {
|
||||
causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream);
|
||||
}
|
||||
}
|
||||
|
||||
template void causal_conv1d_update_cuda<float, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
template void causal_conv1d_update_cuda<at::Half, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
template void causal_conv1d_update_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
@ -1,159 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024, Tri Dao.
|
||||
******************************************************************************/
|
||||
// clang-format off
|
||||
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d.h
|
||||
#pragma once
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct ConvParamsBase {
|
||||
using index_t = uint32_t;
|
||||
|
||||
int batch, dim, seqlen, width;
|
||||
int64_t pad_slot_id;
|
||||
bool silu_activation;
|
||||
|
||||
index_t x_batch_stride;
|
||||
index_t x_c_stride;
|
||||
index_t x_l_stride;
|
||||
index_t weight_c_stride;
|
||||
index_t weight_width_stride;
|
||||
index_t out_batch_stride;
|
||||
index_t out_c_stride;
|
||||
index_t out_l_stride;
|
||||
|
||||
int conv_state_len;
|
||||
index_t conv_state_batch_stride;
|
||||
index_t conv_state_c_stride;
|
||||
index_t conv_state_l_stride;
|
||||
|
||||
// Common data pointers.
|
||||
void *__restrict__ x_ptr;
|
||||
void *__restrict__ weight_ptr;
|
||||
void *__restrict__ bias_ptr;
|
||||
void *__restrict__ out_ptr;
|
||||
|
||||
void *__restrict__ conv_state_ptr;
|
||||
void *__restrict__ query_start_loc_ptr;
|
||||
void *__restrict__ has_initial_state_ptr;
|
||||
void *__restrict__ cache_indices_ptr;
|
||||
int32_t *__restrict__ cache_seqlens;
|
||||
|
||||
// For the continuous batching case. Makes it so that the mamba state for
|
||||
// the current batch doesn't need to be a contiguous tensor.
|
||||
int32_t *__restrict__ conv_state_indices_ptr;
|
||||
|
||||
void *__restrict__ seq_idx_ptr;
|
||||
|
||||
// No __restrict__ since initial_states could be the same as final_states.
|
||||
void * initial_states_ptr;
|
||||
index_t initial_states_batch_stride;
|
||||
index_t initial_states_l_stride;
|
||||
index_t initial_states_c_stride;
|
||||
|
||||
void * final_states_ptr;
|
||||
index_t final_states_batch_stride;
|
||||
index_t final_states_l_stride;
|
||||
index_t final_states_c_stride;
|
||||
|
||||
void * conv_states_ptr;
|
||||
index_t conv_states_batch_stride;
|
||||
index_t conv_states_l_stride;
|
||||
index_t conv_states_c_stride;
|
||||
};
|
||||
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
template<typename T>
|
||||
__device__ inline T shuffle_xor(T val, int offset) {
|
||||
return __shfl_xor_sync(uint32_t(-1), val, offset);
|
||||
}
|
||||
|
||||
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
|
||||
{
|
||||
return std::max(ilist);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr T constexpr_min(T a, T b) {
|
||||
return std::min(a, b);
|
||||
}
|
||||
|
||||
#else
|
||||
#include <hip/hip_bf16.h>
|
||||
|
||||
template<typename T>
|
||||
__device__ inline T shuffle_xor(T val, int offset) {
|
||||
return __shfl_xor(val, offset);
|
||||
}
|
||||
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
|
||||
{
|
||||
return *std::max_element(ilist.begin(), ilist.end());
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr T constexpr_min(T a, T b) {
|
||||
return a < b ? a : b;
|
||||
}
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<int BYTES> struct BytesToType {};
|
||||
|
||||
template<> struct BytesToType<16> {
|
||||
using Type = uint4;
|
||||
static_assert(sizeof(Type) == 16);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<8> {
|
||||
using Type = uint64_t;
|
||||
static_assert(sizeof(Type) == 8);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<4> {
|
||||
using Type = uint32_t;
|
||||
static_assert(sizeof(Type) == 4);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<2> {
|
||||
using Type = uint16_t;
|
||||
static_assert(sizeof(Type) == 2);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<1> {
|
||||
using Type = uint8_t;
|
||||
static_assert(sizeof(Type) == 1);
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T>
|
||||
struct SumOp {
|
||||
__device__ inline T operator()(T const & x, T const & y) { return x + y; }
|
||||
};
|
||||
|
||||
template<int THREADS>
|
||||
struct Allreduce {
|
||||
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
|
||||
template<typename T, typename Operator>
|
||||
static __device__ inline T run(T x, Operator &op) {
|
||||
constexpr int OFFSET = THREADS / 2;
|
||||
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
|
||||
return Allreduce<OFFSET>::run(x, op);
|
||||
}
|
||||
};
|
||||
|
||||
template<>
|
||||
struct Allreduce<2> {
|
||||
template<typename T, typename Operator>
|
||||
static __device__ inline T run(T x, Operator &op) {
|
||||
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
|
||||
return x;
|
||||
}
|
||||
};
|
||||
@ -1,28 +0,0 @@
|
||||
// Inspired by
|
||||
// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
|
||||
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
|
||||
// clang-format off
|
||||
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/static_switch.h
|
||||
|
||||
#pragma once
|
||||
|
||||
/// @param COND - a boolean expression to switch by
|
||||
/// @param CONST_NAME - a name given for the constexpr bool variable.
|
||||
/// @param ... - code to execute for true and false
|
||||
///
|
||||
/// Usage:
|
||||
/// ```
|
||||
/// BOOL_SWITCH(flag, BoolConst, [&] {
|
||||
/// some_function<BoolConst>(...);
|
||||
/// });
|
||||
/// ```
|
||||
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
|
||||
[&] { \
|
||||
if (COND) { \
|
||||
static constexpr bool CONST_NAME = true; \
|
||||
return __VA_ARGS__(); \
|
||||
} else { \
|
||||
static constexpr bool CONST_NAME = false; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
}()
|
||||
@ -1255,8 +1255,6 @@ __global__ void Marlin(
|
||||
if constexpr (has_zp && !is_zp_float) {
|
||||
if (is_new_zp) {
|
||||
if constexpr (group_blocks == -1) is_first_matmul_in_slice = false;
|
||||
FragB frag_zp_0;
|
||||
FragB frag_zp_1;
|
||||
int zp_quant_0, zp_quant_1;
|
||||
|
||||
if constexpr (w_type.size_bits() == 4) {
|
||||
|
||||
@ -239,7 +239,7 @@ void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size]
|
||||
torch::Tensor& output) // [num_tokens, hidden_size]
|
||||
{
|
||||
const int hidden_size = input.size(-1);
|
||||
const int num_tokens = output.numel() / hidden_size;
|
||||
const auto num_tokens = output.numel() / hidden_size;
|
||||
const int topk = input.size(1);
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
|
||||
@ -492,7 +492,7 @@ void topk_softmax(
|
||||
torch::Tensor& gating_output) // [num_tokens, num_experts]
|
||||
{
|
||||
const int num_experts = gating_output.size(-1);
|
||||
const int num_tokens = gating_output.numel() / num_experts;
|
||||
const auto num_tokens = gating_output.numel() / num_experts;
|
||||
const int topk = topk_weights.size(-1);
|
||||
|
||||
const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
|
||||
|
||||
16
csrc/ops.h
16
csrc/ops.h
@ -326,22 +326,6 @@ void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
|
||||
const std::optional<torch::Tensor>& has_initial_state,
|
||||
const torch::Tensor& ssm_states, int64_t pad_slot_id);
|
||||
|
||||
void causal_conv1d_update(const at::Tensor& x, const at::Tensor& conv_state,
|
||||
const at::Tensor& weight,
|
||||
const std::optional<at::Tensor>& bias_,
|
||||
bool silu_activation,
|
||||
const std::optional<at::Tensor>& cache_seqlens_,
|
||||
const std::optional<at::Tensor>& conv_state_indices_,
|
||||
int64_t pad_slot_id);
|
||||
|
||||
void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
|
||||
const std::optional<at::Tensor>& bias_,
|
||||
const std::optional<at::Tensor>& conv_states,
|
||||
const std::optional<at::Tensor>& query_start_loc,
|
||||
const std::optional<at::Tensor>& cache_indices,
|
||||
const std::optional<at::Tensor>& has_initial_state,
|
||||
bool silu_activation, int64_t pad_slot_id);
|
||||
|
||||
using fptr_t = int64_t;
|
||||
fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
|
||||
torch::Tensor& rank_data, int64_t rank,
|
||||
|
||||
@ -162,10 +162,11 @@ __global__ void dynamic_scaled_int8_quant_kernel(
|
||||
|
||||
// calculate for absmax
|
||||
float thread_max = 0.f;
|
||||
for (int i = tid; i < hidden_size; i += stride) {
|
||||
const auto v = fabsf(static_cast<float>(row_in[i]));
|
||||
thread_max = fmaxf(thread_max, v);
|
||||
}
|
||||
vectorize_read_with_alignment<16>(
|
||||
row_in, hidden_size, tid, stride, [&] __device__(const scalar_t& src) {
|
||||
const float v = fabsf(static_cast<float>(src));
|
||||
thread_max = fmaxf(thread_max, v);
|
||||
});
|
||||
using BlockReduce = cub::BlockReduce<float, 256>;
|
||||
__shared__ typename BlockReduce::TempStorage tmp;
|
||||
float block_max = BlockReduce(tmp).Reduce(thread_max, cub::Max{}, blockDim.x);
|
||||
@ -232,9 +233,10 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
|
||||
|
||||
// 1. calculate min & max
|
||||
MinMax thread_mm;
|
||||
for (int i = tid; i < hidden_size; i += stride) {
|
||||
thread_mm += static_cast<float>(row_in[i]);
|
||||
}
|
||||
vectorize_read_with_alignment<16>(row_in, hidden_size, tid, stride,
|
||||
[&] __device__(const scalar_t& src) {
|
||||
thread_mm += static_cast<float>(src);
|
||||
});
|
||||
|
||||
using BlockReduce = cub::BlockReduce<MinMax, 256>;
|
||||
__shared__ typename BlockReduce::TempStorage tmp;
|
||||
|
||||
@ -51,7 +51,8 @@ struct cutlass_3x_gemm {
|
||||
// These are the minimum alignments needed for the kernels to compile
|
||||
static constexpr int AlignmentAB =
|
||||
128 / cutlass::sizeof_bits<ElementAB>::value;
|
||||
static constexpr int AlignmentCD = 4;
|
||||
static constexpr int AlignmentCD =
|
||||
128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
@ -144,4 +145,65 @@ struct cutlass_3x_gemm_sm100 {
|
||||
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
|
||||
};
|
||||
|
||||
template <typename ElementAB_, typename ElementD_,
|
||||
template <typename, typename, typename> typename Epilogue_,
|
||||
typename TileShape, typename ClusterShape, typename KernelSchedule,
|
||||
typename EpilogueSchedule>
|
||||
struct cutlass_3x_gemm_sm120 {
|
||||
using ElementAB = ElementAB_;
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
static constexpr int AlignmentA =
|
||||
128 / cutlass::sizeof_bits<ElementAB>::value;
|
||||
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
static constexpr int AlignmentB =
|
||||
128 / cutlass::sizeof_bits<ElementAB>::value;
|
||||
|
||||
using ElementC = void;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
static constexpr int AlignmentC =
|
||||
128 / cutlass::sizeof_bits<ElementD_>::value;
|
||||
|
||||
using ElementD = ElementD_;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
static constexpr int AlignmentD = AlignmentC;
|
||||
|
||||
using ElementAcc =
|
||||
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
|
||||
float>::type;
|
||||
using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>;
|
||||
|
||||
// MMA type
|
||||
using ElementAccumulator = float;
|
||||
|
||||
// Epilogue types
|
||||
using ElementBias = cutlass::half_t;
|
||||
using ElementCompute = float;
|
||||
using ElementAux = ElementD;
|
||||
using LayoutAux = LayoutD;
|
||||
using ElementAmax = float;
|
||||
|
||||
using EVTCompute = typename Epilogue::EVTCompute;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, TileShape,
|
||||
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC,
|
||||
ElementD, LayoutD, AlignmentD, EpilogueSchedule,
|
||||
EVTCompute>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, ElementAB,
|
||||
LayoutA, AlignmentA, ElementAB, LayoutB, AlignmentB,
|
||||
ElementAccumulator, TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
|
||||
};
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
@ -36,6 +36,12 @@ void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_sm120_fp8(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
|
||||
24
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu
Normal file
24
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu
Normal file
@ -0,0 +1,24 @@
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_sm120_fp8_dispatch.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_sm120_fp8(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
if (bias) {
|
||||
TORCH_CHECK(bias->dtype() == out.dtype(),
|
||||
"currently bias dtype must match output dtype ", out.dtype());
|
||||
return cutlass_scaled_mm_sm120_fp8_epilogue<c3x::ScaledEpilogueBias>(
|
||||
out, a, b, a_scales, b_scales, *bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm120_fp8_epilogue<c3x::ScaledEpilogue>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@ -0,0 +1,67 @@
|
||||
#pragma once
|
||||
|
||||
#include "scaled_mm.cuh"
|
||||
#include "cutlass_gemm_caller.cuh"
|
||||
|
||||
/**
|
||||
* This file defines Gemm kernel configurations for SM120 (fp8) based on the
|
||||
* Gemm shape.
|
||||
*/
|
||||
|
||||
namespace vllm {
|
||||
|
||||
using c3x::cutlass_gemm_caller;
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm120_fp8_config_default {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
|
||||
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||
using TileShape = Shape<_128, _128, _128>;
|
||||
using ClusterShape = Shape<_1, _1, _1>; // Only work with Shape<_1, _1, _1>
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_gemm_sm120<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
inline void cutlass_gemm_sm120_fp8_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
using Cutlass3xGemmDefault =
|
||||
typename sm120_fp8_config_default<InType, OutType,
|
||||
Epilogue>::Cutlass3xGemm;
|
||||
return cutlass_gemm_caller<Cutlass3xGemmDefault>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
|
||||
template <template <typename, typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_scaled_mm_sm120_fp8_epilogue(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_args) {
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_gemm_sm120_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::bfloat16_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return cutlass_gemm_sm120_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@ -0,0 +1,374 @@
|
||||
#include "core/registration.h"
|
||||
|
||||
#include <torch/all.h>
|
||||
#include <cutlass/arch/arch.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/group_array_problem_shape.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include <cassert>
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename ElementAB, typename ElementC, typename ElementAccumulator,
|
||||
typename LayoutSFA, typename LayoutSFB, typename ScaleConfig>
|
||||
__global__ void get_ggemm_starts(
|
||||
int32_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets,
|
||||
ElementC** out_offsets, ElementAccumulator** a_scale_offsets,
|
||||
ElementAccumulator** b_scale_offsets, ElementAB* a_base_as_int,
|
||||
ElementAB* b_base_as_int, ElementC* out_base_as_int,
|
||||
ElementAccumulator* a_scale_base_as_int,
|
||||
ElementAccumulator* b_scale_base_as_int, LayoutSFA* layout_sfa_base_as_int,
|
||||
LayoutSFB* layout_sfb_base_as_int, int* problem_sizes) {
|
||||
int expert_id = threadIdx.x;
|
||||
|
||||
if (expert_id >= gridDim.x * blockDim.x) {
|
||||
return;
|
||||
}
|
||||
|
||||
int m = problem_sizes[expert_id * 3];
|
||||
int n = problem_sizes[expert_id * 3 + 1];
|
||||
int k = problem_sizes[expert_id * 3 + 2];
|
||||
|
||||
int32_t expert_offset = expert_offsets[expert_id];
|
||||
int a_stride = expert_offset * k;
|
||||
int b_stride = expert_id * k * n;
|
||||
int a_scale_stride = expert_offset * k / 128;
|
||||
int b_scale_stride = expert_id * k * n / 128 / 128;
|
||||
|
||||
a_offsets[expert_id] = a_base_as_int + a_stride;
|
||||
b_offsets[expert_id] = b_base_as_int + b_stride;
|
||||
out_offsets[expert_id] = out_base_as_int + expert_offset * n;
|
||||
a_scale_offsets[expert_id] = a_scale_base_as_int + a_scale_stride;
|
||||
b_scale_offsets[expert_id] = b_scale_base_as_int + b_scale_stride;
|
||||
|
||||
LayoutSFA* layout_sfa_ptr = layout_sfa_base_as_int + expert_id;
|
||||
LayoutSFB* layout_sfb_ptr = layout_sfb_base_as_int + expert_id;
|
||||
|
||||
*layout_sfa_ptr =
|
||||
ScaleConfig::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, 1));
|
||||
*layout_sfb_ptr =
|
||||
ScaleConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1));
|
||||
}
|
||||
|
||||
#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB, \
|
||||
ScaleConfig) \
|
||||
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
|
||||
get_ggemm_starts<cutlass::float_e4m3_t, C_TYPE, float, LayoutSFA, \
|
||||
LayoutSFB, ScaleConfig><<<1, num_experts, 0, stream>>>( \
|
||||
static_cast<int32_t*>(expert_offsets.data_ptr()), \
|
||||
static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()), \
|
||||
static_cast<cutlass::float_e4m3_t**>(b_ptrs.data_ptr()), \
|
||||
static_cast<C_TYPE**>(out_ptrs.data_ptr()), \
|
||||
static_cast<float**>(a_scales_ptrs.data_ptr()), \
|
||||
static_cast<float**>(b_scales_ptrs.data_ptr()), \
|
||||
static_cast<cutlass::float_e4m3_t*>(a_tensors.data_ptr()), \
|
||||
static_cast<cutlass::float_e4m3_t*>(b_tensors.data_ptr()), \
|
||||
static_cast<C_TYPE*>(out_tensors.data_ptr()), \
|
||||
static_cast<float*>(a_scales.data_ptr()), \
|
||||
static_cast<float*>(b_scales.data_ptr()), \
|
||||
reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()), \
|
||||
reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr()), \
|
||||
static_cast<int*>(problem_sizes.data_ptr())); \
|
||||
}
|
||||
|
||||
template <typename LayoutSFA, typename LayoutSFB, typename ScaleConfig>
|
||||
void run_get_ggemm_starts(
|
||||
torch::Tensor const& expert_offsets, torch::Tensor& a_ptrs,
|
||||
torch::Tensor& b_ptrs, torch::Tensor& out_ptrs,
|
||||
torch::Tensor& a_scales_ptrs, torch::Tensor& b_scales_ptrs,
|
||||
torch::Tensor const& a_tensors, torch::Tensor const& b_tensors,
|
||||
torch::Tensor out_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& layout_sfa,
|
||||
torch::Tensor const& layout_sfb, torch::Tensor const& problem_sizes) {
|
||||
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(out_tensors.size(1) % 128 == 0 or out_tensors.size(0) % 128 == 0);
|
||||
TORCH_CHECK(a_tensors.size(1) % 128 == 0 or a_tensors.size(0) % 128 == 0);
|
||||
|
||||
int num_experts = (int)expert_offsets.size(0);
|
||||
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
|
||||
|
||||
if (false) {
|
||||
}
|
||||
__CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t, LayoutSFA,
|
||||
LayoutSFB, ScaleConfig)
|
||||
__CALL_GET_STARTS_KERNEL(torch::kFloat16, cutlass::half_t, LayoutSFA,
|
||||
LayoutSFB, ScaleConfig)
|
||||
else {
|
||||
TORCH_CHECK(false, "Unsupported output tensor type");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename OutType, typename ScheduleConfig, typename LayoutD>
|
||||
void run_blockwise_scaled_group_mm(
|
||||
torch::Tensor& out_ptrs, const torch::Tensor& a_ptrs,
|
||||
const torch::Tensor& b_ptrs, const torch::Tensor& a_scales_ptrs,
|
||||
const torch::Tensor& b_scales_ptrs, const torch::Tensor& stride_a,
|
||||
const torch::Tensor& stride_b, const torch::Tensor& stride_c,
|
||||
const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb,
|
||||
const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets) {
|
||||
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;
|
||||
|
||||
// Types
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using ElementB = cutlass::float_e4m3_t;
|
||||
using ElementC = OutType;
|
||||
using ElementD = ElementC;
|
||||
using ElementAccumulator = float;
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = LayoutD;
|
||||
|
||||
// Alignments
|
||||
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
|
||||
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
|
||||
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass, typename ScheduleConfig::MmaTileShape,
|
||||
typename ScheduleConfig::ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator,
|
||||
ElementAccumulator, void, LayoutC*, AlignmentC, ElementD, LayoutC*,
|
||||
AlignmentC, typename ScheduleConfig::EpilogueSchedule>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass, ElementA,
|
||||
cute::tuple<LayoutA*, typename ScheduleConfig::LayoutSFA*>,
|
||||
AlignmentA, ElementB,
|
||||
cute::tuple<LayoutB*, typename ScheduleConfig::LayoutSFB*>,
|
||||
AlignmentB, ElementAccumulator, typename ScheduleConfig::MmaTileShape,
|
||||
typename ScheduleConfig::ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
typename ScheduleConfig::KernelSchedule>::CollectiveOp;
|
||||
|
||||
using GemmKernel =
|
||||
cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop,
|
||||
CollectiveEpilogue, void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::InternalStrideD;
|
||||
|
||||
using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
|
||||
int num_experts = (int)expert_offsets.size(0);
|
||||
|
||||
Gemm gemm_op;
|
||||
|
||||
// Mainloop Arguments
|
||||
typename GemmKernel::MainloopArguments mainloop_args{
|
||||
static_cast<const ElementA**>(a_ptrs.data_ptr()),
|
||||
static_cast<StrideA*>(stride_a.data_ptr()),
|
||||
static_cast<const ElementB**>(b_ptrs.data_ptr()),
|
||||
static_cast<StrideB*>(stride_b.data_ptr()),
|
||||
static_cast<const ElementAccumulator**>(a_scales_ptrs.data_ptr()),
|
||||
reinterpret_cast<typename ScheduleConfig::LayoutSFA*>(
|
||||
layout_sfa.data_ptr()),
|
||||
static_cast<const ElementAccumulator**>(b_scales_ptrs.data_ptr()),
|
||||
reinterpret_cast<typename ScheduleConfig::LayoutSFB*>(
|
||||
layout_sfb.data_ptr())};
|
||||
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
hw_info.device_id = a_ptrs.get_device();
|
||||
hw_info.sm_count =
|
||||
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
|
||||
hw_info.device_id);
|
||||
|
||||
// Epilogue Arguments
|
||||
typename GemmKernel::EpilogueArguments epilogue_args{
|
||||
{}, // epilogue.thread
|
||||
nullptr,
|
||||
static_cast<StrideC*>(stride_c.data_ptr()),
|
||||
static_cast<ElementD**>(out_ptrs.data_ptr()),
|
||||
static_cast<StrideC*>(stride_c.data_ptr())};
|
||||
|
||||
UnderlyingProblemShape* problem_sizes_as_shapes =
|
||||
static_cast<UnderlyingProblemShape*>(problem_sizes.data_ptr());
|
||||
|
||||
// Gemm Arguments
|
||||
typename GemmKernel::Arguments args{
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||
{num_experts, problem_sizes_as_shapes, nullptr},
|
||||
mainloop_args,
|
||||
epilogue_args,
|
||||
hw_info};
|
||||
|
||||
at::cuda::CUDAGuard device_guard{(char)a_ptrs.device().index()};
|
||||
const cudaStream_t stream =
|
||||
at::cuda::getCurrentCUDAStream(a_ptrs.get_device());
|
||||
|
||||
auto can_implement_status = gemm_op.can_implement(args);
|
||||
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
|
||||
"Failed to implement GEMM");
|
||||
|
||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(a_ptrs.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
|
||||
auto status = gemm_op.initialize(args, workspace.data_ptr(), stream);
|
||||
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");
|
||||
|
||||
status = gemm_op.run(stream);
|
||||
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
|
||||
}
|
||||
|
||||
template <typename OutType>
|
||||
void blockwise_scaled_group_mm_dispatch_shape(
|
||||
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
|
||||
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
|
||||
const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets) {
|
||||
struct MmaConfig {
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using KernelSchedule =
|
||||
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
|
||||
using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<
|
||||
1, 128, 128, cute::UMMA::Major::K, cute::UMMA::Major::K>;
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using MmaTileShape = Shape<_128, _128, _128>;
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
};
|
||||
|
||||
int num_experts = (int)expert_offsets.size(0);
|
||||
|
||||
auto a_ptrs = torch::empty(
|
||||
{num_experts},
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
|
||||
auto b_ptrs = torch::empty(
|
||||
{num_experts},
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
|
||||
auto out_ptrs = torch::empty(
|
||||
{num_experts},
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
|
||||
auto a_scales_ptrs = torch::empty(
|
||||
{num_experts},
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
|
||||
auto b_scales_ptrs = torch::empty(
|
||||
{num_experts},
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
|
||||
|
||||
auto layout_sfa = torch::empty(
|
||||
{num_experts, 5},
|
||||
torch::TensorOptions().dtype(torch::kInt32).device(a.device()));
|
||||
auto layout_sfb = torch::empty(
|
||||
{num_experts, 5},
|
||||
torch::TensorOptions().dtype(torch::kInt32).device(a.device()));
|
||||
|
||||
auto stride_a = torch::full(
|
||||
{num_experts}, a.size(1),
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
|
||||
auto stride_b = torch::full(
|
||||
{num_experts}, a.size(1),
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
|
||||
auto stride_c = torch::full(
|
||||
{num_experts}, output.size(1),
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
|
||||
|
||||
torch::TensorOptions options_int =
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device());
|
||||
|
||||
run_get_ggemm_starts<typename MmaConfig::LayoutSFA,
|
||||
typename MmaConfig::LayoutSFB,
|
||||
typename MmaConfig::ScaleConfig>(
|
||||
expert_offsets, a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, a,
|
||||
b, output, scales_a, scales_b, layout_sfa, layout_sfb, problem_sizes);
|
||||
|
||||
run_blockwise_scaled_group_mm<OutType, MmaConfig,
|
||||
typename MmaConfig::LayoutC>(
|
||||
out_ptrs, a_ptrs, b_ptrs, a_scales_ptrs, b_scales_ptrs, stride_a,
|
||||
stride_b, stride_c, layout_sfa, layout_sfb, problem_sizes,
|
||||
expert_offsets);
|
||||
}
|
||||
|
||||
void cutlass_blockwise_scaled_grouped_mm(
|
||||
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
|
||||
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
|
||||
const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets) {
|
||||
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
|
||||
TORCH_CHECK(problem_sizes.size(1) == 3,
|
||||
"problem_sizes must have shape (num_experts, 3)");
|
||||
TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0),
|
||||
"Number of experts in problem_sizes must match expert_offsets");
|
||||
TORCH_CHECK(problem_sizes.dtype() == torch::kInt32,
|
||||
"problem_sizes must be int32");
|
||||
TORCH_CHECK(a.scalar_type() == torch::kFloat8_e4m3fn,
|
||||
"a must be kFloat8_e4m3fn");
|
||||
TORCH_CHECK(b.scalar_type() == torch::kFloat8_e4m3fn,
|
||||
"b must be kFloat8_e4m3fn");
|
||||
TORCH_CHECK(output.scalar_type() == torch::kBFloat16 ||
|
||||
output.scalar_type() == torch::kHalf,
|
||||
"output must be bfloat16 or half");
|
||||
TORCH_CHECK(scales_a.scalar_type() == torch::kFloat32,
|
||||
"scales_a must be float32");
|
||||
TORCH_CHECK(scales_b.scalar_type() == torch::kFloat32,
|
||||
"scales_b must be float32");
|
||||
TORCH_CHECK(expert_offsets.scalar_type() == torch::kInt32,
|
||||
"expert_offsets must be int32");
|
||||
|
||||
TORCH_CHECK(output.dim() == 2, "output must be 2D tensor");
|
||||
TORCH_CHECK(a.dim() == 2, "a must be 2D tensor");
|
||||
TORCH_CHECK(b.dim() == 3, "b must be 3D tensor");
|
||||
TORCH_CHECK(scales_a.dim() == 2, "scales_a must be 2D tensor");
|
||||
TORCH_CHECK(scales_b.dim() == 3, "scales_b must be 3D tensor");
|
||||
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
|
||||
TORCH_CHECK(problem_sizes.size(1) == 3,
|
||||
"problem_sizes must have shape (num_experts, 3)");
|
||||
TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0),
|
||||
"Number of experts in problem_sizes must match expert_offsets");
|
||||
TORCH_CHECK(problem_sizes.dtype() == torch::kInt32,
|
||||
"problem_sizes must be int32");
|
||||
TORCH_CHECK(expert_offsets.dim() == 1, "expert_offsets must be 1D tensor");
|
||||
|
||||
#if defined(ENABLE_CUTLASS_MOE_SM100) && ENABLE_CUTLASS_MOE_SM100
|
||||
if (output.scalar_type() == torch::kBFloat16) {
|
||||
blockwise_scaled_group_mm_dispatch_shape<cutlass::bfloat16_t>(
|
||||
output, a, b, scales_a, scales_b, problem_sizes, expert_offsets);
|
||||
} else if (output.scalar_type() == torch::kFloat16) {
|
||||
blockwise_scaled_group_mm_dispatch_shape<cutlass::half_t>(
|
||||
output, a, b, scales_a, scales_b, problem_sizes, expert_offsets);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported output tensor type");
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("cutlass_blockwise_scaled_grouped_mm",
|
||||
&cutlass_blockwise_scaled_grouped_mm);
|
||||
}
|
||||
@ -7,7 +7,7 @@
|
||||
|
||||
constexpr uint64_t THREADS_PER_EXPERT = 512;
|
||||
|
||||
__global__ void compute_problem_sizes(const uint32_t* __restrict__ topk_ids,
|
||||
__global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids,
|
||||
int32_t* problem_sizes1,
|
||||
int32_t* problem_sizes2,
|
||||
int32_t* atomic_buffer,
|
||||
@ -62,7 +62,7 @@ __global__ void compute_expert_blockscale_offsets(
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void compute_arg_sorts(const uint32_t* __restrict__ topk_ids,
|
||||
__global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids,
|
||||
const int32_t* __restrict__ expert_offsets,
|
||||
int32_t* input_permutation,
|
||||
int32_t* output_permutation,
|
||||
@ -103,7 +103,7 @@ void get_cutlass_moe_mm_data_caller(
|
||||
|
||||
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
|
||||
compute_problem_sizes<<<num_experts, num_threads, 0, stream>>>(
|
||||
static_cast<const uint32_t*>(topk_ids.data_ptr()),
|
||||
static_cast<const int32_t*>(topk_ids.data_ptr()),
|
||||
static_cast<int32_t*>(problem_sizes1.data_ptr()),
|
||||
static_cast<int32_t*>(problem_sizes2.data_ptr()),
|
||||
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n, k);
|
||||
@ -120,7 +120,7 @@ void get_cutlass_moe_mm_data_caller(
|
||||
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
|
||||
}
|
||||
compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
|
||||
static_cast<const uint32_t*>(topk_ids.data_ptr()),
|
||||
static_cast<const int32_t*>(topk_ids.data_ptr()),
|
||||
static_cast<const int32_t*>(expert_offsets.data_ptr()),
|
||||
static_cast<int32_t*>(input_permutation.data_ptr()),
|
||||
static_cast<int32_t*>(output_permutation.data_ptr()),
|
||||
|
||||
34
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu
Normal file
34
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu
Normal file
@ -0,0 +1,34 @@
|
||||
#include <cudaTypedefs.h>
|
||||
#include "c3x/scaled_mm_kernels.hpp"
|
||||
|
||||
#include "cuda_utils.h"
|
||||
|
||||
/*
|
||||
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
|
||||
NVIDIA GPUs with sm120 (Blackwell Geforce).
|
||||
*/
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
|
||||
|
||||
void cutlass_scaled_mm_sm120(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
int M = a.size(0), N = b.size(1), K = a.size(1);
|
||||
TORCH_CHECK(
|
||||
(a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
|
||||
(b_scales.numel() == 1 || b_scales.numel() == b.size(1)),
|
||||
"Currently, block scaled fp8 gemm is not implemented for Blackwell");
|
||||
|
||||
// Standard per-tensor/per-token/per-channel scaling
|
||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn,
|
||||
"Currently, only fp8 gemm is implemented for Blackwell");
|
||||
vllm::cutlass_scaled_mm_sm120_fp8(c, a, b, a_scales, b_scales, bias);
|
||||
}
|
||||
|
||||
#endif
|
||||
@ -41,6 +41,14 @@ void cutlass_moe_mm_sm90(
|
||||
|
||||
#endif
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
|
||||
void cutlass_scaled_mm_sm120(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
#endif
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
|
||||
void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
@ -168,8 +176,15 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
|
||||
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
|
||||
int32_t version_num = get_sm_version_num();
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
|
||||
if (version_num >= 120) {
|
||||
cutlass_scaled_mm_sm120(c, a, b, a_scales, b_scales, bias);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
|
||||
if (version_num >= 100) {
|
||||
if (version_num >= 100 && version_num < 120) {
|
||||
cutlass_scaled_mm_sm100(c, a, b, a_scales, b_scales, bias);
|
||||
return;
|
||||
}
|
||||
|
||||
@ -335,8 +335,10 @@ void run_fp4_blockwise_scaled_group_mm(
|
||||
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
|
||||
}
|
||||
|
||||
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
|
||||
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
|
||||
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
|
||||
#endif
|
||||
|
||||
#define CHECK_TYPE(x, st, m) \
|
||||
TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
|
||||
|
||||
@ -1113,8 +1113,6 @@ __global__ void Marlin(
|
||||
if constexpr (has_zp && !is_zp_float) {
|
||||
if (is_new_zp) {
|
||||
if constexpr (group_blocks == -1) is_first_matmul_in_slice = false;
|
||||
FragB frag_zp_0;
|
||||
FragB frag_zp_1;
|
||||
int zp_quant_0, zp_quant_1;
|
||||
|
||||
if constexpr (w_type.size_bits() == 4) {
|
||||
|
||||
@ -38,7 +38,6 @@
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cute/atom/copy_traits_sm90_tma.hpp"
|
||||
#include "cute/algorithm/gemm.hpp"
|
||||
#include "cute/tensor_predicate.hpp"
|
||||
#include "cute/numeric/arithmetic_tuple.hpp"
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
#include "cutlass/transform/collective/sm90_wgmma_transpose.hpp"
|
||||
|
||||
@ -27,6 +27,26 @@ __device__ inline void vectorize_with_alignment(
|
||||
constexpr int WIDTH = VEC_SIZE * sizeof(InT); // eg: 64 B
|
||||
uintptr_t addr = reinterpret_cast<uintptr_t>(in);
|
||||
|
||||
// fast path when the whole region is already aligned
|
||||
// Note: currently the output is guaranteed to be same as the input, so we
|
||||
// don't check it here, comments here just for future reference.
|
||||
bool can_vec = ((addr & (WIDTH - 1)) == 0) && ((len & (VEC_SIZE - 1)) == 0);
|
||||
if (can_vec) {
|
||||
int num_vec = len / VEC_SIZE;
|
||||
|
||||
using vin_t = vec_n_t<InT, VEC_SIZE>;
|
||||
using vout_t = vec_n_t<OutT, VEC_SIZE>;
|
||||
auto* v_in = reinterpret_cast<const vin_t*>(in);
|
||||
auto* v_out = reinterpret_cast<vout_t*>(out);
|
||||
|
||||
for (int i = tid; i < num_vec; i += stride) {
|
||||
vout_t tmp;
|
||||
vec_op(tmp, v_in[i]);
|
||||
v_out[i] = tmp;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
int misalignment_offset = addr & (WIDTH - 1); // addr % 64
|
||||
int alignment_bytes = WIDTH - misalignment_offset; // 64 - (addr % 64)
|
||||
int prefix_elems = alignment_bytes & (WIDTH - 1); // handle 64
|
||||
@ -72,4 +92,81 @@ __device__ __forceinline__ void vectorize_with_alignment(const InT* in,
|
||||
std::forward<ScaOp>(scalar_op));
|
||||
}
|
||||
|
||||
template <int VEC_SIZE, typename InT, typename ScaOp>
|
||||
struct DefaultReadVecOp {
|
||||
ScaOp scalar_op;
|
||||
|
||||
__device__ __forceinline__ void operator()(
|
||||
const vec_n_t<InT, VEC_SIZE>& src) const {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VEC_SIZE; ++i) {
|
||||
scalar_op(src.val[i]);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// read-only version: iterate over the input with alignment guarantees
|
||||
template <int VEC_SIZE, typename InT, typename VecOp, typename ScaOp>
|
||||
__device__ inline void vectorize_read_with_alignment(const InT* in, int len,
|
||||
int tid, int stride,
|
||||
VecOp&& vec_op,
|
||||
ScaOp&& scalar_op) {
|
||||
static_assert(VEC_SIZE > 0 && (VEC_SIZE & (VEC_SIZE - 1)) == 0,
|
||||
"VEC_SIZE must be a positive power-of-two");
|
||||
constexpr int WIDTH = VEC_SIZE * sizeof(InT);
|
||||
uintptr_t addr = reinterpret_cast<uintptr_t>(in);
|
||||
|
||||
// fast path when the whole region is already aligned
|
||||
bool can_vec = ((addr & (WIDTH - 1)) == 0) && ((len & (VEC_SIZE - 1)) == 0);
|
||||
if (can_vec) {
|
||||
int num_vec = len / VEC_SIZE;
|
||||
|
||||
using vin_t = vec_n_t<InT, VEC_SIZE>;
|
||||
auto* v_in = reinterpret_cast<const vin_t*>(in);
|
||||
|
||||
for (int i = tid; i < num_vec; i += stride) {
|
||||
vec_op(v_in[i]);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
int misalignment_offset = addr & (WIDTH - 1);
|
||||
int alignment_bytes = WIDTH - misalignment_offset;
|
||||
int prefix_elems = alignment_bytes & (WIDTH - 1);
|
||||
prefix_elems /= sizeof(InT);
|
||||
prefix_elems = min(prefix_elems, len);
|
||||
|
||||
// 1. handle the possibly unaligned prefix with scalar access.
|
||||
for (int i = tid; i < prefix_elems; i += stride) {
|
||||
scalar_op(in[i]);
|
||||
}
|
||||
|
||||
in += prefix_elems;
|
||||
len -= prefix_elems;
|
||||
|
||||
int num_vec = len / VEC_SIZE;
|
||||
using vin_t = vec_n_t<InT, VEC_SIZE>;
|
||||
auto* v_in = reinterpret_cast<const vin_t*>(in);
|
||||
|
||||
// 2. vectorized traversal of the main aligned region.
|
||||
for (int i = tid; i < num_vec; i += stride) {
|
||||
vec_op(v_in[i]);
|
||||
}
|
||||
|
||||
// 3. handle remaining tail elements.
|
||||
int tail_start = num_vec * VEC_SIZE;
|
||||
for (int i = tid + tail_start; i < len; i += stride) {
|
||||
scalar_op(in[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// overload that requires only a scalar_op
|
||||
template <int VEC_SIZE, typename InT, typename ScaOp>
|
||||
__device__ __forceinline__ void vectorize_read_with_alignment(
|
||||
const InT* in, int len, int tid, int stride, ScaOp&& scalar_op) {
|
||||
using Vec = DefaultReadVecOp<VEC_SIZE, InT, std::decay_t<ScaOp>>;
|
||||
vectorize_read_with_alignment<VEC_SIZE>(in, len, tid, stride, Vec{scalar_op},
|
||||
std::forward<ScaOp>(scalar_op));
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
@ -59,6 +59,8 @@ void apply_repetition_penalties_(
|
||||
int vocab_size = logits.size(-1);
|
||||
int num_seqs = logits.size(0);
|
||||
|
||||
if (num_seqs == 0) return;
|
||||
|
||||
// Get number of SMs on the current device
|
||||
int sms = 0;
|
||||
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount,
|
||||
|
||||
@ -79,7 +79,8 @@ struct cutlass_sparse_3x_gemm {
|
||||
// These are the minimum alignments needed for the kernels to compile
|
||||
static constexpr int AlignmentAB =
|
||||
128 / cutlass::sizeof_bits<ElementAB>::value;
|
||||
static constexpr int AlignmentCD = 4;
|
||||
static constexpr int AlignmentCD =
|
||||
128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
|
||||
@ -393,6 +393,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
{stride_tag});
|
||||
ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);
|
||||
|
||||
// cutlass blockwise scaledgroup GEMM
|
||||
ops.def(
|
||||
"cutlass_blockwise_scaled_grouped_mm(Tensor! output, Tensor a, Tensor b, "
|
||||
"Tensor scales_a, Tensor scales_b, "
|
||||
"Tensor problem_sizes, Tensor expert_offsets) -> ()",
|
||||
{stride_tag});
|
||||
// conditionally compiled so impl registration is in source file
|
||||
|
||||
// cutlass nvfp4 block scaled group GEMM
|
||||
ops.def(
|
||||
"cutlass_fp4_group_mm(Tensor! out, Tensor a, Tensor b,"
|
||||
@ -586,28 +594,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"int pad_slot_id) -> ()");
|
||||
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
|
||||
|
||||
ops.def(
|
||||
"causal_conv1d_update(Tensor! x,"
|
||||
"Tensor! conv_state,"
|
||||
"Tensor! weight,"
|
||||
"Tensor? bias_,"
|
||||
"bool silu_activation,"
|
||||
"Tensor? cache_seqlens_,"
|
||||
"Tensor? conv_state_indices,"
|
||||
"int pad_slot_id) -> ()");
|
||||
ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
|
||||
|
||||
ops.def(
|
||||
"causal_conv1d_fwd(Tensor! x, Tensor! weight,"
|
||||
"Tensor? bias_,"
|
||||
"Tensor!? conv_states,"
|
||||
"Tensor? query_start_loc,"
|
||||
"Tensor? cache_indices,"
|
||||
"Tensor? has_initial_state,"
|
||||
"bool silu_activation,"
|
||||
"int pad_slot_id) -> ()");
|
||||
ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
// reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
|
||||
ops.def(
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
|
||||
# The vLLM Dockerfile is used to construct vLLM image that can be directly used
|
||||
# to run the OpenAI compatible server.
|
||||
|
||||
@ -62,12 +63,16 @@ ARG PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL=https://download.pytorch.org/whl/nightly
|
||||
ARG PIP_KEYRING_PROVIDER=disabled
|
||||
ARG UV_KEYRING_PROVIDER=${PIP_KEYRING_PROVIDER}
|
||||
|
||||
# Flag enables build-in KV-connector dependency libs into docker images
|
||||
ARG INSTALL_KV_CONNECTORS=false
|
||||
|
||||
#################### BASE BUILD IMAGE ####################
|
||||
# prepare basic build environment
|
||||
FROM ${BUILD_BASE_IMAGE} AS base
|
||||
ARG CUDA_VERSION
|
||||
ARG PYTHON_VERSION
|
||||
ARG TARGETPLATFORM
|
||||
ARG INSTALL_KV_CONNECTORS=false
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
ARG DEADSNAKES_MIRROR_URL
|
||||
@ -276,6 +281,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
FROM ${FINAL_BASE_IMAGE} AS vllm-base
|
||||
ARG CUDA_VERSION
|
||||
ARG PYTHON_VERSION
|
||||
ARG INSTALL_KV_CONNECTORS=false
|
||||
WORKDIR /vllm-workspace
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ARG TARGETPLATFORM
|
||||
@ -373,24 +379,44 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
|
||||
ARG FLASHINFER_CUDA128_INDEX_URL="https://download.pytorch.org/whl/cu128/flashinfer"
|
||||
ARG FLASHINFER_CUDA128_WHEEL="flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl"
|
||||
ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
|
||||
ARG FLASHINFER_GIT_REF="v0.2.6.post1"
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
. /etc/environment && \
|
||||
if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \
|
||||
# FlashInfer already has a wheel for PyTorch 2.7.0 and CUDA 12.8. This is enough for CI use
|
||||
if [[ "$CUDA_VERSION" == 12.8* ]]; then \
|
||||
uv pip install --system ${FLASHINFER_CUDA128_INDEX_URL}/${FLASHINFER_CUDA128_WHEEL} ; \
|
||||
else \
|
||||
export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0a 10.0a 12.0' && \
|
||||
git clone ${FLASHINFER_GIT_REPO} --single-branch --branch ${FLASHINFER_GIT_REF} --recursive && \
|
||||
# Needed to build AOT kernels
|
||||
(cd flashinfer && \
|
||||
python3 -m flashinfer.aot && \
|
||||
uv pip install --system --no-build-isolation . \
|
||||
) && \
|
||||
rm -rf flashinfer; \
|
||||
fi \
|
||||
fi
|
||||
ARG FLASHINFER_GIT_REF="v0.2.8rc1"
|
||||
# Flag to control whether to use pre-built FlashInfer wheels (set to false to force build from source)
|
||||
# TODO: Currently disabled because the pre-built wheels are not available for FLASHINFER_GIT_REF
|
||||
ARG USE_FLASHINFER_PREBUILT_WHEEL=false
|
||||
RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
|
||||
. /etc/environment
|
||||
if [ "$TARGETPLATFORM" != "linux/arm64" ]; then
|
||||
# FlashInfer already has a wheel for PyTorch 2.7.0 and CUDA 12.8. This is enough for CI use
|
||||
if [[ "$CUDA_VERSION" == 12.8* ]] && [[ "$USE_FLASHINFER_PREBUILT_WHEEL" == "true" ]]; then
|
||||
uv pip install --system ${FLASHINFER_CUDA128_INDEX_URL}/${FLASHINFER_CUDA128_WHEEL}
|
||||
else
|
||||
# Exclude CUDA arches for older versions (11.x and 12.0-12.7)
|
||||
# TODO: Update this to allow setting TORCH_CUDA_ARCH_LIST as a build arg.
|
||||
if [[ "${CUDA_VERSION}" == 11.* ]]; then
|
||||
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9"
|
||||
elif [[ "${CUDA_VERSION}" == 12.[0-7]* ]]; then
|
||||
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a"
|
||||
else
|
||||
# CUDA 12.8+ supports 10.0a and 12.0
|
||||
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0"
|
||||
fi
|
||||
echo "🏗️ Building FlashInfer for arches: ${FI_TORCH_CUDA_ARCH_LIST}"
|
||||
|
||||
git clone --depth 1 --recursive --shallow-submodules \
|
||||
--branch ${FLASHINFER_GIT_REF} \
|
||||
${FLASHINFER_GIT_REPO} flashinfer
|
||||
|
||||
# Needed to build AOT kernels
|
||||
pushd flashinfer
|
||||
python3 -m flashinfer.aot
|
||||
TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \
|
||||
uv pip install --system --no-build-isolation .
|
||||
popd
|
||||
|
||||
rm -rf flashinfer
|
||||
fi \
|
||||
fi
|
||||
BASH
|
||||
COPY examples examples
|
||||
COPY benchmarks benchmarks
|
||||
COPY ./vllm/collect_env.py .
|
||||
@ -464,6 +490,7 @@ RUN mv mkdocs.yaml test_docs/
|
||||
# base openai image with additional requirements, for any subsequent openai-style images
|
||||
FROM vllm-base AS vllm-openai-base
|
||||
ARG TARGETPLATFORM
|
||||
ARG INSTALL_KV_CONNECTORS=false
|
||||
|
||||
ARG PIP_INDEX_URL UV_INDEX_URL
|
||||
ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL
|
||||
@ -472,12 +499,17 @@ ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL
|
||||
# Reference: https://github.com/astral-sh/uv/pull/1694
|
||||
ENV UV_HTTP_TIMEOUT=500
|
||||
|
||||
COPY requirements/kv_connectors.txt requirements/kv_connectors.txt
|
||||
|
||||
# install additional dependencies for openai api server
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
if [ "$INSTALL_KV_CONNECTORS" = "true" ]; then \
|
||||
uv pip install --system -r requirements/kv_connectors.txt; \
|
||||
fi; \
|
||||
if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
|
||||
uv pip install --system accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.42.0' 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]; \
|
||||
else \
|
||||
uv pip install --system accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.45.3' 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]; \
|
||||
uv pip install --system accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.46.1' 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]; \
|
||||
fi
|
||||
|
||||
ENV VLLM_USAGE_SOURCE production-docker-image
|
||||
|
||||
@ -8,6 +8,8 @@
|
||||
# Build arguments:
|
||||
# PYTHON_VERSION=3.12 (default)|3.11|3.10|3.9
|
||||
# VLLM_CPU_DISABLE_AVX512=false (default)|true
|
||||
# VLLM_CPU_AVX512BF16=false (default)|true
|
||||
# VLLM_CPU_AVX512VNNI=false (default)|true
|
||||
#
|
||||
|
||||
######################### BASE IMAGE #########################
|
||||
@ -25,7 +27,7 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
|
||||
--mount=type=cache,target=/var/lib/apt,sharing=locked \
|
||||
apt-get update -y \
|
||||
&& apt-get install -y --no-install-recommends ccache git curl wget ca-certificates \
|
||||
gcc-12 g++-12 libtcmalloc-minimal4 libnuma-dev ffmpeg libsm6 libxext6 libgl1 \
|
||||
gcc-12 g++-12 libtcmalloc-minimal4 libnuma-dev ffmpeg libsm6 libxext6 libgl1 jq lsof \
|
||||
&& update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 \
|
||||
&& curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
|
||||
@ -60,8 +62,14 @@ FROM base AS vllm-build
|
||||
|
||||
ARG GIT_REPO_CHECK=0
|
||||
# Support for building with non-AVX512 vLLM: docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" ...
|
||||
ARG VLLM_CPU_DISABLE_AVX512
|
||||
ARG VLLM_CPU_DISABLE_AVX512=0
|
||||
ENV VLLM_CPU_DISABLE_AVX512=${VLLM_CPU_DISABLE_AVX512}
|
||||
# Support for building with AVX512BF16 ISA: docker build --build-arg VLLM_CPU_AVX512BF16="true" ...
|
||||
ARG VLLM_CPU_AVX512BF16=0
|
||||
ENV VLLM_CPU_AVX512BF16=${VLLM_CPU_AVX512BF16}
|
||||
# Support for building with AVX512VNNI ISA: docker build --build-arg VLLM_CPU_AVX512VNNI="true" ...
|
||||
ARG VLLM_CPU_AVX512VNNI=0
|
||||
ENV VLLM_CPU_AVX512VNNI=${VLLM_CPU_AVX512VNNI}
|
||||
|
||||
WORKDIR /workspace/vllm
|
||||
|
||||
@ -134,6 +142,7 @@ ADD ./tests/ ./tests/
|
||||
ADD ./examples/ ./examples/
|
||||
ADD ./benchmarks/ ./benchmarks/
|
||||
ADD ./vllm/collect_env.py .
|
||||
ADD ./.buildkite/ ./.buildkite/
|
||||
|
||||
# install development dependencies (for testing)
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
|
||||
@ -47,7 +47,7 @@ FROM vllm-base AS vllm-openai
|
||||
|
||||
# install additional dependencies for openai api server
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
pip install accelerate hf_transfer 'modelscope!=1.15.0'
|
||||
pip install accelerate hf_transfer pytest 'modelscope!=1.15.0'
|
||||
|
||||
ENV VLLM_USAGE_SOURCE production-docker-image \
|
||||
TRITON_XPU_PROFILE 1
|
||||
|
||||
@ -39,6 +39,7 @@ nav:
|
||||
- models/generative_models.md
|
||||
- models/pooling_models.md
|
||||
- models/extensions
|
||||
- Hardware Supported Models: models/hardware_supported_models
|
||||
- Features:
|
||||
- features/compatibility_matrix.md
|
||||
- features/*
|
||||
@ -54,6 +55,7 @@ nav:
|
||||
- contributing/model/registration.md
|
||||
- contributing/model/tests.md
|
||||
- contributing/model/multimodal.md
|
||||
- CI: contributing/ci
|
||||
- Design Documents:
|
||||
- V0: design
|
||||
- V1: design/v1
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
# Welcome to vLLM
|
||||
|
||||
<figure markdown="span">
|
||||
{ align="center" alt="vLLM" class="no-scaled-link" width="60%" }
|
||||
{ align="center" alt="vLLM Light" class="logo-light" width="60%" }
|
||||
{ align="center" alt="vLLM Dark" class="logo-dark" width="60%" }
|
||||
</figure>
|
||||
|
||||
<p style="text-align:center">
|
||||
@ -47,4 +48,4 @@ For more information, check out the following:
|
||||
- [vLLM announcing blog post](https://vllm.ai) (intro to PagedAttention)
|
||||
- [vLLM paper](https://arxiv.org/abs/2309.06180) (SOSP 2023)
|
||||
- [How continuous batching enables 23x throughput in LLM inference while reducing p50 latency](https://www.anyscale.com/blog/continuous-batching-llm-inference) by Cade Daniel et al.
|
||||
- [vLLM Meetups][meetups]
|
||||
- [vLLM Meetups](community/meetups.md)
|
||||
|
||||
@ -64,7 +64,7 @@ vLLM provides experimental support for multi-modal models through the [vllm.mult
|
||||
Multi-modal inputs can be passed alongside text and token prompts to [supported models][supported-mm-models]
|
||||
via the `multi_modal_data` field in [vllm.inputs.PromptType][].
|
||||
|
||||
Looking to add your own multi-modal model? Please follow the instructions listed [here][supports-multimodal].
|
||||
Looking to add your own multi-modal model? Please follow the instructions listed [here](../contributing/model/multimodal.md).
|
||||
|
||||
- [vllm.multimodal.MULTIMODAL_REGISTRY][]
|
||||
|
||||
|
||||
@ -16,7 +16,7 @@ vllm {chat,complete,serve,bench,collect-env,run-batch}
|
||||
|
||||
Start the vLLM OpenAI Compatible API server.
|
||||
|
||||
??? Examples
|
||||
??? console "Examples"
|
||||
|
||||
```bash
|
||||
# Start with a model
|
||||
|
||||
@ -1,6 +1,3 @@
|
||||
---
|
||||
title: Contact Us
|
||||
---
|
||||
[](){ #contactus }
|
||||
# Contact Us
|
||||
|
||||
--8<-- "README.md:contact-us"
|
||||
|
||||
@ -1,7 +1,4 @@
|
||||
---
|
||||
title: Meetups
|
||||
---
|
||||
[](){ #meetups }
|
||||
# Meetups
|
||||
|
||||
We host regular meetups in San Francisco Bay Area every 2 months. We will share the project updates from the vLLM team and have guest speakers from the industry to share their experience and insights. Please find the materials of our previous meetups below:
|
||||
|
||||
|
||||
@ -33,7 +33,7 @@ Quantized models take less memory at the cost of lower precision.
|
||||
Statically quantized models can be downloaded from HF Hub (some popular ones are available at [Red Hat AI](https://huggingface.co/RedHatAI))
|
||||
and used directly without extra configuration.
|
||||
|
||||
Dynamic quantization is also supported via the `quantization` option -- see [here][quantization-index] for more details.
|
||||
Dynamic quantization is also supported via the `quantization` option -- see [here](../features/quantization/README.md) for more details.
|
||||
|
||||
## Context length and batch size
|
||||
|
||||
@ -57,7 +57,7 @@ By default, we optimize model inference using CUDA graphs which take up extra me
|
||||
|
||||
You can adjust `compilation_config` to achieve a better balance between inference speed and memory usage:
|
||||
|
||||
??? Code
|
||||
??? code
|
||||
|
||||
```python
|
||||
from vllm import LLM
|
||||
@ -129,7 +129,7 @@ reduce the size of the processed multi-modal inputs, which in turn saves memory.
|
||||
|
||||
Here are some examples:
|
||||
|
||||
??? Code
|
||||
??? code
|
||||
|
||||
```python
|
||||
from vllm import LLM
|
||||
|
||||
@ -1,18 +1,20 @@
|
||||
---
|
||||
title: Engine Arguments
|
||||
toc_depth: 3
|
||||
---
|
||||
[](){ #engine-args }
|
||||
|
||||
# Engine Arguments
|
||||
|
||||
Engine arguments control the behavior of the vLLM engine.
|
||||
|
||||
- For [offline inference][offline-inference], they are part of the arguments to [LLM][vllm.LLM] class.
|
||||
- For [online serving][openai-compatible-server], they are part of the arguments to `vllm serve`.
|
||||
- For [offline inference](../serving/offline_inference.md), they are part of the arguments to [LLM][vllm.LLM] class.
|
||||
- For [online serving](../serving/openai_compatible_server.md), they are part of the arguments to `vllm serve`.
|
||||
|
||||
You can look at [EngineArgs][vllm.engine.arg_utils.EngineArgs] and [AsyncEngineArgs][vllm.engine.arg_utils.AsyncEngineArgs] to see the available engine arguments.
|
||||
The engine argument classes, [EngineArgs][vllm.engine.arg_utils.EngineArgs] and [AsyncEngineArgs][vllm.engine.arg_utils.AsyncEngineArgs], are a combination of the configuration classes defined in [vllm.config][]. Therefore, if you are interested in developer documentation, we recommend looking at these configuration classes as they are the source of truth for types, defaults and docstrings.
|
||||
|
||||
However, these classes are a combination of the configuration classes defined in [vllm.config][]. Therefore, we would recommend you read about them there where they are best documented.
|
||||
## `EngineArgs`
|
||||
|
||||
For offline inference you will have access to these configuration classes and for online serving you can cross-reference the configs with `vllm serve --help`, which has its arguments grouped by config.
|
||||
--8<-- "docs/argparse/engine_args.md"
|
||||
|
||||
!!! note
|
||||
Additional arguments are available to the [AsyncLLMEngine][vllm.engine.async_llm_engine.AsyncLLMEngine] which is used for online serving. These can be found by running `vllm serve --help`
|
||||
## `AsyncEngineArgs`
|
||||
|
||||
--8<-- "docs/argparse/async_engine_args.md"
|
||||
|
||||
@ -7,7 +7,7 @@ vLLM uses the following environment variables to configure the system:
|
||||
|
||||
All environment variables used by vLLM are prefixed with `VLLM_`. **Special care should be taken for Kubernetes users**: please do not name the service as `vllm`, otherwise environment variables set by Kubernetes might conflict with vLLM's environment variables, because [Kubernetes sets environment variables for each service with the capitalized service name as the prefix](https://kubernetes.io/docs/concepts/services-networking/service/#environment-variables).
|
||||
|
||||
??? Code
|
||||
??? code
|
||||
|
||||
```python
|
||||
--8<-- "vllm/envs.py:env-vars-definition"
|
||||
|
||||
@ -20,4 +20,4 @@ model = LLM(
|
||||
)
|
||||
```
|
||||
|
||||
Our [list of supported models][supported-models] shows the model architectures that are recognized by vLLM.
|
||||
Our [list of supported models](../models/supported_models.md) shows the model architectures that are recognized by vLLM.
|
||||
|
||||
@ -1,7 +1,4 @@
|
||||
---
|
||||
title: Server Arguments
|
||||
---
|
||||
[](){ #serve-args }
|
||||
# Server Arguments
|
||||
|
||||
The `vllm serve` command is used to launch the OpenAI-compatible server.
|
||||
|
||||
@ -13,7 +10,7 @@ To see the available CLI arguments, run `vllm serve --help`!
|
||||
## Configuration file
|
||||
|
||||
You can load CLI arguments via a [YAML](https://yaml.org/) config file.
|
||||
The argument names must be the long form of those outlined [above][serve-args].
|
||||
The argument names must be the long form of those outlined [above](serve_args.md).
|
||||
|
||||
For example:
|
||||
|
||||
|
||||
@ -95,7 +95,7 @@ For additional features and advanced configurations, refer to the official [MkDo
|
||||
|
||||
## Testing
|
||||
|
||||
??? note "Commands"
|
||||
??? console "Commands"
|
||||
|
||||
```bash
|
||||
pip install -r requirements/dev.txt
|
||||
|
||||
@ -1,7 +1,4 @@
|
||||
---
|
||||
title: Benchmark Suites
|
||||
---
|
||||
[](){ #benchmarks }
|
||||
# Benchmark Suites
|
||||
|
||||
vLLM contains two sets of benchmarks:
|
||||
|
||||
|
||||
@ -6,9 +6,9 @@ the failure?
|
||||
- Check the dashboard of current CI test failures:
|
||||
👉 [CI Failures Dashboard](https://github.com/orgs/vllm-project/projects/20)
|
||||
|
||||
- If your failure **is already listed**, it's likely unrelated to your PR.
|
||||
Help fixing it is always welcome!
|
||||
- Leave comments with links to additional instances of the failure.
|
||||
- If your failure **is already listed**, it's likely unrelated to your PR.
|
||||
Help fixing it is always welcome!
|
||||
- Leave comments with links to additional instances of the failure.
|
||||
- React with a 👍 to signal how many are affected.
|
||||
|
||||
- If your failure **is not listed**, you should **file an issue**.
|
||||
@ -19,25 +19,25 @@ the failure?
|
||||
👉 [New CI Failure Report](https://github.com/vllm-project/vllm/issues/new?template=450-ci-failure.yml)
|
||||
|
||||
- **Use this title format:**
|
||||
|
||||
|
||||
```
|
||||
[CI Failure]: failing-test-job - regex/matching/failing:test
|
||||
```
|
||||
|
||||
- **For the environment field:**
|
||||
|
||||
|
||||
```
|
||||
Still failing on main as of commit abcdef123
|
||||
```
|
||||
|
||||
- **In the description, include failing tests:**
|
||||
|
||||
|
||||
```
|
||||
FAILED failing/test.py:failing_test1 - Failure description
|
||||
FAILED failing/test.py:failing_test2 - Failure description
|
||||
https://github.com/orgs/vllm-project/projects/20
|
||||
https://github.com/vllm-project/vllm/issues/new?template=400-bug-report.yml
|
||||
FAILED failing/test.py:failing_test3 - Failure description
|
||||
FAILED failing/test.py:failing_test1 - Failure description
|
||||
FAILED failing/test.py:failing_test2 - Failure description
|
||||
https://github.com/orgs/vllm-project/projects/20
|
||||
https://github.com/vllm-project/vllm/issues/new?template=400-bug-report.yml
|
||||
FAILED failing/test.py:failing_test3 - Failure description
|
||||
```
|
||||
|
||||
- **Attach logs** (collapsible section example):
|
||||
@ -45,17 +45,17 @@ the failure?
|
||||
<summary>Logs:</summary>
|
||||
|
||||
```text
|
||||
ERROR 05-20 03:26:38 [dump_input.py:68] Dumping input data
|
||||
ERROR 05-20 03:26:38 [dump_input.py:68] Dumping input data
|
||||
--- Logging error ---
|
||||
Traceback (most recent call last):
|
||||
File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 203, in execute_model
|
||||
return self.model_executor.execute_model(scheduler_output)
|
||||
return self.model_executor.execute_model(scheduler_output)
|
||||
...
|
||||
FAILED failing/test.py:failing_test1 - Failure description
|
||||
FAILED failing/test.py:failing_test2 - Failure description
|
||||
FAILED failing/test.py:failing_test3 - Failure description
|
||||
FAILED failing/test.py:failing_test1 - Failure description
|
||||
FAILED failing/test.py:failing_test2 - Failure description
|
||||
FAILED failing/test.py:failing_test3 - Failure description
|
||||
```
|
||||
|
||||
|
||||
</details>
|
||||
|
||||
## Logs Wrangling
|
||||
@ -78,7 +78,7 @@ tail -525 ci_build.log | wl-copy
|
||||
|
||||
## Investigating a CI Test Failure
|
||||
|
||||
1. Go to 👉 [Buildkite main branch](https://buildkite.com/vllm/ci/builds?branch=main)
|
||||
1. Go to 👉 [Buildkite main branch](https://buildkite.com/vllm/ci/builds?branch=main)
|
||||
2. Bisect to find the first build that shows the issue.
|
||||
3. Add your findings to the GitHub issue.
|
||||
4. If you find a strong candidate PR, mention it in the issue and ping contributors.
|
||||
@ -97,9 +97,9 @@ CI test failures may be flaky. Use a bash loop to run repeatedly:
|
||||
|
||||
If you submit a PR to fix a CI failure:
|
||||
|
||||
- Link the PR to the issue:
|
||||
- Link the PR to the issue:
|
||||
Add `Closes #12345` to the PR description.
|
||||
- Add the `ci-failure` label:
|
||||
- Add the `ci-failure` label:
|
||||
This helps track it in the [CI Failures GitHub Project](https://github.com/orgs/vllm-project/projects/20).
|
||||
|
||||
## Other Resources
|
||||
@ -1,15 +1,12 @@
|
||||
---
|
||||
title: Update PyTorch version on vLLM OSS CI/CD
|
||||
---
|
||||
# Update PyTorch version on vLLM OSS CI/CD
|
||||
|
||||
vLLM's current policy is to always use the latest PyTorch stable
|
||||
release in CI/CD. It is standard practice to submit a PR to update the
|
||||
PyTorch version as early as possible when a new [PyTorch stable
|
||||
release](https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-cadence) becomes available.
|
||||
This process is non-trivial due to the gap between PyTorch
|
||||
releases. Using [#16859](https://github.com/vllm-project/vllm/pull/16859) as
|
||||
an example, this document outlines common steps to achieve this update along with
|
||||
a list of potential issues and how to address them.
|
||||
releases. Using <gh-pr:16859> as an example, this document outlines common steps to achieve this
|
||||
update along with a list of potential issues and how to address them.
|
||||
|
||||
## Test PyTorch release candidates (RCs)
|
||||
|
||||
@ -19,11 +16,12 @@ by waiting for the next release or by implementing hacky workarounds in vLLM.
|
||||
The better solution is to test vLLM with PyTorch release candidates (RC) to ensure
|
||||
compatibility before each release.
|
||||
|
||||
PyTorch release candidates can be downloaded from PyTorch test index at https://download.pytorch.org/whl/test.
|
||||
For example, torch2.7.0+cu12.8 RC can be installed using the following command:
|
||||
PyTorch release candidates can be downloaded from [PyTorch test index](https://download.pytorch.org/whl/test).
|
||||
For example, `torch2.7.0+cu12.8` RC can be installed using the following command:
|
||||
|
||||
```
|
||||
uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu128
|
||||
```bash
|
||||
uv pip install torch torchvision torchaudio \
|
||||
--index-url https://download.pytorch.org/whl/test/cu128
|
||||
```
|
||||
|
||||
When the final RC is ready for testing, it will be announced to the community
|
||||
@ -31,13 +29,28 @@ on the [PyTorch dev-discuss forum](https://dev-discuss.pytorch.org/c/release-ann
|
||||
After this announcement, we can begin testing vLLM integration by drafting a pull request
|
||||
following this 3-step process:
|
||||
|
||||
1. Update requirements files in https://github.com/vllm-project/vllm/tree/main/requirements
|
||||
to point to the new releases for torch, torchvision, and torchaudio.
|
||||
2. Use `--extra-index-url https://download.pytorch.org/whl/test/<PLATFORM>` to
|
||||
get the final release candidates' wheels. Some common platforms are `cpu`, `cu128`,
|
||||
and `rocm6.2.4`.
|
||||
3. As vLLM uses uv, make sure that `unsafe-best-match` strategy is set either
|
||||
via `UV_INDEX_STRATEGY` env variable or via `--index-strategy unsafe-best-match`.
|
||||
1. Update [requirements files](https://github.com/vllm-project/vllm/tree/main/requirements)
|
||||
to point to the new releases for `torch`, `torchvision`, and `torchaudio`.
|
||||
|
||||
2. Use the following option to get the final release candidates' wheels. Some common platforms are `cpu`, `cu128`, and `rocm6.2.4`.
|
||||
|
||||
```bash
|
||||
--extra-index-url https://download.pytorch.org/whl/test/<PLATFORM>
|
||||
```
|
||||
|
||||
3. Since vLLM uses `uv`, ensure the following index strategy is applied:
|
||||
|
||||
- Via environment variable:
|
||||
|
||||
```bash
|
||||
export UV_INDEX_STRATEGY=unsafe-best-match
|
||||
```
|
||||
|
||||
- Or via CLI flag:
|
||||
|
||||
```bash
|
||||
--index-strategy unsafe-best-match
|
||||
```
|
||||
|
||||
If failures are found in the pull request, raise them as issues on vLLM and
|
||||
cc the PyTorch release team to initiate discussion on how to address them.
|
||||
@ -45,20 +58,25 @@ cc the PyTorch release team to initiate discussion on how to address them.
|
||||
## Update CUDA version
|
||||
|
||||
The PyTorch release matrix includes both stable and experimental [CUDA versions](https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix). Due to limitations, only the latest stable CUDA version (for example,
|
||||
torch2.7.0+cu12.6) is uploaded to PyPI. However, vLLM may require a different CUDA version,
|
||||
`torch2.7.0+cu12.6`) is uploaded to PyPI. However, vLLM may require a different CUDA version,
|
||||
such as 12.8 for Blackwell support.
|
||||
This complicates the process as we cannot use the out-of-the-box
|
||||
`pip install torch torchvision torchaudio` command. The solution is to use
|
||||
`--extra-index-url` in vLLM's Dockerfiles.
|
||||
|
||||
1. Use `--extra-index-url https://download.pytorch.org/whl/cu128` to install torch+cu128.
|
||||
2. Other important indexes at the moment include:
|
||||
1. CPU ‒ https://download.pytorch.org/whl/cpu
|
||||
2. ROCm ‒ https://download.pytorch.org/whl/rocm6.2.4 and https://download.pytorch.org/whl/rocm6.3
|
||||
3. XPU ‒ https://download.pytorch.org/whl/xpu
|
||||
3. Update .buildkite/release-pipeline.yaml and .buildkite/scripts/upload-wheels.sh to
|
||||
match the CUDA version from step 1. This makes sure that the release vLLM wheel is tested
|
||||
on CI.
|
||||
- Important indexes at the moment include:
|
||||
|
||||
| Platform | `--extra-index-url` |
|
||||
|----------|-----------------|
|
||||
| CUDA 12.8| [https://download.pytorch.org/whl/cu128](https://download.pytorch.org/whl/cu128)|
|
||||
| CPU | [https://download.pytorch.org/whl/cpu](https://download.pytorch.org/whl/cpu)|
|
||||
| ROCm 6.2 | [https://download.pytorch.org/whl/rocm6.2.4](https://download.pytorch.org/whl/rocm6.2.4) |
|
||||
| ROCm 6.3 | [https://download.pytorch.org/whl/rocm6.3](https://download.pytorch.org/whl/rocm6.3) |
|
||||
| XPU | [https://download.pytorch.org/whl/xpu](https://download.pytorch.org/whl/xpu) |
|
||||
|
||||
- Update the below files to match the CUDA version from step 1. This makes sure that the release vLLM wheel is tested on CI.
|
||||
- `.buildkite/release-pipeline.yaml`
|
||||
- `.buildkite/scripts/upload-wheels.sh`
|
||||
|
||||
## Address long vLLM build time
|
||||
|
||||
@ -68,8 +86,8 @@ and timeout. Additionally, since vLLM's fastcheck pipeline runs in read-only mod
|
||||
it doesn't populate the cache, so re-running it to warm up the cache
|
||||
is ineffective.
|
||||
|
||||
While ongoing efforts like [#17419](https://github.com/vllm-project/vllm/issues/17419)
|
||||
address the long build time at its source, the current workaround is to set VLLM_CI_BRANCH
|
||||
While ongoing efforts like [#17419](gh-issue:17419)
|
||||
address the long build time at its source, the current workaround is to set `VLLM_CI_BRANCH`
|
||||
to a custom branch provided by @khluu (`VLLM_CI_BRANCH=khluu/use_postmerge_q`)
|
||||
when manually triggering a build on Buildkite. This branch accomplishes two things:
|
||||
|
||||
@ -89,17 +107,18 @@ releases (which would take too much time), they can be built from
|
||||
source to unblock the update process.
|
||||
|
||||
### FlashInfer
|
||||
Here is how to build and install it from source with torch2.7.0+cu128 in vLLM [Dockerfile](https://github.com/vllm-project/vllm/blob/27bebcd89792d5c4b08af7a65095759526f2f9e1/docker/Dockerfile#L259-L271):
|
||||
Here is how to build and install it from source with `torch2.7.0+cu128` in vLLM [Dockerfile](https://github.com/vllm-project/vllm/blob/27bebcd89792d5c4b08af7a65095759526f2f9e1/docker/Dockerfile#L259-L271):
|
||||
|
||||
```bash
|
||||
export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0 10.0+PTX'
|
||||
export FLASHINFER_ENABLE_SM90=1
|
||||
uv pip install --system --no-build-isolation "git+https://github.com/flashinfer-ai/flashinfer@v0.2.6.post1"
|
||||
uv pip install --system \
|
||||
--no-build-isolation "git+https://github.com/flashinfer-ai/flashinfer@v0.2.6.post1"
|
||||
```
|
||||
|
||||
One caveat is that building FlashInfer from source adds approximately 30
|
||||
minutes to the vLLM build time. Therefore, it's preferable to cache the wheel in a
|
||||
public location for immediate installation, such as https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl. For future releases, contact the PyTorch release
|
||||
public location for immediate installation, such as [this FlashInfer wheel link](https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl). For future releases, contact the PyTorch release
|
||||
team if you want to get the package published there.
|
||||
|
||||
### xFormers
|
||||
@ -107,13 +126,15 @@ Similar to FlashInfer, here is how to build and install xFormers from source:
|
||||
|
||||
```bash
|
||||
export TORCH_CUDA_ARCH_LIST='7.0 7.5 8.0 8.9 9.0 10.0+PTX'
|
||||
MAX_JOBS=16 uv pip install --system --no-build-isolation "git+https://github.com/facebookresearch/xformers@v0.0.30"
|
||||
MAX_JOBS=16 uv pip install --system \
|
||||
--no-build-isolation "git+https://github.com/facebookresearch/xformers@v0.0.30"
|
||||
```
|
||||
|
||||
### Mamba
|
||||
|
||||
```bash
|
||||
uv pip install --system --no-build-isolation "git+https://github.com/state-spaces/mamba@v2.2.4"
|
||||
uv pip install --system \
|
||||
--no-build-isolation "git+https://github.com/state-spaces/mamba@v2.2.4"
|
||||
```
|
||||
|
||||
### causal-conv1d
|
||||
@ -128,7 +149,6 @@ Rather than attempting to update all vLLM platforms in a single pull request, it
|
||||
to handle some platforms separately. The separation of requirements and Dockerfiles
|
||||
for different platforms in vLLM CI/CD allows us to selectively choose
|
||||
which platforms to update. For instance, updating XPU requires the corresponding
|
||||
release from https://github.com/intel/intel-extension-for-pytorch by Intel.
|
||||
While https://github.com/vllm-project/vllm/pull/16859 updated vLLM to PyTorch
|
||||
2.7.0 on CPU, CUDA, and ROCm, https://github.com/vllm-project/vllm/pull/17444
|
||||
completed the update for XPU.
|
||||
release from [Intel Extension for PyTorch](https://github.com/intel/intel-extension-for-pytorch) by Intel.
|
||||
While <gh-pr:16859> updated vLLM to PyTorch 2.7.0 on CPU, CUDA, and ROCm,
|
||||
<gh-pr:17444> completed the update for XPU.
|
||||
@ -37,14 +37,14 @@ multiple Y releases:
|
||||
- **Timeline**: A removal version is explicitly stated in the deprecation
|
||||
warning (e.g., "This will be removed in v0.10.0").
|
||||
- **Communication**: Deprecation is noted in the following, as applicable:
|
||||
- Help strings
|
||||
- Log output
|
||||
- API responses
|
||||
- `/metrics` output (for metrics features)
|
||||
- User-facing documentation
|
||||
- Release notes
|
||||
- GitHub Issue (RFC) for feedback
|
||||
- Documentation and use of the `@typing_extensions.deprecated` decorator for Python APIs
|
||||
- Help strings
|
||||
- Log output
|
||||
- API responses
|
||||
- `/metrics` output (for metrics features)
|
||||
- User-facing documentation
|
||||
- Release notes
|
||||
- GitHub Issue (RFC) for feedback
|
||||
- Documentation and use of the `@typing_extensions.deprecated` decorator for Python APIs
|
||||
|
||||
**2.Deprecated (Off By Default)**
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# Dockerfile
|
||||
|
||||
We provide a <gh-file:docker/Dockerfile> to construct the image for running an OpenAI compatible server with vLLM.
|
||||
More information about deploying with Docker can be found [here][deployment-docker].
|
||||
More information about deploying with Docker can be found [here](../../deployment/docker.md).
|
||||
|
||||
Below is a visual representation of the multi-stage Dockerfile. The build graph contains the following nodes:
|
||||
|
||||
|
||||
@ -14,7 +14,7 @@ Before setting up the incremental build:
|
||||
VLLM_USE_PRECOMPILED=1 uv pip install -U -e . --torch-backend=auto
|
||||
```
|
||||
|
||||
2. **CUDA Toolkit:** Verify that the NVIDIA CUDA Toolkit is correctly installed and `nvcc` is accessible in your `PATH`. CMake relies on `nvcc` to compile CUDA code. You can typically find `nvcc` in `$CUDA_HOME/bin/nvcc` or by running `which nvcc`. If you encounter issues, refer to the [official CUDA Toolkit installation guides](https://developer.nvidia.com/cuda-toolkit-archive) and vLLM's main [GPU installation documentation](../getting_started/installation/gpu/cuda.inc.md#troubleshooting) for troubleshooting. The `CMAKE_CUDA_COMPILER` variable in your `CMakeUserPresets.json` should also point to your `nvcc` binary.
|
||||
2. **CUDA Toolkit:** Verify that the NVIDIA CUDA Toolkit is correctly installed and `nvcc` is accessible in your `PATH`. CMake relies on `nvcc` to compile CUDA code. You can typically find `nvcc` in `$CUDA_HOME/bin/nvcc` or by running `which nvcc`. If you encounter issues, refer to the [official CUDA Toolkit installation guides](https://developer.nvidia.com/cuda-toolkit-archive) and vLLM's main [GPU installation documentation](../getting_started/installation/gpu.md#troubleshooting) for troubleshooting. The `CMAKE_CUDA_COMPILER` variable in your `CMakeUserPresets.json` should also point to your `nvcc` binary.
|
||||
|
||||
3. **Build Tools:** It is highly recommended to install `ccache` for fast rebuilds by caching compilation results (e.g., `sudo apt install ccache` or `conda install ccache`). Also, ensure the core build dependencies like `cmake` and `ninja` are installed. These are installable through `requirements/build.txt` or your system's package manager.
|
||||
|
||||
@ -84,6 +84,7 @@ Below is an example of what the generated `CMakeUserPresets.json` might look lik
|
||||
```
|
||||
|
||||
**What do the various configurations mean?**
|
||||
|
||||
- `CMAKE_CUDA_COMPILER`: Path to your `nvcc` binary. The script attempts to find this automatically.
|
||||
- `CMAKE_C_COMPILER_LAUNCHER`, `CMAKE_CXX_COMPILER_LAUNCHER`, `CMAKE_CUDA_COMPILER_LAUNCHER`: Setting these to `ccache` (or `sccache`) significantly speeds up rebuilds by caching compilation results. Ensure `ccache` is installed (e.g., `sudo apt install ccache` or `conda install ccache`). The script sets these by default.
|
||||
- `VLLM_PYTHON_EXECUTABLE`: Path to the Python executable in your vLLM development environment. The script will prompt for this, defaulting to the current Python environment if suitable.
|
||||
@ -98,16 +99,16 @@ Once your `CMakeUserPresets.json` is configured:
|
||||
1. **Initialize the CMake build environment:**
|
||||
This step configures the build system according to your chosen preset (e.g., `release`) and creates the build directory at `binaryDir`
|
||||
|
||||
```console
|
||||
cmake --preset release
|
||||
```
|
||||
```console
|
||||
cmake --preset release
|
||||
```
|
||||
|
||||
2. **Build and install the vLLM components:**
|
||||
This command compiles the code and installs the resulting binaries into your vLLM source directory, making them available to your editable Python installation.
|
||||
|
||||
```console
|
||||
cmake --build --preset release --target install
|
||||
```
|
||||
```console
|
||||
cmake --build --preset release --target install
|
||||
```
|
||||
|
||||
3. **Make changes and repeat!**
|
||||
Now you start using your editable install of vLLM, testing and making changes as needed. If you need to build again to update based on changes, simply run the CMake command again to build only the affected files.
|
||||
|
||||
@ -1,12 +1,9 @@
|
||||
---
|
||||
title: Summary
|
||||
---
|
||||
[](){ #new-model }
|
||||
# Summary
|
||||
|
||||
!!! important
|
||||
Many decoder language models can now be automatically loaded using the [Transformers backend][transformers-backend] without having to implement them in vLLM. See if `vllm serve <model>` works first!
|
||||
|
||||
vLLM models are specialized [PyTorch](https://pytorch.org/) models that take advantage of various [features][compatibility-matrix] to optimize their performance.
|
||||
vLLM models are specialized [PyTorch](https://pytorch.org/) models that take advantage of various [features](../../features/compatibility_matrix.md) to optimize their performance.
|
||||
|
||||
The complexity of integrating a model into vLLM depends heavily on the model's architecture.
|
||||
The process is considerably straightforward if the model shares a similar architecture with an existing model in vLLM.
|
||||
|
||||
@ -1,7 +1,4 @@
|
||||
---
|
||||
title: Basic Model
|
||||
---
|
||||
[](){ #new-model-basic }
|
||||
# Basic Model
|
||||
|
||||
This guide walks you through the steps to implement a basic vLLM model.
|
||||
|
||||
@ -27,7 +24,7 @@ All vLLM modules within the model must include a `prefix` argument in their cons
|
||||
|
||||
The initialization code should look like this:
|
||||
|
||||
??? Code
|
||||
??? code
|
||||
|
||||
```python
|
||||
from torch import nn
|
||||
@ -108,7 +105,7 @@ This method should load the weights from the HuggingFace's checkpoint file and a
|
||||
|
||||
## 5. Register your model
|
||||
|
||||
See [this page][new-model-registration] for instructions on how to register your new model to be used by vLLM.
|
||||
See [this page](registration.md) for instructions on how to register your new model to be used by vLLM.
|
||||
|
||||
## Frequently Asked Questions
|
||||
|
||||
|
||||
@ -1,15 +1,28 @@
|
||||
---
|
||||
title: Multi-Modal Support
|
||||
---
|
||||
[](){ #supports-multimodal }
|
||||
# Multi-Modal Support
|
||||
|
||||
This document walks you through the steps to extend a basic model so that it accepts [multi-modal inputs][multimodal-inputs].
|
||||
This document walks you through the steps to extend a basic model so that it accepts [multi-modal inputs](../../features/multimodal_inputs.md).
|
||||
|
||||
## 1. Update the base vLLM model
|
||||
|
||||
It is assumed that you have already implemented the model in vLLM according to [these steps][new-model-basic].
|
||||
It is assumed that you have already implemented the model in vLLM according to [these steps](basic.md).
|
||||
Further update the model as follows:
|
||||
|
||||
- Implement [get_placeholder_str][vllm.model_executor.models.interfaces.SupportsMultiModal.get_placeholder_str] to define the placeholder string which is used to represent the multi-modal item in the text prompt. This should be consistent with the chat template of the model.
|
||||
|
||||
??? code
|
||||
|
||||
```python
|
||||
class YourModelForImage2Seq(nn.Module):
|
||||
...
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("image"):
|
||||
return "<image>"
|
||||
|
||||
raise ValueError("Only image modality is supported")
|
||||
```
|
||||
|
||||
- Reserve a keyword parameter in [forward][torch.nn.Module.forward] for each input tensor that corresponds to a multi-modal input, as shown in the following example:
|
||||
|
||||
```diff
|
||||
@ -25,7 +38,7 @@ Further update the model as follows:
|
||||
|
||||
- Implement [get_multimodal_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_multimodal_embeddings] that returns the embeddings from running the multimodal inputs through the multimodal tokenizer of the model. Below we provide a boilerplate of a typical implementation pattern, but feel free to adjust it to your own needs.
|
||||
|
||||
??? Code
|
||||
??? code
|
||||
|
||||
```python
|
||||
class YourModelForImage2Seq(nn.Module):
|
||||
@ -55,7 +68,7 @@ Further update the model as follows:
|
||||
|
||||
- Implement [get_input_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_input_embeddings] to merge `multimodal_embeddings` with text embeddings from the `input_ids`. If input processing for the model is implemented correctly (see sections below), then you can leverage the utility function we provide to easily merge the embeddings.
|
||||
|
||||
??? Code
|
||||
??? code
|
||||
|
||||
```python
|
||||
from .utils import merge_multimodal_embeddings
|
||||
@ -139,7 +152,7 @@ Assuming that the memory usage increases with the number of tokens, the dummy in
|
||||
|
||||
Looking at the code of HF's `LlavaForConditionalGeneration`:
|
||||
|
||||
??? Code
|
||||
??? code
|
||||
|
||||
```python
|
||||
# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/llava/modeling_llava.py#L530-L544
|
||||
@ -163,7 +176,7 @@ Assuming that the memory usage increases with the number of tokens, the dummy in
|
||||
The number of placeholder feature tokens per image is `image_features.shape[1]`.
|
||||
`image_features` is calculated inside the `get_image_features` method:
|
||||
|
||||
??? Code
|
||||
??? code
|
||||
|
||||
```python
|
||||
# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/llava/modeling_llava.py#L290-L300
|
||||
@ -201,7 +214,7 @@ Assuming that the memory usage increases with the number of tokens, the dummy in
|
||||
|
||||
To find the sequence length, we turn to the code of `CLIPVisionEmbeddings`:
|
||||
|
||||
??? Code
|
||||
??? code
|
||||
|
||||
```python
|
||||
# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/clip/modeling_clip.py#L247-L257
|
||||
@ -228,7 +241,7 @@ Assuming that the memory usage increases with the number of tokens, the dummy in
|
||||
|
||||
Overall, the number of placeholder feature tokens for an image can be calculated as:
|
||||
|
||||
??? Code
|
||||
??? code
|
||||
|
||||
```python
|
||||
def get_num_image_tokens(
|
||||
@ -253,7 +266,7 @@ Assuming that the memory usage increases with the number of tokens, the dummy in
|
||||
Notice that the number of image tokens doesn't depend on the image width and height.
|
||||
We can simply use a dummy `image_size` to calculate the multimodal profiling data:
|
||||
|
||||
??? Code
|
||||
??? code
|
||||
|
||||
```python
|
||||
# NOTE: In actuality, this is usually implemented as part of the
|
||||
@ -298,7 +311,7 @@ Assuming that the memory usage increases with the number of tokens, the dummy in
|
||||
|
||||
Looking at the code of HF's `FuyuForCausalLM`:
|
||||
|
||||
??? Code
|
||||
??? code
|
||||
|
||||
```python
|
||||
# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/modeling_fuyu.py#L311-L322
|
||||
@ -328,7 +341,7 @@ Assuming that the memory usage increases with the number of tokens, the dummy in
|
||||
In `FuyuImageProcessor.preprocess`, the images are resized and padded to the target `FuyuImageProcessor.size`,
|
||||
returning the dimensions after resizing (but before padding) as metadata.
|
||||
|
||||
??? Code
|
||||
??? code
|
||||
|
||||
```python
|
||||
# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/processing_fuyu.py#L541-L544
|
||||
@ -366,7 +379,7 @@ Assuming that the memory usage increases with the number of tokens, the dummy in
|
||||
|
||||
In `FuyuImageProcessor.preprocess_with_tokenizer_info`, the images are split into patches based on this metadata:
|
||||
|
||||
??? Code
|
||||
??? code
|
||||
|
||||
```python
|
||||
# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/processing_fuyu.py#L417-L425
|
||||
@ -404,7 +417,7 @@ Assuming that the memory usage increases with the number of tokens, the dummy in
|
||||
|
||||
The number of patches is in turn defined by `FuyuImageProcessor.get_num_patches`:
|
||||
|
||||
??? Code
|
||||
??? code
|
||||
|
||||
```python
|
||||
# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/image_processing_fuyu.py#L552-L562
|
||||
@ -441,7 +454,7 @@ Assuming that the memory usage increases with the number of tokens, the dummy in
|
||||
|
||||
For the multimodal image profiling data, the logic is very similar to LLaVA:
|
||||
|
||||
??? Code
|
||||
??? code
|
||||
|
||||
```python
|
||||
def get_dummy_mm_data(
|
||||
@ -467,7 +480,7 @@ Afterwards, create a subclass of [BaseMultiModalProcessor][vllm.multimodal.proce
|
||||
to fill in the missing details about HF processing.
|
||||
|
||||
!!! info
|
||||
[Multi-Modal Data Processing][mm-processing]
|
||||
[Multi-Modal Data Processing](../../design/mm_processing.md)
|
||||
|
||||
### Multi-modal fields
|
||||
|
||||
@ -530,7 +543,7 @@ return a schema of the tensors outputted by the HF processor that are related to
|
||||
In order to support the use of [MultiModalFieldConfig.batched][] like in LLaVA,
|
||||
we remove the extra batch dimension by overriding [BaseMultiModalProcessor._call_hf_processor][]:
|
||||
|
||||
??? Code
|
||||
??? code
|
||||
|
||||
```python
|
||||
def _call_hf_processor(
|
||||
@ -538,11 +551,13 @@ return a schema of the tensors outputted by the HF processor that are related to
|
||||
prompt: str,
|
||||
mm_data: Mapping[str, object],
|
||||
mm_kwargs: Mapping[str, object],
|
||||
tok_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
processed_outputs = super()._call_hf_processor(
|
||||
prompt=prompt,
|
||||
mm_data=mm_data,
|
||||
mm_kwargs=mm_kwargs,
|
||||
tok_kwargs=tok_kwargs,
|
||||
)
|
||||
|
||||
image_patches = processed_outputs.get("image_patches")
|
||||
@ -566,6 +581,11 @@ return a schema of the tensors outputted by the HF processor that are related to
|
||||
Our [actual code](gh-file:vllm/model_executor/models/fuyu.py) has special handling
|
||||
for text-only inputs to prevent unnecessary warnings from HF processor.
|
||||
|
||||
!!! note
|
||||
The `_call_hf_processor` method specifies both `mm_kwargs` and `tok_kwargs` for
|
||||
processing. `mm_kwargs` is used to both initialize and call the huggingface
|
||||
processor, whereas `tok_kwargs` is only used to call the huggingface processor.
|
||||
|
||||
This lets us override [_get_mm_fields_config][vllm.multimodal.processing.BaseMultiModalProcessor._get_mm_fields_config] as follows:
|
||||
|
||||
```python
|
||||
@ -600,7 +620,7 @@ Each [PromptUpdate][vllm.multimodal.processing.PromptUpdate] instance specifies
|
||||
It simply repeats each input `image_token` a number of times equal to the number of placeholder feature tokens (`num_image_tokens`).
|
||||
Based on this, we override [_get_prompt_updates][vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates] as follows:
|
||||
|
||||
??? Code
|
||||
??? code
|
||||
|
||||
```python
|
||||
def _get_prompt_updates(
|
||||
@ -645,7 +665,7 @@ Each [PromptUpdate][vllm.multimodal.processing.PromptUpdate] instance specifies
|
||||
|
||||
We define a helper function to return `ncols` and `nrows` directly:
|
||||
|
||||
??? Code
|
||||
??? code
|
||||
|
||||
```python
|
||||
def get_image_feature_grid_size(
|
||||
@ -675,7 +695,7 @@ Each [PromptUpdate][vllm.multimodal.processing.PromptUpdate] instance specifies
|
||||
|
||||
Based on this, we can initially define our replacement tokens as:
|
||||
|
||||
??? Code
|
||||
??? code
|
||||
|
||||
```python
|
||||
def get_replacement(item_idx: int):
|
||||
@ -695,7 +715,7 @@ Each [PromptUpdate][vllm.multimodal.processing.PromptUpdate] instance specifies
|
||||
However, this is not entirely correct. After `FuyuImageProcessor.preprocess_with_tokenizer_info` is called,
|
||||
a BOS token (`<s>`) is also added to the promopt:
|
||||
|
||||
??? Code
|
||||
??? code
|
||||
|
||||
```python
|
||||
# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/processing_fuyu.py#L417-L435
|
||||
@ -722,7 +742,7 @@ Each [PromptUpdate][vllm.multimodal.processing.PromptUpdate] instance specifies
|
||||
To assign the vision embeddings to only the image tokens, instead of a string
|
||||
you can return an instance of [PromptUpdateDetails][vllm.multimodal.processing.PromptUpdateDetails]:
|
||||
|
||||
??? Code
|
||||
??? code
|
||||
|
||||
```python
|
||||
hf_config = self.info.get_hf_config()
|
||||
@ -749,7 +769,7 @@ Each [PromptUpdate][vllm.multimodal.processing.PromptUpdate] instance specifies
|
||||
Finally, noticing that the HF processor removes the `|ENDOFTEXT|` token from the tokenized prompt,
|
||||
we can search for it to conduct the replacement at the start of the string:
|
||||
|
||||
??? Code
|
||||
??? code
|
||||
|
||||
```python
|
||||
def _get_prompt_updates(
|
||||
@ -796,7 +816,7 @@ Each [PromptUpdate][vllm.multimodal.processing.PromptUpdate] instance specifies
|
||||
After you have defined [BaseProcessingInfo][vllm.multimodal.processing.BaseProcessingInfo] (Step 2),
|
||||
[BaseDummyInputsBuilder][vllm.multimodal.profiling.BaseDummyInputsBuilder] (Step 3),
|
||||
and [BaseMultiModalProcessor][vllm.multimodal.processing.BaseMultiModalProcessor] (Step 4),
|
||||
decorate the model class with {meth}`MULTIMODAL_REGISTRY.register_processor <vllm.multimodal.registry.MultiModalRegistry.register_processor>`
|
||||
decorate the model class with [MULTIMODAL_REGISTRY.register_processor][vllm.multimodal.processing.MultiModalRegistry.register_processor]
|
||||
to register them to the multi-modal registry:
|
||||
|
||||
```diff
|
||||
@ -823,7 +843,7 @@ Examples:
|
||||
|
||||
### Handling prompt updates unrelated to multi-modal data
|
||||
|
||||
[_get_prompt_updates][vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates] assumes that each application of prompt update corresponds to one multi-modal item. If the HF processor performs additional processing regardless of how many multi-modal items there are, you should override [_apply_hf_processor_tokens_only][vllm.multimodal.processing.BaseMultiModalProcessor._apply_hf_processor_tokens_only] so that the processed token inputs are consistent with the result of applying the HF processor on text inputs. This is because token inputs bypass the HF processor according to [our design][mm-processing].
|
||||
[_get_prompt_updates][vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates] assumes that each application of prompt update corresponds to one multi-modal item. If the HF processor performs additional processing regardless of how many multi-modal items there are, you should override [_apply_hf_processor_tokens_only][vllm.multimodal.processing.BaseMultiModalProcessor._apply_hf_processor_tokens_only] so that the processed token inputs are consistent with the result of applying the HF processor on text inputs. This is because token inputs bypass the HF processor according to [our design](../../design/mm_processing.md).
|
||||
|
||||
Examples:
|
||||
|
||||
|
||||
@ -1,10 +1,7 @@
|
||||
---
|
||||
title: Registering a Model
|
||||
---
|
||||
[](){ #new-model-registration }
|
||||
# Registering a Model
|
||||
|
||||
vLLM relies on a model registry to determine how to run each model.
|
||||
A list of pre-registered architectures can be found [here][supported-models].
|
||||
A list of pre-registered architectures can be found [here](../../models/supported_models.md).
|
||||
|
||||
If your model is not on this list, you must register it to vLLM.
|
||||
This page provides detailed instructions on how to do so.
|
||||
@ -14,16 +11,16 @@ This page provides detailed instructions on how to do so.
|
||||
To add a model directly to the vLLM library, start by forking our [GitHub repository](https://github.com/vllm-project/vllm) and then [build it from source][build-from-source].
|
||||
This gives you the ability to modify the codebase and test your model.
|
||||
|
||||
After you have implemented your model (see [tutorial][new-model-basic]), put it into the <gh-dir:vllm/model_executor/models> directory.
|
||||
After you have implemented your model (see [tutorial](basic.md)), put it into the <gh-dir:vllm/model_executor/models> directory.
|
||||
Then, add your model class to `_VLLM_MODELS` in <gh-file:vllm/model_executor/models/registry.py> so that it is automatically registered upon importing vLLM.
|
||||
Finally, update our [list of supported models][supported-models] to promote your model!
|
||||
Finally, update our [list of supported models](../../models/supported_models.md) to promote your model!
|
||||
|
||||
!!! important
|
||||
The list of models in each section should be maintained in alphabetical order.
|
||||
|
||||
## Out-of-tree models
|
||||
|
||||
You can load an external model [using a plugin][plugin-system] without modifying the vLLM codebase.
|
||||
You can load an external model [using a plugin](../../design/plugin_system.md) without modifying the vLLM codebase.
|
||||
|
||||
To register the model, use the following code:
|
||||
|
||||
@ -51,4 +48,4 @@ def register():
|
||||
|
||||
!!! important
|
||||
If your model is a multimodal model, ensure the model class implements the [SupportsMultiModal][vllm.model_executor.models.interfaces.SupportsMultiModal] interface.
|
||||
Read more about that [here][supports-multimodal].
|
||||
Read more about that [here](multimodal.md).
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user