Compare commits
185 Commits
pil_image
...
benchmark_
| Author | SHA1 | Date | |
|---|---|---|---|
| 221118dc85 | |||
| 66e63e86ec | |||
| 9214e60631 | |||
| f880d42582 | |||
| dcfe95234c | |||
| 48ac2bed5b | |||
| 3e0d435027 | |||
| 4ee4826ede | |||
| 60017dc841 | |||
| 55f1a468d9 | |||
| fd195b194e | |||
| fabe89bbc4 | |||
| e73b7dfd69 | |||
| 7fdfa01530 | |||
| aef94c6d07 | |||
| 0ceaebf87b | |||
| 1db4f47f81 | |||
| d3d91b6f71 | |||
| 87d871470d | |||
| a5f8c111c2 | |||
| e23564cb70 | |||
| 390ec88905 | |||
| 541817670c | |||
| 67da5720d4 | |||
| 5c04bb8b86 | |||
| 3d2779c29a | |||
| 6b31c84aff | |||
| b18201fe06 | |||
| f4937a51c1 | |||
| ee659e3b60 | |||
| 4e1c6a0264 | |||
| c7852a6d9b | |||
| 8795eb9975 | |||
| 0b34593017 | |||
| e3f3aee6f4 | |||
| 92540529c0 | |||
| fadb8d5c2d | |||
| 2aa5470ac5 | |||
| 51ff154639 | |||
| 566ec04c3d | |||
| 01c22335ba | |||
| 451da4bcbd | |||
| 07ad27121f | |||
| a9944aabfa | |||
| a8f5aec20a | |||
| de71fec81b | |||
| 70f8b96724 | |||
| dd2a94596a | |||
| 420caf7557 | |||
| 4f07a64075 | |||
| e6b8e65d2d | |||
| 26d0419309 | |||
| 83f74c698f | |||
| 2dff093574 | |||
| afe3236e90 | |||
| 65334ef3b9 | |||
| e60f550b38 | |||
| f25e0d1125 | |||
| 09f106a91e | |||
| 2142035b51 | |||
| 78aa341d12 | |||
| 7974736740 | |||
| 2fc9075b82 | |||
| d93c976a0d | |||
| 749f792553 | |||
| 856865008e | |||
| f9c069c85e | |||
| 418d2f8bfb | |||
| 964472b966 | |||
| 59dd311cf5 | |||
| d066e52013 | |||
| c8ea982d9b | |||
| dc372b9c8a | |||
| 9b5b39b650 | |||
| 9ccc6ded42 | |||
| d62a076e84 | |||
| 259127f8b8 | |||
| 612c2edb4f | |||
| 38fe728d60 | |||
| 82e7f9bb03 | |||
| 63dc3426e0 | |||
| 8f5dc41481 | |||
| 63ad622233 | |||
| e7ef61c1f0 | |||
| d4154c35a2 | |||
| 6685890d11 | |||
| 33011318c2 | |||
| 4f8b373225 | |||
| 7b2f28deba | |||
| 2d912fb66f | |||
| 12e6c0b41c | |||
| 9a2a6357de | |||
| 6266c57bae | |||
| 754b699cbe | |||
| 6e27c6d86b | |||
| d5af47a149 | |||
| 65f0f74b66 | |||
| 176a95c670 | |||
| f2ae883b67 | |||
| 40de1ef455 | |||
| 0189a65a2e | |||
| 55aa7af994 | |||
| 0b217da646 | |||
| 19324d660c | |||
| fc407a1425 | |||
| 009d9e7590 | |||
| b922c2ebd2 | |||
| 00b14e0f16 | |||
| 54e467e6f8 | |||
| 79a1d25bbd | |||
| 9944011b30 | |||
| 8c946cecca | |||
| ff334ca1cd | |||
| 6223dd8114 | |||
| 906f0598fc | |||
| cb528d0585 | |||
| 98fcba1575 | |||
| 23b3134eb5 | |||
| ea6ae8cb45 | |||
| 2ff297dce9 | |||
| 8dd0671bac | |||
| f0d610a8ae | |||
| e57e4d6e9e | |||
| ee5be834e7 | |||
| 48545728d8 | |||
| dc1a821768 | |||
| 61e0a506a3 | |||
| 1df491c522 | |||
| d8487ef557 | |||
| c06af9a959 | |||
| 60f7624334 | |||
| f6518b2b48 | |||
| d67085c2c8 | |||
| 307939f299 | |||
| 9d7ea9dbbf | |||
| acee8f48aa | |||
| f065de4e88 | |||
| dc9905368d | |||
| ebab1ac37c | |||
| 2b0db9b0e2 | |||
| 195adb47c0 | |||
| 302f3aca7e | |||
| e9c730c9bd | |||
| 289199feb6 | |||
| b9fd0d7a69 | |||
| 72a3f6b898 | |||
| 98ea35601c | |||
| d19110204c | |||
| 05a4324f8e | |||
| 7ea6cb28b2 | |||
| 9fbf2bfbd5 | |||
| 3a5ea75129 | |||
| 891b9d33de | |||
| 430783018c | |||
| 19a3c78d1f | |||
| ada50aa295 | |||
| 08bf784078 | |||
| d45fe333fb | |||
| 021c16c7ca | |||
| 7de18d541b | |||
| a810b5b088 | |||
| 009b3d5382 | |||
| e4b8713380 | |||
| 06c0922a69 | |||
| cd3edfc908 | |||
| 9cea90eab4 | |||
| d1110f5b5a | |||
| 8132365b74 | |||
| eea22a56ab | |||
| 9112155283 | |||
| 90d0a74b60 | |||
| d74e5f37bc | |||
| ca66a1674c | |||
| 950751a987 | |||
| 4c31218f80 | |||
| 68311891f5 | |||
| fc4441a4ee | |||
| 246e3e0a36 | |||
| 7042cc96b0 | |||
| 0c0fdae84f | |||
| 3b602cdea7 | |||
| 4b2ed7926a | |||
| 7e3571134f | |||
| ea2236bf95 | |||
| 7d4aedae7c |
@ -8,12 +8,12 @@ import zipfile
|
||||
# Note that we have 400 MiB quota, please use it wisely.
|
||||
# See https://github.com/pypi/support/issues/3792 .
|
||||
# Please also sync the value with the one in Dockerfile.
|
||||
VLLM_MAX_SIZE_MB = int(os.environ.get('VLLM_MAX_SIZE_MB', 400))
|
||||
VLLM_MAX_SIZE_MB = int(os.environ.get("VLLM_MAX_SIZE_MB", 400))
|
||||
|
||||
|
||||
def print_top_10_largest_files(zip_file):
|
||||
"""Print the top 10 largest files in the given zip file."""
|
||||
with zipfile.ZipFile(zip_file, 'r') as z:
|
||||
with zipfile.ZipFile(zip_file, "r") as z:
|
||||
file_sizes = [(f, z.getinfo(f).file_size) for f in z.namelist()]
|
||||
file_sizes.sort(key=lambda x: x[1], reverse=True)
|
||||
for f, size in file_sizes[:10]:
|
||||
@ -28,14 +28,18 @@ def check_wheel_size(directory):
|
||||
wheel_path = os.path.join(root, file_name)
|
||||
wheel_size_mb = os.path.getsize(wheel_path) / (1024 * 1024)
|
||||
if wheel_size_mb > VLLM_MAX_SIZE_MB:
|
||||
print(f"Not allowed: Wheel {wheel_path} is larger "
|
||||
f"({wheel_size_mb:.2f} MB) than the limit "
|
||||
f"({VLLM_MAX_SIZE_MB} MB).")
|
||||
print(
|
||||
f"Not allowed: Wheel {wheel_path} is larger "
|
||||
f"({wheel_size_mb:.2f} MB) than the limit "
|
||||
f"({VLLM_MAX_SIZE_MB} MB)."
|
||||
)
|
||||
print_top_10_largest_files(wheel_path)
|
||||
return 1
|
||||
else:
|
||||
print(f"Wheel {wheel_path} is within the allowed size "
|
||||
f"({wheel_size_mb:.2f} MB).")
|
||||
print(
|
||||
f"Wheel {wheel_path} is within the allowed size "
|
||||
f"({wheel_size_mb:.2f} MB)."
|
||||
)
|
||||
return 0
|
||||
|
||||
|
||||
@ -45,4 +49,4 @@ if __name__ == "__main__":
|
||||
sys.exit(1)
|
||||
|
||||
directory = sys.argv[1]
|
||||
sys.exit(check_wheel_size(directory))
|
||||
sys.exit(check_wheel_size(directory))
|
||||
|
||||
@ -22,5 +22,5 @@ with open("index.html", "w") as f:
|
||||
print(f"Generated index.html for {args.wheel}")
|
||||
# cloudfront requires escaping the '+' character
|
||||
f.write(
|
||||
template.format(wheel=filename,
|
||||
wheel_html_escaped=filename.replace("+", "%2B")))
|
||||
template.format(wheel=filename, wheel_html_escaped=filename.replace("+", "%2B"))
|
||||
)
|
||||
|
||||
@ -8,11 +8,14 @@ def pytest_addoption(parser):
|
||||
parser.addoption(
|
||||
"--config-list-file",
|
||||
action="store",
|
||||
help="Path to the file listing model config YAMLs (one per line)")
|
||||
parser.addoption("--tp-size",
|
||||
action="store",
|
||||
default="1",
|
||||
help="Tensor parallel size to use for evaluation")
|
||||
help="Path to the file listing model config YAMLs (one per line)",
|
||||
)
|
||||
parser.addoption(
|
||||
"--tp-size",
|
||||
action="store",
|
||||
default="1",
|
||||
help="Tensor parallel size to use for evaluation",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
@ -33,7 +36,8 @@ def pytest_generate_tests(metafunc):
|
||||
config_dir = config_list_file.parent
|
||||
with open(config_list_file, encoding="utf-8") as f:
|
||||
configs = [
|
||||
config_dir / line.strip() for line in f
|
||||
config_dir / line.strip()
|
||||
for line in f
|
||||
if line.strip() and not line.startswith("#")
|
||||
]
|
||||
metafunc.parametrize("config_filename", configs)
|
||||
|
||||
@ -16,19 +16,22 @@ RTOL = 0.08
|
||||
|
||||
|
||||
def launch_lm_eval(eval_config, tp_size):
|
||||
trust_remote_code = eval_config.get('trust_remote_code', False)
|
||||
model_args = f"pretrained={eval_config['model_name']}," \
|
||||
f"tensor_parallel_size={tp_size}," \
|
||||
f"enforce_eager=true," \
|
||||
f"add_bos_token=true," \
|
||||
f"trust_remote_code={trust_remote_code}"
|
||||
trust_remote_code = eval_config.get("trust_remote_code", False)
|
||||
model_args = (
|
||||
f"pretrained={eval_config['model_name']},"
|
||||
f"tensor_parallel_size={tp_size},"
|
||||
f"enforce_eager=true,"
|
||||
f"add_bos_token=true,"
|
||||
f"trust_remote_code={trust_remote_code}"
|
||||
)
|
||||
results = lm_eval.simple_evaluate(
|
||||
model="vllm",
|
||||
model_args=model_args,
|
||||
tasks=[task["name"] for task in eval_config["tasks"]],
|
||||
num_fewshot=eval_config["num_fewshot"],
|
||||
limit=eval_config["limit"],
|
||||
batch_size="auto")
|
||||
batch_size="auto",
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
@ -42,9 +45,10 @@ def test_lm_eval_correctness_param(config_filename, tp_size):
|
||||
for metric in task["metrics"]:
|
||||
ground_truth = metric["value"]
|
||||
measured_value = results["results"][task["name"]][metric["name"]]
|
||||
print(f'{task["name"]} | {metric["name"]}: '
|
||||
f'ground_truth={ground_truth} | measured={measured_value}')
|
||||
success = success and np.isclose(
|
||||
ground_truth, measured_value, rtol=RTOL)
|
||||
print(
|
||||
f"{task['name']} | {metric['name']}: "
|
||||
f"ground_truth={ground_truth} | measured={measured_value}"
|
||||
)
|
||||
success = success and np.isclose(ground_truth, measured_value, rtol=RTOL)
|
||||
|
||||
assert success
|
||||
|
||||
@ -65,18 +65,18 @@ def read_markdown(file):
|
||||
|
||||
|
||||
def results_to_json(latency, throughput, serving):
|
||||
return json.dumps({
|
||||
'latency': latency.to_dict(),
|
||||
'throughput': throughput.to_dict(),
|
||||
'serving': serving.to_dict()
|
||||
})
|
||||
return json.dumps(
|
||||
{
|
||||
"latency": latency.to_dict(),
|
||||
"throughput": throughput.to_dict(),
|
||||
"serving": serving.to_dict(),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# collect results
|
||||
for test_file in results_folder.glob("*.json"):
|
||||
|
||||
with open(test_file) as f:
|
||||
raw_result = json.loads(f.read())
|
||||
|
||||
@ -120,7 +120,8 @@ if __name__ == "__main__":
|
||||
for perc in [10, 25, 50, 75, 90, 99]:
|
||||
# Multiply 1000 to convert the time unit from s to ms
|
||||
raw_result.update(
|
||||
{f"P{perc}": 1000 * raw_result["percentiles"][str(perc)]})
|
||||
{f"P{perc}": 1000 * raw_result["percentiles"][str(perc)]}
|
||||
)
|
||||
raw_result["avg_latency"] = raw_result["avg_latency"] * 1000
|
||||
|
||||
# add the result to raw_result
|
||||
@ -153,26 +154,27 @@ if __name__ == "__main__":
|
||||
serving_results = pd.DataFrame.from_dict(serving_results)
|
||||
throughput_results = pd.DataFrame.from_dict(throughput_results)
|
||||
|
||||
raw_results_json = results_to_json(latency_results, throughput_results,
|
||||
serving_results)
|
||||
raw_results_json = results_to_json(
|
||||
latency_results, throughput_results, serving_results
|
||||
)
|
||||
|
||||
# remapping the key, for visualization purpose
|
||||
if not latency_results.empty:
|
||||
latency_results = latency_results[list(
|
||||
latency_column_mapping.keys())].rename(
|
||||
columns=latency_column_mapping)
|
||||
latency_results = latency_results[list(latency_column_mapping.keys())].rename(
|
||||
columns=latency_column_mapping
|
||||
)
|
||||
if not serving_results.empty:
|
||||
serving_results = serving_results[list(
|
||||
serving_column_mapping.keys())].rename(
|
||||
columns=serving_column_mapping)
|
||||
serving_results = serving_results[list(serving_column_mapping.keys())].rename(
|
||||
columns=serving_column_mapping
|
||||
)
|
||||
if not throughput_results.empty:
|
||||
throughput_results = throughput_results[list(
|
||||
throughput_results_column_mapping.keys())].rename(
|
||||
columns=throughput_results_column_mapping)
|
||||
throughput_results = throughput_results[
|
||||
list(throughput_results_column_mapping.keys())
|
||||
].rename(columns=throughput_results_column_mapping)
|
||||
|
||||
processed_results_json = results_to_json(latency_results,
|
||||
throughput_results,
|
||||
serving_results)
|
||||
processed_results_json = results_to_json(
|
||||
latency_results, throughput_results, serving_results
|
||||
)
|
||||
|
||||
for df in [latency_results, serving_results, throughput_results]:
|
||||
if df.empty:
|
||||
@ -184,38 +186,39 @@ if __name__ == "__main__":
|
||||
# The GPUs sometimes come in format of "GPUTYPE\nGPUTYPE\n...",
|
||||
# we want to turn it into "8xGPUTYPE"
|
||||
df["GPU"] = df["GPU"].apply(
|
||||
lambda x: f"{len(x.split('\n'))}x{x.split('\n')[0]}")
|
||||
lambda x: f"{len(x.split('\n'))}x{x.split('\n')[0]}"
|
||||
)
|
||||
|
||||
# get markdown tables
|
||||
latency_md_table = tabulate(latency_results,
|
||||
headers='keys',
|
||||
tablefmt='pipe',
|
||||
showindex=False)
|
||||
serving_md_table = tabulate(serving_results,
|
||||
headers='keys',
|
||||
tablefmt='pipe',
|
||||
showindex=False)
|
||||
throughput_md_table = tabulate(throughput_results,
|
||||
headers='keys',
|
||||
tablefmt='pipe',
|
||||
showindex=False)
|
||||
latency_md_table = tabulate(
|
||||
latency_results, headers="keys", tablefmt="pipe", showindex=False
|
||||
)
|
||||
serving_md_table = tabulate(
|
||||
serving_results, headers="keys", tablefmt="pipe", showindex=False
|
||||
)
|
||||
throughput_md_table = tabulate(
|
||||
throughput_results, headers="keys", tablefmt="pipe", showindex=False
|
||||
)
|
||||
|
||||
# document the result
|
||||
with open(results_folder / "benchmark_results.md", "w") as f:
|
||||
|
||||
results = read_markdown("../.buildkite/nightly-benchmarks/" +
|
||||
"performance-benchmarks-descriptions.md")
|
||||
results = read_markdown(
|
||||
"../.buildkite/nightly-benchmarks/"
|
||||
+ "performance-benchmarks-descriptions.md"
|
||||
)
|
||||
results = results.format(
|
||||
latency_tests_markdown_table=latency_md_table,
|
||||
throughput_tests_markdown_table=throughput_md_table,
|
||||
serving_tests_markdown_table=serving_md_table,
|
||||
benchmarking_results_in_json_string=processed_results_json)
|
||||
benchmarking_results_in_json_string=processed_results_json,
|
||||
)
|
||||
f.write(results)
|
||||
|
||||
# document benchmarking results in json
|
||||
with open(results_folder / "benchmark_results.json", "w") as f:
|
||||
|
||||
results = latency_results.to_dict(
|
||||
orient='records') + throughput_results.to_dict(
|
||||
orient='records') + serving_results.to_dict(orient='records')
|
||||
results = (
|
||||
latency_results.to_dict(orient="records")
|
||||
+ throughput_results.to_dict(orient="records")
|
||||
+ serving_results.to_dict(orient="records")
|
||||
)
|
||||
f.write(json.dumps(results))
|
||||
|
||||
@ -14,15 +14,12 @@ def main(model, cachedir):
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Download and save Hugging Face tokenizer")
|
||||
parser.add_argument("--model",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Name of the model")
|
||||
parser.add_argument("--cachedir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Directory to save the tokenizer")
|
||||
description="Download and save Hugging Face tokenizer"
|
||||
)
|
||||
parser.add_argument("--model", type=str, required=True, help="Name of the model")
|
||||
parser.add_argument(
|
||||
"--cachedir", type=str, required=True, help="Directory to save the tokenizer"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args.model, args.cachedir)
|
||||
|
||||
@ -11,33 +11,33 @@ from tabulate import tabulate
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser(
|
||||
description=
|
||||
'Parse command line arguments for summary-nightly-results script.')
|
||||
parser.add_argument('--results-folder',
|
||||
type=str,
|
||||
required=True,
|
||||
help='The folder where the results are stored.')
|
||||
parser.add_argument('--description',
|
||||
type=str,
|
||||
required=True,
|
||||
help='Description of the results.')
|
||||
description="Parse command line arguments for summary-nightly-results script."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--results-folder",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The folder where the results are stored.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--description", type=str, required=True, help="Description of the results."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def get_perf(df, method, model, metric):
|
||||
|
||||
means = []
|
||||
|
||||
for qps in [2, 4, 8, 16, "inf"]:
|
||||
target = df['Test name'].str.contains(model)
|
||||
target = target & df['Engine'].str.contains(method)
|
||||
target = target & df['Test name'].str.contains("qps_" + str(qps))
|
||||
target = df["Test name"].str.contains(model)
|
||||
target = target & df["Engine"].str.contains(method)
|
||||
target = target & df["Test name"].str.contains("qps_" + str(qps))
|
||||
filtered_df = df[target]
|
||||
|
||||
if filtered_df.empty:
|
||||
means.append(0.)
|
||||
means.append(0.0)
|
||||
else:
|
||||
means.append(filtered_df[metric].values[0])
|
||||
|
||||
@ -45,7 +45,6 @@ def get_perf(df, method, model, metric):
|
||||
|
||||
|
||||
def get_perf_w_std(df, method, model, metric):
|
||||
|
||||
if metric in ["TTFT", "ITL"]:
|
||||
mean = get_perf(df, method, model, "Mean " + metric + " (ms)")
|
||||
mean = mean.tolist()
|
||||
@ -60,7 +59,8 @@ def get_perf_w_std(df, method, model, metric):
|
||||
else:
|
||||
assert metric == "Tput"
|
||||
mean = get_perf(df, method, model, "Input Tput (tok/s)") + get_perf(
|
||||
df, method, model, "Output Tput (tok/s)")
|
||||
df, method, model, "Output Tput (tok/s)"
|
||||
)
|
||||
mean = mean.tolist()
|
||||
std = None
|
||||
|
||||
@ -80,18 +80,17 @@ def main(args):
|
||||
# generate markdown table
|
||||
df = pd.DataFrame.from_dict(results)
|
||||
|
||||
md_table = tabulate(df, headers='keys', tablefmt='pipe', showindex=False)
|
||||
md_table = tabulate(df, headers="keys", tablefmt="pipe", showindex=False)
|
||||
|
||||
with open(args.description) as f:
|
||||
description = f.read()
|
||||
|
||||
description = description.format(
|
||||
nightly_results_benchmarking_table=md_table)
|
||||
description = description.format(nightly_results_benchmarking_table=md_table)
|
||||
|
||||
with open("nightly_results.md", "w") as f:
|
||||
f.write(description)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
args = parse_arguments()
|
||||
main(args)
|
||||
|
||||
@ -34,10 +34,8 @@ serving_column_mapping = {
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# collect results
|
||||
for test_file in results_folder.glob("*.json"):
|
||||
|
||||
with open(test_file) as f:
|
||||
raw_result = json.loads(f.read())
|
||||
|
||||
@ -56,17 +54,16 @@ if __name__ == "__main__":
|
||||
serving_results = pd.DataFrame.from_dict(serving_results)
|
||||
|
||||
if not serving_results.empty:
|
||||
serving_results = serving_results[list(
|
||||
serving_column_mapping.keys())].rename(
|
||||
columns=serving_column_mapping)
|
||||
serving_results = serving_results[list(serving_column_mapping.keys())].rename(
|
||||
columns=serving_column_mapping
|
||||
)
|
||||
|
||||
serving_md_table_with_headers = tabulate(serving_results,
|
||||
headers='keys',
|
||||
tablefmt='pipe',
|
||||
showindex=False)
|
||||
serving_md_table_with_headers = tabulate(
|
||||
serving_results, headers="keys", tablefmt="pipe", showindex=False
|
||||
)
|
||||
# remove the first line of header
|
||||
serving_md_table_lines = serving_md_table_with_headers.split('\n')
|
||||
serving_md_table_without_header = '\n'.join(serving_md_table_lines[2:])
|
||||
serving_md_table_lines = serving_md_table_with_headers.split("\n")
|
||||
serving_md_table_without_header = "\n".join(serving_md_table_lines[2:])
|
||||
|
||||
prefix = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
prefix = prefix + "_" + os.environ.get("CURRENT_LLM_SERVING_ENGINE")
|
||||
@ -76,10 +73,9 @@ if __name__ == "__main__":
|
||||
# document results with header.
|
||||
# for those who wants to reproduce our benchmark.
|
||||
f.write(serving_md_table_with_headers)
|
||||
f.write('\n')
|
||||
f.write("\n")
|
||||
|
||||
# document benchmarking results in json
|
||||
with open(results_folder / f"{prefix}_nightly_results.json", "w") as f:
|
||||
|
||||
results = serving_results.to_dict(orient='records')
|
||||
results = serving_results.to_dict(orient="records")
|
||||
f.write(json.dumps(results))
|
||||
|
||||
51
.buildkite/pyproject.toml
Normal file
51
.buildkite/pyproject.toml
Normal file
@ -0,0 +1,51 @@
|
||||
# This local pyproject file is part of the migration from yapf to ruff format.
|
||||
# It uses the same core rules as the main pyproject.toml file, but with the
|
||||
# following differences:
|
||||
# - ruff line length is overridden to 88
|
||||
# - deprecated typing ignores (UP006, UP035) have been removed
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 88
|
||||
exclude = [
|
||||
# External file, leaving license intact
|
||||
"examples/other/fp8/quantizer/quantize.py",
|
||||
"vllm/vllm_flash_attn/flash_attn_interface.pyi"
|
||||
]
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"vllm/third_party/**" = ["ALL"]
|
||||
"vllm/version.py" = ["F401"]
|
||||
"vllm/_version.py" = ["ALL"]
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [
|
||||
# pycodestyle
|
||||
"E",
|
||||
# Pyflakes
|
||||
"F",
|
||||
# pyupgrade
|
||||
"UP",
|
||||
# flake8-bugbear
|
||||
"B",
|
||||
# flake8-simplify
|
||||
"SIM",
|
||||
# isort
|
||||
"I",
|
||||
# flake8-logging-format
|
||||
"G",
|
||||
]
|
||||
ignore = [
|
||||
# star imports
|
||||
"F405", "F403",
|
||||
# lambda expression assignment
|
||||
"E731",
|
||||
# Loop control variable not used within loop body
|
||||
"B007",
|
||||
# f-string format
|
||||
"UP032",
|
||||
# Can remove once 3.10+ is the minimum Python version
|
||||
"UP007",
|
||||
]
|
||||
|
||||
[tool.ruff.format]
|
||||
docstring-code-format = true
|
||||
@ -3,6 +3,9 @@
|
||||
# This script runs test inside the corresponding ROCm docker container.
|
||||
set -o pipefail
|
||||
|
||||
# Export Python path
|
||||
export PYTHONPATH=".."
|
||||
|
||||
# Print ROCm version
|
||||
echo "--- Confirming Clean Initial State"
|
||||
while true; do
|
||||
@ -74,6 +77,23 @@ HF_MOUNT="/root/.cache/huggingface"
|
||||
|
||||
commands=$@
|
||||
echo "Commands:$commands"
|
||||
|
||||
if [[ $commands == *"pytest -v -s basic_correctness/test_basic_correctness.py"* ]]; then
|
||||
commands=${commands//"pytest -v -s basic_correctness/test_basic_correctness.py"/"VLLM_USE_TRITON_FLASH_ATTN=0 pytest -v -s basic_correctness/test_basic_correctness.py"}
|
||||
fi
|
||||
|
||||
if [[ $commands == *"pytest -v -s models/test_registry.py"* ]]; then
|
||||
commands=${commands//"pytest -v -s models/test_registry.py"/"pytest -v -s models/test_registry.py -k 'not BambaForCausalLM and not GritLM and not Mamba2ForCausalLM and not Zamba2ForCausalLM'"}
|
||||
fi
|
||||
|
||||
if [[ $commands == *"VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2'"* ]]; then
|
||||
commands=${commands//"VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2'"/"VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2 and not BambaForCausalLM and not Gemma2ForCausalLM and not Grok1ModelForCausalLM and not Zamba2ForCausalLM and not Gemma2Model and not GritLM'"}
|
||||
fi
|
||||
|
||||
if [[ $commands == *"pytest -v -s compile/test_basic_correctness.py"* ]]; then
|
||||
commands=${commands//"pytest -v -s compile/test_basic_correctness.py"/"VLLM_USE_TRITON_FLASH_ATTN=0 pytest -v -s compile/test_basic_correctness.py"}
|
||||
fi
|
||||
|
||||
#ignore certain kernels tests
|
||||
if [[ $commands == *" kernels/core"* ]]; then
|
||||
commands="${commands} \
|
||||
@ -161,6 +181,8 @@ fi
|
||||
|
||||
|
||||
PARALLEL_JOB_COUNT=8
|
||||
MYPYTHONPATH=".."
|
||||
|
||||
# check if the command contains shard flag, we will run all shards in parallel because the host have 8 GPUs.
|
||||
if [[ $commands == *"--shard-id="* ]]; then
|
||||
# assign job count as the number of shards used
|
||||
@ -181,6 +203,7 @@ if [[ $commands == *"--shard-id="* ]]; then
|
||||
-e AWS_SECRET_ACCESS_KEY \
|
||||
-v "${HF_CACHE}:${HF_MOUNT}" \
|
||||
-e "HF_HOME=${HF_MOUNT}" \
|
||||
-e "PYTHONPATH=${MYPYTHONPATH}" \
|
||||
--name "${container_name}_${GPU}" \
|
||||
"${image_name}" \
|
||||
/bin/bash -c "${commands_gpu}" \
|
||||
@ -211,6 +234,7 @@ else
|
||||
-e AWS_SECRET_ACCESS_KEY \
|
||||
-v "${HF_CACHE}:${HF_MOUNT}" \
|
||||
-e "HF_HOME=${HF_MOUNT}" \
|
||||
-e "PYTHONPATH=${MYPYTHONPATH}" \
|
||||
--name "${container_name}" \
|
||||
"${image_name}" \
|
||||
/bin/bash -c "${commands}"
|
||||
|
||||
@ -32,9 +32,12 @@ function cpu_tests() {
|
||||
set -e
|
||||
pip install pytest pytest-asyncio einops peft Pillow soundfile transformers_stream_generator matplotlib
|
||||
pip install sentence-transformers datamodel_code_generator
|
||||
pytest -v -s tests/models/embedding/language/test_cls_models.py::test_classification_models[float-jason9693/Qwen2.5-1.5B-apeach]
|
||||
pytest -v -s tests/models/embedding/language/test_embedding.py::test_models[half-BAAI/bge-base-en-v1.5]
|
||||
pytest -v -s tests/models/encoder_decoder/language -m cpu_model"
|
||||
pytest -v -s tests/models/language/generation/test_bart.py -m cpu_model
|
||||
pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-openai-community/gpt2]
|
||||
pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-facebook/opt-125m]
|
||||
pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-google/gemma-1.1-2b-it]
|
||||
pytest -v -s tests/models/language/pooling/test_classification.py::test_models[float-jason9693/Qwen2.5-1.5B-apeach]
|
||||
pytest -v -s tests/models/language/pooling/test_embedding.py::test_models[half-BAAI/bge-base-en-v1.5]"
|
||||
}
|
||||
|
||||
# All of CPU tests are expected to be finished less than 40 mins.
|
||||
|
||||
@ -26,27 +26,27 @@ docker run --privileged --net host --shm-size=16G -it \
|
||||
&& tpu-info \
|
||||
&& { \
|
||||
echo TEST_0: Running test_perf.py; \
|
||||
pytest -s -v /workspace/vllm/tests/tpu/test_perf.py; \
|
||||
python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_perf.py; \
|
||||
echo TEST_0_EXIT_CODE: \$?; \
|
||||
} & \
|
||||
&& { \
|
||||
{ \
|
||||
echo TEST_1: Running test_compilation.py; \
|
||||
pytest -s -v /workspace/vllm/tests/tpu/test_compilation.py; \
|
||||
python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_compilation.py; \
|
||||
echo TEST_1_EXIT_CODE: \$?; \
|
||||
} & \
|
||||
{ \
|
||||
echo TEST_2: Running test_basic.py; \
|
||||
pytest -s -v /workspace/vllm/tests/v1/tpu/test_basic.py; \
|
||||
python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_basic.py; \
|
||||
echo TEST_2_EXIT_CODE: \$?; \
|
||||
} & \
|
||||
{ \
|
||||
echo TEST_3: Running test_accuracy.py::test_lm_eval_accuracy_v1_engine; \
|
||||
pytest -s -v /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine; \
|
||||
python3 -m pytest -s -v /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine; \
|
||||
echo TEST_3_EXIT_CODE: \$?; \
|
||||
} & \
|
||||
{ \
|
||||
echo TEST_4: Running test_quantization_accuracy.py; \
|
||||
pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py; \
|
||||
python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py; \
|
||||
echo TEST_4_EXIT_CODE: \$?; \
|
||||
} & \
|
||||
{ \
|
||||
@ -56,43 +56,43 @@ docker run --privileged --net host --shm-size=16G -it \
|
||||
} & \
|
||||
{ \
|
||||
echo TEST_6: Running test_tpu_model_runner.py; \
|
||||
pytest -s -v /workspace/vllm/tests/tpu/worker/test_tpu_model_runner.py; \
|
||||
python3 -m pytest -s -v /workspace/vllm/tests/tpu/worker/test_tpu_model_runner.py; \
|
||||
echo TEST_6_EXIT_CODE: \$?; \
|
||||
} & \
|
||||
&& { \
|
||||
{ \
|
||||
echo TEST_7: Running test_sampler.py; \
|
||||
pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py; \
|
||||
python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py; \
|
||||
echo TEST_7_EXIT_CODE: \$?; \
|
||||
} & \
|
||||
&& { \
|
||||
{ \
|
||||
echo TEST_8: Running test_topk_topp_sampler.py; \
|
||||
pytest -s -v /workspace/vllm/tests/v1/tpu/test_topk_topp_sampler.py; \
|
||||
python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_topk_topp_sampler.py; \
|
||||
echo TEST_8_EXIT_CODE: \$?; \
|
||||
} & \
|
||||
&& { \
|
||||
{ \
|
||||
echo TEST_9: Running test_multimodal.py; \
|
||||
pytest -s -v /workspace/vllm/tests/v1/tpu/test_multimodal.py; \
|
||||
python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_multimodal.py; \
|
||||
echo TEST_9_EXIT_CODE: \$?; \
|
||||
} & \
|
||||
&& { \
|
||||
{ \
|
||||
echo TEST_10: Running test_pallas.py; \
|
||||
pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py; \
|
||||
python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py; \
|
||||
echo TEST_10_EXIT_CODE: \$?; \
|
||||
} & \
|
||||
&& { \
|
||||
{ \
|
||||
echo TEST_11: Running test_struct_output_generate.py; \
|
||||
pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py; \
|
||||
python3 -m pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py; \
|
||||
echo TEST_11_EXIT_CODE: \$?; \
|
||||
} & \
|
||||
&& { \
|
||||
{ \
|
||||
echo TEST_12: Running test_moe_pallas.py; \
|
||||
pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py; \
|
||||
python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py; \
|
||||
echo TEST_12_EXIT_CODE: \$?; \
|
||||
} & \
|
||||
# Disable the TPU LoRA tests until the feature is activated
|
||||
# && { \
|
||||
# & { \
|
||||
# echo TEST_13: Running test_moe_pallas.py; \
|
||||
# pytest -s -v /workspace/vllm/tests/tpu/lora/; \
|
||||
# python3 -m pytest -s -v /workspace/vllm/tests/tpu/lora/; \
|
||||
# echo TEST_13_EXIT_CODE: \$?; \
|
||||
# } & \
|
||||
wait \
|
||||
|
||||
@ -75,3 +75,4 @@ else
|
||||
fi
|
||||
|
||||
aws s3 cp "$wheel" "s3://vllm-wheels/$version/"
|
||||
aws s3 cp index.html "s3://vllm-wheels/$version/vllm/index.html"
|
||||
|
||||
@ -32,6 +32,7 @@ steps:
|
||||
##### fast check tests #####
|
||||
|
||||
- label: Documentation Build # 2min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
working_dir: "/vllm-workspace/test_docs/docs"
|
||||
fast_check: true
|
||||
no_gpu: True
|
||||
@ -42,6 +43,7 @@ steps:
|
||||
- grep \"sig sig-object py\" build/html/api/vllm/vllm.sampling_params.html
|
||||
|
||||
- label: Async Engine, Inputs, Utils, Worker Test # 24min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/mq_llm_engine
|
||||
@ -62,6 +64,7 @@ steps:
|
||||
- pytest -v -s worker # Worker
|
||||
|
||||
- label: Python-only Installation Test
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- tests/standalone_tests/python_only_compile.sh
|
||||
- setup.py
|
||||
@ -69,7 +72,7 @@ steps:
|
||||
- bash standalone_tests/python_only_compile.sh
|
||||
|
||||
- label: Basic Correctness Test # 30min
|
||||
#mirror_hardwares: [amd]
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
fast_check: true
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
@ -86,6 +89,7 @@ steps:
|
||||
- VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py
|
||||
|
||||
- label: Chunked Prefill Test
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/basic_correctness/test_chunked_prefill
|
||||
@ -94,7 +98,7 @@ steps:
|
||||
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
|
||||
|
||||
- label: Core Test # 10min
|
||||
mirror_hardwares: [amd]
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
fast_check: true
|
||||
source_file_dependencies:
|
||||
- vllm/core
|
||||
@ -104,10 +108,10 @@ steps:
|
||||
- pytest -v -s core
|
||||
|
||||
- label: Entrypoints Test # 40min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
fast_check: true
|
||||
torch_nightly: true
|
||||
#mirror_hardwares: [amd]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/entrypoints/llm
|
||||
@ -126,6 +130,7 @@ steps:
|
||||
- VLLM_USE_V1=0 pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests
|
||||
|
||||
- label: Distributed Tests (4 GPUs) # 10min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
num_gpus: 4
|
||||
source_file_dependencies:
|
||||
@ -143,6 +148,8 @@ steps:
|
||||
# test with tp=2 and external_dp=2
|
||||
- VLLM_USE_V1=0 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
|
||||
- torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
|
||||
# test with tp=2 and pp=2
|
||||
- PP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
|
||||
# test with internal dp
|
||||
- python3 ../examples/offline_inference/data_parallel.py
|
||||
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
|
||||
@ -153,12 +160,12 @@ steps:
|
||||
# TODO: create a dedicated test section for multi-GPU example tests
|
||||
# when we have multiple distributed example tests
|
||||
- pushd ../examples/offline_inference
|
||||
- python3 rlhf.py
|
||||
- RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
|
||||
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py
|
||||
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
|
||||
- popd
|
||||
|
||||
- label: Metrics, Tracing Test # 10min
|
||||
mirror_hardwares: [amd]
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
num_gpus: 2
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
@ -172,7 +179,7 @@ steps:
|
||||
##### 1 GPU test #####
|
||||
|
||||
- label: Regression Test # 5min
|
||||
#mirror_hardwares: [amd]
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/test_regression
|
||||
@ -182,7 +189,7 @@ steps:
|
||||
working_dir: "/vllm-workspace/tests" # optional
|
||||
|
||||
- label: Engine Test # 10min
|
||||
mirror_hardwares: [amd]
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/engine
|
||||
@ -196,7 +203,7 @@ steps:
|
||||
- pytest -v -s tokenization
|
||||
|
||||
- label: V1 Test
|
||||
#mirror_hardwares: [amd]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/v1
|
||||
@ -209,8 +216,8 @@ steps:
|
||||
- pytest -v -s v1/worker
|
||||
- pytest -v -s v1/structured_output
|
||||
- pytest -v -s v1/spec_decode
|
||||
- pytest -v -s v1/kv_connector/unit
|
||||
- pytest -v -s v1/test_serial_utils.py
|
||||
- pytest -v -s v1/test_stats.py
|
||||
- pytest -v -s v1/test_utils.py
|
||||
- pytest -v -s v1/test_oracle.py
|
||||
# TODO: accuracy does not match, whether setting
|
||||
@ -221,8 +228,8 @@ steps:
|
||||
- pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine
|
||||
|
||||
- label: Examples Test # 25min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
working_dir: "/vllm-workspace/examples"
|
||||
#mirror_hardwares: [amd]
|
||||
source_file_dependencies:
|
||||
- vllm/entrypoints
|
||||
- examples/
|
||||
@ -246,7 +253,7 @@ steps:
|
||||
- VLLM_USE_V1=0 python3 offline_inference/profiling.py --model facebook/opt-125m run_num_steps --num-steps 2
|
||||
|
||||
- label: Prefix Caching Test # 9min
|
||||
mirror_hardwares: [amd]
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/prefix_caching
|
||||
@ -254,6 +261,7 @@ steps:
|
||||
- pytest -v -s prefix_caching
|
||||
|
||||
- label: Samplers Test # 36min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- vllm/model_executor/layers
|
||||
- vllm/sampling_metadata.py
|
||||
@ -264,7 +272,7 @@ steps:
|
||||
- VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers
|
||||
|
||||
- label: LogitsProcessor Test # 5min
|
||||
mirror_hardwares: [amd]
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
source_file_dependencies:
|
||||
- vllm/model_executor/layers
|
||||
- vllm/model_executor/guided_decoding
|
||||
@ -275,6 +283,7 @@ steps:
|
||||
- pytest -v -s model_executor/test_guided_processors.py
|
||||
|
||||
- label: Speculative decoding tests # 40min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- vllm/spec_decode
|
||||
- tests/spec_decode
|
||||
@ -285,7 +294,7 @@ steps:
|
||||
- pytest -v -s spec_decode/e2e/test_eagle_correctness.py
|
||||
|
||||
- label: LoRA Test %N # 15min each
|
||||
#mirror_hardwares: [amd]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- vllm/lora
|
||||
- tests/lora
|
||||
@ -293,6 +302,7 @@ steps:
|
||||
parallelism: 4
|
||||
|
||||
- label: PyTorch Compilation Unit Tests
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
@ -300,9 +310,11 @@ steps:
|
||||
commands:
|
||||
- pytest -v -s compile/test_pass_manager.py
|
||||
- pytest -v -s compile/test_fusion.py
|
||||
- pytest -v -s compile/test_silu_mul_quant_fusion.py
|
||||
- pytest -v -s compile/test_sequence_parallelism.py
|
||||
|
||||
- label: PyTorch Fullgraph Smoke Test # 9min
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
@ -314,6 +326,7 @@ steps:
|
||||
- pytest -v -s compile/piecewise/test_toy_llama.py
|
||||
|
||||
- label: PyTorch Fullgraph Test # 18min
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
@ -322,7 +335,7 @@ steps:
|
||||
- pytest -v -s compile/test_full_graph.py
|
||||
|
||||
- label: Kernels Core Operation Test
|
||||
mirror_hardwares: [amd]
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
source_file_dependencies:
|
||||
- csrc/
|
||||
- tests/kernels/core
|
||||
@ -330,7 +343,7 @@ steps:
|
||||
- pytest -v -s kernels/core
|
||||
|
||||
- label: Kernels Attention Test %N
|
||||
mirror_hardwares: [amd]
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
source_file_dependencies:
|
||||
- csrc/attention/
|
||||
- vllm/attention
|
||||
@ -341,7 +354,7 @@ steps:
|
||||
parallelism: 2
|
||||
|
||||
- label: Kernels Quantization Test %N
|
||||
mirror_hardwares: [amd]
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
source_file_dependencies:
|
||||
- csrc/quantization/
|
||||
- vllm/model_executor/layers/quantization
|
||||
@ -351,7 +364,7 @@ steps:
|
||||
parallelism: 2
|
||||
|
||||
- label: Kernels MoE Test
|
||||
#mirror_hardwares: [amd]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- csrc/moe/
|
||||
- tests/kernels/moe
|
||||
@ -360,7 +373,7 @@ steps:
|
||||
- pytest -v -s kernels/moe
|
||||
|
||||
- label: Kernels Mamba Test
|
||||
#mirror_hardwares: [amd]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- csrc/mamba/
|
||||
- tests/kernels/mamba
|
||||
@ -368,7 +381,7 @@ steps:
|
||||
- pytest -v -s kernels/mamba
|
||||
|
||||
- label: Tensorizer Test # 11min
|
||||
# mirror_hardwares: [amd]
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
soft_fail: true
|
||||
source_file_dependencies:
|
||||
- vllm/model_executor/model_loader
|
||||
@ -379,14 +392,15 @@ steps:
|
||||
- pytest -v -s tensorizer_loader
|
||||
|
||||
- label: Benchmarks # 9min
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
working_dir: "/vllm-workspace/.buildkite"
|
||||
mirror_hardwares: [amd]
|
||||
source_file_dependencies:
|
||||
- benchmarks/
|
||||
commands:
|
||||
- bash scripts/run-benchmarks.sh
|
||||
|
||||
- label: Benchmarks CLI Test # 10min
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/benchmarks/
|
||||
@ -394,6 +408,7 @@ steps:
|
||||
- pytest -v -s benchmarks/
|
||||
|
||||
- label: Quantization Test
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- csrc/
|
||||
- vllm/model_executor/layers/quantization
|
||||
@ -402,6 +417,7 @@ steps:
|
||||
- VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization
|
||||
|
||||
- label: LM Eval Small Models # 53min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
working_dir: "/vllm-workspace/.buildkite/lm-eval-harness"
|
||||
source_file_dependencies:
|
||||
- csrc/
|
||||
@ -411,6 +427,7 @@ steps:
|
||||
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-small.txt --tp-size=1
|
||||
|
||||
- label: OpenAI API correctness
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- csrc/
|
||||
- vllm/entrypoints/openai/
|
||||
@ -419,6 +436,7 @@ steps:
|
||||
- pytest -s entrypoints/openai/correctness/
|
||||
|
||||
- label: Encoder Decoder tests # 5min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/encoder_decoder
|
||||
@ -426,8 +444,8 @@ steps:
|
||||
- pytest -v -s encoder_decoder
|
||||
|
||||
- label: OpenAI-Compatible Tool Use # 20 min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
fast_check: false
|
||||
#mirror_hardwares: [ amd ]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/tool_use
|
||||
@ -439,6 +457,7 @@ steps:
|
||||
##### models test #####
|
||||
|
||||
- label: Basic Models Test # 24min
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
@ -454,16 +473,19 @@ steps:
|
||||
- VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'plamo2'
|
||||
|
||||
- label: Language Models Test (Standard)
|
||||
#mirror_hardwares: [amd]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/models/language
|
||||
commands:
|
||||
# Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile.
|
||||
- pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8'
|
||||
- pip freeze | grep -E 'torch'
|
||||
- pytest -v -s models/language -m core_model
|
||||
|
||||
- label: Language Models Test (Extended)
|
||||
mirror_hardwares: [amdexperimental]
|
||||
optional: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
@ -474,17 +496,20 @@ steps:
|
||||
- pytest -v -s models/language -m 'not core_model'
|
||||
|
||||
- label: Multi-Modal Models Test (Standard)
|
||||
#mirror_hardwares: [amd]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/models/multimodal
|
||||
commands:
|
||||
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
|
||||
- pip freeze | grep -E 'torch'
|
||||
- pytest -v -s models/multimodal/processing
|
||||
- pytest -v -s --ignore models/multimodal/generation/test_whisper.py models/multimodal -m core_model
|
||||
- cd .. && pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work
|
||||
|
||||
- label: Multi-Modal Models Test (Extended) 1
|
||||
mirror_hardwares: [amdexperimental]
|
||||
optional: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
@ -494,6 +519,7 @@ steps:
|
||||
- pytest -v -s --ignore models/multimodal/generation/test_common.py --ignore models/multimodal/processing models/multimodal -m 'not core_model'
|
||||
|
||||
- label: Multi-Modal Models Test (Extended) 2
|
||||
mirror_hardwares: [amdexperimental]
|
||||
optional: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
@ -503,6 +529,7 @@ steps:
|
||||
- pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=0) and not core_model'
|
||||
|
||||
- label: Multi-Modal Models Test (Extended) 3
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
optional: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
@ -512,7 +539,7 @@ steps:
|
||||
- pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=1) and not core_model'
|
||||
|
||||
- label: Quantized Models Test
|
||||
#mirror_hardwares: [amd]
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
source_file_dependencies:
|
||||
- vllm/model_executor/layers/quantization
|
||||
- tests/models/quantization
|
||||
@ -521,7 +548,7 @@ steps:
|
||||
|
||||
# This test is used only in PR development phase to test individual models and should never run on main
|
||||
- label: Custom Models Test
|
||||
mirror_hardwares: [amd]
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
optional: true
|
||||
commands:
|
||||
- echo 'Testing custom models...'
|
||||
@ -533,7 +560,7 @@ steps:
|
||||
##### multi gpus test #####
|
||||
|
||||
- label: Distributed Comm Ops Test # 7min
|
||||
mirror_hardwares: [amd]
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
num_gpus: 2
|
||||
source_file_dependencies:
|
||||
@ -544,6 +571,7 @@ steps:
|
||||
- pytest -v -s distributed/test_shm_broadcast.py
|
||||
|
||||
- label: 2 Node Tests (4 GPUs in total) # 16min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
num_gpus: 2
|
||||
num_nodes: 2
|
||||
@ -562,7 +590,7 @@ steps:
|
||||
- VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
|
||||
|
||||
- label: Distributed Tests (2 GPUs) # 40min
|
||||
#mirror_hardwares: [amd]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
num_gpus: 2
|
||||
source_file_dependencies:
|
||||
@ -599,13 +627,14 @@ steps:
|
||||
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown
|
||||
|
||||
- label: Plugin Tests (2 GPUs) # 40min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
num_gpus: 2
|
||||
source_file_dependencies:
|
||||
- vllm/plugins/
|
||||
- tests/plugins/
|
||||
commands:
|
||||
# begin platform plugin tests, all the code in-between runs on dummy platform
|
||||
# begin platform plugin and general plugin tests, all the code in-between runs on dummy platform
|
||||
- pip install -e ./plugins/vllm_add_dummy_platform
|
||||
- pytest -v -s plugins_tests/test_platform_plugins.py
|
||||
- pip uninstall vllm_add_dummy_platform -y
|
||||
@ -616,8 +645,10 @@ steps:
|
||||
- pytest -v -s distributed/test_distributed_oot.py
|
||||
- pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process
|
||||
- pytest -v -s models/test_oot_registration.py # it needs a clean process
|
||||
- pytest -v -s plugins/lora_resolvers # unit tests for in-tree lora resolver plugins
|
||||
|
||||
- label: Multi-step Tests (4 GPUs) # 36min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
num_gpus: 4
|
||||
source_file_dependencies:
|
||||
@ -638,6 +669,7 @@ steps:
|
||||
- pytest -v -s multi_step/test_correctness_llm.py
|
||||
|
||||
- label: Pipeline Parallelism Test # 45min
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
num_gpus: 4
|
||||
source_file_dependencies:
|
||||
@ -651,6 +683,7 @@ steps:
|
||||
- pytest -v -s distributed/test_pipeline_parallel.py
|
||||
|
||||
- label: LoRA TP Test (Distributed)
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
num_gpus: 4
|
||||
source_file_dependencies:
|
||||
- vllm/lora
|
||||
@ -666,6 +699,7 @@ steps:
|
||||
|
||||
|
||||
- label: Weight Loading Multiple GPU Test # 33min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
num_gpus: 2
|
||||
source_file_dependencies:
|
||||
@ -675,6 +709,7 @@ steps:
|
||||
- bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models.txt
|
||||
|
||||
- label: Weight Loading Multiple GPU Test - Large Models # optional
|
||||
mirror_hardwares: [amdexperimental]
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
num_gpus: 2
|
||||
gpu: a100
|
||||
|
||||
11
.github/mergify.yml
vendored
11
.github/mergify.yml
vendored
@ -163,6 +163,17 @@ pull_request_rules:
|
||||
|
||||
https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork
|
||||
|
||||
- name: assign reviewer for tensorizer changes
|
||||
conditions:
|
||||
- files~=^vllm/model_executor/model_loader/tensorizer.py
|
||||
- files~=^vllm/model_executor/model_loader/tensorizer_loader.py
|
||||
- files~=^tests/entrypoints/openai/test_tensorizer_entrypoint.py
|
||||
- files~=^tests/tensorizer_loader/
|
||||
actions:
|
||||
assign:
|
||||
users:
|
||||
- "sangstar"
|
||||
|
||||
- name: remove 'needs-rebase' label when conflict is resolved
|
||||
conditions:
|
||||
- -conflict
|
||||
|
||||
2
.github/workflows/add_label_automerge.yml
vendored
2
.github/workflows/add_label_automerge.yml
vendored
@ -1,4 +1,6 @@
|
||||
name: Add label on auto-merge enabled
|
||||
permissions:
|
||||
pull-requests: write
|
||||
on:
|
||||
pull_request_target:
|
||||
types:
|
||||
|
||||
3
.github/workflows/lint-and-deploy.yaml
vendored
3
.github/workflows/lint-and-deploy.yaml
vendored
@ -2,6 +2,9 @@ name: Lint and Deploy Charts
|
||||
|
||||
on: pull_request
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
lint-and-deploy:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
3
.github/workflows/pre-commit.yml
vendored
3
.github/workflows/pre-commit.yml
vendored
@ -5,6 +5,9 @@ on:
|
||||
push:
|
||||
branches: [main]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
pre-commit:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
2
.github/workflows/reminder_comment.yml
vendored
2
.github/workflows/reminder_comment.yml
vendored
@ -1,4 +1,6 @@
|
||||
name: PR Reminder Comment Bot
|
||||
permissions:
|
||||
pull-requests: write
|
||||
on:
|
||||
pull_request_target:
|
||||
types: [opened]
|
||||
|
||||
@ -16,6 +16,8 @@ repos:
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--output-format, github, --fix]
|
||||
- id: ruff-format
|
||||
files: ^(.buildkite|benchmarks)/.*
|
||||
- repo: https://github.com/codespell-project/codespell
|
||||
rev: v2.4.1
|
||||
hooks:
|
||||
|
||||
@ -230,6 +230,7 @@ set(VLLM_EXT_SRC
|
||||
"csrc/attention/paged_attention_v1.cu"
|
||||
"csrc/attention/paged_attention_v2.cu"
|
||||
"csrc/attention/merge_attn_states.cu"
|
||||
"csrc/attention/vertical_slash_index.cu"
|
||||
"csrc/pos_encoding_kernels.cu"
|
||||
"csrc/activation_kernels.cu"
|
||||
"csrc/layernorm_kernels.cu"
|
||||
@ -288,6 +289,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
|
||||
"csrc/quantization/fp4/nvfp4_quant_entry.cu"
|
||||
"csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
|
||||
"csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu"
|
||||
"csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
|
||||
"csrc/cutlass_extensions/common.cpp"
|
||||
"csrc/attention/mla/cutlass_mla_entry.cu")
|
||||
@ -299,7 +301,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# Only build Marlin kernels if we are building for at least some compatible archs.
|
||||
# Keep building Marlin for 9.0 as there are some group sizes and shapes that
|
||||
# are not supported by Machete yet.
|
||||
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
|
||||
# 9.0 for latest bf16 atomicAdd PTX
|
||||
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;9.0+PTX" "${CUDA_ARCHS}")
|
||||
if (MARLIN_ARCHS)
|
||||
|
||||
#
|
||||
@ -443,8 +446,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
#
|
||||
# For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
|
||||
# kernels for the remaining archs that are not already built for 3x.
|
||||
# (Build 8.9 for FP8)
|
||||
cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
|
||||
"7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
|
||||
"7.5;8.0;8.9+PTX" "${CUDA_ARCHS}")
|
||||
# subtract out the archs that are already built for 3x
|
||||
list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
|
||||
if (SCALED_MM_2X_ARCHS)
|
||||
@ -495,7 +499,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND FP4_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/quantization/fp4/nvfp4_quant_kernels.cu"
|
||||
"csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu")
|
||||
"csrc/quantization/fp4/nvfp4_experts_quant.cu"
|
||||
"csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu"
|
||||
"csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${FP4_ARCHS}")
|
||||
@ -533,7 +539,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and only works
|
||||
# on Hopper). get_cutlass_moe_mm_data should only be compiled if it's possible
|
||||
# to compile MoE kernels that use its output.
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
|
||||
set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu"
|
||||
"csrc/quantization/cutlass_w8a8/moe/moe_data.cu")
|
||||
@ -671,7 +677,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
CUDA_ARCHS "${CUDA_ARCHS}")
|
||||
|
||||
list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}")
|
||||
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
|
||||
# 9.0 for latest bf16 atomicAdd PTX
|
||||
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;9.0+PTX" "${CUDA_ARCHS}")
|
||||
if (MARLIN_MOE_ARCHS)
|
||||
|
||||
#
|
||||
|
||||
@ -74,7 +74,7 @@ vLLM is flexible and easy to use with:
|
||||
- OpenAI-compatible API server
|
||||
- Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, TPU, and AWS Neuron.
|
||||
- Prefix caching support
|
||||
- Multi-lora support
|
||||
- Multi-LoRA support
|
||||
|
||||
vLLM seamlessly supports most popular open-source models on HuggingFace, including:
|
||||
- Transformer-like LLMs (e.g., Llama)
|
||||
|
||||
@ -12,8 +12,7 @@ from typing import Optional, Union
|
||||
import aiohttp
|
||||
import huggingface_hub.constants
|
||||
from tqdm.asyncio import tqdm
|
||||
from transformers import (AutoTokenizer, PreTrainedTokenizer,
|
||||
PreTrainedTokenizerFast)
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||
|
||||
# NOTE(simon): do not import vLLM here so the benchmark script
|
||||
# can run without vLLM installed.
|
||||
@ -43,8 +42,7 @@ class RequestFuncOutput:
|
||||
latency: float = 0.0
|
||||
output_tokens: int = 0
|
||||
ttft: float = 0.0 # Time to first token
|
||||
itl: list[float] = field(
|
||||
default_factory=list) # list of inter-token latencies
|
||||
itl: list[float] = field(default_factory=list) # list of inter-token latencies
|
||||
tpot: float = 0.0 # avg next-token latencies
|
||||
prompt_len: int = 0
|
||||
error: str = ""
|
||||
@ -57,8 +55,9 @@ async def async_request_tgi(
|
||||
api_url = request_func_input.api_url
|
||||
assert api_url.endswith("generate_stream")
|
||||
|
||||
async with aiohttp.ClientSession(trust_env=True,
|
||||
timeout=AIOHTTP_TIMEOUT) as session:
|
||||
async with aiohttp.ClientSession(
|
||||
trust_env=True, timeout=AIOHTTP_TIMEOUT
|
||||
) as session:
|
||||
params = {
|
||||
"max_new_tokens": request_func_input.output_len,
|
||||
"do_sample": True,
|
||||
@ -105,8 +104,7 @@ async def async_request_tgi(
|
||||
|
||||
# Decoding phase
|
||||
else:
|
||||
output.itl.append(timestamp -
|
||||
most_recent_timestamp)
|
||||
output.itl.append(timestamp - most_recent_timestamp)
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
|
||||
@ -133,8 +131,9 @@ async def async_request_trt_llm(
|
||||
api_url = request_func_input.api_url
|
||||
assert api_url.endswith("generate_stream")
|
||||
|
||||
async with aiohttp.ClientSession(trust_env=True,
|
||||
timeout=AIOHTTP_TIMEOUT) as session:
|
||||
async with aiohttp.ClientSession(
|
||||
trust_env=True, timeout=AIOHTTP_TIMEOUT
|
||||
) as session:
|
||||
payload = {
|
||||
"accumulate_tokens": True,
|
||||
"text_input": request_func_input.prompt,
|
||||
@ -159,8 +158,7 @@ async def async_request_trt_llm(
|
||||
if not chunk_bytes:
|
||||
continue
|
||||
|
||||
chunk = chunk_bytes.decode("utf-8").removeprefix(
|
||||
"data:")
|
||||
chunk = chunk_bytes.decode("utf-8").removeprefix("data:")
|
||||
|
||||
data = json.loads(chunk)
|
||||
output.generated_text += data["text_output"]
|
||||
@ -172,8 +170,7 @@ async def async_request_trt_llm(
|
||||
|
||||
# Decoding phase
|
||||
else:
|
||||
output.itl.append(timestamp -
|
||||
most_recent_timestamp)
|
||||
output.itl.append(timestamp - most_recent_timestamp)
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
|
||||
@ -197,9 +194,9 @@ async def async_request_deepspeed_mii(
|
||||
request_func_input: RequestFuncInput,
|
||||
pbar: Optional[tqdm] = None,
|
||||
) -> RequestFuncOutput:
|
||||
async with aiohttp.ClientSession(trust_env=True,
|
||||
timeout=AIOHTTP_TIMEOUT) as session:
|
||||
|
||||
async with aiohttp.ClientSession(
|
||||
trust_env=True, timeout=AIOHTTP_TIMEOUT
|
||||
) as session:
|
||||
payload = {
|
||||
"model": request_func_input.model,
|
||||
"prompt": request_func_input.prompt,
|
||||
@ -217,19 +214,21 @@ async def async_request_deepspeed_mii(
|
||||
|
||||
st = time.perf_counter()
|
||||
try:
|
||||
async with session.post(url=request_func_input.api_url,
|
||||
json=payload) as response:
|
||||
async with session.post(
|
||||
url=request_func_input.api_url, json=payload
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
parsed_resp = await response.json()
|
||||
output.latency = time.perf_counter() - st
|
||||
if "choices" in parsed_resp:
|
||||
output.generated_text = parsed_resp["choices"][0][
|
||||
"text"]
|
||||
output.generated_text = parsed_resp["choices"][0]["text"]
|
||||
elif "text" in parsed_resp:
|
||||
output.generated_text = parsed_resp["text"][0]
|
||||
else:
|
||||
output.error = ("Unexpected response format: "
|
||||
"neither 'choices' nor 'text' found")
|
||||
output.error = (
|
||||
"Unexpected response format: "
|
||||
"neither 'choices' nor 'text' found"
|
||||
)
|
||||
output.success = False
|
||||
output.success = True
|
||||
else:
|
||||
@ -250,15 +249,17 @@ async def async_request_openai_completions(
|
||||
pbar: Optional[tqdm] = None,
|
||||
) -> RequestFuncOutput:
|
||||
api_url = request_func_input.api_url
|
||||
assert api_url.endswith(
|
||||
("completions", "profile")
|
||||
), "OpenAI Completions API URL must end with 'completions' or 'profile'."
|
||||
assert api_url.endswith(("completions", "profile")), (
|
||||
"OpenAI Completions API URL must end with 'completions' or 'profile'."
|
||||
)
|
||||
|
||||
async with aiohttp.ClientSession(trust_env=True,
|
||||
timeout=AIOHTTP_TIMEOUT) as session:
|
||||
async with aiohttp.ClientSession(
|
||||
trust_env=True, timeout=AIOHTTP_TIMEOUT
|
||||
) as session:
|
||||
payload = {
|
||||
"model": request_func_input.model_name \
|
||||
if request_func_input.model_name else request_func_input.model,
|
||||
"model": request_func_input.model_name
|
||||
if request_func_input.model_name
|
||||
else request_func_input.model,
|
||||
"prompt": request_func_input.prompt,
|
||||
"temperature": 0.0,
|
||||
"repetition_penalty": 1.0,
|
||||
@ -273,9 +274,7 @@ async def async_request_openai_completions(
|
||||
payload["ignore_eos"] = request_func_input.ignore_eos
|
||||
if request_func_input.extra_body:
|
||||
payload.update(request_func_input.extra_body)
|
||||
headers = {
|
||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
|
||||
}
|
||||
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
|
||||
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
@ -284,8 +283,9 @@ async def async_request_openai_completions(
|
||||
st = time.perf_counter()
|
||||
most_recent_timestamp = st
|
||||
try:
|
||||
async with session.post(url=api_url, json=payload,
|
||||
headers=headers) as response:
|
||||
async with session.post(
|
||||
url=api_url, json=payload, headers=headers
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
first_chunk_received = False
|
||||
async for chunk_bytes in response.content:
|
||||
@ -293,8 +293,7 @@ async def async_request_openai_completions(
|
||||
if not chunk_bytes:
|
||||
continue
|
||||
|
||||
chunk = chunk_bytes.decode("utf-8").removeprefix(
|
||||
"data: ")
|
||||
chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
|
||||
if chunk != "[DONE]":
|
||||
data = json.loads(chunk)
|
||||
|
||||
@ -314,21 +313,20 @@ async def async_request_openai_completions(
|
||||
|
||||
# Decoding phase
|
||||
else:
|
||||
output.itl.append(timestamp -
|
||||
most_recent_timestamp)
|
||||
output.itl.append(timestamp - most_recent_timestamp)
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
generated_text += text or ""
|
||||
elif usage := data.get("usage"):
|
||||
output.output_tokens = usage.get(
|
||||
"completion_tokens")
|
||||
output.output_tokens = usage.get("completion_tokens")
|
||||
if first_chunk_received:
|
||||
output.success = True
|
||||
else:
|
||||
output.success = False
|
||||
output.error = (
|
||||
"Never received a valid chunk to calculate TTFT."
|
||||
"This response will be marked as failed!")
|
||||
"This response will be marked as failed!"
|
||||
)
|
||||
output.generated_text = generated_text
|
||||
output.latency = most_recent_timestamp - st
|
||||
else:
|
||||
@ -349,23 +347,22 @@ async def async_request_openai_chat_completions(
|
||||
pbar: Optional[tqdm] = None,
|
||||
) -> RequestFuncOutput:
|
||||
api_url = request_func_input.api_url
|
||||
assert api_url.endswith(
|
||||
("chat/completions", "profile")
|
||||
), "OpenAI Chat Completions API URL must end with 'chat/completions'."
|
||||
assert api_url.endswith(("chat/completions", "profile")), (
|
||||
"OpenAI Chat Completions API URL must end with 'chat/completions'."
|
||||
)
|
||||
|
||||
async with aiohttp.ClientSession(trust_env=True,
|
||||
timeout=AIOHTTP_TIMEOUT) as session:
|
||||
async with aiohttp.ClientSession(
|
||||
trust_env=True, timeout=AIOHTTP_TIMEOUT
|
||||
) as session:
|
||||
content = [{"type": "text", "text": request_func_input.prompt}]
|
||||
if request_func_input.multi_modal_content:
|
||||
content.append(request_func_input.multi_modal_content)
|
||||
payload = {
|
||||
"model": request_func_input.model_name \
|
||||
if request_func_input.model_name else request_func_input.model,
|
||||
"model": request_func_input.model_name
|
||||
if request_func_input.model_name
|
||||
else request_func_input.model,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": content
|
||||
},
|
||||
{"role": "user", "content": content},
|
||||
],
|
||||
"temperature": 0.0,
|
||||
"max_completion_tokens": request_func_input.output_len,
|
||||
@ -391,16 +388,16 @@ async def async_request_openai_chat_completions(
|
||||
st = time.perf_counter()
|
||||
most_recent_timestamp = st
|
||||
try:
|
||||
async with session.post(url=api_url, json=payload,
|
||||
headers=headers) as response:
|
||||
async with session.post(
|
||||
url=api_url, json=payload, headers=headers
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
async for chunk_bytes in response.content:
|
||||
chunk_bytes = chunk_bytes.strip()
|
||||
if not chunk_bytes:
|
||||
continue
|
||||
|
||||
chunk = chunk_bytes.decode("utf-8").removeprefix(
|
||||
"data: ")
|
||||
chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
|
||||
if chunk != "[DONE]":
|
||||
timestamp = time.perf_counter()
|
||||
data = json.loads(chunk)
|
||||
@ -414,13 +411,11 @@ async def async_request_openai_chat_completions(
|
||||
|
||||
# Decoding phase
|
||||
else:
|
||||
output.itl.append(timestamp -
|
||||
most_recent_timestamp)
|
||||
output.itl.append(timestamp - most_recent_timestamp)
|
||||
|
||||
generated_text += content or ""
|
||||
elif usage := data.get("usage"):
|
||||
output.output_tokens = usage.get(
|
||||
"completion_tokens")
|
||||
output.output_tokens = usage.get("completion_tokens")
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
|
||||
@ -446,25 +441,28 @@ async def async_request_openai_audio(
|
||||
) -> RequestFuncOutput:
|
||||
# Lazy import without PlaceholderModule to avoid vllm dep.
|
||||
import soundfile
|
||||
|
||||
api_url = request_func_input.api_url
|
||||
assert api_url.endswith(
|
||||
("transcriptions", "translations"
|
||||
)), "OpenAI Chat Completions API URL must end with 'transcriptions' "
|
||||
assert api_url.endswith(("transcriptions", "translations")), (
|
||||
"OpenAI Chat Completions API URL must end with 'transcriptions' "
|
||||
)
|
||||
"or `translations`."
|
||||
|
||||
async with aiohttp.ClientSession(trust_env=True,
|
||||
timeout=AIOHTTP_TIMEOUT) as session:
|
||||
async with aiohttp.ClientSession(
|
||||
trust_env=True, timeout=AIOHTTP_TIMEOUT
|
||||
) as session:
|
||||
content = [{"type": "text", "text": request_func_input.prompt}]
|
||||
payload = {
|
||||
"model": request_func_input.model_name \
|
||||
if request_func_input.model_name else request_func_input.model,
|
||||
"model": request_func_input.model_name
|
||||
if request_func_input.model_name
|
||||
else request_func_input.model,
|
||||
"temperature": 0.0,
|
||||
"max_completion_tokens": request_func_input.output_len,
|
||||
"stream": True,
|
||||
"language": "en",
|
||||
# Flattened due to multipart/form-data
|
||||
"stream_include_usage": True,
|
||||
"stream_continuous_usage_stats": True
|
||||
"stream_continuous_usage_stats": True,
|
||||
}
|
||||
if request_func_input.extra_body:
|
||||
payload.update(request_func_input.extra_body)
|
||||
@ -479,9 +477,9 @@ async def async_request_openai_audio(
|
||||
buffer.seek(0)
|
||||
return buffer
|
||||
|
||||
with to_bytes(*request_func_input.multi_modal_content['audio']) as f:
|
||||
with to_bytes(*request_func_input.multi_modal_content["audio"]) as f:
|
||||
form = aiohttp.FormData()
|
||||
form.add_field('file', f, content_type='audio/wav')
|
||||
form.add_field("file", f, content_type="audio/wav")
|
||||
for key, value in payload.items():
|
||||
form.add_field(key, str(value))
|
||||
|
||||
@ -493,24 +491,22 @@ async def async_request_openai_audio(
|
||||
st = time.perf_counter()
|
||||
most_recent_timestamp = st
|
||||
try:
|
||||
async with session.post(url=api_url,
|
||||
data=form,
|
||||
headers=headers) as response:
|
||||
async with session.post(
|
||||
url=api_url, data=form, headers=headers
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
async for chunk_bytes in response.content:
|
||||
chunk_bytes = chunk_bytes.strip()
|
||||
if not chunk_bytes:
|
||||
continue
|
||||
|
||||
chunk = chunk_bytes.decode("utf-8").removeprefix(
|
||||
"data: ")
|
||||
chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
|
||||
if chunk != "[DONE]":
|
||||
timestamp = time.perf_counter()
|
||||
data = json.loads(chunk)
|
||||
|
||||
if choices := data.get("choices"):
|
||||
content = choices[0]["delta"].get(
|
||||
"content")
|
||||
content = choices[0]["delta"].get("content")
|
||||
# First token
|
||||
if ttft == 0.0:
|
||||
ttft = timestamp - st
|
||||
@ -519,12 +515,14 @@ async def async_request_openai_audio(
|
||||
# Decoding phase
|
||||
else:
|
||||
output.itl.append(
|
||||
timestamp - most_recent_timestamp)
|
||||
timestamp - most_recent_timestamp
|
||||
)
|
||||
|
||||
generated_text += content or ""
|
||||
elif usage := data.get("usage"):
|
||||
output.output_tokens = usage.get(
|
||||
"completion_tokens")
|
||||
"completion_tokens"
|
||||
)
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
|
||||
@ -545,7 +543,7 @@ async def async_request_openai_audio(
|
||||
|
||||
|
||||
def get_model(pretrained_model_name_or_path: str) -> str:
|
||||
if os.getenv('VLLM_USE_MODELSCOPE', 'False').lower() == 'true':
|
||||
if os.getenv("VLLM_USE_MODELSCOPE", "False").lower() == "true":
|
||||
from modelscope import snapshot_download
|
||||
|
||||
from vllm.model_executor.model_loader.weight_utils import get_lock
|
||||
@ -556,7 +554,8 @@ def get_model(pretrained_model_name_or_path: str) -> str:
|
||||
model_path = snapshot_download(
|
||||
model_id=pretrained_model_name_or_path,
|
||||
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
|
||||
ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"])
|
||||
ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"],
|
||||
)
|
||||
|
||||
return model_path
|
||||
return pretrained_model_name_or_path
|
||||
@ -569,23 +568,23 @@ def get_tokenizer(
|
||||
**kwargs,
|
||||
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
|
||||
if pretrained_model_name_or_path is not None and not os.path.exists(
|
||||
pretrained_model_name_or_path):
|
||||
pretrained_model_name_or_path = get_model(
|
||||
pretrained_model_name_or_path)
|
||||
pretrained_model_name_or_path
|
||||
):
|
||||
pretrained_model_name_or_path = get_model(pretrained_model_name_or_path)
|
||||
if tokenizer_mode == "slow":
|
||||
if kwargs.get("use_fast", False):
|
||||
raise ValueError(
|
||||
"Cannot use the fast tokenizer in slow tokenizer mode.")
|
||||
raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
|
||||
kwargs["use_fast"] = False
|
||||
if tokenizer_mode == "mistral":
|
||||
try:
|
||||
from vllm.transformers_utils.tokenizer import MistralTokenizer
|
||||
except ImportError as e:
|
||||
raise ImportError("MistralTokenizer requires vllm package.\n"
|
||||
"Please install it with `pip install vllm` "
|
||||
"to use mistral tokenizer mode.") from e
|
||||
return MistralTokenizer.from_pretrained(
|
||||
str(pretrained_model_name_or_path))
|
||||
raise ImportError(
|
||||
"MistralTokenizer requires vllm package.\n"
|
||||
"Please install it with `pip install vllm` "
|
||||
"to use mistral tokenizer mode."
|
||||
) from e
|
||||
return MistralTokenizer.from_pretrained(str(pretrained_model_name_or_path))
|
||||
else:
|
||||
return AutoTokenizer.from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
@ -608,7 +607,7 @@ ASYNC_REQUEST_FUNCS = {
|
||||
}
|
||||
|
||||
OPENAI_COMPATIBLE_BACKENDS = [
|
||||
k for k, v in ASYNC_REQUEST_FUNCS.items()
|
||||
if v in (async_request_openai_completions,
|
||||
async_request_openai_chat_completions)
|
||||
k
|
||||
for k, v in ASYNC_REQUEST_FUNCS.items()
|
||||
if v in (async_request_openai_completions, async_request_openai_chat_completions)
|
||||
]
|
||||
|
||||
@ -82,14 +82,12 @@ class BenchmarkDataset(ABC):
|
||||
self.dataset_path = dataset_path
|
||||
# Set the random seed, ensuring that a None value is replaced with the
|
||||
# default seed.
|
||||
self.random_seed = (random_seed
|
||||
if random_seed is not None else self.DEFAULT_SEED)
|
||||
self.random_seed = random_seed if random_seed is not None else self.DEFAULT_SEED
|
||||
self.data = None
|
||||
|
||||
def apply_multimodal_chat_transformation(
|
||||
self,
|
||||
prompt: str,
|
||||
mm_content: Optional[MultiModalDataDict] = None) -> list[dict]:
|
||||
self, prompt: str, mm_content: Optional[MultiModalDataDict] = None
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Transform a prompt and optional multimodal content into a chat format.
|
||||
This method is used for chat models that expect a specific conversation
|
||||
@ -111,8 +109,7 @@ class BenchmarkDataset(ABC):
|
||||
NotImplementedError: If a subclass does not implement this method.
|
||||
"""
|
||||
# TODO (jenniferzhao): add support for downloading data
|
||||
raise NotImplementedError(
|
||||
"load_data must be implemented in subclasses.")
|
||||
raise NotImplementedError("load_data must be implemented in subclasses.")
|
||||
|
||||
def get_random_lora_request(
|
||||
self,
|
||||
@ -158,8 +155,9 @@ class BenchmarkDataset(ABC):
|
||||
return lora_request, lora_tokenizer_cache[lora_id] or tokenizer
|
||||
|
||||
@abstractmethod
|
||||
def sample(self, tokenizer: PreTrainedTokenizerBase,
|
||||
num_requests: int) -> list[SampleRequest]:
|
||||
def sample(
|
||||
self, tokenizer: PreTrainedTokenizerBase, num_requests: int
|
||||
) -> list[SampleRequest]:
|
||||
"""
|
||||
Abstract method to generate sample requests from the dataset.
|
||||
|
||||
@ -177,8 +175,9 @@ class BenchmarkDataset(ABC):
|
||||
"""
|
||||
raise NotImplementedError("sample must be implemented in subclasses.")
|
||||
|
||||
def maybe_oversample_requests(self, requests: list[SampleRequest],
|
||||
num_requests: int) -> None:
|
||||
def maybe_oversample_requests(
|
||||
self, requests: list[SampleRequest], num_requests: int
|
||||
) -> None:
|
||||
"""
|
||||
Oversamples the list of requests if its size is less than the desired
|
||||
number.
|
||||
@ -189,11 +188,9 @@ class BenchmarkDataset(ABC):
|
||||
"""
|
||||
if len(requests) < num_requests:
|
||||
random.seed(self.random_seed)
|
||||
additional = random.choices(requests,
|
||||
k=num_requests - len(requests))
|
||||
additional = random.choices(requests, k=num_requests - len(requests))
|
||||
requests.extend(additional)
|
||||
logger.info("Oversampled requests to reach %d total samples.",
|
||||
num_requests)
|
||||
logger.info("Oversampled requests to reach %d total samples.", num_requests)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@ -218,14 +215,14 @@ def is_valid_sequence(
|
||||
"""
|
||||
# Check for invalid conditions
|
||||
prompt_too_short = prompt_len < min_len
|
||||
output_too_short = (not skip_min_output_len_check) and (output_len
|
||||
< min_len)
|
||||
output_too_short = (not skip_min_output_len_check) and (output_len < min_len)
|
||||
prompt_too_long = prompt_len > max_prompt_len
|
||||
combined_too_long = (prompt_len + output_len) > max_total_len
|
||||
|
||||
# Return True if none of the invalid conditions are met
|
||||
return not (prompt_too_short or output_too_short or prompt_too_long
|
||||
or combined_too_long)
|
||||
return not (
|
||||
prompt_too_short or output_too_short or prompt_too_long or combined_too_long
|
||||
)
|
||||
|
||||
|
||||
@cache
|
||||
@ -257,28 +254,28 @@ def process_image(image: Any) -> Mapping[str, Any]:
|
||||
Raises:
|
||||
ValueError: If the input is not a supported type.
|
||||
"""
|
||||
if isinstance(image, dict) and 'bytes' in image:
|
||||
image = Image.open(BytesIO(image['bytes']))
|
||||
if isinstance(image, dict) and "bytes" in image:
|
||||
image = Image.open(BytesIO(image["bytes"]))
|
||||
if isinstance(image, Image.Image):
|
||||
image = image.convert("RGB")
|
||||
with io.BytesIO() as image_data:
|
||||
image.save(image_data, format="JPEG")
|
||||
image_base64 = base64.b64encode(
|
||||
image_data.getvalue()).decode("utf-8")
|
||||
image_base64 = base64.b64encode(image_data.getvalue()).decode("utf-8")
|
||||
return {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{image_base64}"
|
||||
},
|
||||
"image_url": {"url": f"data:image/jpeg;base64,{image_base64}"},
|
||||
}
|
||||
|
||||
if isinstance(image, str):
|
||||
image_url = (image if image.startswith(
|
||||
("http://", "file://")) else f"file://{image}")
|
||||
image_url = (
|
||||
image if image.startswith(("http://", "file://")) else f"file://{image}"
|
||||
)
|
||||
return {"type": "image_url", "image_url": {"url": image_url}}
|
||||
|
||||
raise ValueError(f"Invalid image input {image}. Must be a PIL.Image.Image"
|
||||
" or str or dictionary with raw image bytes.")
|
||||
raise ValueError(
|
||||
f"Invalid image input {image}. Must be a PIL.Image.Image"
|
||||
" or str or dictionary with raw image bytes."
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@ -318,8 +315,11 @@ class RandomDataset(BenchmarkDataset):
|
||||
num_special_tokens = tokenizer.num_special_tokens_to_add()
|
||||
real_input_len = input_len - num_special_tokens
|
||||
|
||||
prefix_token_ids = (np.random.randint(
|
||||
0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else [])
|
||||
prefix_token_ids = (
|
||||
np.random.randint(0, vocab_size, size=prefix_len).tolist()
|
||||
if prefix_len > 0
|
||||
else []
|
||||
)
|
||||
|
||||
# New sampling logic: [X * (1 - b), X * (1 + b)]
|
||||
input_low = int(real_input_len * (1 - range_ratio))
|
||||
@ -329,21 +329,17 @@ class RandomDataset(BenchmarkDataset):
|
||||
|
||||
# Add logging for debugging
|
||||
logger.info("Sampling input_len from [%s, %s]", input_low, input_high)
|
||||
logger.info("Sampling output_len from [%s, %s]", output_low,
|
||||
output_high)
|
||||
logger.info("Sampling output_len from [%s, %s]", output_low, output_high)
|
||||
|
||||
input_lens = np.random.randint(input_low,
|
||||
input_high + 1,
|
||||
size=num_requests)
|
||||
output_lens = np.random.randint(output_low,
|
||||
output_high + 1,
|
||||
size=num_requests)
|
||||
input_lens = np.random.randint(input_low, input_high + 1, size=num_requests)
|
||||
output_lens = np.random.randint(output_low, output_high + 1, size=num_requests)
|
||||
offsets = np.random.randint(0, vocab_size, size=num_requests)
|
||||
|
||||
requests = []
|
||||
for i in range(num_requests):
|
||||
inner_seq = ((offsets[i] + i + np.arange(input_lens[i])) %
|
||||
vocab_size).tolist()
|
||||
inner_seq = (
|
||||
(offsets[i] + i + np.arange(input_lens[i])) % vocab_size
|
||||
).tolist()
|
||||
token_sequence = prefix_token_ids + inner_seq
|
||||
prompt = tokenizer.decode(token_sequence)
|
||||
# After decoding the prompt we have to encode and decode it again.
|
||||
@ -354,8 +350,9 @@ class RandomDataset(BenchmarkDataset):
|
||||
# [1650, 939, 486] -> ['Ġcall', 'sh', 'ere']
|
||||
# To avoid uncontrolled change of the prompt length,
|
||||
# the encoded sequence is truncated before being decode again.
|
||||
re_encoded_sequence = tokenizer.encode(
|
||||
prompt, add_special_tokens=False)[:input_lens[i]]
|
||||
re_encoded_sequence = tokenizer.encode(prompt, add_special_tokens=False)[
|
||||
: input_lens[i]
|
||||
]
|
||||
prompt = tokenizer.decode(re_encoded_sequence)
|
||||
total_input_len = prefix_len + int(input_lens[i])
|
||||
requests.append(
|
||||
@ -363,7 +360,8 @@ class RandomDataset(BenchmarkDataset):
|
||||
prompt=prompt,
|
||||
prompt_len=total_input_len,
|
||||
expected_output_len=int(output_lens[i]),
|
||||
))
|
||||
)
|
||||
)
|
||||
return requests
|
||||
|
||||
|
||||
@ -390,7 +388,8 @@ class ShareGPTDataset(BenchmarkDataset):
|
||||
self.data = json.load(f)
|
||||
# Filter entries with at least two conversation turns.
|
||||
self.data = [
|
||||
entry for entry in self.data
|
||||
entry
|
||||
for entry in self.data
|
||||
if "conversations" in entry and len(entry["conversations"]) >= 2
|
||||
]
|
||||
random.seed(self.random_seed)
|
||||
@ -416,27 +415,28 @@ class ShareGPTDataset(BenchmarkDataset):
|
||||
)
|
||||
|
||||
lora_request, tokenizer = self.get_random_lora_request(
|
||||
tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path)
|
||||
tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path
|
||||
)
|
||||
prompt_ids = tokenizer(prompt).input_ids
|
||||
completion_ids = tokenizer(completion).input_ids
|
||||
prompt_len = len(prompt_ids)
|
||||
new_output_len = (len(completion_ids)
|
||||
if output_len is None else output_len)
|
||||
if not is_valid_sequence(prompt_len,
|
||||
new_output_len,
|
||||
skip_min_output_len_check=output_len
|
||||
is not None):
|
||||
new_output_len = len(completion_ids) if output_len is None else output_len
|
||||
if not is_valid_sequence(
|
||||
prompt_len,
|
||||
new_output_len,
|
||||
skip_min_output_len_check=output_len is not None,
|
||||
):
|
||||
continue
|
||||
if enable_multimodal_chat:
|
||||
prompt = self.apply_multimodal_chat_transformation(
|
||||
prompt, None)
|
||||
prompt = self.apply_multimodal_chat_transformation(prompt, None)
|
||||
samples.append(
|
||||
SampleRequest(
|
||||
prompt=prompt,
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=new_output_len,
|
||||
lora_request=lora_request,
|
||||
))
|
||||
)
|
||||
)
|
||||
self.maybe_oversample_requests(samples, num_requests)
|
||||
return samples
|
||||
|
||||
@ -482,20 +482,20 @@ class SonnetDataset(BenchmarkDataset):
|
||||
) -> list:
|
||||
# Calculate average token length for a poem line.
|
||||
tokenized_lines = [tokenizer(line).input_ids for line in self.data]
|
||||
avg_len = sum(len(tokens)
|
||||
for tokens in tokenized_lines) / len(tokenized_lines)
|
||||
avg_len = sum(len(tokens) for tokens in tokenized_lines) / len(tokenized_lines)
|
||||
|
||||
# Build the base prompt.
|
||||
base_prompt = "Pick as many lines as you can from these poem lines:\n"
|
||||
base_msg = [{"role": "user", "content": base_prompt}]
|
||||
base_fmt = tokenizer.apply_chat_template(base_msg,
|
||||
add_generation_prompt=True,
|
||||
tokenize=False)
|
||||
base_fmt = tokenizer.apply_chat_template(
|
||||
base_msg, add_generation_prompt=True, tokenize=False
|
||||
)
|
||||
base_offset = len(tokenizer(base_fmt).input_ids)
|
||||
if input_len <= base_offset:
|
||||
raise ValueError(
|
||||
f"'input_len' must be higher than the base prompt length "
|
||||
f"({base_offset}).")
|
||||
f"({base_offset})."
|
||||
)
|
||||
|
||||
# Determine how many poem lines to use.
|
||||
num_input_lines = round((input_len - base_offset) / avg_len)
|
||||
@ -504,21 +504,23 @@ class SonnetDataset(BenchmarkDataset):
|
||||
|
||||
samples = []
|
||||
while len(samples) < num_requests:
|
||||
extra_lines = random.choices(self.data,
|
||||
k=num_input_lines - num_prefix_lines)
|
||||
extra_lines = random.choices(
|
||||
self.data, k=num_input_lines - num_prefix_lines
|
||||
)
|
||||
prompt = f"{base_prompt}{''.join(prefix_lines + extra_lines)}"
|
||||
msg = [{"role": "user", "content": prompt}]
|
||||
prompt_formatted = tokenizer.apply_chat_template(
|
||||
msg, add_generation_prompt=True, tokenize=False)
|
||||
msg, add_generation_prompt=True, tokenize=False
|
||||
)
|
||||
prompt_len = len(tokenizer(prompt_formatted).input_ids)
|
||||
if prompt_len <= input_len:
|
||||
samples.append(
|
||||
SampleRequest(
|
||||
prompt=prompt_formatted
|
||||
if return_prompt_formatted else prompt,
|
||||
prompt=prompt_formatted if return_prompt_formatted else prompt,
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=output_len,
|
||||
))
|
||||
)
|
||||
)
|
||||
return samples
|
||||
|
||||
|
||||
@ -538,7 +540,9 @@ class BurstGPTDataset(BenchmarkDataset):
|
||||
super().__init__(**kwargs)
|
||||
self.load_data()
|
||||
|
||||
def load_data(self, ):
|
||||
def load_data(
|
||||
self,
|
||||
):
|
||||
if self.dataset_path is None:
|
||||
raise ValueError("dataset_path must be provided for loading data.")
|
||||
|
||||
@ -552,8 +556,7 @@ class BurstGPTDataset(BenchmarkDataset):
|
||||
|
||||
def _sample_loaded_data(self, num_requests: int) -> list:
|
||||
if num_requests <= len(self.data):
|
||||
data = self.data.sample(n=num_requests,
|
||||
random_state=self.random_seed)
|
||||
data = self.data.sample(n=num_requests, random_state=self.random_seed)
|
||||
else:
|
||||
data = self.data.sample(
|
||||
n=num_requests,
|
||||
@ -577,7 +580,8 @@ class BurstGPTDataset(BenchmarkDataset):
|
||||
input_len = int(data[i][2])
|
||||
output_len = int(data[i][3])
|
||||
lora_req, tokenizer = self.get_random_lora_request(
|
||||
tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path)
|
||||
tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path
|
||||
)
|
||||
vocab_size = tokenizer.vocab_size
|
||||
# Generate a synthetic prompt: a list of token IDs computed as (i +
|
||||
# j) modulo vocab_size.
|
||||
@ -589,7 +593,8 @@ class BurstGPTDataset(BenchmarkDataset):
|
||||
prompt_len=input_len,
|
||||
expected_output_len=output_len,
|
||||
lora_request=lora_req,
|
||||
))
|
||||
)
|
||||
)
|
||||
return samples
|
||||
|
||||
|
||||
@ -632,20 +637,23 @@ class HuggingFaceDataset(BenchmarkDataset):
|
||||
|
||||
class ConversationDataset(HuggingFaceDataset):
|
||||
"""Dataset for conversation data with multimodal support."""
|
||||
|
||||
SUPPORTED_DATASET_PATHS = {
|
||||
'lmms-lab/LLaVA-OneVision-Data', 'Aeala/ShareGPT_Vicuna_unfiltered'
|
||||
"lmms-lab/LLaVA-OneVision-Data",
|
||||
"Aeala/ShareGPT_Vicuna_unfiltered",
|
||||
}
|
||||
IS_MULTIMODAL = True
|
||||
|
||||
def sample(self,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
num_requests: int,
|
||||
output_len: Optional[int] = None,
|
||||
enable_multimodal_chat: bool = False,
|
||||
**kwargs) -> list:
|
||||
def sample(
|
||||
self,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
num_requests: int,
|
||||
output_len: Optional[int] = None,
|
||||
enable_multimodal_chat: bool = False,
|
||||
**kwargs,
|
||||
) -> list:
|
||||
# Filter examples with at least 2 conversations
|
||||
filtered_data = self.data.filter(
|
||||
lambda x: len(x["conversations"]) >= 2)
|
||||
filtered_data = self.data.filter(lambda x: len(x["conversations"]) >= 2)
|
||||
sampled_requests = []
|
||||
dynamic_output = output_len is None
|
||||
|
||||
@ -661,24 +669,22 @@ class ConversationDataset(HuggingFaceDataset):
|
||||
completion_len = len(completion_ids)
|
||||
output_len = completion_len if dynamic_output else output_len
|
||||
assert isinstance(output_len, int) and output_len > 0
|
||||
if dynamic_output and not is_valid_sequence(
|
||||
prompt_len, completion_len):
|
||||
if dynamic_output and not is_valid_sequence(prompt_len, completion_len):
|
||||
continue
|
||||
mm_content = process_image(
|
||||
item["image"]) if "image" in item else None
|
||||
mm_content = process_image(item["image"]) if "image" in item else None
|
||||
if enable_multimodal_chat:
|
||||
# Note: when chat is enabled the request prompt_len is no longer
|
||||
# accurate and we will be using request output to count the
|
||||
# actual prompt len and output len
|
||||
prompt = self.apply_multimodal_chat_transformation(
|
||||
prompt, mm_content)
|
||||
prompt = self.apply_multimodal_chat_transformation(prompt, mm_content)
|
||||
sampled_requests.append(
|
||||
SampleRequest(
|
||||
prompt=prompt,
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=output_len,
|
||||
multi_modal_data=mm_content,
|
||||
))
|
||||
)
|
||||
)
|
||||
self.maybe_oversample_requests(sampled_requests, num_requests)
|
||||
return sampled_requests
|
||||
|
||||
@ -695,10 +701,8 @@ class VisionArenaDataset(HuggingFaceDataset):
|
||||
|
||||
DEFAULT_OUTPUT_LEN = 128
|
||||
SUPPORTED_DATASET_PATHS = {
|
||||
"lmarena-ai/VisionArena-Chat":
|
||||
lambda x: x["conversation"][0][0]["content"],
|
||||
"lmarena-ai/vision-arena-bench-v0.1":
|
||||
lambda x: x["turns"][0][0]["content"]
|
||||
"lmarena-ai/VisionArena-Chat": lambda x: x["conversation"][0][0]["content"],
|
||||
"lmarena-ai/vision-arena-bench-v0.1": lambda x: x["turns"][0][0]["content"],
|
||||
}
|
||||
IS_MULTIMODAL = True
|
||||
|
||||
@ -710,16 +714,14 @@ class VisionArenaDataset(HuggingFaceDataset):
|
||||
enable_multimodal_chat: bool = False,
|
||||
**kwargs,
|
||||
) -> list:
|
||||
output_len = (output_len
|
||||
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
|
||||
output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
|
||||
sampled_requests = []
|
||||
for item in self.data:
|
||||
if len(sampled_requests) >= num_requests:
|
||||
break
|
||||
parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path)
|
||||
if parser_fn is None:
|
||||
raise ValueError(
|
||||
f"Unsupported dataset path: {self.dataset_path}")
|
||||
raise ValueError(f"Unsupported dataset path: {self.dataset_path}")
|
||||
prompt = parser_fn(item)
|
||||
mm_content = process_image(item["images"][0])
|
||||
prompt_len = len(tokenizer(prompt).input_ids)
|
||||
@ -727,15 +729,15 @@ class VisionArenaDataset(HuggingFaceDataset):
|
||||
# Note: when chat is enabled the request prompt_len is no longer
|
||||
# accurate and we will be using request output to count the
|
||||
# actual prompt len
|
||||
prompt = self.apply_multimodal_chat_transformation(
|
||||
prompt, mm_content)
|
||||
prompt = self.apply_multimodal_chat_transformation(prompt, mm_content)
|
||||
sampled_requests.append(
|
||||
SampleRequest(
|
||||
prompt=prompt,
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=output_len,
|
||||
multi_modal_data=mm_content,
|
||||
))
|
||||
)
|
||||
)
|
||||
self.maybe_oversample_requests(sampled_requests, num_requests)
|
||||
return sampled_requests
|
||||
|
||||
@ -760,14 +762,15 @@ class InstructCoderDataset(HuggingFaceDataset):
|
||||
"likaixin/InstructCoder",
|
||||
}
|
||||
|
||||
def sample(self,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
num_requests: int,
|
||||
output_len: Optional[int] = None,
|
||||
enable_multimodal_chat: bool = False,
|
||||
**kwargs) -> list:
|
||||
output_len = (output_len
|
||||
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
|
||||
def sample(
|
||||
self,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
num_requests: int,
|
||||
output_len: Optional[int] = None,
|
||||
enable_multimodal_chat: bool = False,
|
||||
**kwargs,
|
||||
) -> list:
|
||||
output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
|
||||
sampled_requests = []
|
||||
for item in self.data:
|
||||
if len(sampled_requests) >= num_requests:
|
||||
@ -779,7 +782,8 @@ class InstructCoderDataset(HuggingFaceDataset):
|
||||
prompt=prompt,
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=output_len,
|
||||
))
|
||||
)
|
||||
)
|
||||
self.maybe_oversample_requests(sampled_requests, num_requests)
|
||||
return sampled_requests
|
||||
|
||||
@ -794,38 +798,38 @@ class MTBenchDataset(HuggingFaceDataset):
|
||||
MT-Bench Dataset.
|
||||
https://huggingface.co/datasets/philschmid/mt-bench
|
||||
|
||||
We create a single turn dataset for MT-Bench.
|
||||
We create a single turn dataset for MT-Bench.
|
||||
This is similar to Spec decoding benchmark setup in vLLM
|
||||
https://github.com/vllm-project/vllm/blob/9d98ab5ec/examples/offline_inference/eagle.py#L14-L18
|
||||
""" # noqa: E501
|
||||
""" # noqa: E501
|
||||
|
||||
DEFAULT_OUTPUT_LEN = 256 # avg len used in SD bench in vLLM
|
||||
SUPPORTED_DATASET_PATHS = {
|
||||
"philschmid/mt-bench",
|
||||
}
|
||||
|
||||
def sample(self,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
num_requests: int,
|
||||
output_len: Optional[int] = None,
|
||||
enable_multimodal_chat: bool = False,
|
||||
**kwargs) -> list:
|
||||
output_len = (output_len
|
||||
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
|
||||
def sample(
|
||||
self,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
num_requests: int,
|
||||
output_len: Optional[int] = None,
|
||||
enable_multimodal_chat: bool = False,
|
||||
**kwargs,
|
||||
) -> list:
|
||||
output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
|
||||
sampled_requests = []
|
||||
|
||||
for item in self.data:
|
||||
if len(sampled_requests) >= num_requests:
|
||||
break
|
||||
prompt = item['turns'][0]
|
||||
prompt = item["turns"][0]
|
||||
|
||||
# apply template
|
||||
prompt = tokenizer.apply_chat_template([{
|
||||
"role": "user",
|
||||
"content": prompt
|
||||
}],
|
||||
add_generation_prompt=True,
|
||||
tokenize=False)
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": prompt}],
|
||||
add_generation_prompt=True,
|
||||
tokenize=False,
|
||||
)
|
||||
|
||||
prompt_len = len(tokenizer(prompt).input_ids)
|
||||
sampled_requests.append(
|
||||
@ -833,7 +837,8 @@ class MTBenchDataset(HuggingFaceDataset):
|
||||
prompt=prompt,
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=output_len,
|
||||
))
|
||||
)
|
||||
)
|
||||
self.maybe_oversample_requests(sampled_requests, num_requests)
|
||||
return sampled_requests
|
||||
|
||||
@ -847,23 +852,27 @@ class AIMODataset(HuggingFaceDataset):
|
||||
"""
|
||||
Dataset class for processing a AIMO dataset with reasoning questions.
|
||||
"""
|
||||
|
||||
SUPPORTED_DATASET_PATHS = {
|
||||
"AI-MO/aimo-validation-aime", "AI-MO/NuminaMath-1.5",
|
||||
"AI-MO/NuminaMath-CoT"
|
||||
"AI-MO/aimo-validation-aime",
|
||||
"AI-MO/NuminaMath-1.5",
|
||||
"AI-MO/NuminaMath-CoT",
|
||||
}
|
||||
|
||||
def sample(self,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
num_requests: int,
|
||||
output_len: Optional[int] = None,
|
||||
**kwargs) -> list:
|
||||
def sample(
|
||||
self,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
num_requests: int,
|
||||
output_len: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> list:
|
||||
sampled_requests = []
|
||||
dynamic_output = output_len is None
|
||||
|
||||
for item in self.data:
|
||||
if len(sampled_requests) >= num_requests:
|
||||
break
|
||||
prompt, completion = item['problem'], item["solution"]
|
||||
prompt, completion = item["problem"], item["solution"]
|
||||
|
||||
prompt_ids = tokenizer(prompt).input_ids
|
||||
completion_ids = tokenizer(completion).input_ids
|
||||
@ -871,10 +880,9 @@ class AIMODataset(HuggingFaceDataset):
|
||||
completion_len = len(completion_ids)
|
||||
output_len = completion_len if dynamic_output else output_len
|
||||
assert isinstance(output_len, int) and output_len > 0
|
||||
if dynamic_output and not is_valid_sequence(prompt_len,
|
||||
completion_len,
|
||||
max_prompt_len=2048,
|
||||
max_total_len=32000):
|
||||
if dynamic_output and not is_valid_sequence(
|
||||
prompt_len, completion_len, max_prompt_len=2048, max_total_len=32000
|
||||
):
|
||||
continue
|
||||
sampled_requests.append(
|
||||
SampleRequest(
|
||||
@ -882,7 +890,8 @@ class AIMODataset(HuggingFaceDataset):
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=output_len,
|
||||
multi_modal_data=None,
|
||||
))
|
||||
)
|
||||
)
|
||||
self.maybe_oversample_requests(sampled_requests, num_requests)
|
||||
return sampled_requests
|
||||
|
||||
@ -905,25 +914,25 @@ You are a code completion assistant and your task is to analyze user edits and t
|
||||
|
||||
### Response:
|
||||
|
||||
""" # noqa: E501
|
||||
""" # noqa: E501
|
||||
|
||||
|
||||
def _format_zeta_prompt(
|
||||
sample: dict,
|
||||
original_start_marker: str = "<|editable_region_start|>") -> dict:
|
||||
sample: dict, original_start_marker: str = "<|editable_region_start|>"
|
||||
) -> dict:
|
||||
"""Format the zeta prompt for the Next Edit Prediction (NEP) dataset.
|
||||
|
||||
This function formats examples from the NEP dataset
|
||||
into prompts and expected outputs. It could be
|
||||
|
||||
This function formats examples from the NEP dataset
|
||||
into prompts and expected outputs. It could be
|
||||
further extended to support more NEP datasets.
|
||||
|
||||
|
||||
Args:
|
||||
sample: The dataset sample containing events,
|
||||
sample: The dataset sample containing events,
|
||||
inputs, and outputs.
|
||||
original_start_marker: The marker indicating the
|
||||
start of the editable region. Defaults to
|
||||
original_start_marker: The marker indicating the
|
||||
start of the editable region. Defaults to
|
||||
"<|editable_region_start|>".
|
||||
|
||||
|
||||
Returns:
|
||||
A dictionary with the formatted prompts and expected outputs.
|
||||
"""
|
||||
@ -953,10 +962,8 @@ class NextEditPredictionDataset(HuggingFaceDataset):
|
||||
"zed-industries/zeta": _format_zeta_prompt,
|
||||
}
|
||||
|
||||
def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int,
|
||||
**kwargs):
|
||||
formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(
|
||||
self.dataset_path)
|
||||
def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int, **kwargs):
|
||||
formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(self.dataset_path)
|
||||
if formatting_prompt_func is None:
|
||||
raise ValueError(f"Unsupported dataset path: {self.dataset_path}")
|
||||
samples = []
|
||||
@ -967,8 +974,10 @@ class NextEditPredictionDataset(HuggingFaceDataset):
|
||||
prompt=sample["prompt"],
|
||||
prompt_len=len(tokenizer(sample["prompt"]).input_ids),
|
||||
expected_output_len=len(
|
||||
tokenizer(sample["expected_output"]).input_ids),
|
||||
))
|
||||
tokenizer(sample["expected_output"]).input_ids
|
||||
),
|
||||
)
|
||||
)
|
||||
if len(samples) >= num_requests:
|
||||
break
|
||||
self.maybe_oversample_requests(samples, num_requests)
|
||||
@ -997,18 +1006,22 @@ class ASRDataset(HuggingFaceDataset):
|
||||
| AMI | Meetings | Spontaneous | ihm, sdm |
|
||||
+----------------+----------------------------------------+--------------------------+-----------------------------+
|
||||
|
||||
""" # noqa: E501
|
||||
""" # noqa: E501
|
||||
|
||||
SUPPORTED_DATASET_PATHS = {
|
||||
"openslr/librispeech_asr", "facebook/voxpopuli", "LIUM/tedlium",
|
||||
"edinburghcstr/ami", "speechcolab/gigaspeech", "kensho/spgispeech"
|
||||
"openslr/librispeech_asr",
|
||||
"facebook/voxpopuli",
|
||||
"LIUM/tedlium",
|
||||
"edinburghcstr/ami",
|
||||
"speechcolab/gigaspeech",
|
||||
"kensho/spgispeech",
|
||||
}
|
||||
|
||||
DEFAULT_OUTPUT_LEN = 128
|
||||
IS_MULTIMODAL = True
|
||||
|
||||
# TODO Whisper-specific. Abstract interface when more models are supported.
|
||||
TRANSCRIPTION_PREAMBLE = "<|startoftranscript|><|en|><|transcribe|>"\
|
||||
"<|notimestamps|>"
|
||||
TRANSCRIPTION_PREAMBLE = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"
|
||||
skip_long_audios: bool = True
|
||||
|
||||
def sample(
|
||||
@ -1019,8 +1032,8 @@ class ASRDataset(HuggingFaceDataset):
|
||||
**kwargs,
|
||||
) -> list:
|
||||
import librosa
|
||||
output_len = (output_len
|
||||
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
|
||||
|
||||
output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
|
||||
prompt = ASRDataset.TRANSCRIPTION_PREAMBLE
|
||||
prompt_len = len(tokenizer(prompt).input_ids)
|
||||
sampled_requests = []
|
||||
@ -1043,10 +1056,14 @@ class ASRDataset(HuggingFaceDataset):
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=output_len,
|
||||
multi_modal_data=mm_content,
|
||||
))
|
||||
)
|
||||
)
|
||||
if skipped:
|
||||
logger.warning("%d samples discarded from dataset due to" \
|
||||
" their length being greater than" \
|
||||
" what Whisper supports.", skipped)
|
||||
logger.warning(
|
||||
"%d samples discarded from dataset due to"
|
||||
" their length being greater than"
|
||||
" what Whisper supports.",
|
||||
skipped,
|
||||
)
|
||||
self.maybe_oversample_requests(sampled_requests, num_requests)
|
||||
return sampled_requests
|
||||
|
||||
@ -11,9 +11,9 @@ from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
|
||||
from tqdm import tqdm
|
||||
|
||||
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.inputs import PromptType
|
||||
@ -21,13 +21,14 @@ from vllm.sampling_params import BeamSearchParams
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
def save_to_pytorch_benchmark_format(args: argparse.Namespace,
|
||||
results: dict[str, Any]) -> None:
|
||||
def save_to_pytorch_benchmark_format(
|
||||
args: argparse.Namespace, results: dict[str, Any]
|
||||
) -> None:
|
||||
pt_records = convert_to_pytorch_benchmark_format(
|
||||
args=args,
|
||||
metrics={"latency": results["latencies"]},
|
||||
extra_info={k: results[k]
|
||||
for k in ["avg_latency", "percentiles"]})
|
||||
extra_info={k: results[k] for k in ["avg_latency", "percentiles"]},
|
||||
)
|
||||
if pt_records:
|
||||
pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json"
|
||||
write_to_json(pt_file, pt_records)
|
||||
@ -42,9 +43,11 @@ def main(args: argparse.Namespace):
|
||||
# the engine will automatically process the request in multiple batches.
|
||||
llm = LLM(**dataclasses.asdict(engine_args))
|
||||
assert llm.llm_engine.model_config.max_model_len >= (
|
||||
args.input_len +
|
||||
args.output_len), ("Please ensure that max_model_len is greater than"
|
||||
" the sum of input_len and output_len.")
|
||||
args.input_len + args.output_len
|
||||
), (
|
||||
"Please ensure that max_model_len is greater than"
|
||||
" the sum of input_len and output_len."
|
||||
)
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
n=args.n,
|
||||
@ -55,18 +58,16 @@ def main(args: argparse.Namespace):
|
||||
detokenize=not args.disable_detokenize,
|
||||
)
|
||||
print(sampling_params)
|
||||
dummy_prompt_token_ids = np.random.randint(10000,
|
||||
size=(args.batch_size,
|
||||
args.input_len))
|
||||
dummy_prompts: list[PromptType] = [{
|
||||
"prompt_token_ids": batch
|
||||
} for batch in dummy_prompt_token_ids.tolist()]
|
||||
dummy_prompt_token_ids = np.random.randint(
|
||||
10000, size=(args.batch_size, args.input_len)
|
||||
)
|
||||
dummy_prompts: list[PromptType] = [
|
||||
{"prompt_token_ids": batch} for batch in dummy_prompt_token_ids.tolist()
|
||||
]
|
||||
|
||||
def llm_generate():
|
||||
if not args.use_beam_search:
|
||||
llm.generate(dummy_prompts,
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=False)
|
||||
llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False)
|
||||
else:
|
||||
llm.beam_search(
|
||||
dummy_prompts,
|
||||
@ -80,12 +81,13 @@ def main(args: argparse.Namespace):
|
||||
def run_to_completion(profile_dir: Optional[str] = None):
|
||||
if profile_dir:
|
||||
with torch.profiler.profile(
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
],
|
||||
on_trace_ready=torch.profiler.tensorboard_trace_handler(
|
||||
str(profile_dir)),
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
],
|
||||
on_trace_ready=torch.profiler.tensorboard_trace_handler(
|
||||
str(profile_dir)
|
||||
),
|
||||
) as p:
|
||||
llm_generate()
|
||||
print(p.key_averages().table(sort_by="self_cuda_time_total"))
|
||||
@ -103,8 +105,9 @@ def main(args: argparse.Namespace):
|
||||
if args.profile:
|
||||
profile_dir = args.profile_result_dir
|
||||
if not profile_dir:
|
||||
profile_dir = (Path(".") / "vllm_benchmark_result" /
|
||||
f"latency_result_{time.time()}")
|
||||
profile_dir = (
|
||||
Path(".") / "vllm_benchmark_result" / f"latency_result_{time.time()}"
|
||||
)
|
||||
print(f"Profiling (results will be saved to '{profile_dir}')...")
|
||||
run_to_completion(profile_dir=profile_dir)
|
||||
return
|
||||
@ -135,7 +138,8 @@ def main(args: argparse.Namespace):
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser(
|
||||
description="Benchmark the latency of processing a single batch of "
|
||||
"requests till completion.")
|
||||
"requests till completion."
|
||||
)
|
||||
parser.add_argument("--input-len", type=int, default=32)
|
||||
parser.add_argument("--output-len", type=int, default=128)
|
||||
parser.add_argument("--batch-size", type=int, default=8)
|
||||
@ -152,10 +156,9 @@ if __name__ == "__main__":
|
||||
default=10,
|
||||
help="Number of iterations to run for warmup.",
|
||||
)
|
||||
parser.add_argument("--num-iters",
|
||||
type=int,
|
||||
default=30,
|
||||
help="Number of iterations to run.")
|
||||
parser.add_argument(
|
||||
"--num-iters", type=int, default=30, help="Number of iterations to run."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--profile",
|
||||
action="store_true",
|
||||
@ -165,8 +168,10 @@ if __name__ == "__main__":
|
||||
"--profile-result-dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help=("path to save the pytorch profiler output. Can be visualized "
|
||||
"with ui.perfetto.dev or Tensorboard."),
|
||||
help=(
|
||||
"path to save the pytorch profiler output. Can be visualized "
|
||||
"with ui.perfetto.dev or Tensorboard."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-json",
|
||||
@ -177,8 +182,10 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--disable-detokenize",
|
||||
action="store_true",
|
||||
help=("Do not detokenize responses (i.e. do not include "
|
||||
"detokenization time in the latency measurement)"),
|
||||
help=(
|
||||
"Do not detokenize responses (i.e. do not include "
|
||||
"detokenization time in the latency measurement)"
|
||||
),
|
||||
)
|
||||
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
|
||||
@ -76,7 +76,7 @@ def repeat_prompts(prompts, repeat_count, mode: str):
|
||||
- 'random': Shuffle the prompts randomly after repetition.
|
||||
- 'tile': Repeat the entire prompt list in sequence.
|
||||
Example: [1, 2, 3] -> [1, 2, 3, 1, 2, 3].
|
||||
- 'interleave': Repeat each prompt consecutively before moving to
|
||||
- 'interleave': Repeat each prompt consecutively before moving to
|
||||
the next. Example: [1, 2, 3] -> [1, 1, 2, 2, 3, 3].
|
||||
|
||||
Returns:
|
||||
@ -86,20 +86,21 @@ def repeat_prompts(prompts, repeat_count, mode: str):
|
||||
ValueError: If an invalid mode is provided.
|
||||
"""
|
||||
print("Repeat mode: ", mode)
|
||||
if mode == 'random':
|
||||
if mode == "random":
|
||||
repeated_prompts = prompts * repeat_count
|
||||
random.shuffle(repeated_prompts)
|
||||
return repeated_prompts
|
||||
elif mode == 'tile':
|
||||
elif mode == "tile":
|
||||
return prompts * repeat_count
|
||||
elif mode == 'interleave':
|
||||
elif mode == "interleave":
|
||||
repeated_prompts = []
|
||||
for prompt in prompts:
|
||||
repeated_prompts.extend([prompt] * repeat_count)
|
||||
return repeated_prompts
|
||||
else:
|
||||
raise ValueError(f"Invalid mode: {mode}, only support "
|
||||
"'random', 'tile', 'interleave'")
|
||||
raise ValueError(
|
||||
f"Invalid mode: {mode}, only support 'random', 'tile', 'interleave'"
|
||||
)
|
||||
|
||||
|
||||
def main(args):
|
||||
@ -109,16 +110,16 @@ def main(args):
|
||||
# we append the document id at the beginning to avoid any of the document
|
||||
# being the prefix of other documents
|
||||
prompts = [
|
||||
str(i) + ' '.join(['hi'] * args.document_length)
|
||||
str(i) + " ".join(["hi"] * args.document_length)
|
||||
for i in range(args.num_documents)
|
||||
]
|
||||
|
||||
prompts = repeat_prompts(prompts, args.repeat_count, mode=args.repeat_mode)
|
||||
|
||||
warmup_prompts = [
|
||||
"This is warm up request " + str(i) + \
|
||||
' '.join(['hi'] * args.document_length)
|
||||
for i in range(args.num_documents)]
|
||||
"This is warm up request " + str(i) + " ".join(["hi"] * args.document_length)
|
||||
for i in range(args.num_documents)
|
||||
]
|
||||
|
||||
# Create the LLM engine
|
||||
engine_args = EngineArgs.from_cli_args(args)
|
||||
@ -142,42 +143,52 @@ def main(args):
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser(
|
||||
description=
|
||||
'Benchmark the performance with or without automatic prefix caching.')
|
||||
description="Benchmark the performance with or "
|
||||
"without automatic prefix caching."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--document-length',
|
||||
"--document-length",
|
||||
type=int,
|
||||
# Roughly the number of tokens for a system paper,
|
||||
# excluding images
|
||||
default=20000,
|
||||
help='Range of input lengths for sampling prompts,'
|
||||
'specified as "min:max" (e.g., "128:256").')
|
||||
help="Range of input lengths for sampling prompts, "
|
||||
'specified as "min:max" (e.g., "128:256").',
|
||||
)
|
||||
|
||||
parser.add_argument('--num-documents',
|
||||
type=int,
|
||||
default=8,
|
||||
help='Range of input lengths for sampling prompts,'
|
||||
'specified as "min:max" (e.g., "128:256").')
|
||||
parser.add_argument(
|
||||
"--num-documents",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Range of input lengths for sampling prompts, "
|
||||
'specified as "min:max" (e.g., "128:256").',
|
||||
)
|
||||
|
||||
parser.add_argument('--output-len', type=int, default=10)
|
||||
parser.add_argument("--output-len", type=int, default=10)
|
||||
|
||||
parser.add_argument('--repeat-count',
|
||||
type=int,
|
||||
default=2,
|
||||
help='Number of times to repeat each prompt')
|
||||
parser.add_argument(
|
||||
"--repeat-count",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Number of times to repeat each prompt",
|
||||
)
|
||||
|
||||
parser.add_argument("--repeat-mode",
|
||||
type=str,
|
||||
default='random',
|
||||
help='The mode to repeat prompts. The supported '
|
||||
'modes are "random", "tile", and "interleave". '
|
||||
'See repeat_prompts() in the source code for details.')
|
||||
parser.add_argument(
|
||||
"--repeat-mode",
|
||||
type=str,
|
||||
default="random",
|
||||
help="The mode to repeat prompts. The supported "
|
||||
'modes are "random", "tile", and "interleave". '
|
||||
"See repeat_prompts() in the source code for details.",
|
||||
)
|
||||
|
||||
parser.add_argument("--shuffle-seed",
|
||||
type=int,
|
||||
default=0,
|
||||
help='Random seed when the repeat mode is "random"')
|
||||
parser.add_argument(
|
||||
"--shuffle-seed",
|
||||
type=int,
|
||||
default=0,
|
||||
help='Random seed when the repeat mode is "random"',
|
||||
)
|
||||
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -63,8 +63,7 @@ class Request:
|
||||
output_len: int
|
||||
|
||||
|
||||
def sample_tokens(tokenizer: PreTrainedTokenizerBase,
|
||||
length: int) -> list[int]:
|
||||
def sample_tokens(tokenizer: PreTrainedTokenizerBase, length: int) -> list[int]:
|
||||
vocab = tokenizer.get_vocab()
|
||||
all_special_ids = set(tokenizer.all_special_ids)
|
||||
|
||||
@ -91,8 +90,10 @@ def sample_requests_from_dataset(
|
||||
# Filter out the conversations with less than 2 turns.
|
||||
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
|
||||
# Only keep the first two turns of each conversation.
|
||||
dataset = [(data["conversations"][0]["value"],
|
||||
data["conversations"][1]["value"]) for data in dataset]
|
||||
dataset = [
|
||||
(data["conversations"][0]["value"], data["conversations"][1]["value"])
|
||||
for data in dataset
|
||||
]
|
||||
|
||||
# Shuffle the dataset.
|
||||
random.shuffle(dataset)
|
||||
@ -113,8 +114,9 @@ def sample_requests_from_dataset(
|
||||
completion = dataset[i][1]
|
||||
completion_token_ids = tokenizer(completion).input_ids
|
||||
prompt_len = len(prompt_token_ids)
|
||||
output_len = (len(completion_token_ids)
|
||||
if fixed_output_len is None else fixed_output_len)
|
||||
output_len = (
|
||||
len(completion_token_ids) if fixed_output_len is None else fixed_output_len
|
||||
)
|
||||
if min_len <= prompt_len <= max_len:
|
||||
filtered_requests.append(Request(prompt, prompt_len, output_len))
|
||||
|
||||
@ -128,27 +130,27 @@ def sample_requests_from_random(
|
||||
fixed_output_len: Optional[int],
|
||||
prefix_len: int,
|
||||
) -> list[Request]:
|
||||
|
||||
requests = []
|
||||
prefix_token_ids = sample_tokens(tokenizer, prefix_len)
|
||||
min_len, max_len = input_length_range
|
||||
|
||||
for i in range(num_requests):
|
||||
unique_part_token_ids = sample_tokens(
|
||||
tokenizer,
|
||||
random.randint(min_len - prefix_len, max_len - prefix_len))
|
||||
tokenizer, random.randint(min_len - prefix_len, max_len - prefix_len)
|
||||
)
|
||||
prompt_token_ids = prefix_token_ids + unique_part_token_ids
|
||||
prompt = tokenizer.decode(prompt_token_ids)
|
||||
prompt_len = len(prompt_token_ids)
|
||||
assert (min_len <= prompt_len <= max_len
|
||||
), f"prompt_len {prompt_len} out of range {min_len}:{max_len}"
|
||||
assert min_len <= prompt_len <= max_len, (
|
||||
f"prompt_len {prompt_len} out of range {min_len}:{max_len}"
|
||||
)
|
||||
requests.append(Request(prompt, prompt_len, fixed_output_len))
|
||||
return requests
|
||||
|
||||
|
||||
def repeat_and_sort_requests(requests: list[Request],
|
||||
repeat_count: int,
|
||||
sort: bool = False) -> list[str]:
|
||||
def repeat_and_sort_requests(
|
||||
requests: list[Request], repeat_count: int, sort: bool = False
|
||||
) -> list[str]:
|
||||
repeated_requests = requests * repeat_count
|
||||
if sort:
|
||||
repeated_requests.sort(key=lambda x: x[1])
|
||||
@ -159,14 +161,14 @@ def repeat_and_sort_requests(requests: list[Request],
|
||||
|
||||
def main(args):
|
||||
tokenizer = get_tokenizer(args.model, trust_remote_code=True)
|
||||
input_length_range = tuple(map(int, args.input_length_range.split(':')))
|
||||
input_length_range = tuple(map(int, args.input_length_range.split(":")))
|
||||
random.seed(args.seed)
|
||||
if args.dataset_path is not None:
|
||||
if args.prefix_len > 0:
|
||||
raise ValueError("prefix-len is not supported when "
|
||||
"dataset-path is provided.")
|
||||
print(f"Start to sample {args.num_prompts} prompts "
|
||||
f"from {args.dataset_path}")
|
||||
raise ValueError(
|
||||
"prefix-len is not supported when dataset-path is provided."
|
||||
)
|
||||
print(f"Start to sample {args.num_prompts} prompts from {args.dataset_path}")
|
||||
filtered_requests = sample_requests_from_dataset(
|
||||
dataset_path=args.dataset_path,
|
||||
num_requests=args.num_prompts,
|
||||
@ -196,14 +198,16 @@ def main(args):
|
||||
|
||||
llm = LLM(**dataclasses.asdict(engine_args))
|
||||
|
||||
sampling_params = SamplingParams(temperature=0,
|
||||
max_tokens=args.output_len,
|
||||
detokenize=not args.disable_detokenize)
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0,
|
||||
max_tokens=args.output_len,
|
||||
detokenize=not args.disable_detokenize,
|
||||
)
|
||||
|
||||
print("Testing filtered requests")
|
||||
prompts = repeat_and_sort_requests(filtered_requests,
|
||||
repeat_count=args.repeat_count,
|
||||
sort=args.sort)
|
||||
prompts = repeat_and_sort_requests(
|
||||
filtered_requests, repeat_count=args.repeat_count, sort=args.sort
|
||||
)
|
||||
|
||||
print("------start generating------")
|
||||
test_prefix(
|
||||
@ -215,29 +219,35 @@ def main(args):
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser(
|
||||
description=
|
||||
'Benchmark the performance with or without automatic prefix caching.')
|
||||
parser.add_argument("--dataset-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the dataset.")
|
||||
parser.add_argument('--output-len', type=int, default=10)
|
||||
parser.add_argument('--num-prompts',
|
||||
type=int,
|
||||
required=True,
|
||||
help="Number of the prompts sampled from dataset")
|
||||
parser.add_argument('--repeat-count',
|
||||
type=int,
|
||||
default=1,
|
||||
help='Number of times to repeat each prompt')
|
||||
parser.add_argument('--sort',
|
||||
action='store_true',
|
||||
help='Sort prompts by input length')
|
||||
parser.add_argument('--input-length-range',
|
||||
type=str,
|
||||
required=True,
|
||||
help='Range of input lengths for sampling prompts,'
|
||||
'specified as "min:max" (e.g., "128:256").')
|
||||
description="Benchmark the performance with or without "
|
||||
"automatic prefix caching."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-path", type=str, default=None, help="Path to the dataset."
|
||||
)
|
||||
parser.add_argument("--output-len", type=int, default=10)
|
||||
parser.add_argument(
|
||||
"--num-prompts",
|
||||
type=int,
|
||||
required=True,
|
||||
help="Number of the prompts sampled from dataset",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repeat-count",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of times to repeat each prompt",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sort", action="store_true", help="Sort prompts by input length"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input-length-range",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Range of input lengths for sampling prompts,"
|
||||
'specified as "min:max" (e.g., "128:256").',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prefix-len",
|
||||
type=int,
|
||||
@ -248,10 +258,12 @@ if __name__ == "__main__":
|
||||
"when dataset-path is not provided.",
|
||||
)
|
||||
parser.add_argument(
|
||||
'--disable-detokenize',
|
||||
action='store_true',
|
||||
help=("Do not detokenize responses (i.e. do not include "
|
||||
"detokenization time in the latency measurement)"),
|
||||
"--disable-detokenize",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Do not detokenize responses (i.e. do not include "
|
||||
"detokenization time in the latency measurement)"
|
||||
),
|
||||
)
|
||||
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Benchmark offline prioritization."""
|
||||
|
||||
import argparse
|
||||
import dataclasses
|
||||
import json
|
||||
@ -13,7 +14,7 @@ from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
#Select a equi-probable random priority
|
||||
# Select a equi-probable random priority
|
||||
def get_random_flag():
|
||||
return 0 if random.random() < 0.5 else 1
|
||||
|
||||
@ -33,8 +34,10 @@ def sample_requests(
|
||||
# Filter out the conversations with less than 2 turns.
|
||||
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
|
||||
# Only keep the first two turns of each conversation.
|
||||
dataset = [(data["conversations"][0]["value"],
|
||||
data["conversations"][1]["value"]) for data in dataset]
|
||||
dataset = [
|
||||
(data["conversations"][0]["value"], data["conversations"][1]["value"])
|
||||
for data in dataset
|
||||
]
|
||||
|
||||
# Shuffle the dataset.
|
||||
random.shuffle(dataset)
|
||||
@ -51,8 +54,9 @@ def sample_requests(
|
||||
completion = dataset[i][1]
|
||||
completion_token_ids = tokenizer(completion).input_ids
|
||||
prompt_len = len(prompt_token_ids)
|
||||
output_len = len(completion_token_ids
|
||||
) if fixed_output_len is None else fixed_output_len
|
||||
output_len = (
|
||||
len(completion_token_ids) if fixed_output_len is None else fixed_output_len
|
||||
)
|
||||
if prompt_len < 4 or output_len < 4:
|
||||
# Prune too short sequences.
|
||||
continue
|
||||
@ -74,13 +78,16 @@ def run_vllm(
|
||||
disable_detokenize: bool = False,
|
||||
) -> float:
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
llm = LLM(**dataclasses.asdict(engine_args))
|
||||
|
||||
assert all(
|
||||
llm.llm_engine.model_config.max_model_len >= (request[1] + request[2])
|
||||
for request in requests), (
|
||||
"Please ensure that max_model_len is greater than the sum of"
|
||||
" input_len and output_len for all requests.")
|
||||
for request in requests
|
||||
), (
|
||||
"Please ensure that max_model_len is greater than the sum of"
|
||||
" input_len and output_len for all requests."
|
||||
)
|
||||
|
||||
# Add the requests to the engine.
|
||||
prompts = []
|
||||
@ -97,7 +104,8 @@ def run_vllm(
|
||||
ignore_eos=True,
|
||||
max_tokens=output_len,
|
||||
detokenize=not disable_detokenize,
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
start = time.perf_counter()
|
||||
llm.generate(prompts, sampling_params, priority=priority, use_tqdm=True)
|
||||
@ -111,26 +119,33 @@ def main(args: argparse.Namespace):
|
||||
|
||||
# Sample the requests.
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.tokenizer, trust_remote_code=args.trust_remote_code)
|
||||
args.tokenizer, trust_remote_code=args.trust_remote_code
|
||||
)
|
||||
if args.dataset is None:
|
||||
# Synthesize a prompt with the given input length.
|
||||
prompt = "hi" * (args.input_len - 1)
|
||||
requests = [(prompt, args.input_len, args.output_len,
|
||||
get_random_flag()) for _ in range(args.num_prompts)]
|
||||
requests = [
|
||||
(prompt, args.input_len, args.output_len, get_random_flag())
|
||||
for _ in range(args.num_prompts)
|
||||
]
|
||||
else:
|
||||
requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
|
||||
args.output_len)
|
||||
requests = sample_requests(
|
||||
args.dataset, args.num_prompts, tokenizer, args.output_len
|
||||
)
|
||||
|
||||
if args.backend == "vllm":
|
||||
elapsed_time = run_vllm(requests, args.n,
|
||||
EngineArgs.from_cli_args(args),
|
||||
args.disable_detokenize)
|
||||
elapsed_time = run_vllm(
|
||||
requests, args.n, EngineArgs.from_cli_args(args), args.disable_detokenize
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown backend: {args.backend}")
|
||||
total_num_tokens = sum(prompt_len + output_len
|
||||
for _, prompt_len, output_len, priority in requests)
|
||||
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
|
||||
f"{total_num_tokens / elapsed_time:.2f} tokens/s")
|
||||
total_num_tokens = sum(
|
||||
prompt_len + output_len for _, prompt_len, output_len, priority in requests
|
||||
)
|
||||
print(
|
||||
f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
|
||||
f"{total_num_tokens / elapsed_time:.2f} tokens/s"
|
||||
)
|
||||
|
||||
# Output JSON results if specified
|
||||
if args.output_json:
|
||||
@ -147,41 +162,44 @@ def main(args: argparse.Namespace):
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser(description="Benchmark the throughput.")
|
||||
parser.add_argument("--backend",
|
||||
type=str,
|
||||
choices=["vllm", "hf", "mii"],
|
||||
default="vllm")
|
||||
parser.add_argument("--dataset",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the dataset.")
|
||||
parser.add_argument("--input-len",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Input prompt length for each request")
|
||||
parser.add_argument("--output-len",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Output length for each request. Overrides the "
|
||||
"output length from the dataset.")
|
||||
parser.add_argument("--n",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of generated sequences per prompt.")
|
||||
parser.add_argument("--num-prompts",
|
||||
type=int,
|
||||
default=200,
|
||||
help="Number of prompts to process.")
|
||||
parser.add_argument(
|
||||
'--output-json',
|
||||
"--backend", type=str, choices=["vllm", "hf", "mii"], default="vllm"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset", type=str, default=None, help="Path to the dataset."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input-len",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Input prompt length for each request",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-len",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Output length for each request. Overrides the "
|
||||
"output length from the dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n", type=int, default=1, help="Number of generated sequences per prompt."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-prompts", type=int, default=200, help="Number of prompts to process."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-json",
|
||||
type=str,
|
||||
default=None,
|
||||
help='Path to save the throughput results in JSON format.')
|
||||
help="Path to save the throughput results in JSON format.",
|
||||
)
|
||||
parser.add_argument(
|
||||
'--disable-detokenize',
|
||||
action='store_true',
|
||||
help=("Do not detokenize responses (i.e. do not include "
|
||||
"detokenization time in the latency measurement)"),
|
||||
"--disable-detokenize",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Do not detokenize responses (i.e. do not include "
|
||||
"detokenization time in the latency measurement)"
|
||||
),
|
||||
)
|
||||
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
|
||||
@ -20,6 +20,7 @@ On the client side, run:
|
||||
--endpoint /generate_stream
|
||||
to the end of the command above.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import gc
|
||||
@ -34,12 +35,16 @@ from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
from backend_request_func import (ASYNC_REQUEST_FUNCS,
|
||||
OPENAI_COMPATIBLE_BACKENDS, RequestFuncInput,
|
||||
RequestFuncOutput)
|
||||
from tqdm.asyncio import tqdm
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from backend_request_func import (
|
||||
ASYNC_REQUEST_FUNCS,
|
||||
OPENAI_COMPATIBLE_BACKENDS,
|
||||
RequestFuncInput,
|
||||
RequestFuncOutput,
|
||||
)
|
||||
|
||||
try:
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
except ImportError:
|
||||
@ -50,12 +55,21 @@ try:
|
||||
except ImportError:
|
||||
from argparse import ArgumentParser as FlexibleArgumentParser
|
||||
|
||||
from benchmark_dataset import (AIMODataset, ASRDataset, BurstGPTDataset,
|
||||
ConversationDataset, HuggingFaceDataset,
|
||||
InstructCoderDataset, MTBenchDataset,
|
||||
NextEditPredictionDataset, RandomDataset,
|
||||
SampleRequest, ShareGPTDataset, SonnetDataset,
|
||||
VisionArenaDataset)
|
||||
from benchmark_dataset import (
|
||||
AIMODataset,
|
||||
ASRDataset,
|
||||
BurstGPTDataset,
|
||||
ConversationDataset,
|
||||
HuggingFaceDataset,
|
||||
InstructCoderDataset,
|
||||
MTBenchDataset,
|
||||
NextEditPredictionDataset,
|
||||
RandomDataset,
|
||||
SampleRequest,
|
||||
ShareGPTDataset,
|
||||
SonnetDataset,
|
||||
VisionArenaDataset,
|
||||
)
|
||||
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
|
||||
|
||||
MILLISECONDS_TO_SECONDS_CONVERSION = 1000
|
||||
@ -118,7 +132,8 @@ async def get_request(
|
||||
|
||||
# Calculate scale parameter theta to maintain the desired request_rate.
|
||||
assert burstiness > 0, (
|
||||
f"A positive burstiness factor is expected, but given {burstiness}.")
|
||||
f"A positive burstiness factor is expected, but given {burstiness}."
|
||||
)
|
||||
theta = 1.0 / (request_rate * burstiness)
|
||||
|
||||
for request in input_requests:
|
||||
@ -164,8 +179,10 @@ def calculate_metrics(
|
||||
# bundled together
|
||||
# Note : this may inflate the output token count slightly
|
||||
output_len = len(
|
||||
tokenizer(outputs[i].generated_text,
|
||||
add_special_tokens=False).input_ids)
|
||||
tokenizer(
|
||||
outputs[i].generated_text, add_special_tokens=False
|
||||
).input_ids
|
||||
)
|
||||
actual_output_lens.append(output_len)
|
||||
total_input += input_requests[i].prompt_len
|
||||
tpot = 0
|
||||
@ -188,16 +205,19 @@ def calculate_metrics(
|
||||
|
||||
if "ttft" in goodput_config_dict:
|
||||
valid_metrics.append(ttfts)
|
||||
slo_values.append(goodput_config_dict["ttft"] /
|
||||
MILLISECONDS_TO_SECONDS_CONVERSION)
|
||||
slo_values.append(
|
||||
goodput_config_dict["ttft"] / MILLISECONDS_TO_SECONDS_CONVERSION
|
||||
)
|
||||
if "tpot" in goodput_config_dict:
|
||||
valid_metrics.append(all_tpots)
|
||||
slo_values.append(goodput_config_dict["tpot"] /
|
||||
MILLISECONDS_TO_SECONDS_CONVERSION)
|
||||
slo_values.append(
|
||||
goodput_config_dict["tpot"] / MILLISECONDS_TO_SECONDS_CONVERSION
|
||||
)
|
||||
if "e2el" in goodput_config_dict:
|
||||
valid_metrics.append(e2els)
|
||||
slo_values.append(goodput_config_dict["e2el"] /
|
||||
MILLISECONDS_TO_SECONDS_CONVERSION)
|
||||
slo_values.append(
|
||||
goodput_config_dict["e2el"] / MILLISECONDS_TO_SECONDS_CONVERSION
|
||||
)
|
||||
|
||||
for req_metric in zip(*valid_metrics):
|
||||
is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)])
|
||||
@ -208,7 +228,8 @@ def calculate_metrics(
|
||||
warnings.warn(
|
||||
"All requests failed. This is likely due to a misconfiguration "
|
||||
"on the benchmark arguments.",
|
||||
stacklevel=2)
|
||||
stacklevel=2,
|
||||
)
|
||||
metrics = BenchmarkMetrics(
|
||||
completed=completed,
|
||||
total_input=total_input,
|
||||
@ -217,27 +238,31 @@ def calculate_metrics(
|
||||
request_goodput=good_completed / dur_s,
|
||||
output_throughput=sum(actual_output_lens) / dur_s,
|
||||
total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s,
|
||||
mean_ttft_ms=np.mean(ttfts or 0) *
|
||||
1000, # ttfts is empty if streaming is not supported by backend
|
||||
mean_ttft_ms=np.mean(ttfts or 0)
|
||||
* 1000, # ttfts is empty if streaming is not supported by backend
|
||||
std_ttft_ms=np.std(ttfts or 0) * 1000,
|
||||
median_ttft_ms=np.median(ttfts or 0) * 1000,
|
||||
percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000)
|
||||
for p in selected_percentiles],
|
||||
percentiles_ttft_ms=[
|
||||
(p, np.percentile(ttfts or 0, p) * 1000) for p in selected_percentiles
|
||||
],
|
||||
mean_tpot_ms=np.mean(tpots or 0) * 1000,
|
||||
std_tpot_ms=np.std(tpots or 0) * 1000,
|
||||
median_tpot_ms=np.median(tpots or 0) * 1000,
|
||||
percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000)
|
||||
for p in selected_percentiles],
|
||||
percentiles_tpot_ms=[
|
||||
(p, np.percentile(tpots or 0, p) * 1000) for p in selected_percentiles
|
||||
],
|
||||
mean_itl_ms=np.mean(itls or 0) * 1000,
|
||||
std_itl_ms=np.std(itls or 0) * 1000,
|
||||
median_itl_ms=np.median(itls or 0) * 1000,
|
||||
percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000)
|
||||
for p in selected_percentiles],
|
||||
percentiles_itl_ms=[
|
||||
(p, np.percentile(itls or 0, p) * 1000) for p in selected_percentiles
|
||||
],
|
||||
mean_e2el_ms=np.mean(e2els or 0) * 1000,
|
||||
std_e2el_ms=np.std(e2els or 0) * 1000,
|
||||
median_e2el_ms=np.median(e2els or 0) * 1000,
|
||||
percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000)
|
||||
for p in selected_percentiles],
|
||||
percentiles_e2el_ms=[
|
||||
(p, np.percentile(e2els or 0, p) * 1000) for p in selected_percentiles
|
||||
],
|
||||
)
|
||||
|
||||
return metrics, actual_output_lens
|
||||
@ -250,7 +275,7 @@ async def benchmark(
|
||||
model_id: str,
|
||||
model_name: str,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
input_requests: list[SampleRequest],
|
||||
requests: list[SampleRequest],
|
||||
logprobs: Optional[int],
|
||||
request_rate: float,
|
||||
burstiness: float,
|
||||
@ -270,10 +295,14 @@ async def benchmark(
|
||||
raise ValueError(f"Unknown backend: {backend}")
|
||||
|
||||
print("Starting initial single prompt test run...")
|
||||
test_prompt, test_prompt_len, test_output_len, test_mm_content = \
|
||||
input_requests[0].prompt, input_requests[0].prompt_len, \
|
||||
input_requests[0].expected_output_len, \
|
||||
input_requests[0].multi_modal_data
|
||||
last_idx = len(requests) - 1
|
||||
test_prompt, test_prompt_len, test_output_len, test_mm_content = (
|
||||
requests[last_idx].prompt,
|
||||
requests[last_idx].prompt_len,
|
||||
requests[last_idx].expected_output_len,
|
||||
requests[last_idx].multi_modal_data,
|
||||
)
|
||||
input_requests = requests[:last_idx]
|
||||
|
||||
assert test_mm_content is None or isinstance(test_mm_content, dict)
|
||||
test_input = RequestFuncInput(
|
||||
@ -293,36 +322,36 @@ async def benchmark(
|
||||
if not test_output.success:
|
||||
raise ValueError(
|
||||
"Initial test run failed - Please make sure benchmark arguments "
|
||||
f"are correctly specified. Error: {test_output.error}")
|
||||
f"are correctly specified. Error: {test_output.error}"
|
||||
)
|
||||
else:
|
||||
print("Initial test run completed. Starting main benchmark run...")
|
||||
|
||||
if lora_modules:
|
||||
# For each input request, choose a LoRA module at random.
|
||||
lora_modules = iter(
|
||||
[random.choice(lora_modules) \
|
||||
for _ in range(len(input_requests))])
|
||||
[random.choice(lora_modules) for _ in range(len(input_requests))]
|
||||
)
|
||||
|
||||
if profile:
|
||||
print("Starting profiler...")
|
||||
profile_input = RequestFuncInput(model=model_id,
|
||||
model_name=model_name,
|
||||
prompt=test_prompt,
|
||||
api_url=base_url + "/start_profile",
|
||||
prompt_len=test_prompt_len,
|
||||
output_len=test_output_len,
|
||||
logprobs=logprobs,
|
||||
multi_modal_content=test_mm_content,
|
||||
ignore_eos=ignore_eos,
|
||||
extra_body=extra_body)
|
||||
profile_input = RequestFuncInput(
|
||||
model=model_id,
|
||||
model_name=model_name,
|
||||
prompt=test_prompt,
|
||||
api_url=base_url + "/start_profile",
|
||||
prompt_len=test_prompt_len,
|
||||
output_len=test_output_len,
|
||||
logprobs=logprobs,
|
||||
multi_modal_content=test_mm_content,
|
||||
ignore_eos=ignore_eos,
|
||||
extra_body=extra_body,
|
||||
)
|
||||
profile_output = await request_func(request_func_input=profile_input)
|
||||
if profile_output.success:
|
||||
print("Profiler started")
|
||||
|
||||
if burstiness == 1.0:
|
||||
distribution = "Poisson process"
|
||||
else:
|
||||
distribution = "Gamma distribution"
|
||||
distribution = "Poisson process" if burstiness == 1.0 else "Gamma distribution"
|
||||
|
||||
print(f"Traffic request rate: {request_rate}")
|
||||
print(f"Burstiness factor: {burstiness} ({distribution})")
|
||||
@ -334,42 +363,45 @@ async def benchmark(
|
||||
# and it will simplify the code in limited_request_func.
|
||||
# semaphore = (asyncio.Semaphore(max_concurrency)
|
||||
# if max_concurrency else contextlib.nullcontext())
|
||||
semaphore = (asyncio.Semaphore(max_concurrency)
|
||||
if max_concurrency else None)
|
||||
semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None
|
||||
|
||||
async def limited_request_func(request_func_input, pbar):
|
||||
if semaphore is None:
|
||||
return await request_func(request_func_input=request_func_input,
|
||||
pbar=pbar)
|
||||
return await request_func(request_func_input=request_func_input, pbar=pbar)
|
||||
async with semaphore:
|
||||
return await request_func(request_func_input=request_func_input,
|
||||
pbar=pbar)
|
||||
return await request_func(request_func_input=request_func_input, pbar=pbar)
|
||||
|
||||
benchmark_start_time = time.perf_counter()
|
||||
tasks: list[asyncio.Task] = []
|
||||
async for request in get_request(input_requests, request_rate, burstiness):
|
||||
prompt, prompt_len, output_len, mm_content = request.prompt, \
|
||||
request.prompt_len, request.expected_output_len, \
|
||||
request.multi_modal_data
|
||||
prompt, prompt_len, output_len, mm_content = (
|
||||
request.prompt,
|
||||
request.prompt_len,
|
||||
request.expected_output_len,
|
||||
request.multi_modal_data,
|
||||
)
|
||||
req_model_id, req_model_name = model_id, model_name
|
||||
if lora_modules:
|
||||
req_lora_module = next(lora_modules)
|
||||
req_model_id, req_model_name = req_lora_module, req_lora_module
|
||||
|
||||
request_func_input = RequestFuncInput(model=req_model_id,
|
||||
model_name=req_model_name,
|
||||
prompt=prompt,
|
||||
api_url=api_url,
|
||||
prompt_len=prompt_len,
|
||||
output_len=output_len,
|
||||
logprobs=logprobs,
|
||||
multi_modal_content=mm_content,
|
||||
ignore_eos=ignore_eos,
|
||||
extra_body=extra_body)
|
||||
request_func_input = RequestFuncInput(
|
||||
model=req_model_id,
|
||||
model_name=req_model_name,
|
||||
prompt=prompt,
|
||||
api_url=api_url,
|
||||
prompt_len=prompt_len,
|
||||
output_len=output_len,
|
||||
logprobs=logprobs,
|
||||
multi_modal_content=mm_content,
|
||||
ignore_eos=ignore_eos,
|
||||
extra_body=extra_body,
|
||||
)
|
||||
tasks.append(
|
||||
asyncio.create_task(
|
||||
limited_request_func(request_func_input=request_func_input,
|
||||
pbar=pbar)))
|
||||
limited_request_func(request_func_input=request_func_input, pbar=pbar)
|
||||
)
|
||||
)
|
||||
outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
|
||||
|
||||
if profile:
|
||||
@ -401,22 +433,32 @@ async def benchmark(
|
||||
goodput_config_dict=goodput_config_dict,
|
||||
)
|
||||
|
||||
print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
|
||||
print("{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
|
||||
print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
|
||||
print("{:<40} {:<10.2f}".format("Benchmark duration (s):",
|
||||
benchmark_duration))
|
||||
print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
|
||||
print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
|
||||
print("{:<40} {:<10}".format("Total generated tokens:",
|
||||
metrics.total_output))
|
||||
print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
|
||||
metrics.request_throughput))
|
||||
print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
|
||||
print(
|
||||
"{:<40} {:<10.2f}".format(
|
||||
"Request throughput (req/s):", metrics.request_throughput
|
||||
)
|
||||
)
|
||||
if goodput_config_dict:
|
||||
print("{:<40} {:<10.2f}".format("Request goodput (req/s):",
|
||||
metrics.request_goodput))
|
||||
print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
|
||||
metrics.output_throughput))
|
||||
print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):",
|
||||
metrics.total_token_throughput))
|
||||
print(
|
||||
"{:<40} {:<10.2f}".format(
|
||||
"Request goodput (req/s):", metrics.request_goodput
|
||||
)
|
||||
)
|
||||
print(
|
||||
"{:<40} {:<10.2f}".format(
|
||||
"Output token throughput (tok/s):", metrics.output_throughput
|
||||
)
|
||||
)
|
||||
print(
|
||||
"{:<40} {:<10.2f}".format(
|
||||
"Total Token throughput (tok/s):", metrics.total_token_throughput
|
||||
)
|
||||
)
|
||||
|
||||
result = {
|
||||
"duration": benchmark_duration,
|
||||
@ -424,8 +466,7 @@ async def benchmark(
|
||||
"total_input_tokens": metrics.total_input,
|
||||
"total_output_tokens": metrics.total_output,
|
||||
"request_throughput": metrics.request_throughput,
|
||||
"request_goodput:":
|
||||
metrics.request_goodput if goodput_config_dict else None,
|
||||
"request_goodput:": metrics.request_goodput if goodput_config_dict else None,
|
||||
"output_throughput": metrics.output_throughput,
|
||||
"total_token_throughput": metrics.total_token_throughput,
|
||||
"input_lens": [output.prompt_len for output in outputs],
|
||||
@ -448,29 +489,35 @@ async def benchmark(
|
||||
# metric.
|
||||
if metric_attribute_name not in selected_percentile_metrics:
|
||||
return
|
||||
print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-'))
|
||||
print("{:<40} {:<10.2f}".format(
|
||||
f"Mean {metric_name} (ms):",
|
||||
getattr(metrics, f"mean_{metric_attribute_name}_ms")))
|
||||
print("{:<40} {:<10.2f}".format(
|
||||
f"Median {metric_name} (ms):",
|
||||
getattr(metrics, f"median_{metric_attribute_name}_ms")))
|
||||
print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
|
||||
print(
|
||||
"{:<40} {:<10.2f}".format(
|
||||
f"Mean {metric_name} (ms):",
|
||||
getattr(metrics, f"mean_{metric_attribute_name}_ms"),
|
||||
)
|
||||
)
|
||||
print(
|
||||
"{:<40} {:<10.2f}".format(
|
||||
f"Median {metric_name} (ms):",
|
||||
getattr(metrics, f"median_{metric_attribute_name}_ms"),
|
||||
)
|
||||
)
|
||||
result[f"mean_{metric_attribute_name}_ms"] = getattr(
|
||||
metrics, f"mean_{metric_attribute_name}_ms")
|
||||
metrics, f"mean_{metric_attribute_name}_ms"
|
||||
)
|
||||
result[f"median_{metric_attribute_name}_ms"] = getattr(
|
||||
metrics, f"median_{metric_attribute_name}_ms")
|
||||
metrics, f"median_{metric_attribute_name}_ms"
|
||||
)
|
||||
result[f"std_{metric_attribute_name}_ms"] = getattr(
|
||||
metrics, f"std_{metric_attribute_name}_ms")
|
||||
for p, value in getattr(metrics,
|
||||
f"percentiles_{metric_attribute_name}_ms"):
|
||||
metrics, f"std_{metric_attribute_name}_ms"
|
||||
)
|
||||
for p, value in getattr(metrics, f"percentiles_{metric_attribute_name}_ms"):
|
||||
p_word = str(int(p)) if int(p) == p else str(p)
|
||||
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):",
|
||||
value))
|
||||
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value))
|
||||
result[f"p{p_word}_{metric_attribute_name}_ms"] = value
|
||||
|
||||
process_one_metric("ttft", "TTFT", "Time to First Token")
|
||||
process_one_metric("tpot", "TPOT",
|
||||
"Time per Output Token (excl. 1st token)")
|
||||
process_one_metric("tpot", "TPOT", "Time per Output Token (excl. 1st token)")
|
||||
process_one_metric("itl", "ITL", "Inter-token Latency")
|
||||
process_one_metric("e2el", "E2EL", "End-to-end Latency")
|
||||
|
||||
@ -490,12 +537,14 @@ def check_goodput_args(args):
|
||||
raise ValueError(
|
||||
f"Invalid metric name found, {slo_name}: {slo_val}. "
|
||||
"The service level objective name should be one of "
|
||||
f"{str(VALID_NAMES)}. ")
|
||||
f"{str(VALID_NAMES)}. "
|
||||
)
|
||||
if slo_val < 0:
|
||||
raise ValueError(
|
||||
f"Invalid value found, {slo_name}: {slo_val}. "
|
||||
"The service level objective value should be "
|
||||
"non-negative.")
|
||||
"non-negative."
|
||||
)
|
||||
return goodput_config_dict
|
||||
|
||||
|
||||
@ -508,31 +557,42 @@ def parse_goodput(slo_pairs):
|
||||
except ValueError as err:
|
||||
raise argparse.ArgumentTypeError(
|
||||
"Invalid format found for service level objectives. "
|
||||
"Specify service level objectives for goodput as \"KEY:VALUE\" "
|
||||
'Specify service level objectives for goodput as "KEY:VALUE" '
|
||||
"pairs, where the key is a metric name, and the value is a "
|
||||
"number in milliseconds.") from err
|
||||
"number in milliseconds."
|
||||
) from err
|
||||
return goodput_config_dict
|
||||
|
||||
|
||||
def save_to_pytorch_benchmark_format(args: argparse.Namespace,
|
||||
results: dict[str, Any],
|
||||
file_name: str) -> None:
|
||||
def save_to_pytorch_benchmark_format(
|
||||
args: argparse.Namespace, results: dict[str, Any], file_name: str
|
||||
) -> None:
|
||||
metrics = [
|
||||
"median_ttft_ms", "mean_ttft_ms", "std_ttft_ms", "p99_ttft_ms",
|
||||
"mean_tpot_ms", "median_tpot_ms", "std_tpot_ms", "p99_tpot_ms",
|
||||
"median_itl_ms", "mean_itl_ms", "std_itl_ms", "p99_itl_ms"
|
||||
"median_ttft_ms",
|
||||
"mean_ttft_ms",
|
||||
"std_ttft_ms",
|
||||
"p99_ttft_ms",
|
||||
"mean_tpot_ms",
|
||||
"median_tpot_ms",
|
||||
"std_tpot_ms",
|
||||
"p99_tpot_ms",
|
||||
"median_itl_ms",
|
||||
"mean_itl_ms",
|
||||
"std_itl_ms",
|
||||
"p99_itl_ms",
|
||||
]
|
||||
# These raw data might be useful, but they are rather big. They can be added
|
||||
# later if needed
|
||||
ignored_metrics = ["ttfts", "itls", "generated_texts", "errors"]
|
||||
pt_records = convert_to_pytorch_benchmark_format(
|
||||
args=args,
|
||||
metrics={k: [results[k]]
|
||||
for k in metrics},
|
||||
metrics={k: [results[k]] for k in metrics},
|
||||
extra_info={
|
||||
k: results[k]
|
||||
for k in results if k not in metrics and k not in ignored_metrics
|
||||
})
|
||||
for k in results
|
||||
if k not in metrics and k not in ignored_metrics
|
||||
},
|
||||
)
|
||||
if pt_records:
|
||||
# Don't use json suffix here as we don't want CI to pick it up
|
||||
pt_file = f"{os.path.splitext(file_name)[0]}.pytorch.json"
|
||||
@ -557,34 +617,45 @@ def main(args: argparse.Namespace):
|
||||
api_url = f"http://{args.host}:{args.port}{args.endpoint}"
|
||||
base_url = f"http://{args.host}:{args.port}"
|
||||
|
||||
tokenizer = get_tokenizer(tokenizer_id,
|
||||
tokenizer_mode=tokenizer_mode,
|
||||
trust_remote_code=args.trust_remote_code)
|
||||
# Create one more request (for a test request)
|
||||
total_prompts = args.num_prompts + 1
|
||||
|
||||
tokenizer = get_tokenizer(
|
||||
tokenizer_id,
|
||||
tokenizer_mode=tokenizer_mode,
|
||||
trust_remote_code=args.trust_remote_code,
|
||||
)
|
||||
|
||||
if args.dataset_name is None:
|
||||
raise ValueError(
|
||||
"Please specify '--dataset-name' and the corresponding "
|
||||
"'--dataset-path' if required.")
|
||||
"'--dataset-path' if required."
|
||||
)
|
||||
|
||||
if args.dataset_name == "sonnet":
|
||||
dataset = SonnetDataset(dataset_path=args.dataset_path)
|
||||
# For the "sonnet" dataset, formatting depends on the backend.
|
||||
if args.backend == "openai-chat":
|
||||
input_requests = dataset.sample(num_requests=args.num_prompts,
|
||||
input_len=args.sonnet_input_len,
|
||||
output_len=args.sonnet_output_len,
|
||||
prefix_len=args.sonnet_prefix_len,
|
||||
tokenizer=tokenizer,
|
||||
return_prompt_formatted=False)
|
||||
input_requests = dataset.sample(
|
||||
num_requests=total_prompts,
|
||||
input_len=args.sonnet_input_len,
|
||||
output_len=args.sonnet_output_len,
|
||||
prefix_len=args.sonnet_prefix_len,
|
||||
tokenizer=tokenizer,
|
||||
return_prompt_formatted=False,
|
||||
)
|
||||
else:
|
||||
assert tokenizer.chat_template or tokenizer.default_chat_template, (
|
||||
"Tokenizer/model must have chat template for sonnet dataset.")
|
||||
input_requests = dataset.sample(num_requests=args.num_prompts,
|
||||
input_len=args.sonnet_input_len,
|
||||
output_len=args.sonnet_output_len,
|
||||
prefix_len=args.sonnet_prefix_len,
|
||||
tokenizer=tokenizer,
|
||||
return_prompt_formatted=True)
|
||||
"Tokenizer/model must have chat template for sonnet dataset."
|
||||
)
|
||||
input_requests = dataset.sample(
|
||||
num_requests=total_prompts,
|
||||
input_len=args.sonnet_input_len,
|
||||
output_len=args.sonnet_output_len,
|
||||
prefix_len=args.sonnet_prefix_len,
|
||||
tokenizer=tokenizer,
|
||||
return_prompt_formatted=True,
|
||||
)
|
||||
|
||||
elif args.dataset_name == "hf":
|
||||
# all following datasets are implemented from the
|
||||
@ -611,30 +682,37 @@ def main(args: argparse.Namespace):
|
||||
dataset_class = ASRDataset
|
||||
args.hf_split = "train"
|
||||
else:
|
||||
supported_datasets = set([
|
||||
dataset_name for cls in HuggingFaceDataset.__subclasses__()
|
||||
for dataset_name in cls.SUPPORTED_DATASET_PATHS
|
||||
])
|
||||
supported_datasets = set(
|
||||
[
|
||||
dataset_name
|
||||
for cls in HuggingFaceDataset.__subclasses__()
|
||||
for dataset_name in cls.SUPPORTED_DATASET_PATHS
|
||||
]
|
||||
)
|
||||
raise ValueError(
|
||||
f"Unsupported dataset path: {args.dataset_path}. "
|
||||
"Huggingface dataset only supports dataset_path"
|
||||
f" from one of following: {supported_datasets}. "
|
||||
"Please consider contributing if you would "
|
||||
"like to add support for additional dataset formats.")
|
||||
"like to add support for additional dataset formats."
|
||||
)
|
||||
|
||||
if (dataset_class.IS_MULTIMODAL and backend not in \
|
||||
["openai-chat", "openai-audio"]):
|
||||
if dataset_class.IS_MULTIMODAL and backend not in [
|
||||
"openai-chat",
|
||||
"openai-audio",
|
||||
]:
|
||||
# multi-modal benchmark is only available on OpenAI Chat backend.
|
||||
raise ValueError(
|
||||
"Multi-modal content is only supported on 'openai-chat' and " \
|
||||
"'openai-audio' backend.")
|
||||
"Multi-modal content is only supported on 'openai-chat' and "
|
||||
"'openai-audio' backend."
|
||||
)
|
||||
input_requests = dataset_class(
|
||||
dataset_path=args.dataset_path,
|
||||
dataset_subset=args.hf_subset,
|
||||
dataset_split=args.hf_split,
|
||||
random_seed=args.seed,
|
||||
).sample(
|
||||
num_requests=args.num_prompts,
|
||||
num_requests=total_prompts,
|
||||
tokenizer=tokenizer,
|
||||
output_len=args.hf_output_len,
|
||||
)
|
||||
@ -642,26 +720,24 @@ def main(args: argparse.Namespace):
|
||||
else:
|
||||
# For datasets that follow a similar structure, use a mapping.
|
||||
dataset_mapping = {
|
||||
"sharegpt":
|
||||
lambda: ShareGPTDataset(random_seed=args.seed,
|
||||
dataset_path=args.dataset_path).sample(
|
||||
tokenizer=tokenizer,
|
||||
num_requests=args.num_prompts,
|
||||
output_len=args.sharegpt_output_len,
|
||||
),
|
||||
"burstgpt":
|
||||
lambda: BurstGPTDataset(random_seed=args.seed,
|
||||
dataset_path=args.dataset_path).
|
||||
sample(tokenizer=tokenizer, num_requests=args.num_prompts),
|
||||
"random":
|
||||
lambda: RandomDataset(dataset_path=args.dataset_path).sample(
|
||||
"sharegpt": lambda: ShareGPTDataset(
|
||||
random_seed=args.seed, dataset_path=args.dataset_path
|
||||
).sample(
|
||||
tokenizer=tokenizer,
|
||||
num_requests=args.num_prompts,
|
||||
num_requests=total_prompts,
|
||||
output_len=args.sharegpt_output_len,
|
||||
),
|
||||
"burstgpt": lambda: BurstGPTDataset(
|
||||
random_seed=args.seed, dataset_path=args.dataset_path
|
||||
).sample(tokenizer=tokenizer, num_requests=total_prompts),
|
||||
"random": lambda: RandomDataset(dataset_path=args.dataset_path).sample(
|
||||
tokenizer=tokenizer,
|
||||
num_requests=total_prompts,
|
||||
prefix_len=args.random_prefix_len,
|
||||
input_len=args.random_input_len,
|
||||
output_len=args.random_output_len,
|
||||
range_ratio=args.random_range_ratio,
|
||||
)
|
||||
),
|
||||
}
|
||||
|
||||
try:
|
||||
@ -677,15 +753,16 @@ def main(args: argparse.Namespace):
|
||||
"top_p": args.top_p,
|
||||
"top_k": args.top_k,
|
||||
"min_p": args.min_p,
|
||||
"temperature": args.temperature
|
||||
}.items() if v is not None
|
||||
"temperature": args.temperature,
|
||||
}.items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
# Sampling parameters are only supported by openai-compatible backend.
|
||||
if sampling_params and args.backend not in OPENAI_COMPATIBLE_BACKENDS:
|
||||
raise ValueError(
|
||||
"Sampling parameters are only supported by openai-compatible "
|
||||
"backends.")
|
||||
"Sampling parameters are only supported by openai-compatible backends."
|
||||
)
|
||||
|
||||
if "temperature" not in sampling_params:
|
||||
sampling_params["temperature"] = 0.0 # Default to greedy decoding.
|
||||
@ -702,22 +779,21 @@ def main(args: argparse.Namespace):
|
||||
model_id=model_id,
|
||||
model_name=model_name,
|
||||
tokenizer=tokenizer,
|
||||
input_requests=input_requests,
|
||||
requests=input_requests,
|
||||
logprobs=args.logprobs,
|
||||
request_rate=args.request_rate,
|
||||
burstiness=args.burstiness,
|
||||
disable_tqdm=args.disable_tqdm,
|
||||
profile=args.profile,
|
||||
selected_percentile_metrics=args.percentile_metrics.split(","),
|
||||
selected_percentiles=[
|
||||
float(p) for p in args.metric_percentiles.split(",")
|
||||
],
|
||||
selected_percentiles=[float(p) for p in args.metric_percentiles.split(",")],
|
||||
ignore_eos=args.ignore_eos,
|
||||
goodput_config_dict=goodput_config_dict,
|
||||
max_concurrency=args.max_concurrency,
|
||||
lora_modules=args.lora_modules,
|
||||
extra_body=sampling_params,
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
# Save config and results to json
|
||||
if args.save_result or args.append_result:
|
||||
@ -742,8 +818,9 @@ def main(args: argparse.Namespace):
|
||||
"Invalid metadata format. Please use KEY=VALUE format."
|
||||
)
|
||||
# Traffic
|
||||
result_json["request_rate"] = (args.request_rate if args.request_rate
|
||||
< float("inf") else "inf")
|
||||
result_json["request_rate"] = (
|
||||
args.request_rate if args.request_rate < float("inf") else "inf"
|
||||
)
|
||||
result_json["burstiness"] = args.burstiness
|
||||
result_json["max_concurrency"] = args.max_concurrency
|
||||
|
||||
@ -753,24 +830,31 @@ def main(args: argparse.Namespace):
|
||||
if not args.save_detailed:
|
||||
# Remove fields with too many data points
|
||||
for field in [
|
||||
"input_lens", "output_lens", "ttfts", "itls",
|
||||
"generated_texts", "errors"
|
||||
"input_lens",
|
||||
"output_lens",
|
||||
"ttfts",
|
||||
"itls",
|
||||
"generated_texts",
|
||||
"errors",
|
||||
]:
|
||||
if field in result_json:
|
||||
del result_json[field]
|
||||
|
||||
# Save to file
|
||||
base_model_id = model_id.split("/")[-1]
|
||||
max_concurrency_str = (f"-concurrency{args.max_concurrency}"
|
||||
if args.max_concurrency is not None else "")
|
||||
file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" #noqa
|
||||
max_concurrency_str = (
|
||||
f"-concurrency{args.max_concurrency}"
|
||||
if args.max_concurrency is not None
|
||||
else ""
|
||||
)
|
||||
file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
|
||||
if args.result_filename:
|
||||
file_name = args.result_filename
|
||||
if args.result_dir:
|
||||
file_name = os.path.join(args.result_dir, file_name)
|
||||
with open(file_name,
|
||||
mode="a+" if args.append_result else "w",
|
||||
encoding='utf-8') as outfile:
|
||||
with open(
|
||||
file_name, mode="a+" if args.append_result else "w", encoding="utf-8"
|
||||
) as outfile:
|
||||
# Append a newline.
|
||||
if args.append_result and outfile.tell() != 0:
|
||||
outfile.write("\n")
|
||||
@ -780,7 +864,8 @@ def main(args: argparse.Namespace):
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser(
|
||||
description="Benchmark the online serving throughput.")
|
||||
description="Benchmark the online serving throughput."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
type=str,
|
||||
@ -809,11 +894,13 @@ if __name__ == "__main__":
|
||||
choices=["sharegpt", "burstgpt", "sonnet", "random", "hf"],
|
||||
help="Name of the dataset to benchmark on.",
|
||||
)
|
||||
parser.add_argument("--dataset-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the sharegpt/sonnet dataset. "
|
||||
"Or the huggingface dataset ID if using HF dataset.")
|
||||
parser.add_argument(
|
||||
"--dataset-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the sharegpt/sonnet dataset. "
|
||||
"Or the huggingface dataset ID if using HF dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-concurrency",
|
||||
type=int,
|
||||
@ -825,7 +912,8 @@ if __name__ == "__main__":
|
||||
"initiated, this argument will control how many are actually allowed "
|
||||
"to execute at a time. This means that when used in combination, the "
|
||||
"actual request rate may be lower than specified with --request-rate, "
|
||||
"if the server is not processing requests fast enough to keep up.")
|
||||
"if the server is not processing requests fast enough to keep up.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
@ -836,8 +924,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--tokenizer",
|
||||
type=str,
|
||||
help=
|
||||
"Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
|
||||
help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
|
||||
)
|
||||
parser.add_argument("--use-beam-search", action="store_true")
|
||||
parser.add_argument(
|
||||
@ -850,11 +937,13 @@ if __name__ == "__main__":
|
||||
"--logprobs",
|
||||
type=int,
|
||||
default=None,
|
||||
help=("Number of logprobs-per-token to compute & return as part of "
|
||||
"the request. If unspecified, then either (1) if beam search "
|
||||
"is disabled, no logprobs are computed & a single dummy "
|
||||
"logprob is returned for each token; or (2) if beam search "
|
||||
"is enabled 1 logprob per token is computed"),
|
||||
help=(
|
||||
"Number of logprobs-per-token to compute & return as part of "
|
||||
"the request. If unspecified, then either (1) if beam search "
|
||||
"is disabled, no logprobs are computed & a single dummy "
|
||||
"logprob is returned for each token; or (2) if beam search "
|
||||
"is enabled 1 logprob per token is computed"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--request-rate",
|
||||
@ -938,35 +1027,38 @@ if __name__ == "__main__":
|
||||
"--ignore-eos",
|
||||
action="store_true",
|
||||
help="Set ignore_eos flag when sending the benchmark request."
|
||||
"Warning: ignore_eos is not supported in deepspeed_mii and tgi.")
|
||||
"Warning: ignore_eos is not supported in deepspeed_mii and tgi.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--percentile-metrics",
|
||||
type=str,
|
||||
default="ttft,tpot,itl",
|
||||
help="Comma-separated list of selected metrics to report percentils. "
|
||||
"This argument specifies the metrics to report percentiles. "
|
||||
"Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". "
|
||||
"Default value is \"ttft,tpot,itl\".")
|
||||
'Allowed metric names are "ttft", "tpot", "itl", "e2el". '
|
||||
'Default value is "ttft,tpot,itl".',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--metric-percentiles",
|
||||
type=str,
|
||||
default="99",
|
||||
help="Comma-separated list of percentiles for selected metrics. "
|
||||
"To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". "
|
||||
"Default value is \"99\". "
|
||||
"Use \"--percentile-metrics\" to select metrics.",
|
||||
'To report 25-th, 50-th, and 75-th percentiles, use "25,50,75". '
|
||||
'Default value is "99". '
|
||||
'Use "--percentile-metrics" to select metrics.',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--goodput",
|
||||
nargs="+",
|
||||
required=False,
|
||||
help="Specify service level objectives for goodput as \"KEY:VALUE\" "
|
||||
help='Specify service level objectives for goodput as "KEY:VALUE" '
|
||||
"pairs, where the key is a metric name, and the value is in "
|
||||
"milliseconds. Multiple \"KEY:VALUE\" pairs can be provided, "
|
||||
'milliseconds. Multiple "KEY:VALUE" pairs can be provided, '
|
||||
"separated by spaces. Allowed request level metric names are "
|
||||
"\"ttft\", \"tpot\", \"e2el\". For more context on the definition of "
|
||||
'"ttft", "tpot", "e2el". For more context on the definition of '
|
||||
"goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
|
||||
"and the blog: https://hao-ai-lab.github.io/blogs/distserve")
|
||||
"and the blog: https://hao-ai-lab.github.io/blogs/distserve",
|
||||
)
|
||||
|
||||
# group for dataset specific arguments
|
||||
sonnet_group = parser.add_argument_group("sonnet dataset options")
|
||||
@ -974,22 +1066,19 @@ if __name__ == "__main__":
|
||||
"--sonnet-input-len",
|
||||
type=int,
|
||||
default=550,
|
||||
help=
|
||||
"Number of input tokens per request, used only for sonnet dataset.",
|
||||
help="Number of input tokens per request, used only for sonnet dataset.",
|
||||
)
|
||||
sonnet_group.add_argument(
|
||||
"--sonnet-output-len",
|
||||
type=int,
|
||||
default=150,
|
||||
help=
|
||||
"Number of output tokens per request, used only for sonnet dataset.",
|
||||
help="Number of output tokens per request, used only for sonnet dataset.",
|
||||
)
|
||||
sonnet_group.add_argument(
|
||||
"--sonnet-prefix-len",
|
||||
type=int,
|
||||
default=200,
|
||||
help=
|
||||
"Number of prefix tokens per request, used only for sonnet dataset.",
|
||||
help="Number of prefix tokens per request, used only for sonnet dataset.",
|
||||
)
|
||||
|
||||
sharegpt_group = parser.add_argument_group("sharegpt dataset options")
|
||||
@ -998,22 +1087,21 @@ if __name__ == "__main__":
|
||||
type=int,
|
||||
default=None,
|
||||
help="Output length for each request. Overrides the output length "
|
||||
"from the ShareGPT dataset.")
|
||||
"from the ShareGPT dataset.",
|
||||
)
|
||||
|
||||
random_group = parser.add_argument_group("random dataset options")
|
||||
random_group.add_argument(
|
||||
"--random-input-len",
|
||||
type=int,
|
||||
default=1024,
|
||||
help=
|
||||
"Number of input tokens per request, used only for random sampling.",
|
||||
help="Number of input tokens per request, used only for random sampling.",
|
||||
)
|
||||
random_group.add_argument(
|
||||
"--random-output-len",
|
||||
type=int,
|
||||
default=128,
|
||||
help=
|
||||
"Number of output tokens per request, used only for random sampling.",
|
||||
help="Number of output tokens per request, used only for random sampling.",
|
||||
)
|
||||
random_group.add_argument(
|
||||
"--random-range-ratio",
|
||||
@ -1028,23 +1116,23 @@ if __name__ == "__main__":
|
||||
"--random-prefix-len",
|
||||
type=int,
|
||||
default=0,
|
||||
help=("Number of fixed prefix tokens before the random context "
|
||||
"in a request. "
|
||||
"The total input length is the sum of `random-prefix-len` and "
|
||||
"a random "
|
||||
"context length sampled from [input_len * (1 - range_ratio), "
|
||||
"input_len * (1 + range_ratio)]."),
|
||||
help=(
|
||||
"Number of fixed prefix tokens before the random context "
|
||||
"in a request. "
|
||||
"The total input length is the sum of `random-prefix-len` and "
|
||||
"a random "
|
||||
"context length sampled from [input_len * (1 - range_ratio), "
|
||||
"input_len * (1 + range_ratio)]."
|
||||
),
|
||||
)
|
||||
|
||||
hf_group = parser.add_argument_group("hf dataset options")
|
||||
hf_group.add_argument("--hf-subset",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Subset of the HF dataset.")
|
||||
hf_group.add_argument("--hf-split",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Split of the HF dataset.")
|
||||
hf_group.add_argument(
|
||||
"--hf-subset", type=str, default=None, help="Subset of the HF dataset."
|
||||
)
|
||||
hf_group.add_argument(
|
||||
"--hf-split", type=str, default=None, help="Split of the HF dataset."
|
||||
)
|
||||
hf_group.add_argument(
|
||||
"--hf-output-len",
|
||||
type=int,
|
||||
@ -1058,52 +1146,58 @@ if __name__ == "__main__":
|
||||
"--top-p",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Top-p sampling parameter. Only has effect on openai-compatible "
|
||||
"backends.")
|
||||
help="Top-p sampling parameter. Only has effect on openai-compatible backends.",
|
||||
)
|
||||
sampling_group.add_argument(
|
||||
"--top-k",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Top-k sampling parameter. Only has effect on openai-compatible "
|
||||
"backends.")
|
||||
help="Top-k sampling parameter. Only has effect on openai-compatible backends.",
|
||||
)
|
||||
sampling_group.add_argument(
|
||||
"--min-p",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Min-p sampling parameter. Only has effect on openai-compatible "
|
||||
"backends.")
|
||||
help="Min-p sampling parameter. Only has effect on openai-compatible backends.",
|
||||
)
|
||||
sampling_group.add_argument(
|
||||
"--temperature",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Temperature sampling parameter. Only has effect on "
|
||||
"openai-compatible backends. If not specified, default to greedy "
|
||||
"decoding (i.e. temperature==0.0).")
|
||||
"decoding (i.e. temperature==0.0).",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--tokenizer-mode',
|
||||
"--tokenizer-mode",
|
||||
type=str,
|
||||
default="auto",
|
||||
choices=['auto', 'slow', 'mistral', 'custom'],
|
||||
choices=["auto", "slow", "mistral", "custom"],
|
||||
help='The tokenizer mode.\n\n* "auto" will use the '
|
||||
'fast tokenizer if available.\n* "slow" will '
|
||||
'always use the slow tokenizer. \n* '
|
||||
"always use the slow tokenizer. \n* "
|
||||
'"mistral" will always use the `mistral_common` tokenizer. \n*'
|
||||
'"custom" will use --tokenizer to select the preregistered tokenizer.')
|
||||
'"custom" will use --tokenizer to select the preregistered tokenizer.',
|
||||
)
|
||||
|
||||
parser.add_argument("--served-model-name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The model name used in the API. "
|
||||
"If not specified, the model name will be the "
|
||||
"same as the ``--model`` argument. ")
|
||||
parser.add_argument(
|
||||
"--served-model-name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The model name used in the API. "
|
||||
"If not specified, the model name will be the "
|
||||
"same as the ``--model`` argument. ",
|
||||
)
|
||||
|
||||
parser.add_argument("--lora-modules",
|
||||
nargs='+',
|
||||
default=None,
|
||||
help="A subset of LoRA module names passed in when "
|
||||
"launching the server. For each request, the "
|
||||
"script chooses a LoRA module at random.")
|
||||
parser.add_argument(
|
||||
"--lora-modules",
|
||||
nargs="+",
|
||||
default=None,
|
||||
help="A subset of LoRA module names passed in when "
|
||||
"launching the server. For each request, the "
|
||||
"script chooses a LoRA module at random.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
@ -19,6 +19,7 @@ On the client side, run:
|
||||
--endpoint /generate_stream
|
||||
to the end of the command above.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import copy
|
||||
@ -36,11 +37,15 @@ from typing import Optional
|
||||
import datasets
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput,
|
||||
RequestFuncOutput)
|
||||
from tqdm.asyncio import tqdm
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from backend_request_func import (
|
||||
ASYNC_REQUEST_FUNCS,
|
||||
RequestFuncInput,
|
||||
RequestFuncOutput,
|
||||
)
|
||||
|
||||
try:
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
except ImportError:
|
||||
@ -52,7 +57,8 @@ except ImportError:
|
||||
from argparse import ArgumentParser as FlexibleArgumentParser
|
||||
|
||||
from vllm.v1.structured_output.backend_xgrammar import (
|
||||
has_xgrammar_unsupported_json_features)
|
||||
has_xgrammar_unsupported_json_features,
|
||||
)
|
||||
|
||||
MILLISECONDS_TO_SECONDS_CONVERSION = 1000
|
||||
|
||||
@ -98,6 +104,7 @@ class SampleRequest:
|
||||
prompt_len: The length of the prompt in tokens.
|
||||
expected_output_len: The expected length of the output in tokens.
|
||||
"""
|
||||
|
||||
prompt: str
|
||||
prompt_len: int
|
||||
expected_output_len: int
|
||||
@ -106,32 +113,28 @@ class SampleRequest:
|
||||
completion: str = None
|
||||
|
||||
|
||||
def sample_requests(tokenizer: PreTrainedTokenizerBase,
|
||||
args: argparse.Namespace) -> list[SampleRequest]:
|
||||
if args.dataset == 'json' or args.dataset == 'json-unique':
|
||||
def sample_requests(
|
||||
tokenizer: PreTrainedTokenizerBase, args: argparse.Namespace
|
||||
) -> list[SampleRequest]:
|
||||
if args.dataset == "json" or args.dataset == "json-unique":
|
||||
if args.json_schema_path is None:
|
||||
dir_path = os.path.dirname(os.path.realpath(__file__))
|
||||
args.json_schema_path = os.path.join(dir_path,
|
||||
"structured_schemas",
|
||||
"structured_schema_1.json")
|
||||
args.json_schema_path = os.path.join(
|
||||
dir_path, "structured_schemas", "structured_schema_1.json"
|
||||
)
|
||||
json_schemas = []
|
||||
with open(args.json_schema_path) as f:
|
||||
schema = json.load(f)
|
||||
|
||||
if args.dataset == 'json-unique':
|
||||
json_schemas = [
|
||||
copy.deepcopy(schema) for _ in range(args.num_prompts)
|
||||
]
|
||||
if args.dataset == "json-unique":
|
||||
json_schemas = [copy.deepcopy(schema) for _ in range(args.num_prompts)]
|
||||
for i in range(len(json_schemas)):
|
||||
if "properties" not in json_schemas[i]:
|
||||
json_schemas[i]["properties"] = {}
|
||||
json_schemas[i]["properties"][
|
||||
f"__optional_field_{uuid.uuid4()}"] = {
|
||||
"type":
|
||||
"string",
|
||||
"description":
|
||||
"An unique optional field to avoid cached schemas"
|
||||
}
|
||||
json_schemas[i]["properties"][f"__optional_field_{uuid.uuid4()}"] = {
|
||||
"type": "string",
|
||||
"description": "An unique optional field to avoid cached schemas",
|
||||
}
|
||||
else:
|
||||
json_schemas = [schema] * args.num_prompts
|
||||
|
||||
@ -142,11 +145,13 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
|
||||
return json_schemas[index % len(json_schemas)]
|
||||
|
||||
requests = [
|
||||
SampleRequest(prompt=gen_prompt(i),
|
||||
prompt_len=len(tokenizer(gen_prompt(i)).input_ids),
|
||||
expected_output_len=args.output_len,
|
||||
schema=get_schema(i),
|
||||
structure_type=args.structure_type)
|
||||
SampleRequest(
|
||||
prompt=gen_prompt(i),
|
||||
prompt_len=len(tokenizer(gen_prompt(i)).input_ids),
|
||||
expected_output_len=args.output_len,
|
||||
schema=get_schema(i),
|
||||
structure_type=args.structure_type,
|
||||
)
|
||||
for i in range(args.num_prompts)
|
||||
]
|
||||
|
||||
@ -170,11 +175,13 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
|
||||
input_len = len(tokenizer(prompt).input_ids)
|
||||
print(f"Input length of the prompt: {input_len} tokens")
|
||||
requests = [
|
||||
SampleRequest(prompt=prompt,
|
||||
prompt_len=input_len,
|
||||
expected_output_len=args.output_len,
|
||||
schema=schema,
|
||||
structure_type=args.structure_type)
|
||||
SampleRequest(
|
||||
prompt=prompt,
|
||||
prompt_len=input_len,
|
||||
expected_output_len=args.output_len,
|
||||
schema=schema,
|
||||
structure_type=args.structure_type,
|
||||
)
|
||||
for _ in range(args.num_prompts)
|
||||
]
|
||||
|
||||
@ -188,11 +195,13 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
|
||||
input_len = len(tokenizer(prompt).input_ids)
|
||||
print(f"Input length of the prompt: {input_len} tokens")
|
||||
requests = [
|
||||
SampleRequest(prompt=prompt,
|
||||
prompt_len=input_len,
|
||||
expected_output_len=args.output_len,
|
||||
schema=regex,
|
||||
structure_type=args.structure_type)
|
||||
SampleRequest(
|
||||
prompt=prompt,
|
||||
prompt_len=input_len,
|
||||
expected_output_len=args.output_len,
|
||||
schema=regex,
|
||||
structure_type=args.structure_type,
|
||||
)
|
||||
for _ in range(args.num_prompts)
|
||||
]
|
||||
|
||||
@ -203,48 +212,55 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
|
||||
input_len = len(tokenizer(prompt).input_ids)
|
||||
print(f"Input length of the prompt: {input_len} tokens")
|
||||
requests = [
|
||||
SampleRequest(prompt=prompt,
|
||||
prompt_len=input_len,
|
||||
expected_output_len=args.output_len,
|
||||
schema=choice,
|
||||
structure_type=args.structure_type)
|
||||
SampleRequest(
|
||||
prompt=prompt,
|
||||
prompt_len=input_len,
|
||||
expected_output_len=args.output_len,
|
||||
schema=choice,
|
||||
structure_type=args.structure_type,
|
||||
)
|
||||
for _ in range(args.num_prompts)
|
||||
]
|
||||
|
||||
elif args.dataset == "xgrammar_bench":
|
||||
requests: list[SampleRequest] = []
|
||||
dataset = datasets.load_dataset("NousResearch/json-mode-eval",
|
||||
split="train")
|
||||
dataset = datasets.load_dataset("NousResearch/json-mode-eval", split="train")
|
||||
full_dataset_len = len(dataset)
|
||||
|
||||
def _filter_func(item):
|
||||
import json
|
||||
|
||||
schema = json.loads(item["schema"])
|
||||
return not has_xgrammar_unsupported_json_features(schema)
|
||||
|
||||
dataset = dataset.filter(_filter_func)
|
||||
num_filtered_out = full_dataset_len - len(dataset)
|
||||
print(f"dataset has {len(dataset)} entries after filtering "
|
||||
f"out {num_filtered_out} entries with unsupported features")
|
||||
print(
|
||||
f"dataset has {len(dataset)} entries after filtering "
|
||||
f"out {num_filtered_out} entries with unsupported features"
|
||||
)
|
||||
len_dataset = len(dataset)
|
||||
for data_point_idx in range(args.num_prompts):
|
||||
idx = data_point_idx
|
||||
while idx >= len_dataset:
|
||||
idx -= len_dataset
|
||||
schema = dataset["schema"][idx]
|
||||
prompt = tokenizer.apply_chat_template(dataset["prompt"][idx],
|
||||
tokenize=False,
|
||||
add_generation_prompt=True)
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
dataset["prompt"][idx], tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
input_len = len(tokenizer(prompt).input_ids)
|
||||
completion = dataset["completion"][idx]
|
||||
|
||||
requests.append(
|
||||
SampleRequest(prompt=prompt,
|
||||
prompt_len=input_len,
|
||||
expected_output_len=args.output_len,
|
||||
schema=schema,
|
||||
structure_type=args.structure_type,
|
||||
completion=completion))
|
||||
SampleRequest(
|
||||
prompt=prompt,
|
||||
prompt_len=input_len,
|
||||
expected_output_len=args.output_len,
|
||||
schema=schema,
|
||||
structure_type=args.structure_type,
|
||||
completion=completion,
|
||||
)
|
||||
)
|
||||
|
||||
return requests
|
||||
|
||||
@ -276,7 +292,8 @@ async def get_request(
|
||||
|
||||
# Calculate scale parameter theta to maintain the desired request_rate.
|
||||
assert burstiness > 0, (
|
||||
f"A positive burstiness factor is expected, but given {burstiness}.")
|
||||
f"A positive burstiness factor is expected, but given {burstiness}."
|
||||
)
|
||||
theta = 1.0 / (request_rate * burstiness)
|
||||
|
||||
for i, request in enumerate(input_requests):
|
||||
@ -318,8 +335,8 @@ def calculate_metrics(
|
||||
# multiple output tokens may be bundled together
|
||||
# Note : this may inflate the output token count slightly
|
||||
output_len = len(
|
||||
tokenizer(outputs[i].generated_text,
|
||||
add_special_tokens=False).input_ids)
|
||||
tokenizer(outputs[i].generated_text, add_special_tokens=False).input_ids
|
||||
)
|
||||
actual_output_lens.append(output_len)
|
||||
total_input += input_requests[i].prompt_len
|
||||
tpot = 0
|
||||
@ -343,16 +360,19 @@ def calculate_metrics(
|
||||
|
||||
if "ttft" in goodput_config_dict:
|
||||
valid_metrics.append(ttfts)
|
||||
slo_values.append(goodput_config_dict["ttft"] /
|
||||
MILLISECONDS_TO_SECONDS_CONVERSION)
|
||||
slo_values.append(
|
||||
goodput_config_dict["ttft"] / MILLISECONDS_TO_SECONDS_CONVERSION
|
||||
)
|
||||
if "tpot" in goodput_config_dict:
|
||||
valid_metrics.append(all_tpots)
|
||||
slo_values.append(goodput_config_dict["tpot"] /
|
||||
MILLISECONDS_TO_SECONDS_CONVERSION)
|
||||
slo_values.append(
|
||||
goodput_config_dict["tpot"] / MILLISECONDS_TO_SECONDS_CONVERSION
|
||||
)
|
||||
if "e2el" in goodput_config_dict:
|
||||
valid_metrics.append(e2els)
|
||||
slo_values.append(goodput_config_dict["e2el"] /
|
||||
MILLISECONDS_TO_SECONDS_CONVERSION)
|
||||
slo_values.append(
|
||||
goodput_config_dict["e2el"] / MILLISECONDS_TO_SECONDS_CONVERSION
|
||||
)
|
||||
|
||||
for req_metric in zip(*valid_metrics):
|
||||
is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)])
|
||||
@ -363,7 +383,8 @@ def calculate_metrics(
|
||||
warnings.warn(
|
||||
"All requests failed. This is likely due to a misconfiguration "
|
||||
"on the benchmark arguments.",
|
||||
stacklevel=2)
|
||||
stacklevel=2,
|
||||
)
|
||||
metrics = BenchmarkMetrics(
|
||||
completed=completed,
|
||||
total_input=total_input,
|
||||
@ -372,27 +393,31 @@ def calculate_metrics(
|
||||
request_goodput=good_completed / dur_s,
|
||||
output_throughput=sum(actual_output_lens) / dur_s,
|
||||
total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s,
|
||||
mean_ttft_ms=np.mean(ttfts or 0) *
|
||||
1000, # ttfts is empty if streaming is not supported by backend
|
||||
mean_ttft_ms=np.mean(ttfts or 0)
|
||||
* 1000, # ttfts is empty if streaming is not supported by backend
|
||||
std_ttft_ms=np.std(ttfts or 0) * 1000,
|
||||
median_ttft_ms=np.median(ttfts or 0) * 1000,
|
||||
percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000)
|
||||
for p in selected_percentiles],
|
||||
percentiles_ttft_ms=[
|
||||
(p, np.percentile(ttfts or 0, p) * 1000) for p in selected_percentiles
|
||||
],
|
||||
mean_tpot_ms=np.mean(tpots or 0) * 1000,
|
||||
std_tpot_ms=np.std(tpots or 0) * 1000,
|
||||
median_tpot_ms=np.median(tpots or 0) * 1000,
|
||||
percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000)
|
||||
for p in selected_percentiles],
|
||||
percentiles_tpot_ms=[
|
||||
(p, np.percentile(tpots or 0, p) * 1000) for p in selected_percentiles
|
||||
],
|
||||
mean_itl_ms=np.mean(itls or 0) * 1000,
|
||||
std_itl_ms=np.std(itls or 0) * 1000,
|
||||
median_itl_ms=np.median(itls or 0) * 1000,
|
||||
percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000)
|
||||
for p in selected_percentiles],
|
||||
percentiles_itl_ms=[
|
||||
(p, np.percentile(itls or 0, p) * 1000) for p in selected_percentiles
|
||||
],
|
||||
mean_e2el_ms=np.mean(e2els or 0) * 1000,
|
||||
std_e2el_ms=np.std(e2els or 0) * 1000,
|
||||
median_e2el_ms=np.median(e2els or 0) * 1000,
|
||||
percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000)
|
||||
for p in selected_percentiles],
|
||||
percentiles_e2el_ms=[
|
||||
(p, np.percentile(e2els or 0, p) * 1000) for p in selected_percentiles
|
||||
],
|
||||
)
|
||||
|
||||
return metrics, actual_output_lens
|
||||
@ -429,12 +454,13 @@ async def benchmark(
|
||||
|
||||
print("Starting initial single prompt test run...")
|
||||
structured_output_req_idx = random.sample(
|
||||
range(len(input_requests)),
|
||||
int(len(input_requests) * structured_output_ratio))
|
||||
range(len(input_requests)), int(len(input_requests) * structured_output_ratio)
|
||||
)
|
||||
|
||||
test_request = input_requests[0]
|
||||
test_req_extra_body = (prepare_extra_body(test_request)
|
||||
if 0 in structured_output_req_idx else None)
|
||||
test_req_extra_body = (
|
||||
prepare_extra_body(test_request) if 0 in structured_output_req_idx else None
|
||||
)
|
||||
test_input = RequestFuncInput(
|
||||
model=model_id,
|
||||
prompt=test_request.prompt,
|
||||
@ -448,7 +474,8 @@ async def benchmark(
|
||||
if not test_output.success:
|
||||
raise ValueError(
|
||||
"Initial test run failed - Please make sure benchmark arguments "
|
||||
f"are correctly specified. Error: {test_output.error}")
|
||||
f"are correctly specified. Error: {test_output.error}"
|
||||
)
|
||||
else:
|
||||
print("Initial test run completed. Starting main benchmark run...")
|
||||
|
||||
@ -467,10 +494,7 @@ async def benchmark(
|
||||
if profile_output.success:
|
||||
print("Profiler started")
|
||||
|
||||
if burstiness == 1.0:
|
||||
distribution = "Poisson process"
|
||||
else:
|
||||
distribution = "Gamma distribution"
|
||||
distribution = "Poisson process" if burstiness == 1.0 else "Gamma distribution"
|
||||
|
||||
print(f"Traffic request rate: {request_rate}")
|
||||
print(f"Burstiness factor: {burstiness} ({distribution})")
|
||||
@ -482,24 +506,21 @@ async def benchmark(
|
||||
# and it will simplify the code in limited_request_func.
|
||||
# semaphore = (asyncio.Semaphore(max_concurrency)
|
||||
# if max_concurrency else contextlib.nullcontext())
|
||||
semaphore = (asyncio.Semaphore(max_concurrency)
|
||||
if max_concurrency else None)
|
||||
semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None
|
||||
|
||||
async def limited_request_func(request_func_input, pbar):
|
||||
if semaphore is None:
|
||||
return await request_func(request_func_input=request_func_input,
|
||||
pbar=pbar)
|
||||
return await request_func(request_func_input=request_func_input, pbar=pbar)
|
||||
async with semaphore:
|
||||
return await request_func(request_func_input=request_func_input,
|
||||
pbar=pbar)
|
||||
return await request_func(request_func_input=request_func_input, pbar=pbar)
|
||||
|
||||
benchmark_start_time = time.perf_counter()
|
||||
tasks: list[asyncio.Task] = []
|
||||
expected: list[str] = []
|
||||
async for i, request in get_request(input_requests, request_rate,
|
||||
burstiness):
|
||||
extra_body = prepare_extra_body(
|
||||
request) if i in structured_output_req_idx else None
|
||||
async for i, request in get_request(input_requests, request_rate, burstiness):
|
||||
extra_body = (
|
||||
prepare_extra_body(request) if i in structured_output_req_idx else None
|
||||
)
|
||||
request_func_input = RequestFuncInput(
|
||||
model=model_id,
|
||||
prompt=request.prompt,
|
||||
@ -512,8 +533,9 @@ async def benchmark(
|
||||
expected.append(request.completion)
|
||||
tasks.append(
|
||||
asyncio.create_task(
|
||||
limited_request_func(request_func_input=request_func_input,
|
||||
pbar=pbar)))
|
||||
limited_request_func(request_func_input=request_func_input, pbar=pbar)
|
||||
)
|
||||
)
|
||||
outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
|
||||
|
||||
if profile:
|
||||
@ -545,54 +567,58 @@ async def benchmark(
|
||||
goodput_config_dict=goodput_config_dict,
|
||||
)
|
||||
|
||||
print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
|
||||
print("{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
|
||||
print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
|
||||
print("{:<40} {:<10.2f}".format("Benchmark duration (s):",
|
||||
benchmark_duration))
|
||||
print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
|
||||
print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
|
||||
print("{:<40} {:<10}".format("Total generated tokens:",
|
||||
metrics.total_output))
|
||||
print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
|
||||
metrics.request_throughput))
|
||||
print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
|
||||
print(
|
||||
"{:<40} {:<10.2f}".format(
|
||||
"Request throughput (req/s):", metrics.request_throughput
|
||||
)
|
||||
)
|
||||
if goodput_config_dict:
|
||||
print("{:<40} {:<10.2f}".format("Request goodput (req/s):",
|
||||
metrics.request_goodput))
|
||||
print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
|
||||
metrics.output_throughput))
|
||||
print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):",
|
||||
metrics.total_token_throughput))
|
||||
print(
|
||||
"{:<40} {:<10.2f}".format(
|
||||
"Request goodput (req/s):", metrics.request_goodput
|
||||
)
|
||||
)
|
||||
print(
|
||||
"{:<40} {:<10.2f}".format(
|
||||
"Output token throughput (tok/s):", metrics.output_throughput
|
||||
)
|
||||
)
|
||||
print(
|
||||
"{:<40} {:<10.2f}".format(
|
||||
"Total Token throughput (tok/s):", metrics.total_token_throughput
|
||||
)
|
||||
)
|
||||
|
||||
result = {
|
||||
"duration":
|
||||
benchmark_duration,
|
||||
"completed":
|
||||
metrics.completed,
|
||||
"total_input_tokens":
|
||||
metrics.total_input,
|
||||
"total_output_tokens":
|
||||
metrics.total_output,
|
||||
"request_throughput":
|
||||
metrics.request_throughput,
|
||||
"output_throughput":
|
||||
metrics.output_throughput,
|
||||
"total_token_throughput":
|
||||
metrics.total_token_throughput,
|
||||
"ttft_description":
|
||||
pd.Series([output.ttft for output in outputs]).describe().to_dict(),
|
||||
"tpot_description":
|
||||
pd.Series([output.tpot for output in outputs]).describe().to_dict(),
|
||||
"duration": benchmark_duration,
|
||||
"completed": metrics.completed,
|
||||
"total_input_tokens": metrics.total_input,
|
||||
"total_output_tokens": metrics.total_output,
|
||||
"request_throughput": metrics.request_throughput,
|
||||
"output_throughput": metrics.output_throughput,
|
||||
"total_token_throughput": metrics.total_token_throughput,
|
||||
"ttft_description": pd.Series([output.ttft for output in outputs])
|
||||
.describe()
|
||||
.to_dict(),
|
||||
"tpot_description": pd.Series([output.tpot for output in outputs])
|
||||
.describe()
|
||||
.to_dict(),
|
||||
"input_lens": [output.prompt_len for output in outputs],
|
||||
"output_lens":
|
||||
actual_output_lens,
|
||||
"output_lens": actual_output_lens,
|
||||
"ttfts": [output.ttft for output in outputs],
|
||||
"itls": [output.itl for output in outputs],
|
||||
"errors": [output.error for output in outputs],
|
||||
}
|
||||
|
||||
ret = [{
|
||||
'generated': output.generated_text,
|
||||
'expected': gt
|
||||
} for output, gt in zip(outputs, expected)]
|
||||
ret = [
|
||||
{"generated": output.generated_text, "expected": gt}
|
||||
for output, gt in zip(outputs, expected)
|
||||
]
|
||||
|
||||
def process_one_metric(
|
||||
# E.g., "ttft"
|
||||
@ -606,29 +632,35 @@ async def benchmark(
|
||||
# metric.
|
||||
if metric_attribute_name not in selected_percentile_metrics:
|
||||
return
|
||||
print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-'))
|
||||
print("{:<40} {:<10.2f}".format(
|
||||
f"Mean {metric_name} (ms):",
|
||||
getattr(metrics, f"mean_{metric_attribute_name}_ms")))
|
||||
print("{:<40} {:<10.2f}".format(
|
||||
f"Median {metric_name} (ms):",
|
||||
getattr(metrics, f"median_{metric_attribute_name}_ms")))
|
||||
print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
|
||||
print(
|
||||
"{:<40} {:<10.2f}".format(
|
||||
f"Mean {metric_name} (ms):",
|
||||
getattr(metrics, f"mean_{metric_attribute_name}_ms"),
|
||||
)
|
||||
)
|
||||
print(
|
||||
"{:<40} {:<10.2f}".format(
|
||||
f"Median {metric_name} (ms):",
|
||||
getattr(metrics, f"median_{metric_attribute_name}_ms"),
|
||||
)
|
||||
)
|
||||
result[f"mean_{metric_attribute_name}_ms"] = getattr(
|
||||
metrics, f"mean_{metric_attribute_name}_ms")
|
||||
metrics, f"mean_{metric_attribute_name}_ms"
|
||||
)
|
||||
result[f"median_{metric_attribute_name}_ms"] = getattr(
|
||||
metrics, f"median_{metric_attribute_name}_ms")
|
||||
metrics, f"median_{metric_attribute_name}_ms"
|
||||
)
|
||||
result[f"std_{metric_attribute_name}_ms"] = getattr(
|
||||
metrics, f"std_{metric_attribute_name}_ms")
|
||||
for p, value in getattr(metrics,
|
||||
f"percentiles_{metric_attribute_name}_ms"):
|
||||
metrics, f"std_{metric_attribute_name}_ms"
|
||||
)
|
||||
for p, value in getattr(metrics, f"percentiles_{metric_attribute_name}_ms"):
|
||||
p_word = str(int(p)) if int(p) == p else str(p)
|
||||
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):",
|
||||
value))
|
||||
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value))
|
||||
result[f"p{p_word}_{metric_attribute_name}_ms"] = value
|
||||
|
||||
process_one_metric("ttft", "TTFT", "Time to First Token")
|
||||
process_one_metric("tpot", "TPOT",
|
||||
"Time per Output Token (excl. 1st token)")
|
||||
process_one_metric("tpot", "TPOT", "Time per Output Token (excl. 1st token)")
|
||||
process_one_metric("itl", "ITL", "Inter-token Latency")
|
||||
process_one_metric("e2el", "E2EL", "End-to-end Latency")
|
||||
|
||||
@ -638,13 +670,13 @@ async def benchmark(
|
||||
|
||||
|
||||
def evaluate(ret, args):
|
||||
|
||||
def _eval_correctness_json(expected, actual):
|
||||
# extract json string from string using regex
|
||||
import re
|
||||
actual = actual.replace('\n', '').replace(' ', '').strip()
|
||||
|
||||
actual = actual.replace("\n", "").replace(" ", "").strip()
|
||||
try:
|
||||
actual = re.search(r'\{.*\}', actual).group()
|
||||
actual = re.search(r"\{.*\}", actual).group()
|
||||
actual = json.loads(actual)
|
||||
except Exception:
|
||||
return False
|
||||
@ -656,28 +688,32 @@ def evaluate(ret, args):
|
||||
|
||||
def _eval_correctness_regex(expected, actual):
|
||||
import re
|
||||
|
||||
return re.match(args.regex, actual) is not None
|
||||
|
||||
def _eval_correctness(expected, actual):
|
||||
if args.structure_type == 'guided_json':
|
||||
if args.structure_type == "guided_json":
|
||||
return _eval_correctness_json(expected, actual)
|
||||
elif args.structure_type == 'guided_regex':
|
||||
elif args.structure_type == "guided_regex":
|
||||
return _eval_correctness_regex(expected, actual)
|
||||
elif args.structure_type == 'guided_choice':
|
||||
elif args.structure_type == "guided_choice":
|
||||
return _eval_correctness_choice(expected, actual)
|
||||
else:
|
||||
return None
|
||||
|
||||
scores = []
|
||||
for res in ret:
|
||||
score = _eval_correctness(res['expected'], res['generated'])
|
||||
res['correctness'] = score
|
||||
score = _eval_correctness(res["expected"], res["generated"])
|
||||
res["correctness"] = score
|
||||
scores.append(score)
|
||||
|
||||
not_none_scores = [score for score in scores if score is not None]
|
||||
|
||||
return (sum(not_none_scores) / len(not_none_scores) *
|
||||
100) if len(not_none_scores) > 0 else None
|
||||
return (
|
||||
(sum(not_none_scores) / len(not_none_scores) * 100)
|
||||
if len(not_none_scores) > 0
|
||||
else None
|
||||
)
|
||||
|
||||
|
||||
def parse_goodput(slo_pairs):
|
||||
@ -689,9 +725,10 @@ def parse_goodput(slo_pairs):
|
||||
except ValueError as err:
|
||||
raise argparse.ArgumentTypeError(
|
||||
"Invalid format found for service level objectives. "
|
||||
"Specify service level objectives for goodput as \"KEY:VALUE\" "
|
||||
'Specify service level objectives for goodput as "KEY:VALUE" '
|
||||
"pairs, where the key is a metric name, and the value is a "
|
||||
"number in milliseconds.") from err
|
||||
"number in milliseconds."
|
||||
) from err
|
||||
return goodput_config_dict
|
||||
|
||||
|
||||
@ -705,12 +742,14 @@ def check_goodput_args(args):
|
||||
raise ValueError(
|
||||
f"Invalid metric name found, {slo_name}: {slo_val}. "
|
||||
"The service level objective name should be one of "
|
||||
f"{str(VALID_NAMES)}. ")
|
||||
f"{str(VALID_NAMES)}. "
|
||||
)
|
||||
if slo_val < 0:
|
||||
raise ValueError(
|
||||
f"Invalid value found, {slo_name}: {slo_val}. "
|
||||
"The service level objective value should be "
|
||||
"non-negative.")
|
||||
"non-negative."
|
||||
)
|
||||
return goodput_config_dict
|
||||
|
||||
|
||||
@ -736,19 +775,19 @@ def main(args: argparse.Namespace):
|
||||
tokenizer_mode=args.tokenizer_mode,
|
||||
)
|
||||
|
||||
if args.dataset == 'grammar':
|
||||
args.structure_type = 'guided_grammar'
|
||||
elif args.dataset == 'regex':
|
||||
args.structure_type = 'guided_regex'
|
||||
elif args.dataset == 'choice':
|
||||
args.structure_type = 'guided_choice'
|
||||
if args.dataset == "grammar":
|
||||
args.structure_type = "guided_grammar"
|
||||
elif args.dataset == "regex":
|
||||
args.structure_type = "guided_regex"
|
||||
elif args.dataset == "choice":
|
||||
args.structure_type = "guided_choice"
|
||||
else:
|
||||
args.structure_type = 'guided_json'
|
||||
args.structure_type = "guided_json"
|
||||
|
||||
if args.no_structured_output:
|
||||
args.structured_output_ratio = 0
|
||||
if args.save_results:
|
||||
result_file_name = f'{args.structured_output_ratio}guided'
|
||||
result_file_name = f"{args.structured_output_ratio}guided"
|
||||
result_file_name += f"_{backend}"
|
||||
result_file_name += f"_{args.request_rate}qps"
|
||||
result_file_name += f"_{args.model.split('/')[-1]}"
|
||||
@ -776,36 +815,29 @@ def main(args: argparse.Namespace):
|
||||
disable_tqdm=args.disable_tqdm,
|
||||
profile=args.profile,
|
||||
selected_percentile_metrics=args.percentile_metrics.split(","),
|
||||
selected_percentiles=[
|
||||
float(p) for p in args.metric_percentiles.split(",")
|
||||
],
|
||||
selected_percentiles=[float(p) for p in args.metric_percentiles.split(",")],
|
||||
ignore_eos=args.ignore_eos,
|
||||
max_concurrency=args.max_concurrency,
|
||||
structured_output_ratio=args.structured_output_ratio,
|
||||
goodput_config_dict=goodput_config_dict,
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
# Save config and results to json
|
||||
score = evaluate(ret, args)
|
||||
print("correct_rate(%)", score, '\n')
|
||||
print("correct_rate(%)", score, "\n")
|
||||
if args.save_results:
|
||||
results = {
|
||||
"backend":
|
||||
backend,
|
||||
"model_id":
|
||||
model_id,
|
||||
"tokenizer_id":
|
||||
tokenizer_id,
|
||||
"num_prompts":
|
||||
args.num_prompts,
|
||||
"request_rate":
|
||||
args.request_rate if args.request_rate < float("inf") else "inf",
|
||||
"burstiness":
|
||||
args.burstiness,
|
||||
"max_concurrency":
|
||||
args.max_concurrency,
|
||||
"correct_rate(%)":
|
||||
score
|
||||
"backend": backend,
|
||||
"model_id": model_id,
|
||||
"tokenizer_id": tokenizer_id,
|
||||
"num_prompts": args.num_prompts,
|
||||
"request_rate": args.request_rate
|
||||
if args.request_rate < float("inf")
|
||||
else "inf",
|
||||
"burstiness": args.burstiness,
|
||||
"max_concurrency": args.max_concurrency,
|
||||
"correct_rate(%)": score,
|
||||
}
|
||||
results = {"outputs": ret, **results, **benchmark_result}
|
||||
|
||||
@ -814,13 +846,14 @@ def main(args: argparse.Namespace):
|
||||
result_file_name = args.result_filename
|
||||
if args.result_dir:
|
||||
result_file_name = os.path.join(args.result_dir, result_file_name)
|
||||
with open(result_file_name, "w", encoding='utf-8') as outfile:
|
||||
with open(result_file_name, "w", encoding="utf-8") as outfile:
|
||||
json.dump(results, outfile, indent=4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser(
|
||||
description="Benchmark the online serving throughput.")
|
||||
description="Benchmark the online serving throughput."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
type=str,
|
||||
@ -842,16 +875,14 @@ if __name__ == "__main__":
|
||||
default="/v1/completions",
|
||||
help="API endpoint.",
|
||||
)
|
||||
parser.add_argument("--dataset",
|
||||
default='json',
|
||||
choices=[
|
||||
'json', 'json-unique', 'grammar', 'regex',
|
||||
'choice', 'xgrammar_bench'
|
||||
])
|
||||
parser.add_argument("--json-schema-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to json schema.")
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
default="json",
|
||||
choices=["json", "json-unique", "grammar", "regex", "choice", "xgrammar_bench"],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--json-schema-path", type=str, default=None, help="Path to json schema."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-concurrency",
|
||||
type=int,
|
||||
@ -863,7 +894,8 @@ if __name__ == "__main__":
|
||||
"initiated, this argument will control how many are actually allowed "
|
||||
"to execute at a time. This means that when used in combination, the "
|
||||
"actual request rate may be lower than specified with --request-rate, "
|
||||
"if the server is not processing requests fast enough to keep up.")
|
||||
"if the server is not processing requests fast enough to keep up.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
@ -873,15 +905,13 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--tokenizer",
|
||||
type=str,
|
||||
help=
|
||||
"Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
|
||||
help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer-mode",
|
||||
type=str,
|
||||
default="auto",
|
||||
help=
|
||||
"Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
|
||||
help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-prompts",
|
||||
@ -958,44 +988,51 @@ if __name__ == "__main__":
|
||||
"--ignore-eos",
|
||||
action="store_true",
|
||||
help="Set ignore_eos flag when sending the benchmark request."
|
||||
"Warning: ignore_eos is not supported in deepspeed_mii and tgi.")
|
||||
"Warning: ignore_eos is not supported in deepspeed_mii and tgi.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--percentile-metrics",
|
||||
type=str,
|
||||
default="ttft,tpot,itl",
|
||||
help="Comma-separated list of selected metrics to report percentils. "
|
||||
"This argument specifies the metrics to report percentiles. "
|
||||
"Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". "
|
||||
"Default value is \"ttft,tpot,itl\".")
|
||||
'Allowed metric names are "ttft", "tpot", "itl", "e2el". '
|
||||
'Default value is "ttft,tpot,itl".',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--metric-percentiles",
|
||||
type=str,
|
||||
default="99",
|
||||
help="Comma-separated list of percentiles for selected metrics. "
|
||||
"To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". "
|
||||
"Default value is \"99\". "
|
||||
"Use \"--percentile-metrics\" to select metrics.",
|
||||
'To report 25-th, 50-th, and 75-th percentiles, use "25,50,75". '
|
||||
'Default value is "99". '
|
||||
'Use "--percentile-metrics" to select metrics.',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--goodput",
|
||||
nargs="+",
|
||||
required=False,
|
||||
help="Specify service level objectives for goodput as \"KEY:VALUE\" "
|
||||
help='Specify service level objectives for goodput as "KEY:VALUE" '
|
||||
"pairs, where the key is a metric name, and the value is in "
|
||||
"milliseconds. Multiple \"KEY:VALUE\" pairs can be provided, "
|
||||
'milliseconds. Multiple "KEY:VALUE" pairs can be provided, '
|
||||
"separated by spaces. Allowed request level metric names are "
|
||||
"\"ttft\", \"tpot\", \"e2el\". For more context on the definition of "
|
||||
'"ttft", "tpot", "e2el". For more context on the definition of '
|
||||
"goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
|
||||
"and the blog: https://hao-ai-lab.github.io/blogs/distserve")
|
||||
"and the blog: https://hao-ai-lab.github.io/blogs/distserve",
|
||||
)
|
||||
|
||||
parser.add_argument("--no-structured-output",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="Whether to disable JSON decoding or not.")
|
||||
parser.add_argument("--structured-output-ratio",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Ratio of Structured Outputs requests")
|
||||
parser.add_argument(
|
||||
"--no-structured-output",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Whether to disable JSON decoding or not.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--structured-output-ratio",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Ratio of Structured Outputs requests",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Benchmark offline inference throughput."""
|
||||
|
||||
import argparse
|
||||
import dataclasses
|
||||
import json
|
||||
@ -11,18 +12,25 @@ from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
import uvloop
|
||||
from benchmark_dataset import (AIMODataset, BurstGPTDataset,
|
||||
ConversationDataset, InstructCoderDataset,
|
||||
RandomDataset, SampleRequest, ShareGPTDataset,
|
||||
SonnetDataset, VisionArenaDataset)
|
||||
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
|
||||
from tqdm import tqdm
|
||||
from transformers import (AutoModelForCausalLM, AutoTokenizer,
|
||||
PreTrainedTokenizerBase)
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase
|
||||
|
||||
from benchmark_dataset import (
|
||||
AIMODataset,
|
||||
BurstGPTDataset,
|
||||
ConversationDataset,
|
||||
InstructCoderDataset,
|
||||
RandomDataset,
|
||||
SampleRequest,
|
||||
ShareGPTDataset,
|
||||
SonnetDataset,
|
||||
VisionArenaDataset,
|
||||
)
|
||||
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
|
||||
from vllm.entrypoints.openai.api_server import (
|
||||
build_async_engine_client_from_engine_args)
|
||||
build_async_engine_client_from_engine_args,
|
||||
)
|
||||
from vllm.inputs import TextPrompt, TokensPrompt
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import RequestOutput
|
||||
@ -37,23 +45,30 @@ def run_vllm(
|
||||
disable_detokenize: bool = False,
|
||||
) -> tuple[float, Optional[list[RequestOutput]]]:
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
llm = LLM(**dataclasses.asdict(engine_args))
|
||||
assert all(
|
||||
llm.llm_engine.model_config.max_model_len >= (
|
||||
request.prompt_len + request.expected_output_len)
|
||||
for request in requests), (
|
||||
"Please ensure that max_model_len is greater than the sum of"
|
||||
" prompt_len and expected_output_len for all requests.")
|
||||
llm.llm_engine.model_config.max_model_len
|
||||
>= (request.prompt_len + request.expected_output_len)
|
||||
for request in requests
|
||||
), (
|
||||
"Please ensure that max_model_len is greater than the sum of"
|
||||
" prompt_len and expected_output_len for all requests."
|
||||
)
|
||||
# Add the requests to the engine.
|
||||
prompts: list[Union[TextPrompt, TokensPrompt]] = []
|
||||
sampling_params: list[SamplingParams] = []
|
||||
for request in requests:
|
||||
prompts.append(
|
||||
TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"],
|
||||
multi_modal_data=request.multi_modal_data)
|
||||
if "prompt_token_ids" in request.prompt else \
|
||||
TextPrompt(prompt=request.prompt,
|
||||
multi_modal_data=request.multi_modal_data))
|
||||
TokensPrompt(
|
||||
prompt_token_ids=request.prompt["prompt_token_ids"],
|
||||
multi_modal_data=request.multi_modal_data,
|
||||
)
|
||||
if "prompt_token_ids" in request.prompt
|
||||
else TextPrompt(
|
||||
prompt=request.prompt, multi_modal_data=request.multi_modal_data
|
||||
)
|
||||
)
|
||||
sampling_params.append(
|
||||
SamplingParams(
|
||||
n=n,
|
||||
@ -62,7 +77,8 @@ def run_vllm(
|
||||
ignore_eos=True,
|
||||
max_tokens=request.expected_output_len,
|
||||
detokenize=not disable_detokenize,
|
||||
))
|
||||
)
|
||||
)
|
||||
lora_requests: Optional[list[LoRARequest]] = None
|
||||
if engine_args.enable_lora:
|
||||
lora_requests = [request.lora_request for request in requests]
|
||||
@ -72,10 +88,9 @@ def run_vllm(
|
||||
outputs = None
|
||||
if not use_beam_search:
|
||||
start = time.perf_counter()
|
||||
outputs = llm.generate(prompts,
|
||||
sampling_params,
|
||||
lora_request=lora_requests,
|
||||
use_tqdm=True)
|
||||
outputs = llm.generate(
|
||||
prompts, sampling_params, lora_request=lora_requests, use_tqdm=True
|
||||
)
|
||||
end = time.perf_counter()
|
||||
else:
|
||||
assert lora_requests is None, "BeamSearch API does not support LoRA"
|
||||
@ -91,30 +106,35 @@ def run_vllm(
|
||||
beam_width=n,
|
||||
max_tokens=output_len,
|
||||
ignore_eos=True,
|
||||
))
|
||||
),
|
||||
)
|
||||
end = time.perf_counter()
|
||||
return end - start, outputs
|
||||
|
||||
|
||||
def run_vllm_chat(
|
||||
requests: list[SampleRequest],
|
||||
n: int,
|
||||
engine_args: EngineArgs,
|
||||
disable_detokenize: bool = False) -> tuple[float, list[RequestOutput]]:
|
||||
requests: list[SampleRequest],
|
||||
n: int,
|
||||
engine_args: EngineArgs,
|
||||
disable_detokenize: bool = False,
|
||||
) -> tuple[float, list[RequestOutput]]:
|
||||
"""
|
||||
Run vLLM chat benchmark. This function is recommended ONLY for benchmarking
|
||||
multimodal models as it properly handles multimodal inputs and chat
|
||||
formatting. For non-multimodal models, use run_vllm() instead.
|
||||
"""
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
llm = LLM(**dataclasses.asdict(engine_args))
|
||||
|
||||
assert all(
|
||||
llm.llm_engine.model_config.max_model_len >= (
|
||||
request.prompt_len + request.expected_output_len)
|
||||
for request in requests), (
|
||||
"Please ensure that max_model_len is greater than the sum of "
|
||||
"prompt_len and expected_output_len for all requests.")
|
||||
llm.llm_engine.model_config.max_model_len
|
||||
>= (request.prompt_len + request.expected_output_len)
|
||||
for request in requests
|
||||
), (
|
||||
"Please ensure that max_model_len is greater than the sum of "
|
||||
"prompt_len and expected_output_len for all requests."
|
||||
)
|
||||
|
||||
prompts = []
|
||||
sampling_params: list[SamplingParams] = []
|
||||
@ -128,7 +148,8 @@ def run_vllm_chat(
|
||||
ignore_eos=True,
|
||||
max_tokens=request.expected_output_len,
|
||||
detokenize=not disable_detokenize,
|
||||
))
|
||||
)
|
||||
)
|
||||
start = time.perf_counter()
|
||||
outputs = llm.chat(prompts, sampling_params, use_tqdm=True)
|
||||
end = time.perf_counter()
|
||||
@ -145,13 +166,17 @@ async def run_vllm_async(
|
||||
from vllm import SamplingParams
|
||||
|
||||
async with build_async_engine_client_from_engine_args(
|
||||
engine_args, disable_frontend_multiprocessing) as llm:
|
||||
engine_args, disable_frontend_multiprocessing
|
||||
) as llm:
|
||||
model_config = await llm.get_model_config()
|
||||
assert all(
|
||||
llm.model_config.max_model_len >= (request.prompt_len +
|
||||
request.expected_output_len)
|
||||
for request in requests), (
|
||||
"Please ensure that max_model_len is greater than the sum of"
|
||||
" prompt_len and expected_output_len for all requests.")
|
||||
model_config.max_model_len
|
||||
>= (request.prompt_len + request.expected_output_len)
|
||||
for request in requests
|
||||
), (
|
||||
"Please ensure that max_model_len is greater than the sum of"
|
||||
" prompt_len and expected_output_len for all requests."
|
||||
)
|
||||
|
||||
# Add the requests to the engine.
|
||||
prompts: list[Union[TextPrompt, TokensPrompt]] = []
|
||||
@ -159,11 +184,15 @@ async def run_vllm_async(
|
||||
lora_requests: list[Optional[LoRARequest]] = []
|
||||
for request in requests:
|
||||
prompts.append(
|
||||
TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"],
|
||||
multi_modal_data=request.multi_modal_data)
|
||||
if "prompt_token_ids" in request.prompt else \
|
||||
TextPrompt(prompt=request.prompt,
|
||||
multi_modal_data=request.multi_modal_data))
|
||||
TokensPrompt(
|
||||
prompt_token_ids=request.prompt["prompt_token_ids"],
|
||||
multi_modal_data=request.multi_modal_data,
|
||||
)
|
||||
if "prompt_token_ids" in request.prompt
|
||||
else TextPrompt(
|
||||
prompt=request.prompt, multi_modal_data=request.multi_modal_data
|
||||
)
|
||||
)
|
||||
sampling_params.append(
|
||||
SamplingParams(
|
||||
n=n,
|
||||
@ -172,17 +201,16 @@ async def run_vllm_async(
|
||||
ignore_eos=True,
|
||||
max_tokens=request.expected_output_len,
|
||||
detokenize=not disable_detokenize,
|
||||
))
|
||||
)
|
||||
)
|
||||
lora_requests.append(request.lora_request)
|
||||
|
||||
generators = []
|
||||
start = time.perf_counter()
|
||||
for i, (prompt, sp,
|
||||
lr) in enumerate(zip(prompts, sampling_params, lora_requests)):
|
||||
generator = llm.generate(prompt,
|
||||
sp,
|
||||
lora_request=lr,
|
||||
request_id=f"test{i}")
|
||||
for i, (prompt, sp, lr) in enumerate(
|
||||
zip(prompts, sampling_params, lora_requests)
|
||||
):
|
||||
generator = llm.generate(prompt, sp, lora_request=lr, request_id=f"test{i}")
|
||||
generators.append(generator)
|
||||
all_gens = merge_async_iterators(*generators)
|
||||
async for i, res in all_gens:
|
||||
@ -201,7 +229,8 @@ def run_hf(
|
||||
disable_detokenize: bool = False,
|
||||
) -> float:
|
||||
llm = AutoModelForCausalLM.from_pretrained(
|
||||
model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
|
||||
model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code
|
||||
)
|
||||
if llm.config.model_type == "llama":
|
||||
# To enable padding in the HF backend.
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
@ -224,14 +253,15 @@ def run_hf(
|
||||
# Check if we can add more requests to the batch.
|
||||
next_prompt_len = requests[i + 1].prompt_len
|
||||
next_output_len = requests[i + 1].expected_output_len
|
||||
if (max(max_prompt_len, next_prompt_len) +
|
||||
max(max_output_len, next_output_len)) <= 2048:
|
||||
if (
|
||||
max(max_prompt_len, next_prompt_len)
|
||||
+ max(max_output_len, next_output_len)
|
||||
) <= 2048:
|
||||
# We can add more requests to the batch.
|
||||
continue
|
||||
|
||||
# Generate the sequences.
|
||||
input_ids = tokenizer(batch, return_tensors="pt",
|
||||
padding=True).input_ids
|
||||
input_ids = tokenizer(batch, return_tensors="pt", padding=True).input_ids
|
||||
llm_outputs = llm.generate(
|
||||
input_ids=input_ids.cuda(),
|
||||
do_sample=True,
|
||||
@ -261,6 +291,7 @@ def run_mii(
|
||||
output_len: int,
|
||||
) -> float:
|
||||
from mii import client, serve
|
||||
|
||||
llm = serve(model, tensor_parallel=tensor_parallel_size)
|
||||
prompts = [request.prompt for request in requests]
|
||||
|
||||
@ -272,8 +303,9 @@ def run_mii(
|
||||
return end - start
|
||||
|
||||
|
||||
def save_to_pytorch_benchmark_format(args: argparse.Namespace,
|
||||
results: dict[str, Any]) -> None:
|
||||
def save_to_pytorch_benchmark_format(
|
||||
args: argparse.Namespace, results: dict[str, Any]
|
||||
) -> None:
|
||||
pt_records = convert_to_pytorch_benchmark_format(
|
||||
args=args,
|
||||
metrics={
|
||||
@ -281,9 +313,9 @@ def save_to_pytorch_benchmark_format(args: argparse.Namespace,
|
||||
"tokens_per_second": [results["tokens_per_second"]],
|
||||
},
|
||||
extra_info={
|
||||
k: results[k]
|
||||
for k in ["elapsed_time", "num_requests", "total_num_tokens"]
|
||||
})
|
||||
k: results[k] for k in ["elapsed_time", "num_requests", "total_num_tokens"]
|
||||
},
|
||||
)
|
||||
if pt_records:
|
||||
# Don't use json suffix here as we don't want CI to pick it up
|
||||
pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json"
|
||||
@ -315,7 +347,8 @@ def get_requests(args, tokenizer):
|
||||
sample_kwargs["enable_multimodal_chat"] = True
|
||||
elif args.dataset_name == "sonnet":
|
||||
assert tokenizer.chat_template or tokenizer.default_chat_template, (
|
||||
"Tokenizer/model must have chat template for sonnet dataset.")
|
||||
"Tokenizer/model must have chat template for sonnet dataset."
|
||||
)
|
||||
dataset_cls = SonnetDataset
|
||||
sample_kwargs["prefix_len"] = args.prefix_len
|
||||
sample_kwargs["return_prompt_formatted"] = True
|
||||
@ -324,21 +357,21 @@ def get_requests(args, tokenizer):
|
||||
elif args.dataset_name == "hf":
|
||||
if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_cls = VisionArenaDataset
|
||||
common_kwargs['dataset_subset'] = None
|
||||
common_kwargs['dataset_split'] = "train"
|
||||
common_kwargs["dataset_subset"] = None
|
||||
common_kwargs["dataset_split"] = "train"
|
||||
sample_kwargs["enable_multimodal_chat"] = True
|
||||
elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_cls = InstructCoderDataset
|
||||
common_kwargs['dataset_split'] = "train"
|
||||
common_kwargs["dataset_split"] = "train"
|
||||
elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_cls = ConversationDataset
|
||||
common_kwargs['dataset_subset'] = args.hf_subset
|
||||
common_kwargs['dataset_split'] = args.hf_split
|
||||
common_kwargs["dataset_subset"] = args.hf_subset
|
||||
common_kwargs["dataset_split"] = args.hf_split
|
||||
sample_kwargs["enable_multimodal_chat"] = True
|
||||
elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_cls = AIMODataset
|
||||
common_kwargs['dataset_subset'] = None
|
||||
common_kwargs['dataset_split'] = "train"
|
||||
common_kwargs["dataset_subset"] = None
|
||||
common_kwargs["dataset_split"] = "train"
|
||||
else:
|
||||
raise ValueError(f"Unknown dataset name: {args.dataset_name}")
|
||||
# Remove None values
|
||||
@ -353,10 +386,10 @@ def main(args: argparse.Namespace):
|
||||
random.seed(args.seed)
|
||||
# Sample the requests.
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.tokenizer, trust_remote_code=args.trust_remote_code)
|
||||
args.tokenizer, trust_remote_code=args.trust_remote_code
|
||||
)
|
||||
requests = get_requests(args, tokenizer)
|
||||
is_multi_modal = any(request.multi_modal_data is not None
|
||||
for request in requests)
|
||||
is_multi_modal = any(request.multi_modal_data is not None for request in requests)
|
||||
request_outputs: Optional[list[RequestOutput]] = None
|
||||
if args.backend == "vllm":
|
||||
if args.async_engine:
|
||||
@ -367,23 +400,34 @@ def main(args: argparse.Namespace):
|
||||
AsyncEngineArgs.from_cli_args(args),
|
||||
args.disable_frontend_multiprocessing,
|
||||
args.disable_detokenize,
|
||||
))
|
||||
)
|
||||
)
|
||||
else:
|
||||
elapsed_time, request_outputs = run_vllm(
|
||||
requests, args.n, EngineArgs.from_cli_args(args),
|
||||
args.disable_detokenize)
|
||||
requests,
|
||||
args.n,
|
||||
EngineArgs.from_cli_args(args),
|
||||
args.disable_detokenize,
|
||||
)
|
||||
elif args.backend == "hf":
|
||||
assert args.tensor_parallel_size == 1
|
||||
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
|
||||
args.hf_max_batch_size, args.trust_remote_code,
|
||||
args.disable_detokenize)
|
||||
elapsed_time = run_hf(
|
||||
requests,
|
||||
args.model,
|
||||
tokenizer,
|
||||
args.n,
|
||||
args.hf_max_batch_size,
|
||||
args.trust_remote_code,
|
||||
args.disable_detokenize,
|
||||
)
|
||||
elif args.backend == "mii":
|
||||
elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
|
||||
args.output_len)
|
||||
elapsed_time = run_mii(
|
||||
requests, args.model, args.tensor_parallel_size, args.output_len
|
||||
)
|
||||
elif args.backend == "vllm-chat":
|
||||
elapsed_time, request_outputs = run_vllm_chat(
|
||||
requests, args.n, EngineArgs.from_cli_args(args),
|
||||
args.disable_detokenize)
|
||||
requests, args.n, EngineArgs.from_cli_args(args), args.disable_detokenize
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown backend: {args.backend}")
|
||||
|
||||
@ -395,28 +439,31 @@ def main(args: argparse.Namespace):
|
||||
for ro in request_outputs:
|
||||
if not isinstance(ro, RequestOutput):
|
||||
continue
|
||||
total_prompt_tokens += len(
|
||||
ro.prompt_token_ids) if ro.prompt_token_ids else 0
|
||||
total_output_tokens += sum(
|
||||
len(o.token_ids) for o in ro.outputs if o)
|
||||
total_prompt_tokens += (
|
||||
len(ro.prompt_token_ids) if ro.prompt_token_ids else 0
|
||||
)
|
||||
total_output_tokens += sum(len(o.token_ids) for o in ro.outputs if o)
|
||||
total_num_tokens = total_prompt_tokens + total_output_tokens
|
||||
else:
|
||||
total_num_tokens = sum(r.prompt_len + r.expected_output_len
|
||||
for r in requests)
|
||||
total_num_tokens = sum(r.prompt_len + r.expected_output_len for r in requests)
|
||||
total_output_tokens = sum(r.expected_output_len for r in requests)
|
||||
total_prompt_tokens = total_num_tokens - total_output_tokens
|
||||
|
||||
if is_multi_modal and args.backend != "vllm-chat":
|
||||
print("\033[91mWARNING\033[0m: Multi-modal request with "
|
||||
f"{args.backend} backend detected. The "
|
||||
"following metrics are not accurate because image tokens are not"
|
||||
" counted. See vllm-project/vllm/issues/9778 for details.")
|
||||
print(
|
||||
"\033[91mWARNING\033[0m: Multi-modal request with "
|
||||
f"{args.backend} backend detected. The "
|
||||
"following metrics are not accurate because image tokens are not"
|
||||
" counted. See vllm-project/vllm/issues/9778 for details."
|
||||
)
|
||||
# TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
|
||||
# vllm-chat backend counts the image tokens now
|
||||
|
||||
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
|
||||
f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
|
||||
f"{total_output_tokens / elapsed_time:.2f} output tokens/s")
|
||||
print(
|
||||
f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
|
||||
f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
|
||||
f"{total_output_tokens / elapsed_time:.2f} output tokens/s"
|
||||
)
|
||||
print(f"Total num prompt tokens: {total_prompt_tokens}")
|
||||
print(f"Total num output tokens: {total_output_tokens}")
|
||||
|
||||
@ -444,7 +491,8 @@ def validate_args(args):
|
||||
warnings.warn(
|
||||
"The '--dataset' argument will be deprecated in the next release. "
|
||||
"Please use '--dataset-name' and '--dataset-path' instead.",
|
||||
stacklevel=2)
|
||||
stacklevel=2,
|
||||
)
|
||||
args.dataset_path = args.dataset
|
||||
|
||||
if not getattr(args, "tokenizer", None):
|
||||
@ -457,9 +505,8 @@ def validate_args(args):
|
||||
|
||||
# === Dataset Configuration ===
|
||||
if not args.dataset and not args.dataset_path:
|
||||
print(
|
||||
"When dataset path is not set, it will default to random dataset")
|
||||
args.dataset_name = 'random'
|
||||
print("When dataset path is not set, it will default to random dataset")
|
||||
args.dataset_name = "random"
|
||||
if args.input_len is None:
|
||||
raise ValueError("input_len must be provided for a random dataset")
|
||||
|
||||
@ -467,41 +514,55 @@ def validate_args(args):
|
||||
# --hf-subset and --hf-split: only used
|
||||
# when dataset_name is 'hf'
|
||||
if args.dataset_name != "hf" and (
|
||||
getattr(args, "hf_subset", None) is not None
|
||||
or getattr(args, "hf_split", None) is not None):
|
||||
warnings.warn("--hf-subset and --hf-split will be ignored \
|
||||
getattr(args, "hf_subset", None) is not None
|
||||
or getattr(args, "hf_split", None) is not None
|
||||
):
|
||||
warnings.warn(
|
||||
"--hf-subset and --hf-split will be ignored \
|
||||
since --dataset-name is not 'hf'.",
|
||||
stacklevel=2)
|
||||
stacklevel=2,
|
||||
)
|
||||
elif args.dataset_name == "hf":
|
||||
if args.dataset_path in (
|
||||
VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys()
|
||||
| ConversationDataset.SUPPORTED_DATASET_PATHS):
|
||||
assert args.backend == "vllm-chat", f"{args.dataset_path} needs to use vllm-chat as the backend." #noqa: E501
|
||||
elif args.dataset_path in (InstructCoderDataset.SUPPORTED_DATASET_PATHS
|
||||
| AIMODataset.SUPPORTED_DATASET_PATHS):
|
||||
assert args.backend == "vllm", f"{args.dataset_path} needs to use vllm as the backend." #noqa: E501
|
||||
VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys()
|
||||
| ConversationDataset.SUPPORTED_DATASET_PATHS
|
||||
):
|
||||
assert args.backend == "vllm-chat", (
|
||||
f"{args.dataset_path} needs to use vllm-chat as the backend."
|
||||
) # noqa: E501
|
||||
elif args.dataset_path in (
|
||||
InstructCoderDataset.SUPPORTED_DATASET_PATHS
|
||||
| AIMODataset.SUPPORTED_DATASET_PATHS
|
||||
):
|
||||
assert args.backend == "vllm", (
|
||||
f"{args.dataset_path} needs to use vllm as the backend."
|
||||
) # noqa: E501
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{args.dataset_path} is not supported by hf dataset.")
|
||||
raise ValueError(f"{args.dataset_path} is not supported by hf dataset.")
|
||||
|
||||
# --random-range-ratio: only used when dataset_name is 'random'
|
||||
if args.dataset_name != 'random' and args.random_range_ratio is not None:
|
||||
warnings.warn("--random-range-ratio will be ignored since \
|
||||
if args.dataset_name != "random" and args.random_range_ratio is not None:
|
||||
warnings.warn(
|
||||
"--random-range-ratio will be ignored since \
|
||||
--dataset-name is not 'random'.",
|
||||
stacklevel=2)
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
# --prefix-len: only used when dataset_name is 'random', 'sonnet', or not
|
||||
# set.
|
||||
if args.dataset_name not in {"random", "sonnet", None
|
||||
} and args.prefix_len is not None:
|
||||
warnings.warn("--prefix-len will be ignored since --dataset-name\
|
||||
if (
|
||||
args.dataset_name not in {"random", "sonnet", None}
|
||||
and args.prefix_len is not None
|
||||
):
|
||||
warnings.warn(
|
||||
"--prefix-len will be ignored since --dataset-name\
|
||||
is not 'random', 'sonnet', or not set.",
|
||||
stacklevel=2)
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
# === LoRA Settings ===
|
||||
if getattr(args, "enable_lora", False) and args.backend != "vllm":
|
||||
raise ValueError(
|
||||
"LoRA benchmarking is only supported for vLLM backend")
|
||||
raise ValueError("LoRA benchmarking is only supported for vLLM backend")
|
||||
if getattr(args, "enable_lora", False) and args.lora_path is None:
|
||||
raise ValueError("LoRA path must be provided when enable_lora is True")
|
||||
|
||||
@ -511,8 +572,10 @@ def validate_args(args):
|
||||
if args.backend != "hf" and args.hf_max_batch_size is not None:
|
||||
raise ValueError("HF max batch size is only for HF backend.")
|
||||
|
||||
if args.backend in {"hf", "mii"} and getattr(args, "quantization",
|
||||
None) is not None:
|
||||
if (
|
||||
args.backend in {"hf", "mii"}
|
||||
and getattr(args, "quantization", None) is not None
|
||||
):
|
||||
raise ValueError("Quantization is only for vLLM backend.")
|
||||
|
||||
if args.backend == "mii" and args.dtype != "auto":
|
||||
@ -520,29 +583,32 @@ def validate_args(args):
|
||||
if args.backend == "mii" and args.n != 1:
|
||||
raise ValueError("n must be 1 for MII backend.")
|
||||
if args.backend == "mii" and args.tokenizer != args.model:
|
||||
raise ValueError(
|
||||
"Tokenizer must be the same as the model for MII backend.")
|
||||
raise ValueError("Tokenizer must be the same as the model for MII backend.")
|
||||
|
||||
# --data-parallel is not supported currently.
|
||||
# https://github.com/vllm-project/vllm/issues/16222
|
||||
if args.data_parallel_size > 1:
|
||||
raise ValueError(
|
||||
"Data parallel is not supported in offline benchmark, \
|
||||
please use benchmark serving instead")
|
||||
please use benchmark serving instead"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser(description="Benchmark the throughput.")
|
||||
parser.add_argument("--backend",
|
||||
type=str,
|
||||
choices=["vllm", "hf", "mii", "vllm-chat"],
|
||||
default="vllm")
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
type=str,
|
||||
choices=["vllm", "hf", "mii", "vllm-chat"],
|
||||
default="vllm",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-name",
|
||||
type=str,
|
||||
choices=["sharegpt", "random", "sonnet", "burstgpt", "hf"],
|
||||
help="Name of the dataset to benchmark on.",
|
||||
default="sharegpt")
|
||||
default="sharegpt",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
@ -550,57 +616,70 @@ if __name__ == "__main__":
|
||||
help="Path to the ShareGPT dataset, will be deprecated in\
|
||||
the next release. The dataset is expected to "
|
||||
"be a json in form of list[dict[..., conversations: "
|
||||
"list[dict[..., value: <prompt_or_response>]]]]")
|
||||
parser.add_argument("--dataset-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the dataset")
|
||||
parser.add_argument("--input-len",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Input prompt length for each request")
|
||||
parser.add_argument("--output-len",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Output length for each request. Overrides the "
|
||||
"output length from the dataset.")
|
||||
parser.add_argument("--n",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of generated sequences per prompt.")
|
||||
parser.add_argument("--num-prompts",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Number of prompts to process.")
|
||||
parser.add_argument("--hf-max-batch-size",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Maximum batch size for HF backend.")
|
||||
"list[dict[..., value: <prompt_or_response>]]]]",
|
||||
)
|
||||
parser.add_argument(
|
||||
'--output-json',
|
||||
"--dataset-path", type=str, default=None, help="Path to the dataset"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input-len",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Input prompt length for each request",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-len",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Output length for each request. Overrides the "
|
||||
"output length from the dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n", type=int, default=1, help="Number of generated sequences per prompt."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-prompts", type=int, default=1000, help="Number of prompts to process."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf-max-batch-size",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Maximum batch size for HF backend.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-json",
|
||||
type=str,
|
||||
default=None,
|
||||
help='Path to save the throughput results in JSON format.')
|
||||
parser.add_argument("--async-engine",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="Use vLLM async engine rather than LLM class.")
|
||||
parser.add_argument("--disable-frontend-multiprocessing",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="Disable decoupled async engine frontend.")
|
||||
help="Path to save the throughput results in JSON format.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--async-engine",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use vLLM async engine rather than LLM class.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-frontend-multiprocessing",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Disable decoupled async engine frontend.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-detokenize",
|
||||
action="store_true",
|
||||
help=("Do not detokenize the response (i.e. do not include "
|
||||
"detokenization time in the measurement)"))
|
||||
help=(
|
||||
"Do not detokenize the response (i.e. do not include "
|
||||
"detokenization time in the measurement)"
|
||||
),
|
||||
)
|
||||
# LoRA
|
||||
parser.add_argument(
|
||||
"--lora-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the lora adapters to use. This can be an absolute path, "
|
||||
"a relative path, or a Hugging Face model identifier.")
|
||||
help="Path to the LoRA adapters to use. This can be an absolute path, "
|
||||
"a relative path, or a Hugging Face model identifier.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prefix-len",
|
||||
type=int,
|
||||
@ -614,7 +693,8 @@ if __name__ == "__main__":
|
||||
f"prefix_len (default: {SonnetDataset.DEFAULT_PREFIX_LEN}) "
|
||||
"controls how much of the input is fixed lines versus "
|
||||
"random lines, but the total input length remains approximately "
|
||||
"input_len tokens.")
|
||||
"input_len tokens.",
|
||||
)
|
||||
# random dataset
|
||||
parser.add_argument(
|
||||
"--random-range-ratio",
|
||||
@ -628,14 +708,12 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
# hf dtaset
|
||||
parser.add_argument("--hf-subset",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Subset of the HF dataset.")
|
||||
parser.add_argument("--hf-split",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Split of the HF dataset.")
|
||||
parser.add_argument(
|
||||
"--hf-subset", type=str, default=None, help="Subset of the HF dataset."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf-split", type=str, default=None, help="Split of the HF dataset."
|
||||
)
|
||||
|
||||
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -7,9 +7,9 @@ import os
|
||||
from typing import Any
|
||||
|
||||
|
||||
def convert_to_pytorch_benchmark_format(args: argparse.Namespace,
|
||||
metrics: dict[str, list],
|
||||
extra_info: dict[str, Any]) -> list:
|
||||
def convert_to_pytorch_benchmark_format(
|
||||
args: argparse.Namespace, metrics: dict[str, list], extra_info: dict[str, Any]
|
||||
) -> list:
|
||||
"""
|
||||
Save the benchmark results in the format used by PyTorch OSS benchmark with
|
||||
on metric per record
|
||||
@ -37,12 +37,12 @@ def convert_to_pytorch_benchmark_format(args: argparse.Namespace,
|
||||
},
|
||||
}
|
||||
|
||||
tp = record["benchmark"]["extra_info"]["args"].get(
|
||||
"tensor_parallel_size")
|
||||
tp = record["benchmark"]["extra_info"]["args"].get("tensor_parallel_size")
|
||||
# Save tensor_parallel_size parameter if it's part of the metadata
|
||||
if not tp and "tensor_parallel_size" in extra_info:
|
||||
record["benchmark"]["extra_info"]["args"][
|
||||
"tensor_parallel_size"] = extra_info["tensor_parallel_size"]
|
||||
record["benchmark"]["extra_info"]["args"]["tensor_parallel_size"] = (
|
||||
extra_info["tensor_parallel_size"]
|
||||
)
|
||||
|
||||
records.append(record)
|
||||
|
||||
@ -50,7 +50,6 @@ def convert_to_pytorch_benchmark_format(args: argparse.Namespace,
|
||||
|
||||
|
||||
class InfEncoder(json.JSONEncoder):
|
||||
|
||||
def clear_inf(self, o: Any):
|
||||
if isinstance(o, dict):
|
||||
return {k: self.clear_inf(v) for k, v in o.items()}
|
||||
|
||||
@ -23,8 +23,9 @@ DEFAULT_TP_SIZES = [1]
|
||||
|
||||
|
||||
# bench
|
||||
def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
|
||||
**kwargs) -> TMeasurement:
|
||||
def bench_fn(
|
||||
label: str, sub_label: str, description: str, fn: Callable, *args, **kwargs
|
||||
) -> TMeasurement:
|
||||
min_run_time = 1
|
||||
|
||||
globals = {
|
||||
@ -41,16 +42,18 @@ def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
|
||||
|
||||
def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
|
||||
sub_label: str) -> Iterable[TMeasurement]:
|
||||
def bench_int8(
|
||||
dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str
|
||||
) -> Iterable[TMeasurement]:
|
||||
assert dtype == torch.int8
|
||||
b_compressed, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k)
|
||||
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
|
||||
bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b,
|
||||
torch.bfloat16)
|
||||
out = ops.cutlass_scaled_sparse_mm(
|
||||
a, b_compressed, e, scale_a, scale_b, torch.bfloat16
|
||||
)
|
||||
out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
|
||||
|
||||
if not torch.allclose(out, out_ref):
|
||||
@ -63,54 +66,107 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
|
||||
timers = []
|
||||
# pytorch impl - bfloat16
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales",
|
||||
torch.mm, a.to(dtype=torch.bfloat16),
|
||||
b.to(dtype=torch.bfloat16)))
|
||||
bench_fn(
|
||||
label,
|
||||
sub_label,
|
||||
"pytorch_bf16_bf16_bf16_matmul-no-scales",
|
||||
torch.mm,
|
||||
a.to(dtype=torch.bfloat16),
|
||||
b.to(dtype=torch.bfloat16),
|
||||
)
|
||||
)
|
||||
|
||||
# pytorch impl - float16
|
||||
timers.append(
|
||||
bench_fn(label, sub_label,
|
||||
"pytorch_fp16_fp16_fp16_matmul-no-scales", torch.mm,
|
||||
a.to(dtype=torch.float16), b.to(dtype=torch.float16)))
|
||||
bench_fn(
|
||||
label,
|
||||
sub_label,
|
||||
"pytorch_fp16_fp16_fp16_matmul-no-scales",
|
||||
torch.mm,
|
||||
a.to(dtype=torch.float16),
|
||||
b.to(dtype=torch.float16),
|
||||
)
|
||||
)
|
||||
|
||||
# cutlass impl
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm",
|
||||
ops.cutlass_scaled_mm, a, b, scale_a, scale_b,
|
||||
torch.bfloat16))
|
||||
bench_fn(
|
||||
label,
|
||||
sub_label,
|
||||
"cutlass_i8_i8_bf16_scaled_mm",
|
||||
ops.cutlass_scaled_mm,
|
||||
a,
|
||||
b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
torch.bfloat16,
|
||||
)
|
||||
)
|
||||
|
||||
# cutlass with bias
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_bias",
|
||||
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16,
|
||||
bias))
|
||||
bench_fn(
|
||||
label,
|
||||
sub_label,
|
||||
"cutlass_i8_i8_bf16_scaled_mm_bias",
|
||||
ops.cutlass_scaled_mm,
|
||||
a,
|
||||
b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
torch.bfloat16,
|
||||
bias,
|
||||
)
|
||||
)
|
||||
|
||||
# cutlass sparse impl
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm",
|
||||
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
|
||||
scale_b, torch.bfloat16))
|
||||
bench_fn(
|
||||
label,
|
||||
sub_label,
|
||||
"cutlass_i8_i8_bf16_scaled_sparse_mm",
|
||||
ops.cutlass_scaled_sparse_mm,
|
||||
a,
|
||||
b_compressed,
|
||||
e,
|
||||
scale_a,
|
||||
scale_b,
|
||||
torch.bfloat16,
|
||||
)
|
||||
)
|
||||
|
||||
# cutlass sparse with bias
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm_bias",
|
||||
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
|
||||
scale_b, torch.bfloat16, bias))
|
||||
bench_fn(
|
||||
label,
|
||||
sub_label,
|
||||
"cutlass_i8_i8_bf16_scaled_sparse_mm_bias",
|
||||
ops.cutlass_scaled_sparse_mm,
|
||||
a,
|
||||
b_compressed,
|
||||
e,
|
||||
scale_a,
|
||||
scale_b,
|
||||
torch.bfloat16,
|
||||
bias,
|
||||
)
|
||||
)
|
||||
|
||||
return timers
|
||||
|
||||
|
||||
def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
|
||||
sub_label: str) -> Iterable[TMeasurement]:
|
||||
def bench_fp8(
|
||||
dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str
|
||||
) -> Iterable[TMeasurement]:
|
||||
assert dtype == torch.float8_e4m3fn
|
||||
b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n,
|
||||
k)
|
||||
b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k)
|
||||
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
|
||||
bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b,
|
||||
torch.bfloat16)
|
||||
out = ops.cutlass_scaled_sparse_mm(
|
||||
a, b_compressed, e, scale_a, scale_b, torch.bfloat16
|
||||
)
|
||||
out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
|
||||
|
||||
if not torch.allclose(out, out_ref):
|
||||
@ -124,97 +180,165 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
|
||||
|
||||
# pytorch impl w. bf16
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales",
|
||||
torch.mm, a.to(dtype=torch.bfloat16, device="cuda"),
|
||||
b.to(dtype=torch.bfloat16, device="cuda")))
|
||||
bench_fn(
|
||||
label,
|
||||
sub_label,
|
||||
"pytorch_bf16_bf16_bf16_matmul-no-scales",
|
||||
torch.mm,
|
||||
a.to(dtype=torch.bfloat16, device="cuda"),
|
||||
b.to(dtype=torch.bfloat16, device="cuda"),
|
||||
)
|
||||
)
|
||||
|
||||
# pytorch impl: bf16 output, without fp8 fast accum
|
||||
timers.append(
|
||||
bench_fn(label,
|
||||
sub_label,
|
||||
"pytorch_fp8_fp8_bf16_scaled_mm",
|
||||
torch._scaled_mm,
|
||||
a,
|
||||
b,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=torch.bfloat16))
|
||||
bench_fn(
|
||||
label,
|
||||
sub_label,
|
||||
"pytorch_fp8_fp8_bf16_scaled_mm",
|
||||
torch._scaled_mm,
|
||||
a,
|
||||
b,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=torch.bfloat16,
|
||||
)
|
||||
)
|
||||
|
||||
# pytorch impl: bf16 output, with fp8 fast accum
|
||||
timers.append(
|
||||
bench_fn(label,
|
||||
sub_label,
|
||||
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum",
|
||||
torch._scaled_mm,
|
||||
a,
|
||||
b,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=torch.bfloat16,
|
||||
use_fast_accum=True))
|
||||
bench_fn(
|
||||
label,
|
||||
sub_label,
|
||||
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum",
|
||||
torch._scaled_mm,
|
||||
a,
|
||||
b,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=torch.bfloat16,
|
||||
use_fast_accum=True,
|
||||
)
|
||||
)
|
||||
|
||||
# pytorch impl: fp16 output, without fp8 fast accum
|
||||
timers.append(
|
||||
bench_fn(label,
|
||||
sub_label,
|
||||
"pytorch_fp8_fp8_fp16_scaled_mm",
|
||||
torch._scaled_mm,
|
||||
a,
|
||||
b,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=torch.float16))
|
||||
bench_fn(
|
||||
label,
|
||||
sub_label,
|
||||
"pytorch_fp8_fp8_fp16_scaled_mm",
|
||||
torch._scaled_mm,
|
||||
a,
|
||||
b,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=torch.float16,
|
||||
)
|
||||
)
|
||||
|
||||
# pytorch impl: fp16 output, with fp8 fast accum
|
||||
timers.append(
|
||||
bench_fn(label,
|
||||
sub_label,
|
||||
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum",
|
||||
torch._scaled_mm,
|
||||
a,
|
||||
b,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=torch.float16,
|
||||
use_fast_accum=True))
|
||||
bench_fn(
|
||||
label,
|
||||
sub_label,
|
||||
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum",
|
||||
torch._scaled_mm,
|
||||
a,
|
||||
b,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=torch.float16,
|
||||
use_fast_accum=True,
|
||||
)
|
||||
)
|
||||
|
||||
# cutlass impl: bf16 output
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm",
|
||||
ops.cutlass_scaled_mm, a, b, scale_a, scale_b,
|
||||
torch.bfloat16))
|
||||
bench_fn(
|
||||
label,
|
||||
sub_label,
|
||||
"cutlass_fp8_fp8_bf16_scaled_mm",
|
||||
ops.cutlass_scaled_mm,
|
||||
a,
|
||||
b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
torch.bfloat16,
|
||||
)
|
||||
)
|
||||
|
||||
# cutlass impl: bf16 output
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_sparse_mm",
|
||||
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
|
||||
scale_b, torch.bfloat16))
|
||||
bench_fn(
|
||||
label,
|
||||
sub_label,
|
||||
"cutlass_fp8_fp8_bf16_scaled_sparse_mm",
|
||||
ops.cutlass_scaled_sparse_mm,
|
||||
a,
|
||||
b_compressed,
|
||||
e,
|
||||
scale_a,
|
||||
scale_b,
|
||||
torch.bfloat16,
|
||||
)
|
||||
)
|
||||
|
||||
# cutlass impl: fp16 output
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_sparse_mm",
|
||||
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
|
||||
scale_b, torch.float16))
|
||||
bench_fn(
|
||||
label,
|
||||
sub_label,
|
||||
"cutlass_fp8_fp8_fp16_scaled_sparse_mm",
|
||||
ops.cutlass_scaled_sparse_mm,
|
||||
a,
|
||||
b_compressed,
|
||||
e,
|
||||
scale_a,
|
||||
scale_b,
|
||||
torch.float16,
|
||||
)
|
||||
)
|
||||
|
||||
# cutlass impl: bf16 output, with bias
|
||||
timers.append(
|
||||
bench_fn(label, sub_label,
|
||||
"cutlass_fp8_fp8_bf16_scaled_sparse_mm_bias",
|
||||
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
|
||||
scale_b, torch.bfloat16, bias))
|
||||
bench_fn(
|
||||
label,
|
||||
sub_label,
|
||||
"cutlass_fp8_fp8_bf16_scaled_sparse_mm_bias",
|
||||
ops.cutlass_scaled_sparse_mm,
|
||||
a,
|
||||
b_compressed,
|
||||
e,
|
||||
scale_a,
|
||||
scale_b,
|
||||
torch.bfloat16,
|
||||
bias,
|
||||
)
|
||||
)
|
||||
|
||||
# cutlass impl: fp16 output, with bias
|
||||
timers.append(
|
||||
bench_fn(label, sub_label,
|
||||
"cutlass_fp8_fp8_fp16_scaled_sparse_mm_bias",
|
||||
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
|
||||
scale_b, torch.float16, bias.to(dtype=torch.float16)))
|
||||
bench_fn(
|
||||
label,
|
||||
sub_label,
|
||||
"cutlass_fp8_fp8_fp16_scaled_sparse_mm_bias",
|
||||
ops.cutlass_scaled_sparse_mm,
|
||||
a,
|
||||
b_compressed,
|
||||
e,
|
||||
scale_a,
|
||||
scale_b,
|
||||
torch.float16,
|
||||
bias.to(dtype=torch.float16),
|
||||
)
|
||||
)
|
||||
|
||||
return timers
|
||||
|
||||
|
||||
def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str,
|
||||
sub_label: str) -> Iterable[TMeasurement]:
|
||||
def bench(
|
||||
dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str
|
||||
) -> Iterable[TMeasurement]:
|
||||
if dtype == torch.int8:
|
||||
return bench_int8(dtype, m, k, n, label, sub_label)
|
||||
if dtype == torch.float8_e4m3fn:
|
||||
@ -228,12 +352,12 @@ def print_timers(timers: Iterable[TMeasurement]):
|
||||
compare.print()
|
||||
|
||||
|
||||
def run(dtype: torch.dtype,
|
||||
MKNs: Iterable[tuple[int, int, int]]) -> Iterable[TMeasurement]:
|
||||
def run(
|
||||
dtype: torch.dtype, MKNs: Iterable[tuple[int, int, int]]
|
||||
) -> Iterable[TMeasurement]:
|
||||
results = []
|
||||
for m, k, n in MKNs:
|
||||
timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm",
|
||||
f"MKN=({m}x{k}x{n})")
|
||||
timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm", f"MKN=({m}x{k}x{n})")
|
||||
print_timers(timers)
|
||||
results.extend(timers)
|
||||
|
||||
@ -241,10 +365,12 @@ def run(dtype: torch.dtype,
|
||||
|
||||
|
||||
# output makers
|
||||
def make_output(data: Iterable[TMeasurement],
|
||||
MKNs: Iterable[tuple[int, int, int]],
|
||||
base_description: str,
|
||||
timestamp=None):
|
||||
def make_output(
|
||||
data: Iterable[TMeasurement],
|
||||
MKNs: Iterable[tuple[int, int, int]],
|
||||
base_description: str,
|
||||
timestamp=None,
|
||||
):
|
||||
print(f"== All Results {base_description} ====")
|
||||
print_timers(data)
|
||||
|
||||
@ -258,8 +384,7 @@ def make_output(data: Iterable[TMeasurement],
|
||||
|
||||
|
||||
def run_square_bench(args):
|
||||
dim_sizes = list(
|
||||
range(args.dim_start, args.dim_end + 1, args.dim_increment))
|
||||
dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment))
|
||||
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
|
||||
data = run(args.dtype, MKNs)
|
||||
|
||||
@ -319,7 +444,7 @@ def run_model_bench(args):
|
||||
pkl.dump(all_data, f)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
|
||||
def to_torch_dtype(dt):
|
||||
if dt == "int8":
|
||||
@ -344,12 +469,15 @@ Benchmark Cutlass GEMM.
|
||||
Output:
|
||||
- a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs.
|
||||
""", # noqa: E501
|
||||
formatter_class=argparse.RawTextHelpFormatter)
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
)
|
||||
|
||||
parser.add_argument("--dtype",
|
||||
type=to_torch_dtype,
|
||||
required=True,
|
||||
help="Available options are ['int8', 'fp8']")
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
type=to_torch_dtype,
|
||||
required=True,
|
||||
help="Available options are ['int8', 'fp8']",
|
||||
)
|
||||
subparsers = parser.add_subparsers(dest="cmd")
|
||||
|
||||
square_parser = subparsers.add_parser("square_bench")
|
||||
@ -368,19 +496,19 @@ Benchmark Cutlass GEMM.
|
||||
range_parser.set_defaults(func=run_range_bench)
|
||||
|
||||
model_parser = subparsers.add_parser("model_bench")
|
||||
model_parser.add_argument("--models",
|
||||
nargs="+",
|
||||
type=str,
|
||||
default=DEFAULT_MODELS,
|
||||
choices=WEIGHT_SHAPES.keys())
|
||||
model_parser.add_argument("--tp-sizes",
|
||||
nargs="+",
|
||||
type=int,
|
||||
default=DEFAULT_TP_SIZES)
|
||||
model_parser.add_argument("--batch-sizes",
|
||||
nargs="+",
|
||||
type=int,
|
||||
default=DEFAULT_BATCH_SIZES)
|
||||
model_parser.add_argument(
|
||||
"--models",
|
||||
nargs="+",
|
||||
type=str,
|
||||
default=DEFAULT_MODELS,
|
||||
choices=WEIGHT_SHAPES.keys(),
|
||||
)
|
||||
model_parser.add_argument(
|
||||
"--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES
|
||||
)
|
||||
model_parser.add_argument(
|
||||
"--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
|
||||
)
|
||||
model_parser.set_defaults(func=run_model_bench)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -10,8 +10,9 @@ import vllm._custom_ops as ops
|
||||
|
||||
def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
return torch.round(tensor.clamp(
|
||||
min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
|
||||
return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to(
|
||||
dtype=torch.float8_e4m3fn
|
||||
)
|
||||
|
||||
|
||||
def to_int8(tensor: torch.Tensor) -> torch.Tensor:
|
||||
@ -26,10 +27,11 @@ def to_fp16(tensor: torch.Tensor) -> torch.Tensor:
|
||||
return tensor.to(dtype=torch.float16)
|
||||
|
||||
|
||||
def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
|
||||
k: int) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
a = torch.randn((m, k), device='cuda') * 5
|
||||
b = torch.randn((n, k), device='cuda').t() * 5
|
||||
def make_rand_tensors(
|
||||
dtype: torch.dtype, m: int, n: int, k: int
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
a = torch.randn((m, k), device="cuda") * 5
|
||||
b = torch.randn((n, k), device="cuda").t() * 5
|
||||
|
||||
if dtype == torch.int8:
|
||||
return to_int8(a), to_int8(b)
|
||||
@ -49,9 +51,7 @@ def prune_to_2_4(tensor):
|
||||
|
||||
# Create binary mask
|
||||
mask = torch.zeros_like(reshaped)
|
||||
mask.scatter_(dim=1,
|
||||
index=indices,
|
||||
src=torch.ones_like(indices, dtype=mask.dtype))
|
||||
mask.scatter_(dim=1, index=indices, src=torch.ones_like(indices, dtype=mask.dtype))
|
||||
|
||||
# Apply mask and reshape back
|
||||
pruned = reshaped * mask
|
||||
@ -62,10 +62,11 @@ def prune_to_2_4(tensor):
|
||||
return pruned.reshape(original_shape)
|
||||
|
||||
|
||||
def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int,
|
||||
k: int) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
a = torch.randn((m, k), device='cuda') * 5
|
||||
b = torch.randn((n, k), device='cuda').t() * 5
|
||||
def make_rand_sparse_tensors(
|
||||
dtype: torch.dtype, m: int, n: int, k: int
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
a = torch.randn((m, k), device="cuda") * 5
|
||||
b = torch.randn((n, k), device="cuda").t() * 5
|
||||
|
||||
b = prune_to_2_4(b.t()).t()
|
||||
|
||||
@ -86,9 +87,9 @@ def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int,
|
||||
return b_compressed, e, a, b
|
||||
|
||||
|
||||
def make_n_rand_sparse_tensors(num_tensors: int, dtype: torch.dtype,
|
||||
m: int, n: int, k: int) -> \
|
||||
tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]:
|
||||
def make_n_rand_sparse_tensors(
|
||||
num_tensors: int, dtype: torch.dtype, m: int, n: int, k: int
|
||||
) -> tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]:
|
||||
ABs = []
|
||||
for _ in range(num_tensors):
|
||||
b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k)
|
||||
|
||||
@ -16,7 +16,8 @@ from weight_shapes import WEIGHT_SHAPES
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
w8a8_block_fp8_matmul)
|
||||
w8a8_block_fp8_matmul,
|
||||
)
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
|
||||
@ -25,8 +26,9 @@ DEFAULT_TP_SIZES = [1]
|
||||
|
||||
|
||||
# bench
|
||||
def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
|
||||
**kwargs) -> TMeasurement:
|
||||
def bench_fn(
|
||||
label: str, sub_label: str, description: str, fn: Callable, *args, **kwargs
|
||||
) -> TMeasurement:
|
||||
min_run_time = 1
|
||||
|
||||
globals = {
|
||||
@ -44,45 +46,48 @@ def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
|
||||
|
||||
|
||||
def bench_int8(
|
||||
dtype: torch.dtype,
|
||||
m: int,
|
||||
k: int,
|
||||
n: int,
|
||||
label: str,
|
||||
sub_label: str,
|
||||
bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]:
|
||||
dtype: torch.dtype,
|
||||
m: int,
|
||||
k: int,
|
||||
n: int,
|
||||
label: str,
|
||||
sub_label: str,
|
||||
bench_kernels: Optional[list[str]] = None,
|
||||
) -> Iterable[TMeasurement]:
|
||||
"""Benchmark INT8-based kernels."""
|
||||
assert dtype == torch.int8
|
||||
a, b = make_rand_tensors(torch.int8, m, n, k)
|
||||
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
|
||||
azp = torch.zeros((m, ), device="cuda", dtype=torch.int32)
|
||||
azp_adj = torch.zeros((n, ), device="cuda", dtype=torch.int32)
|
||||
bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)
|
||||
azp = torch.zeros((m,), device="cuda", dtype=torch.int32)
|
||||
azp_adj = torch.zeros((n,), device="cuda", dtype=torch.int32)
|
||||
|
||||
bench_fns = {
|
||||
"pytorch_bf16_bf16_bf16_matmul-no-scales":
|
||||
lambda: torch.mm(a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16)
|
||||
),
|
||||
"pytorch_fp16_fp16_fp16_matmul-no-scales":
|
||||
lambda: torch.mm(a.to(dtype=torch.float16), b.to(dtype=torch.float16)),
|
||||
"cutlass_i8_i8_bf16_scaled_mm":
|
||||
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16),
|
||||
"cutlass_i8_i8_bf16_scaled_mm_bias":
|
||||
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16,
|
||||
bias),
|
||||
"cutlass_i8_i8_bf16_scaled_mm_azp":
|
||||
lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch.
|
||||
bfloat16, azp_adj),
|
||||
"cutlass_i8_i8_bf16_scaled_mm_azp_bias":
|
||||
lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch.
|
||||
bfloat16, azp_adj, None, bias),
|
||||
"cutlass_i8_i8_bf16_scaled_mm_azp_pt":
|
||||
lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch.
|
||||
bfloat16, azp_adj, azp),
|
||||
"cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias":
|
||||
lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch.
|
||||
bfloat16, azp_adj, azp, bias),
|
||||
"pytorch_bf16_bf16_bf16_matmul-no-scales": lambda: torch.mm(
|
||||
a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16)
|
||||
),
|
||||
"pytorch_fp16_fp16_fp16_matmul-no-scales": lambda: torch.mm(
|
||||
a.to(dtype=torch.float16), b.to(dtype=torch.float16)
|
||||
),
|
||||
"cutlass_i8_i8_bf16_scaled_mm": lambda: ops.cutlass_scaled_mm(
|
||||
a, b, scale_a, scale_b, torch.bfloat16
|
||||
),
|
||||
"cutlass_i8_i8_bf16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm(
|
||||
a, b, scale_a, scale_b, torch.bfloat16, bias
|
||||
),
|
||||
"cutlass_i8_i8_bf16_scaled_mm_azp": lambda: ops.cutlass_scaled_mm_azp(
|
||||
a, b, scale_a, scale_b, torch.bfloat16, azp_adj
|
||||
),
|
||||
"cutlass_i8_i8_bf16_scaled_mm_azp_bias": lambda: ops.cutlass_scaled_mm_azp(
|
||||
a, b, scale_a, scale_b, torch.bfloat16, azp_adj, None, bias
|
||||
),
|
||||
"cutlass_i8_i8_bf16_scaled_mm_azp_pt": lambda: ops.cutlass_scaled_mm_azp(
|
||||
a, b, scale_a, scale_b, torch.bfloat16, azp_adj, azp
|
||||
),
|
||||
"cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias": lambda: ops.cutlass_scaled_mm_azp(
|
||||
a, b, scale_a, scale_b, torch.bfloat16, azp_adj, azp, bias
|
||||
),
|
||||
}
|
||||
|
||||
timers = []
|
||||
@ -96,73 +101,73 @@ def bench_int8(
|
||||
|
||||
|
||||
def bench_fp8(
|
||||
dtype: torch.dtype,
|
||||
m: int,
|
||||
k: int,
|
||||
n: int,
|
||||
label: str,
|
||||
sub_label: str,
|
||||
bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]:
|
||||
dtype: torch.dtype,
|
||||
m: int,
|
||||
k: int,
|
||||
n: int,
|
||||
label: str,
|
||||
sub_label: str,
|
||||
bench_kernels: Optional[list[str]] = None,
|
||||
) -> Iterable[TMeasurement]:
|
||||
"""Benchmark FP8-based kernels."""
|
||||
assert dtype == torch.float8_e4m3fn
|
||||
a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k)
|
||||
a_cont = a.contiguous()
|
||||
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
block_scale_a = torch.rand((m, k // 128),
|
||||
device="cuda",
|
||||
dtype=torch.float32)
|
||||
block_scale_b = torch.rand((k // 128, n // 128),
|
||||
device="cuda",
|
||||
dtype=torch.float32)
|
||||
|
||||
def ceil_div(x: int, y: int) -> int:
|
||||
return (x + y - 1) // y
|
||||
|
||||
block_scale_a = torch.rand(
|
||||
(m, ceil_div(k, 128)), device="cuda", dtype=torch.float32
|
||||
)
|
||||
block_scale_b = torch.rand(
|
||||
ceil_div(k, 128), ceil_div(n, 128), device="cuda", dtype=torch.float32
|
||||
)
|
||||
block_scale_a_M_major = block_scale_a.t().contiguous().t()
|
||||
block_scale_b_K_major = block_scale_b.t().contiguous().t()
|
||||
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
|
||||
bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
print(m, k, n)
|
||||
|
||||
bench_fns = {
|
||||
"pytorch_bf16_bf16_bf16_matmul-no-scales":
|
||||
lambda: torch.mm(a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16)
|
||||
),
|
||||
"pytorch_fp16_fp16_fp16_matmul-no-scales":
|
||||
lambda: torch.mm(a.to(dtype=torch.float16), b.to(dtype=torch.float16)),
|
||||
"pytorch_fp8_fp8_fp16_scaled_mm":
|
||||
lambda: torch._scaled_mm(
|
||||
a, b, scale_a, scale_b, out_dtype=torch.float16),
|
||||
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum":
|
||||
lambda: torch._scaled_mm(a,
|
||||
b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
out_dtype=torch.float16,
|
||||
use_fast_accum=True),
|
||||
"pytorch_fp8_fp8_bf16_scaled_mm":
|
||||
lambda: torch._scaled_mm(
|
||||
a, b, scale_a, scale_b, out_dtype=torch.bfloat16),
|
||||
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum":
|
||||
lambda: torch._scaled_mm(a,
|
||||
b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
out_dtype=torch.bfloat16,
|
||||
use_fast_accum=True),
|
||||
"cutlass_fp8_fp8_bf16_scaled_mm":
|
||||
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16),
|
||||
"cutlass_fp8_fp8_fp16_scaled_mm":
|
||||
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.float16),
|
||||
"cutlass_fp8_fp8_bf16_scaled_mm_bias":
|
||||
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16,
|
||||
bias),
|
||||
"cutlass_fp8_fp8_fp16_scaled_mm_bias":
|
||||
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.float16,
|
||||
bias.to(dtype=torch.float16)),
|
||||
"triton_fp8_fp8_fp16_scaled_mm_blockwise":
|
||||
lambda: w8a8_block_fp8_matmul(a_cont, b.t(), block_scale_a,
|
||||
block_scale_b.t(), (128, 128)),
|
||||
"cutlass_fp8_fp8_fp16_scaled_mm_blockwise":
|
||||
lambda: ops.cutlass_scaled_mm(a, b, block_scale_a_M_major,
|
||||
block_scale_b_K_major, torch.float16),
|
||||
"pytorch_bf16_bf16_bf16_matmul-no-scales": lambda: torch.mm(
|
||||
a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16)
|
||||
),
|
||||
"pytorch_fp16_fp16_fp16_matmul-no-scales": lambda: torch.mm(
|
||||
a.to(dtype=torch.float16), b.to(dtype=torch.float16)
|
||||
),
|
||||
"pytorch_fp8_fp8_fp16_scaled_mm": lambda: torch._scaled_mm(
|
||||
a, b, scale_a, scale_b, out_dtype=torch.float16
|
||||
),
|
||||
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum": lambda: torch._scaled_mm(
|
||||
a, b, scale_a, scale_b, out_dtype=torch.float16, use_fast_accum=True
|
||||
),
|
||||
"pytorch_fp8_fp8_bf16_scaled_mm": lambda: torch._scaled_mm(
|
||||
a, b, scale_a, scale_b, out_dtype=torch.bfloat16
|
||||
),
|
||||
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum": lambda: torch._scaled_mm(
|
||||
a, b, scale_a, scale_b, out_dtype=torch.bfloat16, use_fast_accum=True
|
||||
),
|
||||
"cutlass_fp8_fp8_bf16_scaled_mm": lambda: ops.cutlass_scaled_mm(
|
||||
a, b, scale_a, scale_b, torch.bfloat16
|
||||
),
|
||||
"cutlass_fp8_fp8_fp16_scaled_mm": lambda: ops.cutlass_scaled_mm(
|
||||
a, b, scale_a, scale_b, torch.float16
|
||||
),
|
||||
"cutlass_fp8_fp8_bf16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm(
|
||||
a, b, scale_a, scale_b, torch.bfloat16, bias
|
||||
),
|
||||
"cutlass_fp8_fp8_fp16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm(
|
||||
a, b, scale_a, scale_b, torch.float16, bias.to(dtype=torch.float16)
|
||||
),
|
||||
"triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_block_fp8_matmul(
|
||||
a_cont, b.t(), block_scale_a, block_scale_b.t(), (128, 128)
|
||||
),
|
||||
"cutlass_fp8_fp8_fp16_scaled_mm_blockwise": lambda: ops.cutlass_scaled_mm(
|
||||
a, b, block_scale_a_M_major, block_scale_b_K_major, torch.float16
|
||||
),
|
||||
}
|
||||
|
||||
timers = []
|
||||
@ -175,13 +180,15 @@ def bench_fp8(
|
||||
return timers
|
||||
|
||||
|
||||
def bench(dtype: torch.dtype,
|
||||
m: int,
|
||||
k: int,
|
||||
n: int,
|
||||
label: str,
|
||||
sub_label: str,
|
||||
bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]:
|
||||
def bench(
|
||||
dtype: torch.dtype,
|
||||
m: int,
|
||||
k: int,
|
||||
n: int,
|
||||
label: str,
|
||||
sub_label: str,
|
||||
bench_kernels: Optional[list[str]] = None,
|
||||
) -> Iterable[TMeasurement]:
|
||||
if dtype == torch.int8:
|
||||
return bench_int8(dtype, m, k, n, label, sub_label, bench_kernels)
|
||||
if dtype == torch.float8_e4m3fn:
|
||||
@ -195,27 +202,33 @@ def print_timers(timers: Iterable[TMeasurement]):
|
||||
compare.print()
|
||||
|
||||
|
||||
def run(dtype: torch.dtype,
|
||||
MKNs: Iterable[tuple[int, int, int]],
|
||||
bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]:
|
||||
def run(
|
||||
dtype: torch.dtype,
|
||||
MKNs: Iterable[tuple[int, int, int]],
|
||||
bench_kernels: Optional[list[str]] = None,
|
||||
) -> Iterable[TMeasurement]:
|
||||
results = []
|
||||
for m, k, n in MKNs:
|
||||
timers = bench(dtype,
|
||||
m,
|
||||
k,
|
||||
n,
|
||||
f"scaled-{dtype}-gemm",
|
||||
f"MKN=({m}x{k}x{n})",
|
||||
bench_kernels=bench_kernels)
|
||||
timers = bench(
|
||||
dtype,
|
||||
m,
|
||||
k,
|
||||
n,
|
||||
f"scaled-{dtype}-gemm",
|
||||
f"MKN=({m}x{k}x{n})",
|
||||
bench_kernels=bench_kernels,
|
||||
)
|
||||
print_timers(timers)
|
||||
results.extend(timers)
|
||||
return results
|
||||
|
||||
|
||||
def make_output(data: Iterable[TMeasurement],
|
||||
MKNs: Iterable[tuple[int, int, int]],
|
||||
base_description: str,
|
||||
timestamp=None):
|
||||
def make_output(
|
||||
data: Iterable[TMeasurement],
|
||||
MKNs: Iterable[tuple[int, int, int]],
|
||||
base_description: str,
|
||||
timestamp=None,
|
||||
):
|
||||
print(f"== All Results {base_description} ====")
|
||||
print_timers(data)
|
||||
|
||||
@ -226,8 +239,7 @@ def make_output(data: Iterable[TMeasurement],
|
||||
|
||||
|
||||
def run_square_bench(args):
|
||||
dim_sizes = list(
|
||||
range(args.dim_start, args.dim_end + 1, args.dim_increment))
|
||||
dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment))
|
||||
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
|
||||
data = run(args.dtype, MKNs, bench_kernels=args.kernels)
|
||||
make_output(data, MKNs, f"square_bench-{args.dtype}")
|
||||
@ -285,7 +297,7 @@ def run_model_bench(args):
|
||||
pkl.dump(all_data, f)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
|
||||
def to_torch_dtype(dt):
|
||||
if dt == "int8":
|
||||
@ -310,19 +322,21 @@ Benchmark Cutlass GEMM.
|
||||
Output:
|
||||
- a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs.
|
||||
""", # noqa: E501
|
||||
formatter_class=argparse.RawTextHelpFormatter)
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
)
|
||||
|
||||
parser.add_argument("--dtype",
|
||||
type=to_torch_dtype,
|
||||
required=True,
|
||||
help="Available options are ['int8', 'fp8']")
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
type=to_torch_dtype,
|
||||
required=True,
|
||||
help="Available options are ['int8', 'fp8']",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--kernels",
|
||||
nargs="+",
|
||||
type=str,
|
||||
default=None,
|
||||
help=
|
||||
"Exact names of the kernels to benchmark. If not set, runs all kernels."
|
||||
help="Exact names of the kernels to benchmark. If not set, runs all kernels.",
|
||||
)
|
||||
|
||||
subparsers = parser.add_subparsers(dest="cmd")
|
||||
@ -343,19 +357,19 @@ Benchmark Cutlass GEMM.
|
||||
range_parser.set_defaults(func=run_range_bench)
|
||||
|
||||
model_parser = subparsers.add_parser("model_bench")
|
||||
model_parser.add_argument("--models",
|
||||
nargs="+",
|
||||
type=str,
|
||||
default=DEFAULT_MODELS,
|
||||
choices=WEIGHT_SHAPES.keys())
|
||||
model_parser.add_argument("--tp-sizes",
|
||||
nargs="+",
|
||||
type=int,
|
||||
default=DEFAULT_TP_SIZES)
|
||||
model_parser.add_argument("--batch-sizes",
|
||||
nargs="+",
|
||||
type=int,
|
||||
default=DEFAULT_BATCH_SIZES)
|
||||
model_parser.add_argument(
|
||||
"--models",
|
||||
nargs="+",
|
||||
type=str,
|
||||
default=DEFAULT_MODELS,
|
||||
choices=WEIGHT_SHAPES.keys(),
|
||||
)
|
||||
model_parser.add_argument(
|
||||
"--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES
|
||||
)
|
||||
model_parser.add_argument(
|
||||
"--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
|
||||
)
|
||||
model_parser.set_defaults(func=run_model_bench)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -42,4 +42,4 @@ WEIGHT_SHAPES = {
|
||||
([8192, 57344], 1),
|
||||
([28672, 8192], 0),
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
@ -12,39 +12,37 @@ app = Quart(__name__)
|
||||
|
||||
async def forward_request(url, data):
|
||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
|
||||
}
|
||||
async with session.post(url=url, json=data,
|
||||
headers=headers) as response:
|
||||
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
|
||||
async with session.post(url=url, json=data, headers=headers) as response:
|
||||
if response.status == 200:
|
||||
# if response.headers.get('Transfer-Encoding') == 'chunked':
|
||||
if True:
|
||||
async for chunk_bytes in response.content.iter_chunked(
|
||||
1024):
|
||||
async for chunk_bytes in response.content.iter_chunked(1024):
|
||||
yield chunk_bytes
|
||||
else:
|
||||
content = await response.read()
|
||||
yield content
|
||||
|
||||
|
||||
@app.route('/v1/completions', methods=['POST'])
|
||||
@app.route("/v1/completions", methods=["POST"])
|
||||
async def handle_request():
|
||||
try:
|
||||
original_request_data = await request.get_json()
|
||||
|
||||
prefill_request = original_request_data.copy()
|
||||
# change max_tokens = 1 to let it only do prefill
|
||||
prefill_request['max_tokens'] = 1
|
||||
prefill_request["max_tokens"] = 1
|
||||
|
||||
# finish prefill
|
||||
async for _ in forward_request('http://localhost:8100/v1/completions',
|
||||
prefill_request):
|
||||
async for _ in forward_request(
|
||||
"http://localhost:8100/v1/completions", prefill_request
|
||||
):
|
||||
continue
|
||||
|
||||
# return decode
|
||||
generator = forward_request('http://localhost:8200/v1/completions',
|
||||
original_request_data)
|
||||
generator = forward_request(
|
||||
"http://localhost:8200/v1/completions", original_request_data
|
||||
)
|
||||
response = await make_response(generator)
|
||||
response.timeout = None
|
||||
|
||||
@ -53,11 +51,12 @@ async def handle_request():
|
||||
except Exception as e:
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
exc_info = sys.exc_info()
|
||||
print("Error occurred in disagg prefill proxy server")
|
||||
print(e)
|
||||
print("".join(traceback.format_exception(*exc_info)))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
app.run(port=8000)
|
||||
|
||||
@ -8,7 +8,6 @@ from aiohttp import web
|
||||
|
||||
|
||||
class RoundRobinProxy:
|
||||
|
||||
def __init__(self, target_ports):
|
||||
self.target_ports = target_ports
|
||||
self.port_cycle = itertools.cycle(self.target_ports)
|
||||
@ -21,14 +20,15 @@ class RoundRobinProxy:
|
||||
try:
|
||||
# Forward the request
|
||||
async with session.request(
|
||||
method=request.method,
|
||||
url=target_url,
|
||||
headers=request.headers,
|
||||
data=request.content,
|
||||
method=request.method,
|
||||
url=target_url,
|
||||
headers=request.headers,
|
||||
data=request.content,
|
||||
) as response:
|
||||
# Start sending the response
|
||||
resp = web.StreamResponse(status=response.status,
|
||||
headers=response.headers)
|
||||
resp = web.StreamResponse(
|
||||
status=response.status, headers=response.headers
|
||||
)
|
||||
await resp.prepare(request)
|
||||
|
||||
# Stream the response content
|
||||
@ -45,11 +45,11 @@ class RoundRobinProxy:
|
||||
async def main():
|
||||
proxy = RoundRobinProxy([8100, 8200])
|
||||
app = web.Application()
|
||||
app.router.add_route('*', '/{path:.*}', proxy.handle_request)
|
||||
app.router.add_route("*", "/{path:.*}", proxy.handle_request)
|
||||
|
||||
runner = web.AppRunner(app)
|
||||
await runner.setup()
|
||||
site = web.TCPSite(runner, 'localhost', 8000)
|
||||
site = web.TCPSite(runner, "localhost", 8000)
|
||||
await site.start()
|
||||
|
||||
print("Proxy server started on http://localhost:8000")
|
||||
@ -58,5 +58,5 @@ async def main():
|
||||
await asyncio.Event().wait()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
@ -6,43 +6,41 @@ import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
data = []
|
||||
for name in ['disagg_prefill', 'chunked_prefill']:
|
||||
for name in ["disagg_prefill", "chunked_prefill"]:
|
||||
for qps in [2, 4, 6, 8]:
|
||||
with open(f"results/{name}-qps-{qps}.json") as f:
|
||||
x = json.load(f)
|
||||
x['name'] = name
|
||||
x['qps'] = qps
|
||||
x["name"] = name
|
||||
x["qps"] = qps
|
||||
data.append(x)
|
||||
|
||||
df = pd.DataFrame.from_dict(data)
|
||||
dis_df = df[df['name'] == 'disagg_prefill']
|
||||
chu_df = df[df['name'] == 'chunked_prefill']
|
||||
dis_df = df[df["name"] == "disagg_prefill"]
|
||||
chu_df = df[df["name"] == "chunked_prefill"]
|
||||
|
||||
plt.style.use('bmh')
|
||||
plt.rcParams['font.size'] = 20
|
||||
plt.style.use("bmh")
|
||||
plt.rcParams["font.size"] = 20
|
||||
|
||||
for key in [
|
||||
'mean_ttft_ms', 'median_ttft_ms', 'p99_ttft_ms', 'mean_itl_ms',
|
||||
'median_itl_ms', 'p99_itl_ms'
|
||||
"mean_ttft_ms",
|
||||
"median_ttft_ms",
|
||||
"p99_ttft_ms",
|
||||
"mean_itl_ms",
|
||||
"median_itl_ms",
|
||||
"p99_itl_ms",
|
||||
]:
|
||||
|
||||
fig, ax = plt.subplots(figsize=(11, 7))
|
||||
plt.plot(dis_df['qps'],
|
||||
dis_df[key],
|
||||
label='disagg_prefill',
|
||||
marker='o',
|
||||
linewidth=4)
|
||||
plt.plot(chu_df['qps'],
|
||||
chu_df[key],
|
||||
label='chunked_prefill',
|
||||
marker='o',
|
||||
linewidth=4)
|
||||
plt.plot(
|
||||
dis_df["qps"], dis_df[key], label="disagg_prefill", marker="o", linewidth=4
|
||||
)
|
||||
plt.plot(
|
||||
chu_df["qps"], chu_df[key], label="chunked_prefill", marker="o", linewidth=4
|
||||
)
|
||||
ax.legend()
|
||||
|
||||
ax.set_xlabel('QPS')
|
||||
ax.set_xlabel("QPS")
|
||||
ax.set_ylabel(key)
|
||||
ax.set_ylim(bottom=0)
|
||||
fig.savefig(f'results/{key}.png')
|
||||
fig.savefig(f"results/{key}.png")
|
||||
plt.close(fig)
|
||||
|
||||
@ -24,10 +24,12 @@ class bench_params_t:
|
||||
dtype: torch.dtype
|
||||
|
||||
def description(self):
|
||||
return (f'N {self.num_tokens} '
|
||||
f'x D {self.hidden_size} '
|
||||
f'x R {self.add_residual} '
|
||||
f'x DT {self.dtype}')
|
||||
return (
|
||||
f"N {self.num_tokens} "
|
||||
f"x D {self.hidden_size} "
|
||||
f"x R {self.add_residual} "
|
||||
f"x DT {self.dtype}"
|
||||
)
|
||||
|
||||
|
||||
def get_bench_params() -> list[bench_params_t]:
|
||||
@ -38,15 +40,19 @@ def get_bench_params() -> list[bench_params_t]:
|
||||
DTYPES = [torch.bfloat16, torch.float]
|
||||
|
||||
combinations = product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES)
|
||||
bench_params = list(map(lambda x: \
|
||||
bench_params_t(x[0], x[1], x[2], x[3]), combinations))
|
||||
bench_params = list(
|
||||
map(lambda x: bench_params_t(x[0], x[1], x[2], x[3]), combinations)
|
||||
)
|
||||
return bench_params
|
||||
|
||||
|
||||
# Reference impls
|
||||
def unfused_int8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
quant_dtype: torch.dtype):
|
||||
def unfused_int8_impl(
|
||||
rms_norm_layer: RMSNorm,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
quant_dtype: torch.dtype,
|
||||
):
|
||||
# Norm
|
||||
torch_out = None
|
||||
if residual is None:
|
||||
@ -58,9 +64,12 @@ def unfused_int8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor,
|
||||
torch_out, _, _ = ops.scaled_int8_quant(torch_out)
|
||||
|
||||
|
||||
def unfused_fp8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
quant_dtype: torch.dtype):
|
||||
def unfused_fp8_impl(
|
||||
rms_norm_layer: RMSNorm,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
quant_dtype: torch.dtype,
|
||||
):
|
||||
# Norm
|
||||
torch_out = None
|
||||
if residual is None:
|
||||
@ -73,22 +82,27 @@ def unfused_fp8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor,
|
||||
|
||||
|
||||
def fused_impl(
|
||||
rms_norm_layer: RMSNorm, # this stores the weights
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
quant_dtype: torch.dtype):
|
||||
out, _ = ops.rms_norm_dynamic_per_token_quant(x,
|
||||
rms_norm_layer.weight,
|
||||
1e-6,
|
||||
quant_dtype,
|
||||
residual=residual)
|
||||
rms_norm_layer: RMSNorm, # this stores the weights
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
quant_dtype: torch.dtype,
|
||||
):
|
||||
out, _ = ops.rms_norm_dynamic_per_token_quant(
|
||||
x, rms_norm_layer.weight, 1e-6, quant_dtype, residual=residual
|
||||
)
|
||||
|
||||
|
||||
# Bench functions
|
||||
def bench_fn(rms_norm_layer: RMSNorm, x: torch.Tensor, residual: torch.Tensor,
|
||||
quant_dtype: torch.dtype, label: str, sub_label: str,
|
||||
fn: Callable, description: str) -> TMeasurement:
|
||||
|
||||
def bench_fn(
|
||||
rms_norm_layer: RMSNorm,
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
quant_dtype: torch.dtype,
|
||||
label: str,
|
||||
sub_label: str,
|
||||
fn: Callable,
|
||||
description: str,
|
||||
) -> TMeasurement:
|
||||
min_run_time = 1
|
||||
|
||||
globals = {
|
||||
@ -106,43 +120,81 @@ def bench_fn(rms_norm_layer: RMSNorm, x: torch.Tensor, residual: torch.Tensor,
|
||||
description=description,
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
|
||||
def bench(params: bench_params_t, label: str, sub_label: str) \
|
||||
-> Iterable[TMeasurement]:
|
||||
|
||||
def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasurement]:
|
||||
# Make inputs
|
||||
layer = RMSNorm(params.hidden_size, 1e-6).to(dtype=params.dtype)
|
||||
# Make weights
|
||||
layer.weight.data.normal_(mean=1.0, std=0.1)
|
||||
# Make inputs
|
||||
scale = 1 / params.hidden_size
|
||||
x = torch.randn(params.num_tokens,
|
||||
params.hidden_size,
|
||||
dtype=params.dtype,
|
||||
device='cuda') * scale
|
||||
residual = (torch.randn_like(x) * scale).to(device='cuda') \
|
||||
if params.add_residual else None
|
||||
x = (
|
||||
torch.randn(
|
||||
params.num_tokens, params.hidden_size, dtype=params.dtype, device="cuda"
|
||||
)
|
||||
* scale
|
||||
)
|
||||
residual = (
|
||||
(torch.randn_like(x) * scale).to(device="cuda") if params.add_residual else None
|
||||
)
|
||||
|
||||
timers = []
|
||||
|
||||
# unfused int8 impl.
|
||||
timers.append(
|
||||
bench_fn(layer, x, residual, torch.int8, label, sub_label,
|
||||
unfused_int8_impl, "unfused_int8_impl"))
|
||||
bench_fn(
|
||||
layer,
|
||||
x,
|
||||
residual,
|
||||
torch.int8,
|
||||
label,
|
||||
sub_label,
|
||||
unfused_int8_impl,
|
||||
"unfused_int8_impl",
|
||||
)
|
||||
)
|
||||
|
||||
# unfused fp8 impl.
|
||||
timers.append(
|
||||
bench_fn(layer, x, residual, torch.float8_e4m3fn, label, sub_label,
|
||||
unfused_fp8_impl, "unfused_fp8_impl"))
|
||||
bench_fn(
|
||||
layer,
|
||||
x,
|
||||
residual,
|
||||
torch.float8_e4m3fn,
|
||||
label,
|
||||
sub_label,
|
||||
unfused_fp8_impl,
|
||||
"unfused_fp8_impl",
|
||||
)
|
||||
)
|
||||
|
||||
# fused int8 impl.
|
||||
timers.append(
|
||||
bench_fn(layer, x, residual, torch.int8, label, sub_label, fused_impl,
|
||||
"fused_int8_impl"))
|
||||
bench_fn(
|
||||
layer,
|
||||
x,
|
||||
residual,
|
||||
torch.int8,
|
||||
label,
|
||||
sub_label,
|
||||
fused_impl,
|
||||
"fused_int8_impl",
|
||||
)
|
||||
)
|
||||
|
||||
# fused fp8 impl.
|
||||
timers.append(
|
||||
bench_fn(layer, x, residual, torch.float8_e4m3fn, label, sub_label,
|
||||
fused_impl, "fused_fp8_impl"))
|
||||
bench_fn(
|
||||
layer,
|
||||
x,
|
||||
residual,
|
||||
torch.float8_e4m3fn,
|
||||
label,
|
||||
sub_label,
|
||||
fused_impl,
|
||||
"fused_fp8_impl",
|
||||
)
|
||||
)
|
||||
|
||||
print_timers(timers)
|
||||
|
||||
@ -157,13 +209,12 @@ def print_timers(timers: Iterable[TMeasurement]):
|
||||
|
||||
|
||||
def main():
|
||||
torch.set_default_device('cuda')
|
||||
torch.set_default_device("cuda")
|
||||
bench_params = get_bench_params()
|
||||
|
||||
timers = []
|
||||
for bp in tqdm(bench_params):
|
||||
timers.extend(
|
||||
bench(bp, "rms-norm-dynamic-per-token-quant", bp.description()))
|
||||
timers.extend(bench(bp, "rms-norm-dynamic-per-token-quant", bp.description()))
|
||||
print_timers(timers)
|
||||
|
||||
# pickle all the results
|
||||
@ -172,5 +223,5 @@ def main():
|
||||
pkl.dump(timers, f)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@ -9,32 +9,39 @@ import torch.nn.functional as F
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.aqlm import (
|
||||
dequantize_weight, generic_dequantize_gemm, get_int_dtype,
|
||||
optimized_dequantize_gemm)
|
||||
dequantize_weight,
|
||||
generic_dequantize_gemm,
|
||||
get_int_dtype,
|
||||
optimized_dequantize_gemm,
|
||||
)
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||
|
||||
|
||||
def torch_mult(
|
||||
input: torch.Tensor, # [..., in_features]
|
||||
weights: torch.Tensor,
|
||||
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
|
||||
# [..., in_features]
|
||||
input: torch.Tensor,
|
||||
weights: torch.Tensor,
|
||||
# [num_out_groups, 1, 1, 1]
|
||||
scales: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
output = F.linear(input, weights)
|
||||
return output
|
||||
|
||||
|
||||
def dequant_out_scale(
|
||||
input: torch.Tensor, # [..., in_features]
|
||||
codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
|
||||
codebooks: torch.
|
||||
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
|
||||
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
|
||||
# [..., in_features]
|
||||
input: torch.Tensor,
|
||||
# [num_out_groups, num_in_groups, num_codebooks]
|
||||
codes: torch.IntTensor,
|
||||
# [num_codebooks, codebook_size, out_group_size, in_group_size]
|
||||
codebooks: torch.Tensor,
|
||||
# [num_out_groups, 1, 1, 1]
|
||||
scales: torch.Tensor,
|
||||
output_partition_sizes: torch.IntTensor,
|
||||
bias: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
|
||||
weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
|
||||
|
||||
if bias is None:
|
||||
@ -46,40 +53,42 @@ def dequant_out_scale(
|
||||
flattened_output *= b_scales
|
||||
return flattened_output.view(orig_shape)
|
||||
else:
|
||||
b_scales = scales.view(scales.shape[:-3] + (-1, )).expand(
|
||||
-1, weights.shape[1])
|
||||
b_scales = scales.view(scales.shape[:-3] + (-1,)).expand(-1, weights.shape[1])
|
||||
weights *= b_scales
|
||||
return F.linear(input, weights, bias)
|
||||
|
||||
|
||||
def dequant_weight_scale(
|
||||
input: torch.Tensor, # [..., in_features]
|
||||
codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
|
||||
codebooks: torch.
|
||||
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
|
||||
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
|
||||
# [..., in_features]
|
||||
input: torch.Tensor,
|
||||
# [num_out_groups, num_in_groups, num_codebooks]
|
||||
codes: torch.IntTensor,
|
||||
# [num_codebooks, codebook_size, out_group_size, in_group_size]
|
||||
codebooks: torch.Tensor,
|
||||
# [num_out_groups, 1, 1, 1]
|
||||
scales: torch.Tensor,
|
||||
output_partition_sizes: torch.IntTensor,
|
||||
bias: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
|
||||
weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
|
||||
|
||||
b_scales = scales.view(scales.shape[:-3] + (-1, )).expand(
|
||||
-1, weights.shape[1])
|
||||
b_scales = scales.view(scales.shape[:-3] + (-1,)).expand(-1, weights.shape[1])
|
||||
weights *= b_scales
|
||||
return F.linear(input, weights, bias)
|
||||
|
||||
|
||||
def dequant_no_scale(
|
||||
input: torch.Tensor, # [..., in_features]
|
||||
codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
|
||||
codebooks: torch.
|
||||
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
|
||||
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
|
||||
# [..., in_features]
|
||||
input: torch.Tensor,
|
||||
# [num_out_groups, num_in_groups, num_codebooks]
|
||||
codes: torch.IntTensor,
|
||||
# [num_codebooks, codebook_size, out_group_size, in_group_size]
|
||||
codebooks: torch.Tensor,
|
||||
# [num_out_groups, 1, 1, 1]
|
||||
scales: torch.Tensor,
|
||||
output_partition_sizes: torch.IntTensor,
|
||||
bias: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
|
||||
weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
|
||||
|
||||
return F.linear(input, weights, bias)
|
||||
@ -89,23 +98,26 @@ def dequant_no_scale(
|
||||
# the generic pytorch version.
|
||||
# Just visual comparison.
|
||||
def dequant_test(k: int, parts: torch.Tensor, nbooks: int, bits: int) -> None:
|
||||
|
||||
n = int(parts.sum().item())
|
||||
|
||||
device = torch.device('cuda:0')
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
code_range = (1 << bits) // 2
|
||||
ingroups = 8
|
||||
|
||||
codes = torch.randint(-code_range,
|
||||
code_range,
|
||||
size=(n, k // ingroups, nbooks),
|
||||
dtype=get_int_dtype(bits),
|
||||
device=device)
|
||||
codes = torch.randint(
|
||||
-code_range,
|
||||
code_range,
|
||||
size=(n, k // ingroups, nbooks),
|
||||
dtype=get_int_dtype(bits),
|
||||
device=device,
|
||||
)
|
||||
|
||||
codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
|
||||
dtype=torch.float16,
|
||||
device=device)
|
||||
codebooks = torch.randn(
|
||||
size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
|
||||
dtype=torch.float16,
|
||||
device=device,
|
||||
)
|
||||
|
||||
count = 0
|
||||
for index in range(16):
|
||||
@ -138,24 +150,25 @@ def dequant_test(k: int, parts: torch.Tensor, nbooks: int, bits: int) -> None:
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
parser = FlexibleArgumentParser(description="Benchmark aqlm performance.")
|
||||
|
||||
# Add arguments
|
||||
parser.add_argument("--nbooks",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of codebooks (default: 1)")
|
||||
parser.add_argument("--bits",
|
||||
type=int,
|
||||
default=16,
|
||||
help="Number of bits per code element (default: 16)")
|
||||
parser.add_argument(
|
||||
"--nbooks", type=int, default=1, help="Number of codebooks (default: 1)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bits",
|
||||
type=int,
|
||||
default=16,
|
||||
help="Number of bits per code element (default: 16)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="Run the decompression/dequant tester rather than benchmarking "
|
||||
"(default: False)")
|
||||
"(default: False)",
|
||||
)
|
||||
|
||||
# Parse the arguments
|
||||
args = parser.parse_args()
|
||||
@ -165,7 +178,7 @@ def main():
|
||||
bits = args.bits
|
||||
|
||||
if args.test:
|
||||
dequant_test(4096, torch.tensor((4096, )), nbooks, bits)
|
||||
dequant_test(4096, torch.tensor((4096,)), nbooks, bits)
|
||||
return
|
||||
|
||||
# Otherwise, benchmark.
|
||||
@ -184,31 +197,54 @@ def main():
|
||||
with open(filename, "w") as f:
|
||||
sys.stdout = f
|
||||
|
||||
print('m | k | n | n parts', end='')
|
||||
print("m | k | n | n parts", end="")
|
||||
for method in methods:
|
||||
print(f" | {method.__name__.replace('_', ' ')} (µs)", end='')
|
||||
print('')
|
||||
print(f" | {method.__name__.replace('_', ' ')} (µs)", end="")
|
||||
print("")
|
||||
|
||||
# These are reasonable prefill sizes.
|
||||
ksandpartions = ((4096, (4096, 4096, 4096)), (4096, (4096, )),
|
||||
(4096, (11008, 11008)), (11008, (4096, )))
|
||||
ksandpartions = (
|
||||
(4096, (4096, 4096, 4096)),
|
||||
(4096, (4096,)),
|
||||
(4096, (11008, 11008)),
|
||||
(11008, (4096,)),
|
||||
)
|
||||
|
||||
# reasonable ranges for m.
|
||||
for m in [
|
||||
1, 2, 4, 8, 10, 12, 14, 16, 24, 32, 48, 52, 56, 64, 96, 112,
|
||||
128, 256, 512, 1024, 1536, 2048, 3072, 4096
|
||||
1,
|
||||
2,
|
||||
4,
|
||||
8,
|
||||
10,
|
||||
12,
|
||||
14,
|
||||
16,
|
||||
24,
|
||||
32,
|
||||
48,
|
||||
52,
|
||||
56,
|
||||
64,
|
||||
96,
|
||||
112,
|
||||
128,
|
||||
256,
|
||||
512,
|
||||
1024,
|
||||
1536,
|
||||
2048,
|
||||
3072,
|
||||
4096,
|
||||
]:
|
||||
print(f'{m}', file=sys.__stdout__)
|
||||
print(f"{m}", file=sys.__stdout__)
|
||||
for ksp in ksandpartions:
|
||||
run_grid(m, ksp[0], torch.tensor(ksp[1]), nbooks, bits,
|
||||
methods)
|
||||
run_grid(m, ksp[0], torch.tensor(ksp[1]), nbooks, bits, methods)
|
||||
|
||||
sys.stdout = sys.__stdout__
|
||||
|
||||
|
||||
def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int,
|
||||
methods):
|
||||
|
||||
def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int, methods):
|
||||
# I didn't see visible improvements from increasing these, but feel free :)
|
||||
num_warmup_trials = 1
|
||||
num_trials = 1
|
||||
@ -229,7 +265,7 @@ def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int,
|
||||
)
|
||||
|
||||
n = parts.sum().item()
|
||||
print(f'{m} | {k} | {n} | {parts.tolist()}', end='')
|
||||
print(f"{m} | {k} | {n} | {parts.tolist()}", end="")
|
||||
|
||||
for method in methods:
|
||||
best_time_us = 1e20
|
||||
@ -249,32 +285,36 @@ def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int,
|
||||
if kernel_dur_us < best_time_us:
|
||||
best_time_us = kernel_dur_us
|
||||
|
||||
print(f' | {kernel_dur_us:.0f}', end='')
|
||||
print(f" | {kernel_dur_us:.0f}", end="")
|
||||
|
||||
print('')
|
||||
print("")
|
||||
|
||||
|
||||
def run_timing(num_calls: int, m: int, k: int, parts: torch.Tensor,
|
||||
nbooks: int, bits: int, method) -> float:
|
||||
|
||||
def run_timing(
|
||||
num_calls: int, m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int, method
|
||||
) -> float:
|
||||
n = int(parts.sum().item())
|
||||
|
||||
device = torch.device('cuda:0')
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
input = torch.randn((1, m, k), dtype=torch.float16, device=device)
|
||||
|
||||
code_range = (1 << bits) // 2
|
||||
ingroups = 8
|
||||
|
||||
codes = torch.randint(-code_range,
|
||||
code_range,
|
||||
size=(n, k // ingroups, nbooks),
|
||||
dtype=get_int_dtype(bits),
|
||||
device=device)
|
||||
codes = torch.randint(
|
||||
-code_range,
|
||||
code_range,
|
||||
size=(n, k // ingroups, nbooks),
|
||||
dtype=get_int_dtype(bits),
|
||||
device=device,
|
||||
)
|
||||
|
||||
codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
|
||||
dtype=torch.float16,
|
||||
device=device)
|
||||
codebooks = torch.randn(
|
||||
size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
|
||||
dtype=torch.float16,
|
||||
device=device,
|
||||
)
|
||||
|
||||
scales = torch.randn(size=(n, 1, 1, 1), dtype=torch.float16, device=device)
|
||||
|
||||
|
||||
@ -3,27 +3,33 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
|
||||
MINIMUM_BITBLAS_VERSION)
|
||||
MINIMUM_BITBLAS_VERSION,
|
||||
)
|
||||
|
||||
try:
|
||||
import bitblas
|
||||
|
||||
if bitblas.__version__ < MINIMUM_BITBLAS_VERSION:
|
||||
raise ImportError("bitblas version is wrong. Please "
|
||||
f"install bitblas>={MINIMUM_BITBLAS_VERSION}")
|
||||
raise ImportError(
|
||||
"bitblas version is wrong. Please "
|
||||
f"install bitblas>={MINIMUM_BITBLAS_VERSION}"
|
||||
)
|
||||
except ImportError as e:
|
||||
bitblas_import_exception = e
|
||||
raise ValueError("Trying to use the bitblas backend, but could not import"
|
||||
f"with the following error: {bitblas_import_exception}. "
|
||||
"Please install bitblas through the following command: "
|
||||
f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`"
|
||||
) from bitblas_import_exception
|
||||
raise ValueError(
|
||||
"Trying to use the bitblas backend, but could not import"
|
||||
f"with the following error: {bitblas_import_exception}. "
|
||||
"Please install bitblas through the following command: "
|
||||
f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`"
|
||||
) from bitblas_import_exception
|
||||
|
||||
from bitblas import Matmul, MatmulConfig, auto_detect_nvidia_target
|
||||
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
parser = FlexibleArgumentParser(
|
||||
description="Benchmark BitBLAS int4 on a specific target.")
|
||||
description="Benchmark BitBLAS int4 on a specific target."
|
||||
)
|
||||
|
||||
# Add arguments to the parser
|
||||
parser.add_argument(
|
||||
@ -32,10 +38,9 @@ parser.add_argument(
|
||||
default=auto_detect_nvidia_target(),
|
||||
help="Specify the target device for benchmarking.",
|
||||
)
|
||||
parser.add_argument("--group_size",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Group size for grouped quantization.")
|
||||
parser.add_argument(
|
||||
"--group_size", type=int, default=None, help="Group size for grouped quantization."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--A_dtype",
|
||||
type=str,
|
||||
@ -82,17 +87,17 @@ parser.add_argument(
|
||||
choices=["nt", "nn"],
|
||||
help="Matrix layout, 'nt' for non-transpose A and transpose W.",
|
||||
)
|
||||
parser.add_argument("--with_bias",
|
||||
action="store_true",
|
||||
help="Include bias in the benchmark.")
|
||||
parser.add_argument(
|
||||
"--with_bias", action="store_true", help="Include bias in the benchmark."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--with_scaling",
|
||||
action="store_true",
|
||||
help="Include scaling factor in the quantization.",
|
||||
)
|
||||
parser.add_argument("--with_zeros",
|
||||
action="store_true",
|
||||
help="Include zeros in the quantization.")
|
||||
parser.add_argument(
|
||||
"--with_zeros", action="store_true", help="Include zeros in the quantization."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--zeros_mode",
|
||||
type=str,
|
||||
@ -170,8 +175,7 @@ shapes = [
|
||||
]
|
||||
|
||||
# Build test shapes with all the shared arguments
|
||||
test_shapes = [(MatmulConfig, Matmul, (*shape, *shared_args))
|
||||
for shape in shapes]
|
||||
test_shapes = [(MatmulConfig, Matmul, (*shape, *shared_args)) for shape in shapes]
|
||||
|
||||
benchmark_sets = []
|
||||
benchmark_sets.extend(test_shapes)
|
||||
@ -206,12 +210,12 @@ for config_key, values in benchmark_results.items():
|
||||
func_name = args_split[0]
|
||||
input_args_str = "-".join(args_split[1:])
|
||||
col_widths[0] = max(col_widths[0], len(func_name) + 2, len(headers[0]) + 2)
|
||||
col_widths[1] = max(col_widths[1],
|
||||
len(input_args_str) + 2,
|
||||
len(headers[1]) + 2)
|
||||
col_widths[2] = max(col_widths[2],
|
||||
len(f"{values['BitBLAS_top20_latency']:.3f} ms") + 2,
|
||||
len(headers[2]) + 2)
|
||||
col_widths[1] = max(col_widths[1], len(input_args_str) + 2, len(headers[1]) + 2)
|
||||
col_widths[2] = max(
|
||||
col_widths[2],
|
||||
len(f"{values['BitBLAS_top20_latency']:.3f} ms") + 2,
|
||||
len(headers[2]) + 2,
|
||||
)
|
||||
# break only if you want to measure widths from a single example;
|
||||
# otherwise, let it loop over all items.
|
||||
|
||||
@ -232,5 +236,6 @@ for config_key, values in benchmark_results.items():
|
||||
f"{values['BitBLAS_top20_latency']:.3f} ms",
|
||||
]
|
||||
row_str = "".join(
|
||||
[str(cell).ljust(col_widths[idx]) for idx, cell in enumerate(row)])
|
||||
[str(cell).ljust(col_widths[idx]) for idx, cell in enumerate(row)]
|
||||
)
|
||||
print(row_str)
|
||||
|
||||
489
benchmarks/kernels/benchmark_cutlass_fp4_moe.py
Normal file
489
benchmarks/kernels/benchmark_cutlass_fp4_moe.py
Normal file
@ -0,0 +1,489 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""
|
||||
Benchmark the performance of the cutlass_moe_fp4 kernel vs the triton_moe
|
||||
kernel. The cutlass_moe_fp4 kernel takes in fp4 quantized weights and 16-bit
|
||||
activations. The triton_moe kernel takes in fp8 weights(tensor scaled to fp8)
|
||||
and 16-bit activations.
|
||||
"""
|
||||
|
||||
import nvtx
|
||||
import torch
|
||||
import torch.utils.benchmark as benchmark
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
|
||||
from vllm.scalar_type import scalar_types
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
WEIGHT_SHAPES_MOE = {
|
||||
"nvidia/DeepSeek-R1-FP4": [
|
||||
[256, 8, 2048, 7168],
|
||||
],
|
||||
}
|
||||
|
||||
DEFAULT_MODELS = [
|
||||
"nvidia/DeepSeek-R1-FP4",
|
||||
]
|
||||
|
||||
DEFAULT_BATCH_SIZES = [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
|
||||
DEFAULT_TP_SIZES = [1]
|
||||
|
||||
PER_ACT_TOKEN_OPTS = [False]
|
||||
PER_OUT_CH_OPTS = [False]
|
||||
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
|
||||
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
|
||||
|
||||
|
||||
def to_fp8(tensor: torch.Tensor):
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to(
|
||||
dtype=torch.float8_e4m3fn
|
||||
)
|
||||
|
||||
|
||||
def bench_run(
|
||||
results: list[benchmark.Measurement],
|
||||
model: str,
|
||||
num_experts: int,
|
||||
topk: int,
|
||||
per_act_token: bool,
|
||||
per_out_ch: bool,
|
||||
mkn: tuple[int, int, int],
|
||||
):
|
||||
label = "NVFP4 Blockscaled CUTLASS MOE vs FP8 Tensor Scaled Triton"
|
||||
|
||||
sub_label = (
|
||||
"{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, MKN=({})".format(
|
||||
model, num_experts, topk, per_act_token, per_out_ch, mkn
|
||||
)
|
||||
)
|
||||
|
||||
print(f"Testing: {sub_label}")
|
||||
|
||||
(m, k, n) = mkn
|
||||
|
||||
dtype = torch.half
|
||||
device = "cuda"
|
||||
a = torch.randn((m, k), device=device, dtype=dtype) / 10
|
||||
w1 = torch.randn((num_experts, 2 * n, k), device=device, dtype=dtype) / 10
|
||||
w2 = torch.randn((num_experts, k, n), device=device, dtype=dtype) / 10
|
||||
|
||||
_, a_fp8_scale = ops.scaled_fp8_quant(a)
|
||||
|
||||
w1_fp8q = torch.empty(
|
||||
(num_experts, 2 * n, k), device=device, dtype=torch.float8_e4m3fn
|
||||
)
|
||||
w2_fp8q = torch.empty((num_experts, k, n), device=device, dtype=torch.float8_e4m3fn)
|
||||
w1_fp8scale = torch.empty((num_experts, 1, 1), device=device, dtype=torch.float32)
|
||||
w2_fp8scale = torch.empty((num_experts, 1, 1), device=device, dtype=torch.float32)
|
||||
|
||||
for expert in range(num_experts):
|
||||
w1_fp8q[expert], w1_fp8scale[expert] = ops.scaled_fp8_quant(w1[expert])
|
||||
w2_fp8q[expert], w2_fp8scale[expert] = ops.scaled_fp8_quant(w2[expert])
|
||||
|
||||
w1_fp8q_notransp = w1_fp8q.clone()
|
||||
w2_fp8q_notransp = w2_fp8q.clone()
|
||||
w1_fp8q = w1_fp8q.transpose(1, 2)
|
||||
w2_fp8q = w2_fp8q.transpose(1, 2)
|
||||
|
||||
score = torch.randn((m, num_experts), device=device, dtype=dtype)
|
||||
|
||||
topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False)
|
||||
|
||||
quant_blocksize = 16
|
||||
w1_blockscale = torch.empty(
|
||||
(num_experts, 2 * n, k // quant_blocksize),
|
||||
device=device,
|
||||
dtype=torch.float8_e4m3fn,
|
||||
)
|
||||
w2_blockscale = torch.empty(
|
||||
(num_experts, k, n // quant_blocksize), device=device, dtype=torch.float8_e4m3fn
|
||||
)
|
||||
|
||||
# n_b_scales = 2 * n if per_out_ch else 1
|
||||
# k_b_scales = k if per_out_ch else 1
|
||||
w1_fp4 = torch.empty((num_experts, 2 * n, k // 2), device=device, dtype=torch.uint8)
|
||||
w2_fp4 = torch.empty((num_experts, k, n // 2), device=device, dtype=torch.uint8)
|
||||
|
||||
w1_gs = torch.empty((num_experts,), device=device, dtype=torch.float32)
|
||||
w2_gs = torch.empty((num_experts,), device=device, dtype=torch.float32)
|
||||
a1_gs = torch.ones((num_experts,), device=device, dtype=torch.float32)
|
||||
a2_gs = torch.ones((num_experts,), device=device, dtype=torch.float32)
|
||||
|
||||
for expert in range(num_experts):
|
||||
w1_e = w1[expert]
|
||||
w2_e = w2[expert]
|
||||
w1_amax = torch.abs(w1_e).max().to(torch.float32)
|
||||
w2_amax = torch.abs(w2_e).max().to(torch.float32)
|
||||
w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
|
||||
w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax
|
||||
|
||||
w1_fp4[expert], w1_blockscale[expert] = ops.scaled_fp4_quant(
|
||||
w1_e, w1_gs[expert]
|
||||
)
|
||||
|
||||
w2_fp4[expert], w2_blockscale[expert] = ops.scaled_fp4_quant(
|
||||
w2_e, w2_gs[expert]
|
||||
)
|
||||
|
||||
def run_triton_moe(
|
||||
a: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
a_fp8_scale: torch.Tensor,
|
||||
num_repeats: int,
|
||||
):
|
||||
for _ in range(num_repeats):
|
||||
fused_experts(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a_fp8_scale,
|
||||
)
|
||||
|
||||
def run_cutlass_moe_fp4(
|
||||
a: torch.Tensor,
|
||||
w1_fp4: torch.Tensor,
|
||||
w2_fp4: torch.Tensor,
|
||||
w1_blockscale: torch.Tensor,
|
||||
w2_blockscale: torch.Tensor,
|
||||
w1_gs: torch.Tensor,
|
||||
w2_gs: torch.Tensor,
|
||||
a1_gs: torch.Tensor,
|
||||
a2_gs: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
device: torch.device,
|
||||
num_repeats: int,
|
||||
):
|
||||
for _ in range(num_repeats):
|
||||
with nvtx.annotate("cutlass_moe_fp4", color="green"):
|
||||
cutlass_moe_fp4(
|
||||
a=a,
|
||||
a1_gscale=a1_gs,
|
||||
a2_gscale=a2_gs,
|
||||
w1_fp4=w1_fp4,
|
||||
w1_blockscale=w1_blockscale,
|
||||
w1_alphas=w1_gs,
|
||||
w2_fp4=w2_fp4,
|
||||
w2_blockscale=w2_blockscale,
|
||||
w2_alphas=w2_gs,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
m=m,
|
||||
n=n,
|
||||
k=k,
|
||||
e=num_experts,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def run_cutlass_from_graph(
|
||||
a: torch.Tensor,
|
||||
a1_gscale: torch.Tensor,
|
||||
w1_fp4: torch.Tensor,
|
||||
w1_blockscale: torch.Tensor,
|
||||
w1_alphas: torch.Tensor,
|
||||
a2_gscale: torch.Tensor,
|
||||
w2_fp4: torch.Tensor,
|
||||
w2_blockscale: torch.Tensor,
|
||||
w2_alphas: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
device: torch.device,
|
||||
):
|
||||
with set_current_vllm_config(
|
||||
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
|
||||
):
|
||||
return cutlass_moe_fp4(
|
||||
a=a,
|
||||
a1_gscale=a1_gs,
|
||||
w1_fp4=w1_fp4,
|
||||
w1_blockscale=w1_blockscale,
|
||||
w1_alphas=w1_alphas,
|
||||
a2_gscale=a2_gs,
|
||||
w2_fp4=w2_fp4,
|
||||
w2_blockscale=w2_blockscale,
|
||||
w2_alphas=w2_alphas,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
m=m,
|
||||
n=n,
|
||||
k=k,
|
||||
e=num_experts,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def run_triton_from_graph(
|
||||
a: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
a_fp8_scale: torch.Tensor,
|
||||
):
|
||||
with set_current_vllm_config(
|
||||
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
|
||||
):
|
||||
return fused_experts(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a_fp8_scale,
|
||||
)
|
||||
|
||||
def replay_graph(graph, num_repeats):
|
||||
for _ in range(num_repeats):
|
||||
graph.replay()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
cutlass_stream = torch.cuda.Stream()
|
||||
cutlass_graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(cutlass_graph, stream=cutlass_stream):
|
||||
run_cutlass_from_graph(
|
||||
a=a,
|
||||
a1_gscale=a1_gs,
|
||||
w1_fp4=w1_fp4,
|
||||
w1_blockscale=w1_blockscale,
|
||||
w1_alphas=w1_gs,
|
||||
a2_gscale=a2_gs,
|
||||
w2_fp4=w2_fp4,
|
||||
w2_blockscale=w2_blockscale,
|
||||
w2_alphas=w2_gs,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
m=m,
|
||||
n=n,
|
||||
k=k,
|
||||
e=num_experts,
|
||||
device=device,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
triton_stream = torch.cuda.Stream()
|
||||
triton_graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(triton_graph, stream=triton_stream):
|
||||
run_triton_from_graph(
|
||||
a,
|
||||
w1_fp8q_notransp,
|
||||
w2_fp8q_notransp,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
w1_fp8scale,
|
||||
w2_fp8scale,
|
||||
a_fp8_scale,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
min_run_time = 5
|
||||
num_warmup = 5
|
||||
num_runs = 25
|
||||
|
||||
globals = {
|
||||
# Baseline params
|
||||
"w1": w1,
|
||||
"w2": w2,
|
||||
"score": score,
|
||||
"topk": topk,
|
||||
"w1_fp8q_notransp": w1_fp8q_notransp,
|
||||
"w2_fp8q_notransp": w2_fp8q_notransp,
|
||||
"w1_fp8scale": w1_fp8scale,
|
||||
"w2_fp8scale": w2_fp8scale,
|
||||
"a_fp8_scale": a_fp8_scale,
|
||||
# Cutlass params
|
||||
"a": a,
|
||||
"a1_gscale": a1_gs,
|
||||
"w1_fp4": w1_fp4,
|
||||
"w1_blockscale": w1_blockscale,
|
||||
"w1_alphas": w1_gs,
|
||||
"a2_gscale": a2_gs,
|
||||
"w2_fp4": w2_fp4,
|
||||
"w2_blockscale": w2_blockscale,
|
||||
"w2_alphas": w2_gs,
|
||||
"topk_weights": topk_weights,
|
||||
"topk_ids": topk_ids,
|
||||
"m": m,
|
||||
"n": n,
|
||||
"k": k,
|
||||
"e": num_experts,
|
||||
"device": device,
|
||||
# cuda graph params
|
||||
"cutlass_graph": cutlass_graph,
|
||||
"triton_graph": triton_graph,
|
||||
# Gen params
|
||||
"num_runs": num_runs,
|
||||
# Kernels
|
||||
"run_triton_moe": run_triton_moe,
|
||||
"run_cutlass_moe_fp4": run_cutlass_moe_fp4,
|
||||
"replay_graph": replay_graph,
|
||||
}
|
||||
|
||||
# Warmup
|
||||
run_triton_moe(
|
||||
a,
|
||||
w1_fp8q_notransp,
|
||||
w2_fp8q_notransp,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
w1_fp8scale,
|
||||
w2_fp8scale,
|
||||
a_fp8_scale,
|
||||
num_warmup,
|
||||
)
|
||||
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt="run_triton_moe(a, w1_fp8q_notransp, w2_fp8q_notransp, topk_weights, topk_ids, w1_fp8scale, w2_fp8scale, a_fp8_scale, num_runs)", # noqa: E501
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description="triton_moe",
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
)
|
||||
|
||||
# Warmup
|
||||
replay_graph(triton_graph, num_warmup)
|
||||
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt="replay_graph(triton_graph, num_runs)",
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description="triton_moe_cuda_graphs",
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
)
|
||||
|
||||
# Warmup
|
||||
|
||||
run_cutlass_moe_fp4(
|
||||
a,
|
||||
w1_fp4,
|
||||
w2_fp4,
|
||||
w1_blockscale,
|
||||
w2_blockscale,
|
||||
w1_gs,
|
||||
w2_gs,
|
||||
a1_gs,
|
||||
a2_gs,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
num_experts,
|
||||
device,
|
||||
num_warmup,
|
||||
)
|
||||
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt="run_cutlass_moe_fp4(a, w1_fp4, w2_fp4, w1_blockscale, w2_blockscale, w1_alphas, w2_alphas, a1_gscale, a2_gscale, topk_weights, topk_ids, m, n, k, e, device, num_runs)", # noqa: E501
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description="cutlass_moe_fp4",
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
)
|
||||
|
||||
# Warmup
|
||||
replay_graph(cutlass_graph, num_warmup)
|
||||
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt="replay_graph(cutlass_graph, num_runs)",
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description="cutlass_moe_fp4_cuda_graphs",
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
)
|
||||
|
||||
|
||||
def main(args):
|
||||
print("Benchmarking models:")
|
||||
for i, model in enumerate(args.models):
|
||||
print(f"[{i}] {model}")
|
||||
|
||||
results: list[benchmark.Measurement] = []
|
||||
|
||||
for model in args.models:
|
||||
for tp in args.tp_sizes:
|
||||
for layer in WEIGHT_SHAPES_MOE[model]:
|
||||
num_experts = layer[0]
|
||||
topk = layer[1]
|
||||
size_k = layer[2]
|
||||
size_n = layer[3] // tp
|
||||
|
||||
if len(args.limit_k) > 0 and size_k not in args.limit_k:
|
||||
continue
|
||||
|
||||
if len(args.limit_n) > 0 and size_n not in args.limit_n:
|
||||
continue
|
||||
|
||||
for per_act_token in PER_ACT_TOKEN_OPTS:
|
||||
for per_out_ch in PER_OUT_CH_OPTS:
|
||||
for size_m in args.batch_sizes:
|
||||
mkn = (size_m, size_k, size_n)
|
||||
bench_run(
|
||||
results,
|
||||
model,
|
||||
num_experts,
|
||||
topk,
|
||||
per_act_token,
|
||||
per_out_ch,
|
||||
mkn,
|
||||
)
|
||||
|
||||
compare = benchmark.Compare(results)
|
||||
compare.print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser(
|
||||
description="Benchmark NVFP4 CUTLASS MOE across specified models/shapes/batches"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--models",
|
||||
nargs="+",
|
||||
type=str,
|
||||
default=DEFAULT_MODELS,
|
||||
choices=WEIGHT_SHAPES_MOE.keys(),
|
||||
)
|
||||
parser.add_argument("--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES)
|
||||
parser.add_argument(
|
||||
"--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
|
||||
)
|
||||
parser.add_argument("--limit-k", nargs="+", type=int, default=[])
|
||||
parser.add_argument("--limit-n", nargs="+", type=int, default=[])
|
||||
parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[])
|
||||
parser.add_argument("--limit-per-act-token", nargs="+", type=int, default=[])
|
||||
parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[])
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
@ -6,14 +6,18 @@ from benchmark_shapes import WEIGHT_SHAPES_MOE
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe_fp8,
|
||||
fused_experts,
|
||||
fused_topk)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
cutlass_moe_fp8,
|
||||
fused_experts,
|
||||
fused_topk,
|
||||
)
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
DEFAULT_MODELS = [
|
||||
"nm-testing/Mixtral-8x7B-Instruct-v0.1", "nm-testing/deepseekv2-lite",
|
||||
"ibm-granite/granite-3.0-1b-a400m", "ibm-granite/granite-3.0-3b-a800m"
|
||||
"nm-testing/Mixtral-8x7B-Instruct-v0.1",
|
||||
"nm-testing/deepseekv2-lite",
|
||||
"ibm-granite/granite-3.0-1b-a400m",
|
||||
"ibm-granite/granite-3.0-3b-a800m",
|
||||
]
|
||||
DEFAULT_BATCH_SIZES = [1, 4, 8, 16, 32, 64, 128, 256, 512]
|
||||
DEFAULT_TP_SIZES = [1]
|
||||
@ -24,19 +28,27 @@ PER_OUT_CH_OPTS = [False]
|
||||
|
||||
def to_fp8(tensor: torch.Tensor):
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
return torch.round(tensor.clamp(
|
||||
min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
|
||||
return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to(
|
||||
dtype=torch.float8_e4m3fn
|
||||
)
|
||||
|
||||
|
||||
def bench_run(results: list[benchmark.Measurement], model: str,
|
||||
num_experts: int, topk: int, per_act_token: bool,
|
||||
per_out_ch: bool, mkn: tuple[int, int, int]):
|
||||
def bench_run(
|
||||
results: list[benchmark.Measurement],
|
||||
model: str,
|
||||
num_experts: int,
|
||||
topk: int,
|
||||
per_act_token: bool,
|
||||
per_out_ch: bool,
|
||||
mkn: tuple[int, int, int],
|
||||
):
|
||||
label = "Quant Matmul"
|
||||
|
||||
sub_label = (
|
||||
"{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, "
|
||||
"MKN=({})".format(model, num_experts, topk, per_act_token, per_out_ch,
|
||||
mkn))
|
||||
"{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, MKN=({})".format(
|
||||
model, num_experts, topk, per_act_token, per_out_ch, mkn
|
||||
)
|
||||
)
|
||||
|
||||
print(f"Testing: {sub_label}")
|
||||
|
||||
@ -50,35 +62,17 @@ def bench_run(results: list[benchmark.Measurement], model: str,
|
||||
|
||||
_, a_scale = ops.scaled_fp8_quant(a)
|
||||
|
||||
w1_q = torch.empty((num_experts, 2 * n, k),
|
||||
device="cuda",
|
||||
dtype=torch.float8_e4m3fn)
|
||||
w2_q = torch.empty((num_experts, k, n),
|
||||
device="cuda",
|
||||
dtype=torch.float8_e4m3fn)
|
||||
w1_scale = torch.empty((num_experts, 1, 1),
|
||||
device="cuda",
|
||||
dtype=torch.float32)
|
||||
w2_scale = torch.empty((num_experts, 1, 1),
|
||||
device="cuda",
|
||||
dtype=torch.float32)
|
||||
w1_q = torch.empty(
|
||||
(num_experts, 2 * n, k), device="cuda", dtype=torch.float8_e4m3fn
|
||||
)
|
||||
w2_q = torch.empty((num_experts, k, n), device="cuda", dtype=torch.float8_e4m3fn)
|
||||
w1_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)
|
||||
w2_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)
|
||||
|
||||
ab_strides1 = torch.full((num_experts, ),
|
||||
k,
|
||||
device="cuda",
|
||||
dtype=torch.int64)
|
||||
c_strides1 = torch.full((num_experts, ),
|
||||
2 * n,
|
||||
device="cuda",
|
||||
dtype=torch.int64)
|
||||
ab_strides2 = torch.full((num_experts, ),
|
||||
n,
|
||||
device="cuda",
|
||||
dtype=torch.int64)
|
||||
c_strides2 = torch.full((num_experts, ),
|
||||
k,
|
||||
device="cuda",
|
||||
dtype=torch.int64)
|
||||
ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
|
||||
c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64)
|
||||
ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64)
|
||||
c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
|
||||
|
||||
for expert in range(num_experts):
|
||||
w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert])
|
||||
@ -91,82 +85,120 @@ def bench_run(results: list[benchmark.Measurement], model: str,
|
||||
score = torch.randn((m, num_experts), device="cuda", dtype=dtype)
|
||||
|
||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||
a, score, topk, renormalize=False)
|
||||
a, score, topk, renormalize=False
|
||||
)
|
||||
|
||||
def run_triton_moe(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
|
||||
w1_scale: torch.Tensor, w2_scale: torch.Tensor,
|
||||
a_scale: torch.Tensor, num_repeats: int):
|
||||
def run_triton_moe(
|
||||
a: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
a_scale: torch.Tensor,
|
||||
num_repeats: int,
|
||||
):
|
||||
for _ in range(num_repeats):
|
||||
fused_experts(a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a_scale)
|
||||
fused_experts(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a_scale,
|
||||
)
|
||||
|
||||
def run_cutlass_moe(a: torch.Tensor, a_scale: torch.Tensor,
|
||||
w1: torch.Tensor, w2: torch.Tensor,
|
||||
w1_scale: torch.Tensor, w2_scale: torch.Tensor,
|
||||
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
|
||||
ab_strides1: torch.Tensor, c_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor, c_strides2: torch.Tensor,
|
||||
num_repeats: int):
|
||||
def run_cutlass_moe(
|
||||
a: torch.Tensor,
|
||||
a_scale: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
ab_strides1: torch.Tensor,
|
||||
c_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor,
|
||||
c_strides2: torch.Tensor,
|
||||
num_repeats: int,
|
||||
):
|
||||
for _ in range(num_repeats):
|
||||
cutlass_moe_fp8(a,
|
||||
w1,
|
||||
w2,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
ab_strides1,
|
||||
c_strides1,
|
||||
ab_strides2,
|
||||
c_strides2,
|
||||
a1_scale=a_scale)
|
||||
cutlass_moe_fp8(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
ab_strides1,
|
||||
c_strides1,
|
||||
ab_strides2,
|
||||
c_strides2,
|
||||
a1_scale=a_scale,
|
||||
)
|
||||
|
||||
def run_cutlass_from_graph(
|
||||
a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor,
|
||||
w2_q: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor,
|
||||
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
|
||||
ab_strides1: torch.Tensor, c_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor, c_strides2: torch.Tensor):
|
||||
a: torch.Tensor,
|
||||
a_scale: torch.Tensor,
|
||||
w1_q: torch.Tensor,
|
||||
w2_q: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
ab_strides1: torch.Tensor,
|
||||
c_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor,
|
||||
c_strides2: torch.Tensor,
|
||||
):
|
||||
with set_current_vllm_config(
|
||||
VllmConfig(parallel_config=ParallelConfig(
|
||||
pipeline_parallel_size=1))):
|
||||
return cutlass_moe_fp8(a,
|
||||
w1_q,
|
||||
w2_q,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
ab_strides1,
|
||||
c_strides1,
|
||||
ab_strides2,
|
||||
c_strides2,
|
||||
a1_scale=a_scale)
|
||||
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
|
||||
):
|
||||
return cutlass_moe_fp8(
|
||||
a,
|
||||
w1_q,
|
||||
w2_q,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
ab_strides1,
|
||||
c_strides1,
|
||||
ab_strides2,
|
||||
c_strides2,
|
||||
a1_scale=a_scale,
|
||||
)
|
||||
|
||||
def run_triton_from_graph(a: torch.Tensor, w1: torch.Tensor,
|
||||
w2: torch.Tensor, topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor, w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor, a_scale: torch.Tensor):
|
||||
def run_triton_from_graph(
|
||||
a: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
a_scale: torch.Tensor,
|
||||
):
|
||||
with set_current_vllm_config(
|
||||
VllmConfig(parallel_config=ParallelConfig(
|
||||
pipeline_parallel_size=1))):
|
||||
return fused_experts(a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a_scale)
|
||||
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
|
||||
):
|
||||
return fused_experts(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a_scale,
|
||||
)
|
||||
|
||||
def replay_graph(graph, num_repeats):
|
||||
for _ in range(num_repeats):
|
||||
@ -176,16 +208,35 @@ def bench_run(results: list[benchmark.Measurement], model: str,
|
||||
cutlass_stream = torch.cuda.Stream()
|
||||
cutlass_graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(cutlass_graph, stream=cutlass_stream):
|
||||
run_cutlass_from_graph(a, a_scale, w1_q, w2_q, w1_scale, w2_scale,
|
||||
topk_weights, topk_ids, ab_strides1, c_strides1,
|
||||
ab_strides2, c_strides2)
|
||||
run_cutlass_from_graph(
|
||||
a,
|
||||
a_scale,
|
||||
w1_q,
|
||||
w2_q,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
ab_strides1,
|
||||
c_strides1,
|
||||
ab_strides2,
|
||||
c_strides2,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
triton_stream = torch.cuda.Stream()
|
||||
triton_graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(triton_graph, stream=triton_stream):
|
||||
run_triton_from_graph(a, w1_q_notransp, w2_q_notransp, topk_weights,
|
||||
topk_ids, w1_scale, w2_scale, a_scale)
|
||||
run_triton_from_graph(
|
||||
a,
|
||||
w1_q_notransp,
|
||||
w2_q_notransp,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
a_scale,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
min_run_time = 5
|
||||
@ -225,18 +276,27 @@ def bench_run(results: list[benchmark.Measurement], model: str,
|
||||
}
|
||||
|
||||
# Warmup
|
||||
run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids,
|
||||
w1_scale, w2_scale, a_scale, num_warmup)
|
||||
run_triton_moe(
|
||||
a,
|
||||
w1_q_notransp,
|
||||
w2_q_notransp,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
a_scale,
|
||||
num_warmup,
|
||||
)
|
||||
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt=
|
||||
"run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)", # noqa: E501
|
||||
stmt="run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)", # noqa: E501
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description="triton_moe",
|
||||
).blocked_autorange(min_run_time=min_run_time))
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
)
|
||||
|
||||
# Warmup
|
||||
replay_graph(triton_graph, num_warmup)
|
||||
@ -248,22 +308,35 @@ def bench_run(results: list[benchmark.Measurement], model: str,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description="triton_moe_cuda_graphs",
|
||||
).blocked_autorange(min_run_time=min_run_time))
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
)
|
||||
|
||||
# Warmup
|
||||
run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights,
|
||||
topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2,
|
||||
num_warmup)
|
||||
run_cutlass_moe(
|
||||
a,
|
||||
a_scale,
|
||||
w1_q,
|
||||
w2_q,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
ab_strides1,
|
||||
c_strides1,
|
||||
ab_strides2,
|
||||
c_strides2,
|
||||
num_warmup,
|
||||
)
|
||||
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt=
|
||||
"run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, num_runs)", # noqa: E501
|
||||
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, num_runs)", # noqa: E501
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description="grouped_gemm_moe",
|
||||
).blocked_autorange(min_run_time=min_run_time))
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
)
|
||||
|
||||
# Warmup
|
||||
replay_graph(cutlass_graph, num_warmup)
|
||||
@ -275,7 +348,8 @@ def bench_run(results: list[benchmark.Measurement], model: str,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description="grouped_gemm_moe_cuda_graphs",
|
||||
).blocked_autorange(min_run_time=min_run_time))
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
)
|
||||
|
||||
|
||||
def main(args):
|
||||
@ -303,8 +377,15 @@ def main(args):
|
||||
for per_out_ch in PER_OUT_CH_OPTS:
|
||||
for size_m in DEFAULT_BATCH_SIZES:
|
||||
mkn = (size_m, size_k, size_n)
|
||||
bench_run(results, model, num_experts, topk,
|
||||
per_act_token, per_out_ch, mkn)
|
||||
bench_run(
|
||||
results,
|
||||
model,
|
||||
num_experts,
|
||||
topk,
|
||||
per_act_token,
|
||||
per_out_ch,
|
||||
mkn,
|
||||
)
|
||||
|
||||
compare = benchmark.Compare(results)
|
||||
compare.print()
|
||||
@ -312,7 +393,8 @@ def main(args):
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser(
|
||||
description="Benchmark Marlin across specified models/shapes/batches")
|
||||
description="Benchmark Marlin across specified models/shapes/batches"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--models",
|
||||
nargs="+",
|
||||
@ -320,21 +402,14 @@ if __name__ == "__main__":
|
||||
default=DEFAULT_MODELS,
|
||||
choices=WEIGHT_SHAPES_MOE.keys(),
|
||||
)
|
||||
parser.add_argument("--tp-sizes",
|
||||
nargs="+",
|
||||
type=int,
|
||||
default=DEFAULT_TP_SIZES)
|
||||
parser.add_argument("--batch-sizes",
|
||||
nargs="+",
|
||||
type=int,
|
||||
default=DEFAULT_BATCH_SIZES)
|
||||
parser.add_argument("--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES)
|
||||
parser.add_argument(
|
||||
"--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
|
||||
)
|
||||
parser.add_argument("--limit-k", nargs="+", type=int, default=[])
|
||||
parser.add_argument("--limit-n", nargs="+", type=int, default=[])
|
||||
parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[])
|
||||
parser.add_argument("--limit-per-act-token",
|
||||
nargs="+",
|
||||
type=int,
|
||||
default=[])
|
||||
parser.add_argument("--limit-per-act-token", nargs="+", type=int, default=[])
|
||||
parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[])
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -10,14 +10,16 @@ from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def main(num_tokens: int,
|
||||
hidden_size: int,
|
||||
add_residual: bool,
|
||||
dtype: torch.dtype,
|
||||
seed: int = 0,
|
||||
do_profile: bool = False,
|
||||
num_warmup_iters: int = 5,
|
||||
num_iters: int = 100) -> None:
|
||||
def main(
|
||||
num_tokens: int,
|
||||
hidden_size: int,
|
||||
add_residual: bool,
|
||||
dtype: torch.dtype,
|
||||
seed: int = 0,
|
||||
do_profile: bool = False,
|
||||
num_warmup_iters: int = 5,
|
||||
num_iters: int = 100,
|
||||
) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
torch.set_default_device("cuda")
|
||||
|
||||
@ -56,33 +58,35 @@ def main(num_tokens: int,
|
||||
print(f"Kernel running time: {latency * 1000000:.3f} us")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = FlexibleArgumentParser(
|
||||
description="Benchmark the layernorm kernel.")
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser(description="Benchmark the layernorm kernel.")
|
||||
parser.add_argument("--num-tokens", type=int, default=4096)
|
||||
parser.add_argument("--hidden-size", type=int, default=8192)
|
||||
parser.add_argument("--add-residual", action="store_true")
|
||||
parser.add_argument("--dtype",
|
||||
type=str,
|
||||
choices=["half", "bfloat16", "float"],
|
||||
default="half")
|
||||
parser.add_argument(
|
||||
"--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--profile", action="store_true")
|
||||
parser.add_argument("--num-warmup-iters", type=int, default=5)
|
||||
parser.add_argument("--num-iters",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Number of benchmark iterations. "
|
||||
"If --profile is set, this number is ignored")
|
||||
parser.add_argument(
|
||||
"--num-iters",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Number of benchmark iterations. "
|
||||
"If --profile is set, this number is ignored",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
|
||||
main(num_tokens=args.num_tokens,
|
||||
hidden_size=args.hidden_size,
|
||||
add_residual=args.add_residual,
|
||||
dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
|
||||
seed=args.seed,
|
||||
do_profile=args.profile,
|
||||
num_warmup_iters=args.num_warmup_iters,
|
||||
num_iters=args.num_iters)
|
||||
main(
|
||||
num_tokens=args.num_tokens,
|
||||
hidden_size=args.hidden_size,
|
||||
add_residual=args.add_residual,
|
||||
dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
|
||||
seed=args.seed,
|
||||
do_profile=args.profile,
|
||||
num_warmup_iters=args.num_warmup_iters,
|
||||
num_iters=args.num_iters,
|
||||
)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -20,12 +20,18 @@ from weight_shapes import WEIGHT_SHAPES
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, marlin_permute_scales,
|
||||
marlin_zero_points)
|
||||
GPTQ_MARLIN_MAX_PARALLEL,
|
||||
GPTQ_MARLIN_MIN_THREAD_N,
|
||||
marlin_permute_scales,
|
||||
marlin_zero_points,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
||||
MarlinWorkspace)
|
||||
MarlinWorkspace,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
pack_rows, quantize_weights)
|
||||
pack_rows,
|
||||
quantize_weights,
|
||||
)
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
@ -82,12 +88,14 @@ def rand_data(shape, dtype=torch.float16, scale=1):
|
||||
return torch.randint(-15, 15, shape, dtype=dtype, device="cuda")
|
||||
|
||||
|
||||
def quantize_and_pack(atype: torch.dtype,
|
||||
w: torch.Tensor,
|
||||
wtype: ScalarType,
|
||||
stype: Optional[torch.dtype],
|
||||
group_size: Optional[int],
|
||||
zero_points: bool = False):
|
||||
def quantize_and_pack(
|
||||
atype: torch.dtype,
|
||||
w: torch.Tensor,
|
||||
wtype: ScalarType,
|
||||
stype: Optional[torch.dtype],
|
||||
group_size: Optional[int],
|
||||
zero_points: bool = False,
|
||||
):
|
||||
assert wtype.is_integer(), "TODO: support floating point weights"
|
||||
|
||||
w_ref, w_q, w_s, w_zp = quantize_weights(
|
||||
@ -96,21 +104,24 @@ def quantize_and_pack(atype: torch.dtype,
|
||||
group_size=group_size,
|
||||
zero_points=zero_points,
|
||||
# to match how the kernel applies zps
|
||||
ref_zero_points_after_scales=True)
|
||||
ref_zero_points_after_scales=True,
|
||||
)
|
||||
|
||||
w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape)
|
||||
return w_ref, w_q, w_s, w_zp
|
||||
|
||||
|
||||
def create_bench_tensors(shape: tuple[int, int, int], types: TypeConfig,
|
||||
group_size: Optional[int]) -> list[BenchmarkTensors]:
|
||||
def create_bench_tensors(
|
||||
shape: tuple[int, int, int], types: TypeConfig, group_size: Optional[int]
|
||||
) -> list[BenchmarkTensors]:
|
||||
m, n, k = shape
|
||||
|
||||
# we want to make sure that weights don't fit into L2 cache between runs so
|
||||
# we construct enough weights to exceed L2 cache, which is 50mb on a H100
|
||||
# so we target total weight size > 2*50mb
|
||||
num_weights = math.ceil(2 * 50 * 1024**2 * 8 /
|
||||
(k * n * types.weight_type.size_bits))
|
||||
num_weights = math.ceil(
|
||||
2 * 50 * 1024**2 * 8 / (k * n * types.weight_type.size_bits)
|
||||
)
|
||||
|
||||
a = rand_data((m, k), types.act_type, scale=5)
|
||||
|
||||
@ -124,8 +135,13 @@ def create_bench_tensors(shape: tuple[int, int, int], types: TypeConfig,
|
||||
w = w.to(torch.float16)
|
||||
|
||||
w_ref, w_q_packed, w_s, w_zp = quantize_and_pack(
|
||||
a.dtype, w, types.weight_type, types.group_scale_type, group_size,
|
||||
types.group_zero_type is not None)
|
||||
a.dtype,
|
||||
w,
|
||||
types.weight_type,
|
||||
types.group_scale_type,
|
||||
group_size,
|
||||
types.group_zero_type is not None,
|
||||
)
|
||||
|
||||
if not a.dtype.is_floating_point:
|
||||
aiinfo = torch.iinfo(a.dtype)
|
||||
@ -133,21 +149,30 @@ def create_bench_tensors(shape: tuple[int, int, int], types: TypeConfig,
|
||||
|
||||
w_ref = w_ref.to(torch.float32)
|
||||
|
||||
w_ch_s = None if types.channel_scale_type is None else\
|
||||
rand_data((n,), types.channel_scale_type)
|
||||
w_tok_s = None if types.token_scale_type is None else\
|
||||
rand_data((m,), types.token_scale_type)
|
||||
w_ch_s = (
|
||||
None
|
||||
if types.channel_scale_type is None
|
||||
else rand_data((n,), types.channel_scale_type)
|
||||
)
|
||||
w_tok_s = (
|
||||
None
|
||||
if types.token_scale_type is None
|
||||
else rand_data((m,), types.token_scale_type)
|
||||
)
|
||||
|
||||
benchmark_tensors.append(
|
||||
BenchmarkTensors(w_ref=w_ref,
|
||||
a=a,
|
||||
w_q=w_q_packed,
|
||||
wtype=types.weight_type,
|
||||
w_g_s=w_s,
|
||||
w_g_zp=w_zp,
|
||||
group_size=group_size,
|
||||
w_ch_s=w_ch_s,
|
||||
w_tok_s=w_tok_s))
|
||||
BenchmarkTensors(
|
||||
w_ref=w_ref,
|
||||
a=a,
|
||||
w_q=w_q_packed,
|
||||
wtype=types.weight_type,
|
||||
w_g_s=w_s,
|
||||
w_g_zp=w_zp,
|
||||
group_size=group_size,
|
||||
w_ch_s=w_ch_s,
|
||||
w_tok_s=w_tok_s,
|
||||
)
|
||||
)
|
||||
|
||||
return benchmark_tensors
|
||||
|
||||
@ -170,50 +195,57 @@ def cutlass_scaled_mm_create_bench_fn(bt: BenchmarkTensors) -> Callable:
|
||||
scale_b = torch.tensor(1.0, dtype=torch.float32, device=bt.a.device)
|
||||
w_col_major = bt.w_ref.to(bt.a.dtype).t().contiguous().t()
|
||||
return lambda: ops.cutlass_scaled_mm(
|
||||
bt.a, w_col_major, scale_a, scale_b, out_dtype=torch.float16)
|
||||
bt.a, w_col_major, scale_a, scale_b, out_dtype=torch.float16
|
||||
)
|
||||
|
||||
|
||||
def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
|
||||
device = bt.a.device
|
||||
|
||||
workspace = MarlinWorkspace(bt.w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N,
|
||||
GPTQ_MARLIN_MAX_PARALLEL)
|
||||
workspace = MarlinWorkspace(
|
||||
bt.w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL
|
||||
)
|
||||
|
||||
if bt.w_g_zp is None:
|
||||
w_zp = torch.empty(0, dtype=torch.int, device=device)
|
||||
else:
|
||||
w_zp = marlin_zero_points(bt.w_g_zp, bt.w_ref.shape[0],
|
||||
bt.w_ref.shape[1], bt.wtype.size_bits)
|
||||
w_zp = marlin_zero_points(
|
||||
bt.w_g_zp, bt.w_ref.shape[0], bt.w_ref.shape[1], bt.wtype.size_bits
|
||||
)
|
||||
|
||||
if bt.group_size is None:
|
||||
w_s = torch.tensor([], device="cuda", dtype=torch.half)
|
||||
else:
|
||||
w_s = marlin_permute_scales(bt.w_g_s, bt.w_ref.shape[0],
|
||||
bt.w_ref.shape[1], bt.group_size)
|
||||
w_s = marlin_permute_scales(
|
||||
bt.w_g_s, bt.w_ref.shape[0], bt.w_ref.shape[1], bt.group_size
|
||||
)
|
||||
|
||||
sort_indices = torch.empty(0, dtype=torch.int, device=device)
|
||||
g_idx = torch.empty(0, dtype=torch.int, device=device)
|
||||
w_q = ops.gptq_marlin_repack(bt.w_q, sort_indices, bt.w_ref.shape[0],
|
||||
bt.w_ref.shape[1], bt.wtype.size_bits)
|
||||
w_q = ops.gptq_marlin_repack(
|
||||
bt.w_q, sort_indices, bt.w_ref.shape[0], bt.w_ref.shape[1], bt.wtype.size_bits
|
||||
)
|
||||
|
||||
if bt.a.dtype.is_floating_point:
|
||||
assert bt.w_ch_s is None
|
||||
assert bt.w_tok_s is None
|
||||
assert bt.group_size is not None
|
||||
|
||||
fn = lambda: ops.gptq_marlin_gemm(a=bt.a,
|
||||
b_q_weight=w_q,
|
||||
b_scales=w_s,
|
||||
b_zeros=w_zp,
|
||||
g_idx=g_idx,
|
||||
perm=sort_indices,
|
||||
workspace=workspace.scratch,
|
||||
b_q_type=bt.wtype,
|
||||
size_m=bt.a.shape[0],
|
||||
size_n=bt.w_ref.shape[1],
|
||||
size_k=bt.w_ref.shape[0],
|
||||
is_k_full=True,
|
||||
is_zp_float=False)
|
||||
fn = lambda: ops.gptq_marlin_gemm(
|
||||
a=bt.a,
|
||||
b_q_weight=w_q,
|
||||
b_scales=w_s,
|
||||
b_zeros=w_zp,
|
||||
g_idx=g_idx,
|
||||
perm=sort_indices,
|
||||
workspace=workspace.scratch,
|
||||
b_q_type=bt.wtype,
|
||||
size_m=bt.a.shape[0],
|
||||
size_n=bt.w_ref.shape[1],
|
||||
size_k=bt.w_ref.shape[0],
|
||||
is_k_full=True,
|
||||
is_zp_float=False,
|
||||
)
|
||||
else:
|
||||
assert bt.a.dtype == torch.int8
|
||||
assert bt.wtype == scalar_types.uint4b8
|
||||
@ -221,36 +253,35 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
|
||||
if bt.w_ch_s is not None:
|
||||
s_ch = bt.w_ch_s.to(torch.float32)
|
||||
else:
|
||||
s_ch = torch.ones(bt.w_ref.shape[1],
|
||||
dtype=torch.float32,
|
||||
device=device)
|
||||
s_ch = torch.ones(bt.w_ref.shape[1], dtype=torch.float32, device=device)
|
||||
|
||||
if bt.w_tok_s is not None:
|
||||
s_tok = bt.w_tok_s.to(torch.float32)
|
||||
else:
|
||||
s_tok = torch.ones(bt.a.shape[0],
|
||||
dtype=torch.float32,
|
||||
device=device)
|
||||
s_tok = torch.ones(bt.a.shape[0], dtype=torch.float32, device=device)
|
||||
|
||||
fn = lambda: ops.marlin_qqq_gemm(a=bt.a,
|
||||
b_q_weight=w_q,
|
||||
s_group=w_s,
|
||||
s_tok=s_tok,
|
||||
s_ch=s_ch,
|
||||
workspace=workspace.scratch,
|
||||
size_m=bt.a.shape[0],
|
||||
size_n=bt.w_ref.shape[1],
|
||||
size_k=bt.w_ref.shape[0])
|
||||
fn = lambda: ops.marlin_qqq_gemm(
|
||||
a=bt.a,
|
||||
b_q_weight=w_q,
|
||||
s_group=w_s,
|
||||
s_tok=s_tok,
|
||||
s_ch=s_ch,
|
||||
workspace=workspace.scratch,
|
||||
size_m=bt.a.shape[0],
|
||||
size_n=bt.w_ref.shape[1],
|
||||
size_k=bt.w_ref.shape[0],
|
||||
)
|
||||
|
||||
return fn
|
||||
|
||||
|
||||
def machete_create_bench_fn(bt: BenchmarkTensors,
|
||||
out_type=torch.dtype,
|
||||
schedule=None) -> Callable:
|
||||
def machete_create_bench_fn(
|
||||
bt: BenchmarkTensors, out_type=torch.dtype, schedule=None
|
||||
) -> Callable:
|
||||
w_q = bt.w_q.t().contiguous().t() # make col major
|
||||
w_q = ops.machete_prepack_B(w_q, bt.a.dtype, bt.wtype,
|
||||
None if bt.w_g_s is None else bt.w_g_s.dtype)
|
||||
w_q = ops.machete_prepack_B(
|
||||
w_q, bt.a.dtype, bt.wtype, None if bt.w_g_s is None else bt.w_g_s.dtype
|
||||
)
|
||||
|
||||
w_g_zp = bt.w_g_zp
|
||||
if w_g_zp is not None:
|
||||
@ -275,26 +306,24 @@ def machete_create_bench_fn(bt: BenchmarkTensors,
|
||||
# bench
|
||||
|
||||
|
||||
def bench_fns(label: str, sub_label: str, description: str,
|
||||
fns: list[Callable]):
|
||||
|
||||
def bench_fns(label: str, sub_label: str, description: str, fns: list[Callable]):
|
||||
min_run_time = 1 if not NVTX_PROFILE else 0.1
|
||||
res = TBenchmark.Timer(
|
||||
stmt="""
|
||||
for fn in fns:
|
||||
fn()
|
||||
""",
|
||||
globals={
|
||||
"fns": fns
|
||||
},
|
||||
globals={"fns": fns},
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description=description,
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
|
||||
if NVTX_PROFILE:
|
||||
with nvtx.annotate("mm-bench"), nvtx.annotate(
|
||||
f"{label}|{sub_label}|{description}"):
|
||||
with (
|
||||
nvtx.annotate("mm-bench"),
|
||||
nvtx.annotate(f"{label}|{sub_label}|{description}"),
|
||||
):
|
||||
fns[0]()
|
||||
|
||||
return res
|
||||
@ -304,19 +333,20 @@ _SWEEP_SCHEDULES_RESULTS: Optional[pd.DataFrame] = None
|
||||
_SWEEP_SCHEDULES_RESULTS_CSV: Optional[str] = None
|
||||
|
||||
|
||||
def bench(types: TypeConfig,
|
||||
group_size: int,
|
||||
m: int,
|
||||
k: int,
|
||||
n: int,
|
||||
label: str,
|
||||
sub_label: str,
|
||||
sweep_schedules: bool = True) -> list[TMeasurement]:
|
||||
def bench(
|
||||
types: TypeConfig,
|
||||
group_size: int,
|
||||
m: int,
|
||||
k: int,
|
||||
n: int,
|
||||
label: str,
|
||||
sub_label: str,
|
||||
sweep_schedules: bool = True,
|
||||
) -> list[TMeasurement]:
|
||||
benchmark_tensors = create_bench_tensors((m, n, k), types, group_size)
|
||||
sub_label += f", L={len(benchmark_tensors)}"
|
||||
|
||||
name_type_string = f"W{types.weight_type}"+\
|
||||
f"-A{terse_type_name(types.act_type)}"
|
||||
name_type_string = f"W{types.weight_type}" + f"-A{terse_type_name(types.act_type)}"
|
||||
if types.group_scale_type is not None:
|
||||
name_type_string += f"-GS{terse_type_name(types.group_scale_type)}"
|
||||
if types.group_zero_type is not None:
|
||||
@ -332,31 +362,45 @@ def bench(types: TypeConfig,
|
||||
# pytorch impl
|
||||
timers.append(
|
||||
bench_fns(
|
||||
label, sub_label, "torch.matmul (fp16)",
|
||||
[torch_matmul_f16_create_bench_fn(bt)
|
||||
for bt in benchmark_tensors]))
|
||||
label,
|
||||
sub_label,
|
||||
"torch.matmul (fp16)",
|
||||
[torch_matmul_f16_create_bench_fn(bt) for bt in benchmark_tensors],
|
||||
)
|
||||
)
|
||||
|
||||
if types.act_type == torch.int8 or types.act_type == torch.float8_e4m3fn:
|
||||
timers.append(
|
||||
bench_fns(
|
||||
label, sub_label,
|
||||
f"cutlass_scaled_mm ({terse_type_name(types.act_type)})", [
|
||||
cutlass_scaled_mm_create_bench_fn(bt)
|
||||
for bt in benchmark_tensors
|
||||
]))
|
||||
label,
|
||||
sub_label,
|
||||
f"cutlass_scaled_mm ({terse_type_name(types.act_type)})",
|
||||
[cutlass_scaled_mm_create_bench_fn(bt) for bt in benchmark_tensors],
|
||||
)
|
||||
)
|
||||
|
||||
if types.act_type != torch.float8_e4m3fn:
|
||||
timers.append(
|
||||
bench_fns(label, sub_label, f"marlin ({name_type_string})",
|
||||
[marlin_create_bench_fn(bt)
|
||||
for bt in benchmark_tensors]))
|
||||
bench_fns(
|
||||
label,
|
||||
sub_label,
|
||||
f"marlin ({name_type_string})",
|
||||
[marlin_create_bench_fn(bt) for bt in benchmark_tensors],
|
||||
)
|
||||
)
|
||||
|
||||
# machete
|
||||
timers.append(
|
||||
bench_fns(label, sub_label, f"machete ({name_type_string})", [
|
||||
machete_create_bench_fn(bt, out_type=types.output_type)
|
||||
for bt in benchmark_tensors
|
||||
]))
|
||||
bench_fns(
|
||||
label,
|
||||
sub_label,
|
||||
f"machete ({name_type_string})",
|
||||
[
|
||||
machete_create_bench_fn(bt, out_type=types.output_type)
|
||||
for bt in benchmark_tensors
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
if sweep_schedules:
|
||||
global _SWEEP_SCHEDULES_RESULTS
|
||||
@ -371,7 +415,8 @@ def bench(types: TypeConfig,
|
||||
group_zeros_type=types.group_zero_type,
|
||||
token_scales_type=types.token_scale_type,
|
||||
channel_scales_type=types.channel_scale_type,
|
||||
out_type=types.output_type)
|
||||
out_type=types.output_type,
|
||||
)
|
||||
|
||||
if schedules is None or len(schedules) == 0:
|
||||
raise ValueError("No schedules found to sweep")
|
||||
@ -383,11 +428,17 @@ def bench(types: TypeConfig,
|
||||
if schedule_M >= 2 * max(m, 16) or schedule_M < m // 4:
|
||||
continue
|
||||
|
||||
res = bench_fns(label, sub_label, "machete_best", [
|
||||
machete_create_bench_fn(
|
||||
bt, out_type=types.output_type, schedule=schedule)
|
||||
for bt in benchmark_tensors
|
||||
])
|
||||
res = bench_fns(
|
||||
label,
|
||||
sub_label,
|
||||
"machete_best",
|
||||
[
|
||||
machete_create_bench_fn(
|
||||
bt, out_type=types.output_type, schedule=schedule
|
||||
)
|
||||
for bt in benchmark_tensors
|
||||
],
|
||||
)
|
||||
|
||||
results_row = {
|
||||
"M": m,
|
||||
@ -398,10 +449,8 @@ def bench(types: TypeConfig,
|
||||
"median": res.median,
|
||||
}
|
||||
if _SWEEP_SCHEDULES_RESULTS is None:
|
||||
_SWEEP_SCHEDULES_RESULTS = pd.DataFrame(
|
||||
columns=results_row.keys())
|
||||
_SWEEP_SCHEDULES_RESULTS.\
|
||||
loc[len(_SWEEP_SCHEDULES_RESULTS)] = results_row
|
||||
_SWEEP_SCHEDULES_RESULTS = pd.DataFrame(columns=results_row.keys())
|
||||
_SWEEP_SCHEDULES_RESULTS.loc[len(_SWEEP_SCHEDULES_RESULTS)] = results_row
|
||||
|
||||
print(f" {res.median:5.5} ", schedule)
|
||||
if not best or res.median < best.median:
|
||||
@ -422,8 +471,9 @@ def print_timers(timers: list[TMeasurement]):
|
||||
def run(args, MKNs: Iterable[tuple[int, int, int]]) -> Iterable[TMeasurement]:
|
||||
types = TypeConfig(
|
||||
act_type=args.act_type,
|
||||
weight_type=scalar_types.uint4b8 if args.group_zero_type is None \
|
||||
else scalar_types.uint4,
|
||||
weight_type=scalar_types.uint4b8
|
||||
if args.group_zero_type is None
|
||||
else scalar_types.uint4,
|
||||
output_type=args.out_type,
|
||||
group_scale_type=args.group_scale_type,
|
||||
group_zero_type=args.group_zero_type,
|
||||
@ -433,14 +483,16 @@ def run(args, MKNs: Iterable[tuple[int, int, int]]) -> Iterable[TMeasurement]:
|
||||
|
||||
results: list[TMeasurement] = []
|
||||
for m, k, n in MKNs:
|
||||
timers = bench(types,
|
||||
args.group_size,
|
||||
m,
|
||||
k,
|
||||
n,
|
||||
f"{args.act_type}-gemm",
|
||||
f"MKN=({m}x{k}x{n})",
|
||||
sweep_schedules=args.sweep_schedules)
|
||||
timers = bench(
|
||||
types,
|
||||
args.group_size,
|
||||
m,
|
||||
k,
|
||||
n,
|
||||
f"{args.act_type}-gemm",
|
||||
f"MKN=({m}x{k}x{n})",
|
||||
sweep_schedules=args.sweep_schedules,
|
||||
)
|
||||
print_timers(timers)
|
||||
results.extend(timers)
|
||||
|
||||
@ -454,7 +506,6 @@ def make_output(
|
||||
base_description: str,
|
||||
timestamp=None,
|
||||
):
|
||||
|
||||
print(f"== All Results {base_description} ====")
|
||||
print_timers(data)
|
||||
|
||||
@ -468,8 +519,7 @@ def make_output(
|
||||
|
||||
|
||||
def run_square_bench(args):
|
||||
dim_sizes = list(
|
||||
range(args.dim_start, args.dim_end + 1, args.dim_increment))
|
||||
dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment))
|
||||
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
|
||||
data = run(args.dtype, args.sweep_schedules, MKNs)
|
||||
|
||||
@ -479,8 +529,9 @@ def run_square_bench(args):
|
||||
def run_range_bench(args):
|
||||
m_start, k_start, n_start = (int(x) for x in args.dim_start.split(","))
|
||||
m_end, k_end, n_end = (int(x) for x in args.dim_end.split(","))
|
||||
m_increment, k_increment, n_increment = \
|
||||
(int(x) for x in args.dim_increment.split(","))
|
||||
m_increment, k_increment, n_increment = (
|
||||
int(x) for x in args.dim_increment.split(",")
|
||||
)
|
||||
Ms = list(range(m_start, m_end + 1, m_increment))
|
||||
Ks = list(range(k_start, k_end + 1, k_increment))
|
||||
Ns = list(range(n_start, n_end + 1, n_increment))
|
||||
@ -492,7 +543,6 @@ def run_range_bench(args):
|
||||
|
||||
|
||||
def run_model_bench(args):
|
||||
|
||||
print("Benchmarking models:")
|
||||
for i, model in enumerate(args.models):
|
||||
print(f"[{i}] {model}")
|
||||
@ -535,10 +585,13 @@ def run_model_bench(args):
|
||||
with open(f"model_bench-{type_string}-{timestr}.pkl", "wb") as f:
|
||||
args_dict = vars(args)
|
||||
args_dict.pop("func")
|
||||
pkl.dump({
|
||||
"args": args_dict,
|
||||
"results": all_results,
|
||||
}, f)
|
||||
pkl.dump(
|
||||
{
|
||||
"args": args_dict,
|
||||
"results": all_results,
|
||||
},
|
||||
f,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@ -554,7 +607,6 @@ if __name__ == "__main__":
|
||||
}[dt]
|
||||
|
||||
class ToTorchDtype(argparse.Action):
|
||||
|
||||
def __call__(self, parser, namespace, values, option_string=None):
|
||||
setattr(namespace, self.dest, to_torch_dtype(values))
|
||||
|
||||
@ -580,32 +632,32 @@ Benchmark Machete GEMM.
|
||||
"--act-type",
|
||||
action=ToTorchDtype,
|
||||
required=True,
|
||||
choices=['bfloat16', 'float16', 'int8', 'float8_e4m3fn'],
|
||||
choices=["bfloat16", "float16", "int8", "float8_e4m3fn"],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--group-scale-type",
|
||||
action=ToTorchDtype,
|
||||
choices=['bfloat16', 'float16'],
|
||||
choices=["bfloat16", "float16"],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--group-zero-type",
|
||||
type=to_torch_dtype,
|
||||
choices=['bfloat16', 'float16'],
|
||||
choices=["bfloat16", "float16"],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--channel-scale-type",
|
||||
action=ToTorchDtype,
|
||||
choices=['float'],
|
||||
choices=["float"],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--token-scale-type",
|
||||
action=ToTorchDtype,
|
||||
choices=['float'],
|
||||
choices=["float"],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--out-type",
|
||||
action=ToTorchDtype,
|
||||
choices=['bfloat16', 'float16'],
|
||||
choices=["bfloat16", "float16"],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--group-size",
|
||||
@ -618,9 +670,11 @@ Benchmark Machete GEMM.
|
||||
action="store_true",
|
||||
help="Run a sweep over all supported schedules",
|
||||
)
|
||||
parser.add_argument("--sweep-csv-out",
|
||||
help="CSV to store sweep results",
|
||||
default="sch_sweep_results.csv")
|
||||
parser.add_argument(
|
||||
"--sweep-csv-out",
|
||||
help="CSV to store sweep results",
|
||||
default="sch_sweep_results.csv",
|
||||
)
|
||||
subparsers = parser.add_subparsers(dest="cmd", required=True)
|
||||
|
||||
square_parser = subparsers.add_parser("square_bench")
|
||||
@ -634,17 +688,20 @@ Benchmark Machete GEMM.
|
||||
"--dim-start",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Start value for M,K,N as common separated list")
|
||||
help="Start value for M,K,N as common separated list",
|
||||
)
|
||||
range_parser.add_argument(
|
||||
"--dim-end",
|
||||
type=str,
|
||||
required=True,
|
||||
help="End value (inclusive) for M,K,N as common separated list")
|
||||
help="End value (inclusive) for M,K,N as common separated list",
|
||||
)
|
||||
range_parser.add_argument(
|
||||
"--dim-increment",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Increment value for M,K,N as common separated list")
|
||||
help="Increment value for M,K,N as common separated list",
|
||||
)
|
||||
range_parser.set_defaults(func=run_range_bench)
|
||||
|
||||
model_parser = subparsers.add_parser("model_bench")
|
||||
@ -655,14 +712,12 @@ Benchmark Machete GEMM.
|
||||
default=DEFAULT_MODELS,
|
||||
choices=WEIGHT_SHAPES.keys(),
|
||||
)
|
||||
model_parser.add_argument("--tp-sizes",
|
||||
nargs="+",
|
||||
type=int,
|
||||
default=DEFAULT_TP_SIZES)
|
||||
model_parser.add_argument("--batch-sizes",
|
||||
nargs="+",
|
||||
type=int,
|
||||
default=DEFAULT_BATCH_SIZES)
|
||||
model_parser.add_argument(
|
||||
"--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES
|
||||
)
|
||||
model_parser.add_argument(
|
||||
"--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
|
||||
)
|
||||
model_parser.set_defaults(func=run_model_bench)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -6,19 +6,34 @@ from benchmark_shapes import WEIGHT_SHAPES
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
|
||||
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
|
||||
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
|
||||
GPTQ_MARLIN_24_MAX_PARALLEL,
|
||||
GPTQ_MARLIN_24_MIN_THREAD_N,
|
||||
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES,
|
||||
GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.allspark_utils import (
|
||||
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, ALLSPARK_SUPPORTED_QUANT_TYPES)
|
||||
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
|
||||
ALLSPARK_SUPPORTED_QUANT_TYPES,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
|
||||
MARLIN_SUPPORTED_GROUP_SIZES, query_marlin_supported_quant_types)
|
||||
GPTQ_MARLIN_MAX_PARALLEL,
|
||||
GPTQ_MARLIN_MIN_THREAD_N,
|
||||
MARLIN_SUPPORTED_GROUP_SIZES,
|
||||
query_marlin_supported_quant_types,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
||||
MarlinWorkspace, marlin_quantize)
|
||||
MarlinWorkspace,
|
||||
marlin_quantize,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
|
||||
marlin_24_quantize)
|
||||
marlin_24_quantize,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights)
|
||||
gptq_pack,
|
||||
gptq_quantize_weights,
|
||||
quantize_weights,
|
||||
sort_weights,
|
||||
)
|
||||
from vllm.scalar_type import ScalarType
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
@ -29,22 +44,29 @@ ACT_ORDER_OPTS = [False, True]
|
||||
K_FULL_OPTS = [False, True]
|
||||
|
||||
|
||||
def bench_run(results: list[benchmark.Measurement], model: str,
|
||||
act_order: bool, is_k_full: bool, quant_type: ScalarType,
|
||||
group_size: int, size_m: int, size_k: int, size_n: int):
|
||||
def bench_run(
|
||||
results: list[benchmark.Measurement],
|
||||
model: str,
|
||||
act_order: bool,
|
||||
is_k_full: bool,
|
||||
quant_type: ScalarType,
|
||||
group_size: int,
|
||||
size_m: int,
|
||||
size_k: int,
|
||||
size_n: int,
|
||||
):
|
||||
label = "Quant Matmul"
|
||||
|
||||
sub_label = ("{}, act={} k_full={}, q={}, g={}, "
|
||||
"MKN=({}x{}x{})".format(model, act_order, is_k_full,
|
||||
str(quant_type), group_size, size_m,
|
||||
size_k, size_n))
|
||||
sub_label = "{}, act={} k_full={}, q={}, g={}, MKN=({}x{}x{})".format(
|
||||
model, act_order, is_k_full, str(quant_type), group_size, size_m, size_k, size_n
|
||||
)
|
||||
|
||||
print(f"Testing: {sub_label}")
|
||||
|
||||
a = torch.randn(size_m, size_k).to(torch.half).cuda()
|
||||
b = torch.rand(size_k, size_n).to(torch.half).cuda()
|
||||
|
||||
a_tmp = (torch.zeros(size_m, size_k).to(torch.half).cuda())
|
||||
a_tmp = torch.zeros(size_m, size_k).to(torch.half).cuda()
|
||||
|
||||
# Marlin quant
|
||||
(
|
||||
@ -57,14 +79,16 @@ def bench_run(results: list[benchmark.Measurement], model: str,
|
||||
) = marlin_quantize(b, quant_type, group_size, act_order)
|
||||
|
||||
# Marlin_24 quant
|
||||
(marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta,
|
||||
marlin_24_s) = marlin_24_quantize(b, quant_type, group_size)
|
||||
(marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = (
|
||||
marlin_24_quantize(b, quant_type, group_size)
|
||||
)
|
||||
|
||||
marlin_zp = torch.empty(0, dtype=torch.int, device=b.device)
|
||||
|
||||
# GPTQ quant
|
||||
(w_ref, q_w, s, g_idx,
|
||||
rand_perm) = gptq_quantize_weights(b, quant_type, group_size, act_order)
|
||||
(w_ref, q_w, s, g_idx, rand_perm) = gptq_quantize_weights(
|
||||
b, quant_type, group_size, act_order
|
||||
)
|
||||
q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)
|
||||
|
||||
# For act_order, sort the "weights" and "g_idx"
|
||||
@ -74,32 +98,37 @@ def bench_run(results: list[benchmark.Measurement], model: str,
|
||||
(q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx)
|
||||
|
||||
# Prepare
|
||||
marlin_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
|
||||
GPTQ_MARLIN_MAX_PARALLEL)
|
||||
marlin_workspace = MarlinWorkspace(
|
||||
size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL
|
||||
)
|
||||
|
||||
marlin_24_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N,
|
||||
GPTQ_MARLIN_24_MAX_PARALLEL)
|
||||
marlin_24_workspace = MarlinWorkspace(
|
||||
size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL
|
||||
)
|
||||
marlin_zp = torch.zeros_like(marlin_s, dtype=torch.int)
|
||||
|
||||
# AllSpark W8A16 quant
|
||||
as_supported_case = (quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES
|
||||
and group_size == -1 and not act_order and is_k_full)
|
||||
as_supported_case = (
|
||||
quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES
|
||||
and group_size == -1
|
||||
and not act_order
|
||||
and is_k_full
|
||||
)
|
||||
if as_supported_case:
|
||||
properties = torch.cuda.get_device_properties(b.device.index)
|
||||
sm_count = properties.multi_processor_count
|
||||
sm_version = properties.major * 10 + properties.minor
|
||||
|
||||
supported_arch = (sm_version >= 80 and sm_version < 90)
|
||||
supported_arch = sm_version >= 80 and sm_version < 90
|
||||
as_supported_case = as_supported_case and supported_arch
|
||||
if supported_arch:
|
||||
has_zp = False
|
||||
w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size,
|
||||
has_zp)
|
||||
w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size, has_zp)
|
||||
qw = qw.to(torch.uint8)
|
||||
|
||||
qw_reorder, s_reorder, zp_reorder = \
|
||||
ops.allspark_repack_weight(
|
||||
qw, s, zp, has_zp)
|
||||
qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight(
|
||||
qw, s, zp, has_zp
|
||||
)
|
||||
CUBLAS_M_THRESHOLD = ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD
|
||||
|
||||
globals = {
|
||||
@ -136,8 +165,7 @@ def bench_run(results: list[benchmark.Measurement], model: str,
|
||||
"zp_reorder": zp_reorder if as_supported_case else None,
|
||||
"sm_count": sm_count if as_supported_case else None,
|
||||
"sm_version": sm_version if as_supported_case else None,
|
||||
"CUBLAS_M_THRESHOLD":
|
||||
CUBLAS_M_THRESHOLD if as_supported_case else None,
|
||||
"CUBLAS_M_THRESHOLD": CUBLAS_M_THRESHOLD if as_supported_case else None,
|
||||
# Kernels
|
||||
"gptq_marlin_gemm": ops.gptq_marlin_gemm,
|
||||
"gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm,
|
||||
@ -158,60 +186,63 @@ def bench_run(results: list[benchmark.Measurement], model: str,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description="pytorch_gemm",
|
||||
).blocked_autorange(min_run_time=min_run_time))
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
)
|
||||
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt=
|
||||
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501
|
||||
stmt="output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description="gptq_marlin_gemm_fp16",
|
||||
).blocked_autorange(min_run_time=min_run_time))
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
)
|
||||
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt=
|
||||
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501
|
||||
stmt="output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description="gptq_marlin_gemm_fp32",
|
||||
).blocked_autorange(min_run_time=min_run_time))
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
)
|
||||
|
||||
if (quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
|
||||
and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES):
|
||||
if (
|
||||
quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
|
||||
and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
|
||||
):
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt=
|
||||
"output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)", # noqa: E501
|
||||
stmt="output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)", # noqa: E501
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description="gptq_marlin_24_gemm",
|
||||
).blocked_autorange(min_run_time=min_run_time))
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
)
|
||||
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt=
|
||||
"q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)", # noqa: E501
|
||||
stmt="q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)", # noqa: E501
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description="gptq_marlin_repack",
|
||||
).blocked_autorange(min_run_time=min_run_time))
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
)
|
||||
|
||||
if as_supported_case:
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt=
|
||||
"output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)", # noqa: E501
|
||||
stmt="output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)", # noqa: E501
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description="allspark_w8a16_gemm_fp32",
|
||||
).blocked_autorange(min_run_time=min_run_time))
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
)
|
||||
|
||||
|
||||
def main(args):
|
||||
@ -233,37 +264,50 @@ def main(args):
|
||||
continue
|
||||
|
||||
for act_order in ACT_ORDER_OPTS:
|
||||
if len(args.limit_act_order
|
||||
) > 0 and act_order not in args.limit_act_order:
|
||||
if (
|
||||
len(args.limit_act_order) > 0
|
||||
and act_order not in args.limit_act_order
|
||||
):
|
||||
continue
|
||||
|
||||
for is_k_full in K_FULL_OPTS:
|
||||
if len(args.limit_k_full
|
||||
) > 0 and is_k_full not in args.limit_k_full:
|
||||
if (
|
||||
len(args.limit_k_full) > 0
|
||||
and is_k_full not in args.limit_k_full
|
||||
):
|
||||
continue
|
||||
|
||||
for quant_type in query_marlin_supported_quant_types(
|
||||
False):
|
||||
if len(args.limit_num_bits) > 0 and \
|
||||
quant_type.size_bits not in args.limit_num_bits:
|
||||
for quant_type in query_marlin_supported_quant_types(False):
|
||||
if (
|
||||
len(args.limit_num_bits) > 0
|
||||
and quant_type.size_bits not in args.limit_num_bits
|
||||
):
|
||||
continue
|
||||
|
||||
for group_size in MARLIN_SUPPORTED_GROUP_SIZES:
|
||||
if len(
|
||||
args.limit_group_size
|
||||
) > 0 and group_size not in args.limit_group_size:
|
||||
if (
|
||||
len(args.limit_group_size) > 0
|
||||
and group_size not in args.limit_group_size
|
||||
):
|
||||
continue
|
||||
|
||||
# For act_order, the group_size must be less than
|
||||
# size_k
|
||||
if act_order and (group_size == size_k
|
||||
or group_size == -1):
|
||||
if act_order and (group_size == size_k or group_size == -1):
|
||||
continue
|
||||
|
||||
for size_m in args.batch_sizes:
|
||||
bench_run(results, model, act_order, is_k_full,
|
||||
quant_type, group_size, size_m,
|
||||
size_k, size_n)
|
||||
bench_run(
|
||||
results,
|
||||
model,
|
||||
act_order,
|
||||
is_k_full,
|
||||
quant_type,
|
||||
group_size,
|
||||
size_m,
|
||||
size_k,
|
||||
size_n,
|
||||
)
|
||||
|
||||
compare = benchmark.Compare(results)
|
||||
compare.print()
|
||||
@ -274,7 +318,8 @@ def main(args):
|
||||
#
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser(
|
||||
description="Benchmark Marlin across specified models/shapes/batches")
|
||||
description="Benchmark Marlin across specified models/shapes/batches"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--models",
|
||||
nargs="+",
|
||||
@ -282,10 +327,9 @@ if __name__ == "__main__":
|
||||
default=DEFAULT_MODELS,
|
||||
choices=WEIGHT_SHAPES.keys(),
|
||||
)
|
||||
parser.add_argument("--batch-sizes",
|
||||
nargs="+",
|
||||
type=int,
|
||||
default=DEFAULT_BATCH_SIZES)
|
||||
parser.add_argument(
|
||||
"--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
|
||||
)
|
||||
parser.add_argument("--limit-k", nargs="+", type=int, default=[])
|
||||
parser.add_argument("--limit-n", nargs="+", type=int, default=[])
|
||||
parser.add_argument("--limit-group-size", nargs="+", type=int, default=[])
|
||||
|
||||
@ -31,56 +31,60 @@ class BenchmarkConfig(TypedDict):
|
||||
num_stages: int
|
||||
|
||||
|
||||
def benchmark_config(config: BenchmarkConfig,
|
||||
num_tokens: int,
|
||||
num_experts: int,
|
||||
shard_intermediate_size: int,
|
||||
hidden_size: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
num_iters: int = 100,
|
||||
block_quant_shape: List[int] = None,
|
||||
use_deep_gemm: bool = False) -> float:
|
||||
def benchmark_config(
|
||||
config: BenchmarkConfig,
|
||||
num_tokens: int,
|
||||
num_experts: int,
|
||||
shard_intermediate_size: int,
|
||||
hidden_size: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
num_iters: int = 100,
|
||||
block_quant_shape: List[int] = None,
|
||||
use_deep_gemm: bool = False,
|
||||
) -> float:
|
||||
init_dtype = torch.float16 if use_fp8_w8a8 else dtype
|
||||
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
||||
if use_int8_w8a16:
|
||||
w1 = torch.randint(-127,
|
||||
127, (
|
||||
num_experts,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
),
|
||||
dtype=torch.int8)
|
||||
w2 = torch.randint(-127,
|
||||
127, (
|
||||
num_experts,
|
||||
hidden_size,
|
||||
shard_intermediate_size // 2,
|
||||
),
|
||||
dtype=torch.int8)
|
||||
w1 = torch.randint(
|
||||
-127,
|
||||
127,
|
||||
(
|
||||
num_experts,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
),
|
||||
dtype=torch.int8,
|
||||
)
|
||||
w2 = torch.randint(
|
||||
-127,
|
||||
127,
|
||||
(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
shard_intermediate_size // 2,
|
||||
),
|
||||
dtype=torch.int8,
|
||||
)
|
||||
else:
|
||||
w1 = torch.randn(num_experts,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
dtype=init_dtype)
|
||||
w2 = torch.randn(num_experts,
|
||||
hidden_size,
|
||||
shard_intermediate_size // 2,
|
||||
dtype=init_dtype)
|
||||
gating_output = torch.randn(num_iters,
|
||||
num_tokens,
|
||||
num_experts,
|
||||
dtype=torch.float32)
|
||||
w1 = torch.randn(
|
||||
num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
|
||||
)
|
||||
w2 = torch.randn(
|
||||
num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
|
||||
)
|
||||
gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)
|
||||
|
||||
w1_scale = None
|
||||
w2_scale = None
|
||||
a1_scale = None
|
||||
a2_scale = None
|
||||
if use_int8_w8a16:
|
||||
w1_scale = torch.randn((num_experts, 2 * shard_intermediate_size),
|
||||
dtype=torch.float32)
|
||||
w1_scale = torch.randn(
|
||||
(num_experts, 2 * shard_intermediate_size), dtype=torch.float32
|
||||
)
|
||||
w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
|
||||
if use_fp8_w8a8:
|
||||
if block_quant_shape:
|
||||
@ -93,10 +97,14 @@ def benchmark_config(config: BenchmarkConfig,
|
||||
n_tiles_w2 = (K + block_n - 1) // block_n
|
||||
k_tiles_w1 = (K + block_k - 1) // block_k
|
||||
k_tiles_w2 = (N + block_k - 1) // block_k
|
||||
w1_scale = torch.rand((E, n_tiles_w1, k_tiles_w1),
|
||||
dtype=torch.float32) * factor_for_scale
|
||||
w2_scale = torch.rand((E, n_tiles_w2, k_tiles_w2),
|
||||
dtype=torch.float32) * factor_for_scale
|
||||
w1_scale = (
|
||||
torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
|
||||
* factor_for_scale
|
||||
)
|
||||
w2_scale = (
|
||||
torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32)
|
||||
* factor_for_scale
|
||||
)
|
||||
else:
|
||||
w1_scale = torch.randn(num_experts, dtype=torch.float32)
|
||||
w2_scale = torch.randn(num_experts, dtype=torch.float32)
|
||||
@ -114,10 +122,12 @@ def benchmark_config(config: BenchmarkConfig,
|
||||
|
||||
def run():
|
||||
from vllm.model_executor.layers.fused_moe import override_config
|
||||
|
||||
with override_config(config):
|
||||
if use_deep_gemm:
|
||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||
x, input_gating, topk, False)
|
||||
x, input_gating, topk, False
|
||||
)
|
||||
return fused_experts(
|
||||
x,
|
||||
w1,
|
||||
@ -213,8 +223,7 @@ def get_rocm_tuning_space(use_fp16):
|
||||
return param_ranges
|
||||
|
||||
|
||||
def get_configs_compute_bound(use_fp16,
|
||||
block_quant_shape) -> list[dict[str, int]]:
|
||||
def get_configs_compute_bound(use_fp16, block_quant_shape) -> list[dict[str, int]]:
|
||||
configs: list[BenchmarkConfig] = []
|
||||
|
||||
if current_platform.is_rocm():
|
||||
@ -250,20 +259,25 @@ def get_configs_compute_bound(use_fp16,
|
||||
if block_quant_shape is not None and not use_fp16:
|
||||
block_n, block_k = block_quant_shape[0], block_quant_shape[1]
|
||||
for config in configs[:]:
|
||||
if config["BLOCK_SIZE_K"] % block_k != 0 or config[
|
||||
"BLOCK_SIZE_N"] % block_n != 0:
|
||||
if (
|
||||
config["BLOCK_SIZE_K"] % block_k != 0
|
||||
or config["BLOCK_SIZE_N"] % block_n != 0
|
||||
):
|
||||
configs.remove(config)
|
||||
return configs
|
||||
|
||||
|
||||
def prune_rocm_search_space(num_tokens, shard_intermediate_size, hidden_size,
|
||||
search_space, is_fp16, topk):
|
||||
def prune_rocm_search_space(
|
||||
num_tokens, shard_intermediate_size, hidden_size, search_space, is_fp16, topk
|
||||
):
|
||||
N1, K1 = shard_intermediate_size, hidden_size
|
||||
N2, K2 = hidden_size, shard_intermediate_size // 2
|
||||
pruned_space_1 = prune_rocm_configs(num_tokens * topk, N1, K1,
|
||||
search_space, is_fp16)
|
||||
pruned_space_2 = prune_rocm_configs(num_tokens * topk, N2, K2,
|
||||
search_space, is_fp16)
|
||||
pruned_space_1 = prune_rocm_configs(
|
||||
num_tokens * topk, N1, K1, search_space, is_fp16
|
||||
)
|
||||
pruned_space_2 = prune_rocm_configs(
|
||||
num_tokens * topk, N2, K2, search_space, is_fp16
|
||||
)
|
||||
search_space = merge_unique_dicts(pruned_space_1, pruned_space_2)
|
||||
return search_space
|
||||
|
||||
@ -301,14 +315,14 @@ def prune_rocm_configs(M, N, K, configs, is_fp16=True):
|
||||
SPLIT_K = config.get("SPLIT_K", 1)
|
||||
GROUP_M = config.get("GROUP_SIZE_M")
|
||||
if is_fp16:
|
||||
if (matrix_instr_nonkdim > BLOCK_SIZE_M
|
||||
or matrix_instr_nonkdim > BLOCK_SIZE_N):
|
||||
if (
|
||||
matrix_instr_nonkdim > BLOCK_SIZE_M
|
||||
or matrix_instr_nonkdim > BLOCK_SIZE_N
|
||||
):
|
||||
continue
|
||||
if (matrix_instr_nonkdim >= M
|
||||
and matrix_instr_nonkdim != BLOCK_SIZE_M):
|
||||
if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M:
|
||||
continue
|
||||
if (matrix_instr_nonkdim >= N
|
||||
and matrix_instr_nonkdim != BLOCK_SIZE_N):
|
||||
if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N:
|
||||
continue
|
||||
# Skip BLOCK_SIZE that is too large compare to M/N
|
||||
# unless BLOCK_SIZE is already small enough
|
||||
@ -329,8 +343,10 @@ def prune_rocm_configs(M, N, K, configs, is_fp16=True):
|
||||
continue
|
||||
# out of shared memory resource
|
||||
# TODO (zhanglx): This does not consider the LDS usage in the epilogue
|
||||
LDS = (BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a +
|
||||
BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b)
|
||||
LDS = (
|
||||
BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a
|
||||
+ BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b
|
||||
)
|
||||
if LDS > 65536:
|
||||
continue
|
||||
# Skip small block sizes and num_warps for large gemm
|
||||
@ -364,7 +380,6 @@ def merge_unique_dicts(list1, list2):
|
||||
|
||||
@ray.remote(num_gpus=1)
|
||||
class BenchmarkWorker:
|
||||
|
||||
def __init__(self, seed: int) -> None:
|
||||
torch.set_default_device("cuda")
|
||||
current_platform.seed_everything(seed)
|
||||
@ -388,36 +403,40 @@ class BenchmarkWorker:
|
||||
use_deep_gemm: bool = False,
|
||||
) -> tuple[dict[str, int], float]:
|
||||
current_platform.seed_everything(self.seed)
|
||||
dtype_str = get_config_dtype_str(dtype,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_fp8_w8a8=use_fp8_w8a8)
|
||||
dtype_str = get_config_dtype_str(
|
||||
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
|
||||
)
|
||||
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
|
||||
# is the intermediate size after silu_and_mul.
|
||||
op_config = get_moe_configs(num_experts, shard_intermediate_size // 2,
|
||||
dtype_str)
|
||||
op_config = get_moe_configs(
|
||||
num_experts, shard_intermediate_size // 2, dtype_str
|
||||
)
|
||||
if op_config is None:
|
||||
config = get_default_config(num_tokens,
|
||||
num_experts,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype_str,
|
||||
is_marlin=False)
|
||||
config = get_default_config(
|
||||
num_tokens,
|
||||
num_experts,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype_str,
|
||||
is_marlin=False,
|
||||
)
|
||||
else:
|
||||
config = op_config[min(op_config.keys(),
|
||||
key=lambda x: abs(x - num_tokens))]
|
||||
kernel_time = benchmark_config(config,
|
||||
num_tokens,
|
||||
num_experts,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a16,
|
||||
num_iters=100,
|
||||
block_quant_shape=block_quant_shape,
|
||||
use_deep_gemm=use_deep_gemm)
|
||||
config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))]
|
||||
kernel_time = benchmark_config(
|
||||
config,
|
||||
num_tokens,
|
||||
num_experts,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a16,
|
||||
num_iters=100,
|
||||
block_quant_shape=block_quant_shape,
|
||||
use_deep_gemm=use_deep_gemm,
|
||||
)
|
||||
return config, kernel_time
|
||||
|
||||
def tune(
|
||||
@ -438,10 +457,14 @@ class BenchmarkWorker:
|
||||
best_time = float("inf")
|
||||
if current_platform.is_rocm():
|
||||
is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
|
||||
search_space = prune_rocm_search_space(num_tokens,
|
||||
shard_intermediate_size,
|
||||
hidden_size, search_space,
|
||||
is_fp16, topk)
|
||||
search_space = prune_rocm_search_space(
|
||||
num_tokens,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
search_space,
|
||||
is_fp16,
|
||||
topk,
|
||||
)
|
||||
|
||||
need_device_guard = False
|
||||
if current_platform.is_rocm():
|
||||
@ -449,8 +472,7 @@ class BenchmarkWorker:
|
||||
if visible_device != f"{self.device_id}":
|
||||
need_device_guard = True
|
||||
|
||||
with torch.cuda.device(
|
||||
self.device_id) if need_device_guard else nullcontext():
|
||||
with torch.cuda.device(self.device_id) if need_device_guard else nullcontext():
|
||||
for config in tqdm(search_space):
|
||||
try:
|
||||
kernel_time = benchmark_config(
|
||||
@ -465,7 +487,8 @@ class BenchmarkWorker:
|
||||
use_int8_w8a16,
|
||||
num_iters=20,
|
||||
block_quant_shape=block_quant_shape,
|
||||
use_deep_gemm=use_deep_gemm)
|
||||
use_deep_gemm=use_deep_gemm,
|
||||
)
|
||||
except triton.runtime.autotuner.OutOfResources:
|
||||
# Some configurations may be invalid and fail to compile.
|
||||
continue
|
||||
@ -481,42 +504,44 @@ class BenchmarkWorker:
|
||||
|
||||
def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
|
||||
return {
|
||||
"BLOCK_SIZE_M":
|
||||
config["BLOCK_SIZE_M"],
|
||||
"BLOCK_SIZE_N":
|
||||
config["BLOCK_SIZE_N"],
|
||||
"BLOCK_SIZE_K":
|
||||
config["BLOCK_SIZE_K"],
|
||||
"GROUP_SIZE_M":
|
||||
config["GROUP_SIZE_M"],
|
||||
"num_warps":
|
||||
config["num_warps"],
|
||||
"num_stages":
|
||||
config["num_stages"],
|
||||
**({
|
||||
"waves_per_eu": config["waves_per_eu"]
|
||||
} if "waves_per_eu" in config else {}),
|
||||
**({
|
||||
"matrix_instr_nonkdim": config["matrix_instr_nonkdim"]
|
||||
} if "matrix_instr_nonkdim" in config else {}),
|
||||
**({
|
||||
"kpack": config["kpack"]
|
||||
} if "kpack" in config else {}),
|
||||
"BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
|
||||
"BLOCK_SIZE_N": config["BLOCK_SIZE_N"],
|
||||
"BLOCK_SIZE_K": config["BLOCK_SIZE_K"],
|
||||
"GROUP_SIZE_M": config["GROUP_SIZE_M"],
|
||||
"num_warps": config["num_warps"],
|
||||
"num_stages": config["num_stages"],
|
||||
**(
|
||||
{"waves_per_eu": config["waves_per_eu"]} if "waves_per_eu" in config else {}
|
||||
),
|
||||
**(
|
||||
{"matrix_instr_nonkdim": config["matrix_instr_nonkdim"]}
|
||||
if "matrix_instr_nonkdim" in config
|
||||
else {}
|
||||
),
|
||||
**({"kpack": config["kpack"]} if "kpack" in config else {}),
|
||||
}
|
||||
|
||||
|
||||
def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int,
|
||||
shard_intermediate_size: int, hidden_size: int, topk: int,
|
||||
dtype: torch.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool,
|
||||
block_quant_shape: List[int]) -> None:
|
||||
dtype_str = get_config_dtype_str(dtype,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_fp8_w8a8=use_fp8_w8a8)
|
||||
def save_configs(
|
||||
configs: dict[int, BenchmarkConfig],
|
||||
num_experts: int,
|
||||
shard_intermediate_size: int,
|
||||
hidden_size: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
block_quant_shape: List[int],
|
||||
) -> None:
|
||||
dtype_str = get_config_dtype_str(
|
||||
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
|
||||
)
|
||||
|
||||
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
|
||||
# is the intermediate size after silu_and_mul.
|
||||
filename = get_config_file_name(num_experts, shard_intermediate_size // 2,
|
||||
dtype_str, block_quant_shape)
|
||||
filename = get_config_file_name(
|
||||
num_experts, shard_intermediate_size // 2, dtype_str, block_quant_shape
|
||||
)
|
||||
|
||||
print(f"Writing best config to {filename}...")
|
||||
with open(filename, "w") as f:
|
||||
@ -525,18 +550,16 @@ def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int,
|
||||
|
||||
|
||||
def get_weight_block_size_safety(config, default_value=None):
|
||||
|
||||
quantization_config = getattr(config, 'quantization_config', {})
|
||||
quantization_config = getattr(config, "quantization_config", {})
|
||||
if isinstance(quantization_config, dict):
|
||||
return quantization_config.get('weight_block_size', default_value)
|
||||
return quantization_config.get("weight_block_size", default_value)
|
||||
return default_value
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
print(args)
|
||||
|
||||
config = get_config(model=args.model,
|
||||
trust_remote_code=args.trust_remote_code)
|
||||
config = get_config(model=args.model, trust_remote_code=args.trust_remote_code)
|
||||
if args.model_prefix:
|
||||
config = getattr(config, args.model_prefix)
|
||||
config = SimpleNamespace(**config)
|
||||
@ -551,14 +574,12 @@ def main(args: argparse.Namespace):
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||
elif (config.architectures[0]
|
||||
in ("DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM")):
|
||||
elif config.architectures[0] in ("DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"):
|
||||
E = config.n_routed_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||
elif config.architectures[0] in ("Qwen2MoeForCausalLM",
|
||||
"Qwen3MoeForCausalLM"):
|
||||
elif config.architectures[0] in ("Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"):
|
||||
E = config.num_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
@ -573,16 +594,35 @@ def main(args: argparse.Namespace):
|
||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||
|
||||
hidden_size = config.hidden_size
|
||||
dtype = torch.float16 if current_platform.is_rocm() else getattr(
|
||||
torch, config.torch_dtype)
|
||||
dtype = (
|
||||
torch.float16
|
||||
if current_platform.is_rocm()
|
||||
else getattr(torch, config.torch_dtype)
|
||||
)
|
||||
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
|
||||
use_int8_w8a16 = args.dtype == "int8_w8a16"
|
||||
block_quant_shape = get_weight_block_size_safety(config)
|
||||
|
||||
if args.batch_size is None:
|
||||
batch_sizes = [
|
||||
1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
|
||||
2048, 3072, 4096
|
||||
1,
|
||||
2,
|
||||
4,
|
||||
8,
|
||||
16,
|
||||
24,
|
||||
32,
|
||||
48,
|
||||
64,
|
||||
96,
|
||||
128,
|
||||
256,
|
||||
512,
|
||||
1024,
|
||||
1536,
|
||||
2048,
|
||||
3072,
|
||||
4096,
|
||||
]
|
||||
else:
|
||||
batch_sizes = [args.batch_size]
|
||||
@ -593,7 +633,8 @@ def main(args: argparse.Namespace):
|
||||
# Ray will set ROCR_VISIBLE_DEVICES for device visibility
|
||||
logger.warning(
|
||||
"Ray uses ROCR_VISIBLE_DEVICES to control device accessibility."
|
||||
"Replacing HIP_VISIBLE_DEVICES with ROCR_VISIBLE_DEVICES.")
|
||||
"Replacing HIP_VISIBLE_DEVICES with ROCR_VISIBLE_DEVICES."
|
||||
)
|
||||
val = os.environ["HIP_VISIBLE_DEVICES"]
|
||||
os.environ["ROCR_VISIBLE_DEVICES"] = val
|
||||
del os.environ["HIP_VISIBLE_DEVICES"]
|
||||
@ -620,25 +661,59 @@ def main(args: argparse.Namespace):
|
||||
|
||||
start = time.time()
|
||||
configs = _distribute(
|
||||
"tune", [(batch_size, E, shard_intermediate_size, hidden_size,
|
||||
topk, dtype, use_fp8_w8a8, use_int8_w8a16, search_space,
|
||||
block_quant_shape, use_deep_gemm)
|
||||
for batch_size in batch_sizes])
|
||||
"tune",
|
||||
[
|
||||
(
|
||||
batch_size,
|
||||
E,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a16,
|
||||
search_space,
|
||||
block_quant_shape,
|
||||
use_deep_gemm,
|
||||
)
|
||||
for batch_size in batch_sizes
|
||||
],
|
||||
)
|
||||
best_configs = {
|
||||
M: sort_config(config)
|
||||
for M, config in zip(batch_sizes, configs)
|
||||
M: sort_config(config) for M, config in zip(batch_sizes, configs)
|
||||
}
|
||||
save_configs(best_configs, E, shard_intermediate_size, hidden_size,
|
||||
topk, dtype, use_fp8_w8a8, use_int8_w8a16,
|
||||
block_quant_shape)
|
||||
save_configs(
|
||||
best_configs,
|
||||
E,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a16,
|
||||
block_quant_shape,
|
||||
)
|
||||
end = time.time()
|
||||
print(f"Tuning took {end - start:.2f} seconds")
|
||||
else:
|
||||
outputs = _distribute(
|
||||
"benchmark",
|
||||
[(batch_size, E, shard_intermediate_size, hidden_size, topk, dtype,
|
||||
use_fp8_w8a8, use_int8_w8a16, block_quant_shape, use_deep_gemm)
|
||||
for batch_size in batch_sizes])
|
||||
[
|
||||
(
|
||||
batch_size,
|
||||
E,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a16,
|
||||
block_quant_shape,
|
||||
use_deep_gemm,
|
||||
)
|
||||
for batch_size in batch_sizes
|
||||
],
|
||||
)
|
||||
|
||||
for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
|
||||
print(f"Batch size: {batch_size}, config: {config}")
|
||||
@ -647,18 +722,15 @@ def main(args: argparse.Namespace):
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser()
|
||||
parser.add_argument("--model",
|
||||
type=str,
|
||||
default="mistralai/Mixtral-8x7B-Instruct-v0.1")
|
||||
parser.add_argument("--tp-size",
|
||||
"-tp",
|
||||
"--tensor-parallel-size",
|
||||
type=int,
|
||||
default=2)
|
||||
parser.add_argument("--dtype",
|
||||
type=str,
|
||||
choices=["auto", "fp8_w8a8", "int8_w8a16"],
|
||||
default="auto")
|
||||
parser.add_argument(
|
||||
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tp-size", "-tp", "--tensor-parallel-size", type=int, default=2
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto"
|
||||
)
|
||||
parser.add_argument("--use-deep-gemm", action="store_true")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--batch-size", type=int, required=False)
|
||||
|
||||
@ -8,7 +8,9 @@ import torch
|
||||
from transformers import AutoConfig
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
||||
_moe_permute, _moe_unpermute_and_reduce)
|
||||
_moe_permute,
|
||||
_moe_unpermute_and_reduce,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import *
|
||||
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import *
|
||||
from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize
|
||||
@ -27,15 +29,17 @@ class BenchmarkConfig(TypedDict):
|
||||
num_stages: int
|
||||
|
||||
|
||||
def benchmark_permute(num_tokens: int,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
num_iters: int = 100,
|
||||
use_customized_permute: bool = False) -> float:
|
||||
def benchmark_permute(
|
||||
num_tokens: int,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
num_iters: int = 100,
|
||||
use_customized_permute: bool = False,
|
||||
) -> float:
|
||||
# init_dtype = torch.float16 if use_fp8_w8a8 else dtype
|
||||
hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
||||
# output_hidden_states = torch.empty_like(hidden_states)
|
||||
@ -46,36 +50,41 @@ def benchmark_permute(num_tokens: int,
|
||||
align_block_size = None
|
||||
qhidden_states = hidden_states
|
||||
|
||||
gating_output = torch.randn(num_iters,
|
||||
num_tokens,
|
||||
num_experts,
|
||||
dtype=torch.float32)
|
||||
gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)
|
||||
|
||||
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
|
||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||
qhidden_states, input_gating, topk, False)
|
||||
qhidden_states, input_gating, topk, False
|
||||
)
|
||||
|
||||
def prepare(i: int):
|
||||
input_gating.copy_(gating_output[i])
|
||||
|
||||
def run():
|
||||
if use_customized_permute:
|
||||
(permuted_hidden_states, first_token_off, inv_perm_idx,
|
||||
m_indices) = moe_permute(
|
||||
qhidden_states,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
token_expert_indices=token_expert_indices,
|
||||
topk=topk,
|
||||
n_expert=num_experts,
|
||||
n_local_expert=num_experts,
|
||||
expert_map=None,
|
||||
align_block_size=align_block_size,
|
||||
)
|
||||
(permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = (
|
||||
moe_permute(
|
||||
qhidden_states,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
token_expert_indices=token_expert_indices,
|
||||
topk=topk,
|
||||
n_expert=num_experts,
|
||||
n_local_expert=num_experts,
|
||||
expert_map=None,
|
||||
align_block_size=align_block_size,
|
||||
)
|
||||
)
|
||||
else:
|
||||
(permuted_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
|
||||
inv_perm) = _moe_permute(qhidden_states, None, topk_ids,
|
||||
num_experts, None, align_block_size)
|
||||
(
|
||||
permuted_hidden_states,
|
||||
a1q_scale,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
inv_perm,
|
||||
) = _moe_permute(
|
||||
qhidden_states, None, topk_ids, num_experts, None, align_block_size
|
||||
)
|
||||
|
||||
# JIT compilation & warmup
|
||||
run()
|
||||
@ -111,15 +120,17 @@ def benchmark_permute(num_tokens: int,
|
||||
return avg
|
||||
|
||||
|
||||
def benchmark_unpermute(num_tokens: int,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
num_iters: int = 100,
|
||||
use_customized_permute: bool = False) -> float:
|
||||
def benchmark_unpermute(
|
||||
num_tokens: int,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
num_iters: int = 100,
|
||||
use_customized_permute: bool = False,
|
||||
) -> float:
|
||||
# init_dtype = torch.float16 if use_fp8_w8a8 else dtype
|
||||
hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
||||
output_hidden_states = torch.empty_like(hidden_states)
|
||||
@ -133,46 +144,74 @@ def benchmark_unpermute(num_tokens: int,
|
||||
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
|
||||
|
||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||
qhidden_states, input_gating, topk, False)
|
||||
qhidden_states, input_gating, topk, False
|
||||
)
|
||||
|
||||
def prepare():
|
||||
if use_customized_permute:
|
||||
(permuted_hidden_states, first_token_off, inv_perm_idx,
|
||||
m_indices) = moe_permute(
|
||||
qhidden_states,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
token_expert_indices=token_expert_indices,
|
||||
topk=topk,
|
||||
n_expert=num_experts,
|
||||
n_local_expert=num_experts,
|
||||
expert_map=None,
|
||||
align_block_size=align_block_size,
|
||||
)
|
||||
(permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = (
|
||||
moe_permute(
|
||||
qhidden_states,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
token_expert_indices=token_expert_indices,
|
||||
topk=topk,
|
||||
n_expert=num_experts,
|
||||
n_local_expert=num_experts,
|
||||
expert_map=None,
|
||||
align_block_size=align_block_size,
|
||||
)
|
||||
)
|
||||
# convert to fp16/bf16 as gemm output
|
||||
return (permuted_hidden_states.to(dtype), first_token_off,
|
||||
inv_perm_idx, m_indices)
|
||||
return (
|
||||
permuted_hidden_states.to(dtype),
|
||||
first_token_off,
|
||||
inv_perm_idx,
|
||||
m_indices,
|
||||
)
|
||||
else:
|
||||
(permuted_qhidden_states, a1q_scale, sorted_token_ids, expert_ids,
|
||||
inv_perm) = _moe_permute(qhidden_states, None, topk_ids,
|
||||
num_experts, None, align_block_size)
|
||||
(
|
||||
permuted_qhidden_states,
|
||||
a1q_scale,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
inv_perm,
|
||||
) = _moe_permute(
|
||||
qhidden_states, None, topk_ids, num_experts, None, align_block_size
|
||||
)
|
||||
# convert to fp16/bf16 as gemm output
|
||||
return (permuted_qhidden_states.to(dtype), a1q_scale,
|
||||
sorted_token_ids, expert_ids, inv_perm)
|
||||
return (
|
||||
permuted_qhidden_states.to(dtype),
|
||||
a1q_scale,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
inv_perm,
|
||||
)
|
||||
|
||||
def run(input: tuple):
|
||||
if use_customized_permute:
|
||||
(permuted_hidden_states, first_token_off, inv_perm_idx,
|
||||
m_indices) = input
|
||||
moe_unpermute(permuted_hidden_states, topk_weights, topk_ids,
|
||||
inv_perm_idx, first_token_off, topk, num_experts,
|
||||
num_experts)
|
||||
(permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = input
|
||||
moe_unpermute(
|
||||
permuted_hidden_states,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inv_perm_idx,
|
||||
first_token_off,
|
||||
topk,
|
||||
num_experts,
|
||||
num_experts,
|
||||
)
|
||||
else:
|
||||
(permuted_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
|
||||
inv_perm) = input
|
||||
_moe_unpermute_and_reduce(output_hidden_states,
|
||||
permuted_hidden_states, inv_perm,
|
||||
topk_weights)
|
||||
(
|
||||
permuted_hidden_states,
|
||||
a1q_scale,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
inv_perm,
|
||||
) = input
|
||||
_moe_unpermute_and_reduce(
|
||||
output_hidden_states, permuted_hidden_states, inv_perm, topk_weights
|
||||
)
|
||||
|
||||
# JIT compilation & warmup
|
||||
input = prepare()
|
||||
@ -209,7 +248,6 @@ def benchmark_unpermute(num_tokens: int,
|
||||
|
||||
@ray.remote(num_gpus=1)
|
||||
class BenchmarkWorker:
|
||||
|
||||
def __init__(self, seed: int) -> None:
|
||||
torch.set_default_device("cuda")
|
||||
current_platform.seed_everything(seed)
|
||||
@ -241,7 +279,8 @@ class BenchmarkWorker:
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a16,
|
||||
num_iters=100,
|
||||
use_customized_permute=use_customized_permute)
|
||||
use_customized_permute=use_customized_permute,
|
||||
)
|
||||
unpermute_time = benchmark_unpermute(
|
||||
num_tokens,
|
||||
num_experts,
|
||||
@ -251,15 +290,15 @@ class BenchmarkWorker:
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a16,
|
||||
num_iters=100,
|
||||
use_customized_permute=use_customized_permute)
|
||||
use_customized_permute=use_customized_permute,
|
||||
)
|
||||
return permute_time, unpermute_time
|
||||
|
||||
|
||||
def get_weight_block_size_safety(config, default_value=None):
|
||||
|
||||
quantization_config = getattr(config, 'quantization_config', {})
|
||||
quantization_config = getattr(config, "quantization_config", {})
|
||||
if isinstance(quantization_config, dict):
|
||||
return quantization_config.get('weight_block_size', default_value)
|
||||
return quantization_config.get("weight_block_size", default_value)
|
||||
return default_value
|
||||
|
||||
|
||||
@ -267,20 +306,21 @@ def main(args: argparse.Namespace):
|
||||
print(args)
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
args.model, trust_remote_code=args.trust_remote_code)
|
||||
args.model, trust_remote_code=args.trust_remote_code
|
||||
)
|
||||
if config.architectures[0] == "DbrxForCausalLM":
|
||||
E = config.ffn_config.moe_num_experts
|
||||
topk = config.ffn_config.moe_top_k
|
||||
elif config.architectures[0] == "JambaForCausalLM":
|
||||
E = config.num_experts
|
||||
topk = config.num_experts_per_tok
|
||||
elif (config.architectures[0] == "DeepseekV3ForCausalLM"
|
||||
or config.architectures[0] == "DeepseekV2ForCausalLM"):
|
||||
elif (
|
||||
config.architectures[0] == "DeepseekV3ForCausalLM"
|
||||
or config.architectures[0] == "DeepseekV2ForCausalLM"
|
||||
):
|
||||
E = config.n_routed_experts
|
||||
topk = config.num_experts_per_tok
|
||||
elif config.architectures[0] in [
|
||||
"Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"
|
||||
]:
|
||||
elif config.architectures[0] in ["Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"]:
|
||||
E = config.num_experts
|
||||
topk = config.num_experts_per_tok
|
||||
|
||||
@ -299,8 +339,24 @@ def main(args: argparse.Namespace):
|
||||
|
||||
if args.batch_size is None:
|
||||
batch_sizes = [
|
||||
1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
|
||||
2048, 3072, 4096
|
||||
1,
|
||||
2,
|
||||
4,
|
||||
8,
|
||||
16,
|
||||
24,
|
||||
32,
|
||||
48,
|
||||
64,
|
||||
96,
|
||||
128,
|
||||
256,
|
||||
512,
|
||||
1024,
|
||||
1536,
|
||||
2048,
|
||||
3072,
|
||||
4096,
|
||||
]
|
||||
else:
|
||||
batch_sizes = [args.batch_size]
|
||||
@ -321,9 +377,21 @@ def main(args: argparse.Namespace):
|
||||
return ray.get(outputs)
|
||||
|
||||
outputs = _distribute(
|
||||
"benchmark", [(batch_size, E, hidden_size, topk, dtype, use_fp8_w8a8,
|
||||
use_int8_w8a16, use_customized_permute)
|
||||
for batch_size in batch_sizes])
|
||||
"benchmark",
|
||||
[
|
||||
(
|
||||
batch_size,
|
||||
E,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a16,
|
||||
use_customized_permute,
|
||||
)
|
||||
for batch_size in batch_sizes
|
||||
],
|
||||
)
|
||||
|
||||
for batch_size, (permute, unpermute) in zip(batch_sizes, outputs):
|
||||
print(f"Batch size: {batch_size}")
|
||||
@ -333,13 +401,12 @@ def main(args: argparse.Namespace):
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser()
|
||||
parser.add_argument("--model",
|
||||
type=str,
|
||||
default="mistralai/Mixtral-8x7B-Instruct-v0.1")
|
||||
parser.add_argument("--dtype",
|
||||
type=str,
|
||||
choices=["auto", "fp8_w8a8", "int8_w8a16"],
|
||||
default="auto")
|
||||
parser.add_argument(
|
||||
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto"
|
||||
)
|
||||
parser.add_argument("--use-customized-permute", action="store_true")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--batch-size", type=int, required=False)
|
||||
|
||||
@ -9,8 +9,11 @@ import torch
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
|
||||
create_kv_caches_with_random)
|
||||
from vllm.utils import (
|
||||
STR_DTYPE_TO_TORCH_DTYPE,
|
||||
FlexibleArgumentParser,
|
||||
create_kv_caches_with_random,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -38,19 +41,15 @@ def main(
|
||||
current_platform.seed_everything(seed)
|
||||
|
||||
scale = float(1.0 / (head_size**0.5))
|
||||
query = torch.empty(num_seqs,
|
||||
num_query_heads,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
query = torch.empty(
|
||||
num_seqs, num_query_heads, head_size, dtype=dtype, device=device
|
||||
)
|
||||
query.uniform_(-scale, scale)
|
||||
|
||||
assert num_query_heads % num_kv_heads == 0
|
||||
alibi_slopes = None
|
||||
if use_alibi:
|
||||
alibi_slopes = torch.randn(num_query_heads,
|
||||
dtype=torch.float,
|
||||
device=device)
|
||||
alibi_slopes = torch.randn(num_query_heads, dtype=torch.float, device=device)
|
||||
|
||||
seq_lens = [seq_len for _ in range(num_seqs)]
|
||||
max_seq_len = max(seq_lens)
|
||||
@ -61,24 +60,23 @@ def main(
|
||||
block_tables_lst: list[list[int]] = []
|
||||
for _ in range(num_seqs):
|
||||
block_table = [
|
||||
random.randint(0, NUM_BLOCKS - 1)
|
||||
for _ in range(max_num_blocks_per_seq)
|
||||
random.randint(0, NUM_BLOCKS - 1) for _ in range(max_num_blocks_per_seq)
|
||||
]
|
||||
block_tables_lst.append(block_table)
|
||||
|
||||
block_tables = torch.tensor(block_tables_lst,
|
||||
dtype=torch.int,
|
||||
device=device)
|
||||
block_tables = torch.tensor(block_tables_lst, dtype=torch.int, device=device)
|
||||
|
||||
# Create the KV cache.
|
||||
key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS,
|
||||
block_size,
|
||||
1,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
kv_cache_dtype,
|
||||
dtype,
|
||||
device=device)
|
||||
key_caches, value_caches = create_kv_caches_with_random(
|
||||
NUM_BLOCKS,
|
||||
block_size,
|
||||
1,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
kv_cache_dtype,
|
||||
dtype,
|
||||
device=device,
|
||||
)
|
||||
key_cache, value_cache = key_caches[0], value_caches[0]
|
||||
|
||||
# Prepare for the paged attention kernel.
|
||||
@ -86,11 +84,8 @@ def main(
|
||||
if version == "v2":
|
||||
if current_platform.is_rocm():
|
||||
global PARTITION_SIZE
|
||||
if not args.custom_paged_attn:
|
||||
PARTITION_SIZE = 1024
|
||||
else:
|
||||
PARTITION_SIZE = PARTITION_SIZE_ROCM
|
||||
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
|
||||
PARTITION_SIZE = 1024 if not args.custom_paged_attn else PARTITION_SIZE_ROCM
|
||||
num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE
|
||||
tmp_output = torch.empty(
|
||||
size=(num_seqs, num_query_heads, num_partitions, head_size),
|
||||
dtype=output.dtype,
|
||||
@ -110,9 +105,7 @@ def main(
|
||||
start_time = time.perf_counter()
|
||||
|
||||
# Using default kv_scale
|
||||
k_scale = v_scale = torch.tensor(1.0,
|
||||
dtype=torch.float32,
|
||||
device=device)
|
||||
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
|
||||
for _ in range(num_iters):
|
||||
if version == "v1":
|
||||
@ -195,30 +188,29 @@ def main(
|
||||
print(f"Kernel running time: {latency * 1000000:.3f} us")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
logger.warning("This script benchmarks the paged attention kernel. "
|
||||
"By default this is no longer used in vLLM inference.")
|
||||
if __name__ == "__main__":
|
||||
logger.warning(
|
||||
"This script benchmarks the paged attention kernel. "
|
||||
"By default this is no longer used in vLLM inference."
|
||||
)
|
||||
|
||||
parser = FlexibleArgumentParser(
|
||||
description="Benchmark the paged attention kernel.")
|
||||
parser.add_argument("--version",
|
||||
type=str,
|
||||
choices=["v1", "v2"],
|
||||
default="v2")
|
||||
parser = FlexibleArgumentParser(description="Benchmark the paged attention kernel.")
|
||||
parser.add_argument("--version", type=str, choices=["v1", "v2"], default="v2")
|
||||
parser.add_argument("--batch-size", type=int, default=8)
|
||||
parser.add_argument("--seq-len", type=int, default=4096)
|
||||
parser.add_argument("--num-query-heads", type=int, default=64)
|
||||
parser.add_argument("--num-kv-heads", type=int, default=8)
|
||||
parser.add_argument("--head-size",
|
||||
type=int,
|
||||
choices=[64, 80, 96, 112, 120, 128, 192, 256],
|
||||
default=128)
|
||||
parser.add_argument(
|
||||
"--head-size",
|
||||
type=int,
|
||||
choices=[64, 80, 96, 112, 120, 128, 192, 256],
|
||||
default=128,
|
||||
)
|
||||
parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
|
||||
parser.add_argument("--use-alibi", action="store_true")
|
||||
parser.add_argument("--dtype",
|
||||
type=str,
|
||||
choices=["half", "bfloat16", "float"],
|
||||
default="half")
|
||||
parser.add_argument(
|
||||
"--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--profile", action="store_true")
|
||||
parser.add_argument(
|
||||
@ -228,10 +220,11 @@ if __name__ == '__main__':
|
||||
default="auto",
|
||||
help="Data type for kv cache storage. If 'auto', will use model "
|
||||
"data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
|
||||
"ROCm (AMD GPU) supports fp8 (=fp8_e4m3)")
|
||||
parser.add_argument("--custom-paged-attn",
|
||||
action="store_true",
|
||||
help="Use custom paged attention")
|
||||
"ROCm (AMD GPU) supports fp8 (=fp8_e4m3)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--custom-paged-attn", action="store_true", help="Use custom paged attention"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
|
||||
|
||||
@ -10,15 +10,17 @@ from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def main(num_tokens: int,
|
||||
hidden_size: int,
|
||||
static_scale: bool,
|
||||
quant_dtype: torch.dtype,
|
||||
dtype: torch.dtype,
|
||||
seed: int = 0,
|
||||
do_profile: bool = False,
|
||||
num_warmup_iters: int = 5,
|
||||
num_iters: int = 100) -> None:
|
||||
def main(
|
||||
num_tokens: int,
|
||||
hidden_size: int,
|
||||
static_scale: bool,
|
||||
quant_dtype: torch.dtype,
|
||||
dtype: torch.dtype,
|
||||
seed: int = 0,
|
||||
do_profile: bool = False,
|
||||
num_warmup_iters: int = 5,
|
||||
num_iters: int = 100,
|
||||
) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
torch.set_default_device("cuda")
|
||||
|
||||
@ -56,7 +58,7 @@ def main(num_tokens: int,
|
||||
print(f"Kernel running time: {latency * 1000000:.3f} us")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
|
||||
def to_torch_dtype(dt):
|
||||
if dt == "int8":
|
||||
@ -66,37 +68,40 @@ if __name__ == '__main__':
|
||||
raise ValueError(f"Unsupported dtype: {dt}")
|
||||
|
||||
parser = FlexibleArgumentParser(
|
||||
description="Benchmark the quantization (fp8 or int8) kernel.")
|
||||
description="Benchmark the quantization (fp8 or int8) kernel."
|
||||
)
|
||||
parser.add_argument("--num-tokens", type=int, default=4096)
|
||||
parser.add_argument("--hidden-size", type=int, default=8192)
|
||||
parser.add_argument("--static-scale", action="store_true")
|
||||
parser.add_argument("--quant-dtype",
|
||||
type=str,
|
||||
choices=["fp8", "int8"],
|
||||
default="int8")
|
||||
parser.add_argument("--dtype",
|
||||
type=str,
|
||||
choices=["half", "bfloat16", "float"],
|
||||
default="half")
|
||||
parser.add_argument(
|
||||
"--quant-dtype", type=str, choices=["fp8", "int8"], default="int8"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
|
||||
)
|
||||
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--profile", action="store_true")
|
||||
parser.add_argument("--num-warmup-iters", type=int, default=5)
|
||||
parser.add_argument("--num-iters",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Number of benchmark iterations. "
|
||||
"If --profile is set, this number is ignored")
|
||||
parser.add_argument(
|
||||
"--num-iters",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Number of benchmark iterations. "
|
||||
"If --profile is set, this number is ignored",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
|
||||
main(num_tokens=args.num_tokens,
|
||||
hidden_size=args.hidden_size,
|
||||
static_scale=args.static_scale,
|
||||
quant_dtype=to_torch_dtype(args.quant_dtype),
|
||||
dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
|
||||
seed=args.seed,
|
||||
do_profile=args.profile,
|
||||
num_warmup_iters=args.num_warmup_iters,
|
||||
num_iters=args.num_iters)
|
||||
main(
|
||||
num_tokens=args.num_tokens,
|
||||
hidden_size=args.hidden_size,
|
||||
static_scale=args.static_scale,
|
||||
quant_dtype=to_torch_dtype(args.quant_dtype),
|
||||
dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
|
||||
seed=args.seed,
|
||||
do_profile=args.profile,
|
||||
num_warmup_iters=args.num_warmup_iters,
|
||||
num_iters=args.num_iters,
|
||||
)
|
||||
|
||||
@ -12,7 +12,6 @@ from vllm.triton_utils import triton
|
||||
|
||||
|
||||
class HuggingFaceRMSNorm(nn.Module):
|
||||
|
||||
def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
@ -114,23 +113,19 @@ def rmsnorm_vllm(
|
||||
|
||||
def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
|
||||
dtype = torch.bfloat16
|
||||
x = torch.randn(batch_size,
|
||||
seq_len,
|
||||
hidden_size,
|
||||
dtype=dtype,
|
||||
device="cuda")
|
||||
x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
|
||||
weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
|
||||
residual = torch.randn_like(x) if use_residual else None
|
||||
|
||||
output_naive = rmsnorm_naive(
|
||||
x.clone(), weight,
|
||||
residual.clone() if residual is not None else None)
|
||||
x.clone(), weight, residual.clone() if residual is not None else None
|
||||
)
|
||||
output_flashinfer = rmsnorm_flashinfer(
|
||||
x.clone(), weight,
|
||||
residual.clone() if residual is not None else None)
|
||||
x.clone(), weight, residual.clone() if residual is not None else None
|
||||
)
|
||||
output_vllm = rmsnorm_vllm(
|
||||
x.clone(), weight,
|
||||
residual.clone() if residual is not None else None)
|
||||
x.clone(), weight, residual.clone() if residual is not None else None
|
||||
)
|
||||
|
||||
if use_residual:
|
||||
output_naive = output_naive[0]
|
||||
@ -141,9 +136,9 @@ def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
|
||||
print(f"FlashInfer output={output_flashinfer}")
|
||||
print(f"vLLM output={output_vllm}")
|
||||
|
||||
if torch.allclose(output_naive, output_flashinfer, atol=1e-2,
|
||||
rtol=1e-2) and torch.allclose(
|
||||
output_naive, output_vllm, atol=1e-2, rtol=1e-2):
|
||||
if torch.allclose(
|
||||
output_naive, output_flashinfer, atol=1e-2, rtol=1e-2
|
||||
) and torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2):
|
||||
print("✅ All implementations match")
|
||||
else:
|
||||
print("❌ Implementations differ")
|
||||
@ -152,12 +147,10 @@ def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
|
||||
batch_size_range = [2**i for i in range(0, 7, 2)]
|
||||
seq_length_range = [2**i for i in range(6, 11, 1)]
|
||||
head_num_range = [32, 48]
|
||||
configs = list(
|
||||
itertools.product(head_num_range, batch_size_range, seq_length_range))
|
||||
configs = list(itertools.product(head_num_range, batch_size_range, seq_length_range))
|
||||
|
||||
|
||||
def get_benchmark(use_residual):
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["head_num", "batch_size", "seq_len"],
|
||||
@ -167,19 +160,15 @@ def get_benchmark(use_residual):
|
||||
line_names=["HuggingFace", "FlashInfer", "vLLM"],
|
||||
styles=[("blue", "-"), ("green", "-"), ("red", "-")],
|
||||
ylabel="us",
|
||||
plot_name=
|
||||
f"rmsnorm-perf-{'with' if use_residual else 'without'}-residual",
|
||||
plot_name=f"rmsnorm-perf-{'with' if use_residual else 'without'}-residual",
|
||||
args={},
|
||||
))
|
||||
)
|
||||
)
|
||||
def benchmark(head_num, batch_size, seq_len, provider):
|
||||
dtype = torch.bfloat16
|
||||
hidden_size = head_num * 128 # assuming head_dim = 128
|
||||
|
||||
x = torch.randn(batch_size,
|
||||
seq_len,
|
||||
hidden_size,
|
||||
dtype=dtype,
|
||||
device="cuda")
|
||||
x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
|
||||
weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
|
||||
residual = torch.randn_like(x) if use_residual else None
|
||||
|
||||
@ -240,9 +229,9 @@ if __name__ == "__main__":
|
||||
default=4096,
|
||||
help="Hidden size (2nd dimension) of the sequence",
|
||||
)
|
||||
parser.add_argument("--use-residual",
|
||||
action="store_true",
|
||||
help="Whether to use residual connection")
|
||||
parser.add_argument(
|
||||
"--use-residual", action="store_true", help="Whether to use residual connection"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-path",
|
||||
type=str,
|
||||
@ -253,10 +242,12 @@ if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
# Run correctness test
|
||||
calculate_diff(batch_size=args.batch_size,
|
||||
seq_len=args.seq_len,
|
||||
hidden_size=args.hidden_size,
|
||||
use_residual=args.use_residual)
|
||||
calculate_diff(
|
||||
batch_size=args.batch_size,
|
||||
seq_len=args.seq_len,
|
||||
hidden_size=args.hidden_size,
|
||||
use_residual=args.use_residual,
|
||||
)
|
||||
|
||||
# Get the benchmark function with proper use_residual setting
|
||||
benchmark = get_benchmark(args.use_residual)
|
||||
|
||||
@ -6,8 +6,7 @@ from typing import Optional
|
||||
import nvtx
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.rotary_embedding import (RotaryEmbedding,
|
||||
get_rope)
|
||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding, get_rope
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
@ -32,40 +31,49 @@ def benchmark_rope_kernels_multi_lora(
|
||||
# silulating serving 4 LoRAs
|
||||
scaling_factors = [1, 2, 4, 8]
|
||||
# batched RoPE can take multiple scaling factors
|
||||
batched_rope = get_rope(head_size, rotary_dim, max_position, base,
|
||||
is_neox_style, {
|
||||
"rope_type": "linear",
|
||||
"factor": tuple(scaling_factors)
|
||||
})
|
||||
batched_rope = get_rope(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position,
|
||||
base,
|
||||
is_neox_style,
|
||||
{"rope_type": "linear", "factor": tuple(scaling_factors)},
|
||||
)
|
||||
# non-batched RoPE takes only one scaling factor, we create multiple
|
||||
# instances to simulate the same behavior
|
||||
non_batched_ropes: list[RotaryEmbedding] = []
|
||||
for scaling_factor in scaling_factors:
|
||||
non_batched_ropes.append(
|
||||
get_rope(head_size, rotary_dim, max_position, base, is_neox_style,
|
||||
{
|
||||
"rope_type": "linear",
|
||||
"factor": (scaling_factor, )
|
||||
}))
|
||||
get_rope(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position,
|
||||
base,
|
||||
is_neox_style,
|
||||
{"rope_type": "linear", "factor": (scaling_factor,)},
|
||||
)
|
||||
)
|
||||
|
||||
positions = torch.randint(0, max_position, (batch_size, seq_len))
|
||||
query = torch.randn(batch_size,
|
||||
seq_len,
|
||||
num_heads * head_size,
|
||||
dtype=dtype)
|
||||
query = torch.randn(batch_size, seq_len, num_heads * head_size, dtype=dtype)
|
||||
key = torch.randn_like(query)
|
||||
|
||||
# create query offsets for batched RoPE, we concat multiple kv cache
|
||||
# together and each query needs to find the right kv cache of its type
|
||||
offset_map = torch.tensor(
|
||||
list(
|
||||
accumulate([0] + [
|
||||
max_position * scaling_factor * 2
|
||||
for scaling_factor in scaling_factors[:-1]
|
||||
])))
|
||||
query_types = torch.randint(0,
|
||||
len(scaling_factors), (batch_size, seq_len),
|
||||
device=device)
|
||||
accumulate(
|
||||
[0]
|
||||
+ [
|
||||
max_position * scaling_factor * 2
|
||||
for scaling_factor in scaling_factors[:-1]
|
||||
]
|
||||
)
|
||||
)
|
||||
)
|
||||
query_types = torch.randint(
|
||||
0, len(scaling_factors), (batch_size, seq_len), device=device
|
||||
)
|
||||
# map query types to offsets
|
||||
query_offsets = offset_map[query_types]
|
||||
# the kernel takes flattened offsets
|
||||
@ -86,27 +94,28 @@ def benchmark_rope_kernels_multi_lora(
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser(
|
||||
description="Benchmark the rotary embedding kernels.")
|
||||
description="Benchmark the rotary embedding kernels."
|
||||
)
|
||||
parser.add_argument("--is-neox-style", type=bool, default=True)
|
||||
parser.add_argument("--batch-size", type=int, default=16)
|
||||
parser.add_argument("--seq-len", type=int, default=512)
|
||||
parser.add_argument("--num-heads", type=int, default=8)
|
||||
parser.add_argument("--head-size",
|
||||
type=int,
|
||||
choices=[64, 80, 96, 112, 120, 128, 192, 256],
|
||||
default=128)
|
||||
parser.add_argument(
|
||||
"--head-size",
|
||||
type=int,
|
||||
choices=[64, 80, 96, 112, 120, 128, 192, 256],
|
||||
default=128,
|
||||
)
|
||||
parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32)
|
||||
parser.add_argument("--dtype",
|
||||
type=str,
|
||||
choices=["bfloat16", "float"],
|
||||
default="float")
|
||||
parser.add_argument(
|
||||
"--dtype", type=str, choices=["bfloat16", "float"], default="float"
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--device",
|
||||
type=str,
|
||||
choices=["cuda:0", "cuda:1"],
|
||||
default="cuda:0")
|
||||
parser.add_argument(
|
||||
"--device", type=str, choices=["cuda:0", "cuda:1"], default="cuda:0"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
|
||||
|
||||
@ -14,14 +14,16 @@ import tqdm
|
||||
import triton
|
||||
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
_w8a8_block_fp8_matmul)
|
||||
_w8a8_block_fp8_matmul,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
mp.set_start_method("spawn", force=True)
|
||||
|
||||
assert current_platform.is_cuda(
|
||||
), "Only support tune w8a8 block fp8 kernel on CUDA device."
|
||||
assert current_platform.is_cuda(), (
|
||||
"Only support tune w8a8 block fp8 kernel on CUDA device."
|
||||
)
|
||||
|
||||
DTYPE_MAP = {
|
||||
"float32": torch.float32,
|
||||
@ -40,7 +42,7 @@ def w8a8_block_matmul(
|
||||
config: dict[str, Any],
|
||||
output_dtype: torch.dtype = torch.float16,
|
||||
) -> torch.Tensor:
|
||||
"""This function performs matrix multiplication with
|
||||
"""This function performs matrix multiplication with
|
||||
block-wise quantization.
|
||||
|
||||
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
|
||||
@ -51,7 +53,7 @@ def w8a8_block_matmul(
|
||||
B: The input tensor, e.g., weight.
|
||||
As: The per-token-group quantization scale for `A`.
|
||||
Bs: The per-block quantization scale for `B`.
|
||||
block_size: The block size for per-block quantization.
|
||||
block_size: The block size for per-block quantization.
|
||||
It should be 2-dim, e.g., [128, 128].
|
||||
output_dytpe: The dtype of the returned tensor.
|
||||
|
||||
@ -71,18 +73,18 @@ def w8a8_block_matmul(
|
||||
assert triton.cdiv(N, block_n) == Bs.shape[0]
|
||||
assert triton.cdiv(K, block_k) == Bs.shape[1]
|
||||
|
||||
C_shape = A.shape[:-1] + (N, )
|
||||
C_shape = A.shape[:-1] + (N,)
|
||||
C = A.new_empty(C_shape, dtype=output_dtype)
|
||||
|
||||
def grid(META):
|
||||
return (triton.cdiv(M, META["BLOCK_SIZE_M"]) *
|
||||
triton.cdiv(N, META["BLOCK_SIZE_N"]), )
|
||||
return (
|
||||
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
|
||||
)
|
||||
|
||||
if A.dtype == torch.float8_e4m3fn:
|
||||
kernel = _w8a8_block_fp8_matmul
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Currently, only support tune w8a8 block fp8 kernel.")
|
||||
raise RuntimeError("Currently, only support tune w8a8 block fp8 kernel.")
|
||||
|
||||
kernel[grid](
|
||||
A,
|
||||
@ -119,14 +121,16 @@ def get_configs_compute_bound():
|
||||
for block_n in [32, 64, 128, 256]:
|
||||
for num_warps in [4, 8]:
|
||||
for group_size in [1, 16, 32, 64]:
|
||||
configs.append({
|
||||
"BLOCK_SIZE_M": block_m,
|
||||
"BLOCK_SIZE_N": block_n,
|
||||
"BLOCK_SIZE_K": block_k,
|
||||
"GROUP_SIZE_M": group_size,
|
||||
"num_warps": num_warps,
|
||||
"num_stages": num_stages,
|
||||
})
|
||||
configs.append(
|
||||
{
|
||||
"BLOCK_SIZE_M": block_m,
|
||||
"BLOCK_SIZE_N": block_n,
|
||||
"BLOCK_SIZE_K": block_k,
|
||||
"GROUP_SIZE_M": group_size,
|
||||
"num_warps": num_warps,
|
||||
"num_stages": num_stages,
|
||||
}
|
||||
)
|
||||
return configs
|
||||
|
||||
|
||||
@ -165,15 +169,9 @@ def get_weight_shapes(tp_size):
|
||||
return weight_shapes
|
||||
|
||||
|
||||
def benchmark_config(A,
|
||||
B,
|
||||
As,
|
||||
Bs,
|
||||
block_size,
|
||||
config,
|
||||
out_dtype=torch.float16,
|
||||
num_iters=10):
|
||||
|
||||
def benchmark_config(
|
||||
A, B, As, Bs, block_size, config, out_dtype=torch.float16, num_iters=10
|
||||
):
|
||||
def run():
|
||||
w8a8_block_matmul(A, B, As, Bs, block_size, config, out_dtype)
|
||||
|
||||
@ -206,26 +204,26 @@ def tune(M, N, K, block_size, out_dtype, search_space, input_type):
|
||||
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
||||
|
||||
A_fp32 = (
|
||||
(torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 *
|
||||
fp8_max)
|
||||
(torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
|
||||
)
|
||||
A = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
||||
|
||||
B_fp32 = (
|
||||
(torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 *
|
||||
fp8_max)
|
||||
(torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
|
||||
)
|
||||
B = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Currently, only support tune w8a8 block fp8 kernel.")
|
||||
raise RuntimeError("Currently, only support tune w8a8 block fp8 kernel.")
|
||||
|
||||
block_n, block_k = block_size[0], block_size[1]
|
||||
n_tiles = (N + block_n - 1) // block_n
|
||||
k_tiles = (K + block_k - 1) // block_k
|
||||
|
||||
As = torch.rand(M, k_tiles, dtype=torch.float32,
|
||||
device="cuda") * factor_for_scale
|
||||
Bs = (torch.rand(n_tiles, k_tiles, dtype=torch.float32, device="cuda") *
|
||||
factor_for_scale)
|
||||
As = torch.rand(M, k_tiles, dtype=torch.float32, device="cuda") * factor_for_scale
|
||||
Bs = (
|
||||
torch.rand(n_tiles, k_tiles, dtype=torch.float32, device="cuda")
|
||||
* factor_for_scale
|
||||
)
|
||||
|
||||
best_config = None
|
||||
best_time = float("inf")
|
||||
@ -267,7 +265,8 @@ def save_configs(
|
||||
device_name = current_platform.get_device_name().replace(" ", "_")
|
||||
json_file_name = (
|
||||
f"N={N},K={K},device_name={device_name},dtype={input_type}_w8a8,"
|
||||
f"block_shape=[{block_n},{block_k}].json")
|
||||
f"block_shape=[{block_n},{block_k}].json"
|
||||
)
|
||||
|
||||
config_file_path = os.path.join(save_path, json_file_name)
|
||||
print(f"Writing best config to {config_file_path}...")
|
||||
@ -295,8 +294,7 @@ def tune_on_gpu(args_dict):
|
||||
|
||||
search_space = get_configs_compute_bound()
|
||||
search_space = [
|
||||
config for config in search_space
|
||||
if block_k % config["BLOCK_SIZE_K"] == 0
|
||||
config for config in search_space if block_k % config["BLOCK_SIZE_K"] == 0
|
||||
]
|
||||
|
||||
start = time.time()
|
||||
@ -312,15 +310,11 @@ def tune_on_gpu(args_dict):
|
||||
out_dtype,
|
||||
search_space,
|
||||
input_type,
|
||||
) for batch_size in tqdm(batch_sizes,
|
||||
desc=f"GPU {gpu_id} - Batch sizes")
|
||||
)
|
||||
for batch_size in tqdm(batch_sizes, desc=f"GPU {gpu_id} - Batch sizes")
|
||||
]
|
||||
best_configs = {
|
||||
M: config
|
||||
for M, config in zip(batch_sizes, benchmark_results)
|
||||
}
|
||||
save_configs(N, K, block_n, block_k, best_configs, save_path,
|
||||
input_type)
|
||||
best_configs = {M: config for M, config in zip(batch_sizes, benchmark_results)}
|
||||
save_configs(N, K, block_n, block_k, best_configs, save_path, input_type)
|
||||
|
||||
end = time.time()
|
||||
print(f"Tuning on GPU {gpu_id} took {end - start:.2f} seconds")
|
||||
@ -376,13 +370,14 @@ def main(args):
|
||||
|
||||
process_args = []
|
||||
for gpu_id in range(num_gpus):
|
||||
process_args.append({
|
||||
"gpu_id": gpu_id,
|
||||
"batch_sizes": batches_per_gpu[gpu_id],
|
||||
"weight_shapes":
|
||||
weight_shapes, # Each GPU processes all weight shapes
|
||||
"args": args,
|
||||
})
|
||||
process_args.append(
|
||||
{
|
||||
"gpu_id": gpu_id,
|
||||
"batch_sizes": batches_per_gpu[gpu_id],
|
||||
"weight_shapes": weight_shapes, # Each GPU processes all weight shapes
|
||||
"args": args,
|
||||
}
|
||||
)
|
||||
|
||||
ctx = mp.get_context("spawn")
|
||||
with ctx.Pool(num_gpus) as pool:
|
||||
@ -398,13 +393,11 @@ Tune triton w8a8 block fp8 for DeepSeek-V3/DeepSeek-R1:
|
||||
python3 benchmark_w8a8_block_fp8.py --tp-size 8 --input-type fp8
|
||||
Then copy to model_executor/layers/quantization/utils/configs
|
||||
""",
|
||||
formatter_class=argparse.RawTextHelpFormatter)
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
)
|
||||
|
||||
parser.add_argument("--tp-size", "-tp", type=int, default=8)
|
||||
parser.add_argument("--input-type",
|
||||
type=str,
|
||||
choices=["fp8"],
|
||||
default="fp8")
|
||||
parser.add_argument("--input-type", type=str, choices=["fp8"], default="fp8")
|
||||
parser.add_argument(
|
||||
"--out-dtype",
|
||||
type=str,
|
||||
|
||||
@ -11,7 +11,9 @@ from deep_gemm import calc_diff, ceil_div, get_col_major_tma_aligned_tensor
|
||||
# Import vLLM functions
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8, w8a8_block_fp8_matmul)
|
||||
per_token_group_quant_fp8,
|
||||
w8a8_block_fp8_matmul,
|
||||
)
|
||||
from vllm.triton_utils import triton
|
||||
|
||||
|
||||
|
||||
@ -14,13 +14,14 @@ from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser(
|
||||
description='Benchmark the latency of processing a single batch of '
|
||||
'requests till completion.')
|
||||
parser.add_argument('filename', type=str)
|
||||
description="Benchmark the latency of processing a single batch of "
|
||||
"requests till completion."
|
||||
)
|
||||
parser.add_argument("filename", type=str)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
with open(args.filename, 'rb') as f:
|
||||
with open(args.filename, "rb") as f:
|
||||
data = pickle.load(f)
|
||||
raw_results: list[TMeasurement] = data["results"]
|
||||
|
||||
@ -38,11 +39,7 @@ if __name__ == "__main__":
|
||||
raise Exception("MKN not found")
|
||||
|
||||
kernel = v.task_spec.description
|
||||
results[KN].append({
|
||||
"kernel": kernel,
|
||||
"batch_size": M,
|
||||
"median": v.median
|
||||
})
|
||||
results[KN].append({"kernel": kernel, "batch_size": M, "median": v.median})
|
||||
|
||||
rows = int(math.ceil(len(results) / 2))
|
||||
fig, axs = plt.subplots(rows, 2, figsize=(12, 5 * rows))
|
||||
@ -50,14 +47,16 @@ if __name__ == "__main__":
|
||||
for axs_idx, (shape, data) in enumerate(results.items()):
|
||||
plt.sca(axs[axs_idx])
|
||||
df = pd.DataFrame(data)
|
||||
sns.lineplot(data=df,
|
||||
x="batch_size",
|
||||
y="median",
|
||||
hue="kernel",
|
||||
style="kernel",
|
||||
markers=True,
|
||||
dashes=False,
|
||||
palette="Dark2")
|
||||
sns.lineplot(
|
||||
data=df,
|
||||
x="batch_size",
|
||||
y="median",
|
||||
hue="kernel",
|
||||
style="kernel",
|
||||
markers=True,
|
||||
dashes=False,
|
||||
palette="Dark2",
|
||||
)
|
||||
plt.title(f"Shape: {shape}")
|
||||
plt.ylabel("time (median, s)")
|
||||
plt.tight_layout()
|
||||
|
||||
@ -23,6 +23,7 @@ class ArgPool:
|
||||
For every invocation during a benchmarking run, it will choose a
|
||||
different value from the list.
|
||||
"""
|
||||
|
||||
values: Iterable[Any]
|
||||
|
||||
def __getitem__(self, index):
|
||||
@ -30,9 +31,7 @@ class ArgPool:
|
||||
|
||||
|
||||
class Bench:
|
||||
|
||||
class ArgsIterator:
|
||||
|
||||
def __init__(self, args_list, kwargs_list):
|
||||
assert len(args_list) == len(kwargs_list)
|
||||
self.args_list = args_list
|
||||
@ -53,10 +52,16 @@ class Bench:
|
||||
def n_args(self):
|
||||
return self.n
|
||||
|
||||
def __init__(self, cuda_graph_params: Optional[CudaGraphBenchParams],
|
||||
label: str, sub_label: str, description: str, fn: Callable,
|
||||
*args, **kwargs):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cuda_graph_params: Optional[CudaGraphBenchParams],
|
||||
label: str,
|
||||
sub_label: str,
|
||||
description: str,
|
||||
fn: Callable,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
self.cuda_graph_params = cuda_graph_params
|
||||
self.use_cuda_graph = self.cuda_graph_params is not None
|
||||
self.label = label
|
||||
@ -67,10 +72,8 @@ class Bench:
|
||||
# Process args
|
||||
self._args = args
|
||||
self._kwargs = kwargs
|
||||
self.args_list, self.kwargs_list = self.collapse_argpool(
|
||||
*args, **kwargs)
|
||||
self.args_iterator = self.ArgsIterator(self.args_list,
|
||||
self.kwargs_list)
|
||||
self.args_list, self.kwargs_list = self.collapse_argpool(*args, **kwargs)
|
||||
self.args_iterator = self.ArgsIterator(self.args_list, self.kwargs_list)
|
||||
|
||||
# Cudagraph runner
|
||||
self.g = None
|
||||
@ -100,16 +103,13 @@ class Bench:
|
||||
|
||||
for i in range(argpool_size):
|
||||
# collapse args; Just pick the ith value
|
||||
args_list[i] = tuple([
|
||||
arg[i] if isinstance(arg, ArgPool) else arg
|
||||
for arg in args_list[i]
|
||||
])
|
||||
args_list[i] = tuple(
|
||||
[arg[i] if isinstance(arg, ArgPool) else arg for arg in args_list[i]]
|
||||
)
|
||||
|
||||
# collapse kwargs
|
||||
kwargs_i = kwargs_list[i]
|
||||
arg_pool_keys = [
|
||||
k for k, v in kwargs_i.items() if isinstance(v, ArgPool)
|
||||
]
|
||||
arg_pool_keys = [k for k, v in kwargs_i.items() if isinstance(v, ArgPool)]
|
||||
for k in arg_pool_keys:
|
||||
# again just pick the ith value
|
||||
kwargs_i[k] = kwargs_i[k][i]
|
||||
@ -142,7 +142,7 @@ class Bench:
|
||||
|
||||
def run_cudagrah(self) -> TMeasurement:
|
||||
assert self.use_cuda_graph
|
||||
globals = {'g': self.g}
|
||||
globals = {"g": self.g}
|
||||
|
||||
return TBenchmark.Timer(
|
||||
stmt="g.replay()",
|
||||
@ -162,15 +162,15 @@ class Bench:
|
||||
|
||||
has_arg_pool = self.args_iterator.n_args > 1
|
||||
if has_arg_pool:
|
||||
setup = '''
|
||||
setup = """
|
||||
args_iterator.reset()
|
||||
args_it = args_iterator.__next__()
|
||||
'''
|
||||
stmt = '''
|
||||
"""
|
||||
stmt = """
|
||||
args, kwargs = next(args_it)
|
||||
fn(*args, **kwargs)
|
||||
'''
|
||||
globals = {'fn': self.fn, 'args_iterator': self.args_iterator}
|
||||
"""
|
||||
globals = {"fn": self.fn, "args_iterator": self.args_iterator}
|
||||
else:
|
||||
# no arg pool. Just use the args and kwargs directly
|
||||
self.args_iterator.reset()
|
||||
@ -178,10 +178,10 @@ class Bench:
|
||||
args, kwargs = next(args_it)
|
||||
|
||||
setup = ""
|
||||
stmt = '''
|
||||
stmt = """
|
||||
fn(*args, **kwargs)
|
||||
'''
|
||||
globals = {'fn': self.fn, 'args': args, 'kwargs': kwargs}
|
||||
"""
|
||||
globals = {"fn": self.fn, "args": args, "kwargs": kwargs}
|
||||
|
||||
return TBenchmark.Timer(
|
||||
stmt=stmt,
|
||||
|
||||
@ -7,9 +7,8 @@ from vllm import LLM, SamplingParams
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
# A very long prompt, total number of tokens is about 15k.
|
||||
LONG_PROMPT = ["You are an expert in large language models, aren't you?"
|
||||
] * 1000
|
||||
LONG_PROMPT = ' '.join(LONG_PROMPT)
|
||||
LONG_PROMPT = ["You are an expert in large language models, aren't you?"] * 1000
|
||||
LONG_PROMPT = " ".join(LONG_PROMPT)
|
||||
|
||||
|
||||
def main(args):
|
||||
@ -30,32 +29,35 @@ def main(args):
|
||||
|
||||
print("------start generating------")
|
||||
for i in range(3):
|
||||
profiler.runctx('llm.generate(LONG_PROMPT, sampling_params)',
|
||||
globals(), locals())
|
||||
profiler.runctx(
|
||||
"llm.generate(LONG_PROMPT, sampling_params)", globals(), locals()
|
||||
)
|
||||
|
||||
# analyze the runtime of hashing function
|
||||
stats = pstats.Stats(profiler)
|
||||
stats.sort_stats('cumulative')
|
||||
stats.sort_stats("cumulative")
|
||||
total_time = 0
|
||||
total_calls = 0
|
||||
for func in stats.stats:
|
||||
if 'hash_of_block' in func[2]:
|
||||
if "hash_of_block" in func[2]:
|
||||
total_time = stats.stats[func][3]
|
||||
total_calls = stats.stats[func][0]
|
||||
percentage = (total_time / stats.total_tt) * 100
|
||||
print(f"Hashing took {total_time:.2f} seconds,"
|
||||
f"{percentage:.2f}% of the total runtime.")
|
||||
print(
|
||||
f"Hashing took {total_time:.2f} seconds,{percentage:.2f}% of the total runtime."
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser(
|
||||
description='Benchmark the performance of hashing function in'
|
||||
'automatic prefix caching.')
|
||||
parser.add_argument('--model', type=str, default='lmsys/longchat-7b-16k')
|
||||
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
|
||||
parser.add_argument('--output-len', type=int, default=10)
|
||||
parser.add_argument('--enable-prefix-caching',
|
||||
action='store_true',
|
||||
help='enable prefix caching')
|
||||
description="Benchmark the performance of hashing function in"
|
||||
"automatic prefix caching."
|
||||
)
|
||||
parser.add_argument("--model", type=str, default="lmsys/longchat-7b-16k")
|
||||
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
|
||||
parser.add_argument("--output-len", type=int, default=10)
|
||||
parser.add_argument(
|
||||
"--enable-prefix-caching", action="store_true", help="enable prefix caching"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
||||
54
benchmarks/pyproject.toml
Normal file
54
benchmarks/pyproject.toml
Normal file
@ -0,0 +1,54 @@
|
||||
# This local pyproject file is part of the migration from yapf to ruff format.
|
||||
# It uses the same core rules as the main pyproject.toml file, but with the
|
||||
# following differences:
|
||||
# - ruff line length is overridden to 88
|
||||
# - deprecated typing ignores (UP006, UP035) have been removed
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 88
|
||||
exclude = [
|
||||
# External file, leaving license intact
|
||||
"examples/other/fp8/quantizer/quantize.py",
|
||||
"vllm/vllm_flash_attn/flash_attn_interface.pyi"
|
||||
]
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"vllm/third_party/**" = ["ALL"]
|
||||
"vllm/version.py" = ["F401"]
|
||||
"vllm/_version.py" = ["ALL"]
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [
|
||||
# pycodestyle
|
||||
"E",
|
||||
# Pyflakes
|
||||
"F",
|
||||
# pyupgrade
|
||||
"UP",
|
||||
# flake8-bugbear
|
||||
"B",
|
||||
# flake8-simplify
|
||||
"SIM",
|
||||
# isort
|
||||
"I",
|
||||
# flake8-logging-format
|
||||
"G",
|
||||
]
|
||||
ignore = [
|
||||
# star imports
|
||||
"F405", "F403",
|
||||
# lambda expression assignment
|
||||
"E731",
|
||||
# Loop control variable not used within loop body
|
||||
"B007",
|
||||
# f-string format
|
||||
"UP032",
|
||||
# Can remove once 3.10+ is the minimum Python version
|
||||
"UP007",
|
||||
]
|
||||
|
||||
[tool.ruff.lint.isort]
|
||||
known-first-party = ["vllm"]
|
||||
|
||||
[tool.ruff.format]
|
||||
docstring-code-format = true
|
||||
@ -1,32 +1,98 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Define the model to use
|
||||
MODEL=${1:-"Qwen/Qwen2.5-7B-Instruct"}
|
||||
|
||||
# Define the backend to use
|
||||
BACKEND=${2:-"vllm"}
|
||||
|
||||
# Define the dataset to use
|
||||
DATASET=${3:-"xgrammar_bench"}
|
||||
|
||||
# default values
|
||||
MODEL=${MODEL:-"Qwen/Qwen2.5-7B-Instruct"}
|
||||
BACKEND=${BACKEND:-"vllm"}
|
||||
DATASET=${DATASET:-"xgrammar_bench"}
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
OUTPUT_DIR=${4:-"$SCRIPT_DIR/structured_output_benchmark_results"}
|
||||
OUTPUT_DIR=${OUTPUT_DIR:-"$SCRIPT_DIR/structured_output_benchmark_results"}
|
||||
PORT=${PORT:-8000}
|
||||
STRUCTURED_OUTPUT_RATIO=${STRUCTURED_OUTPUT_RATIO:-1}
|
||||
TOTAL_SECONDS=${TOTAL_SECONDS:-90}
|
||||
MAX_NEW_TOKENS=${MAX_NEW_TOKENS:-300}
|
||||
TOKENIZER_MODE=${TOKENIZER_MODE:-"auto"}
|
||||
|
||||
GUIDED_RATIO=${5:-0.5}
|
||||
usage() {
|
||||
echo "Usage: $0 [options]"
|
||||
echo "Options:"
|
||||
echo " --model MODEL Model to benchmark (default: $MODEL)"
|
||||
echo " --backend BACKEND Backend to use (default: $BACKEND)"
|
||||
echo " --dataset DATASET Dataset to use (default: $DATASET)"
|
||||
echo " --max-new-tokens N Maximum number of tokens to generate (default: $MAX_NEW_TOKENS)"
|
||||
echo " --output-dir DIR Output directory for results (default: $OUTPUT_DIR)"
|
||||
echo " --port PORT Port to use (default: $PORT)"
|
||||
echo " --structured-output-ratio N Ratio of structured outputs (default: $STRUCTURED_OUTPUT_RATIO)"
|
||||
echo " --tokenizer-mode MODE Tokenizer mode to use (default: $TOKENIZER_MODE)"
|
||||
echo " --total-seconds N Total seconds to run the benchmark (default: $TOTAL_SECONDS)"
|
||||
echo " -h, --help Show this help message and exit"
|
||||
exit 0
|
||||
}
|
||||
|
||||
# parse command line arguments
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
--model)
|
||||
MODEL="$2"
|
||||
shift 2
|
||||
;;
|
||||
--backend)
|
||||
BACKEND="$2"
|
||||
shift 2
|
||||
;;
|
||||
--dataset)
|
||||
DATASET="$2"
|
||||
shift 2
|
||||
;;
|
||||
--max-new-tokens)
|
||||
MAX_NEW_TOKENS="$2"
|
||||
shift 2
|
||||
;;
|
||||
--output-dir)
|
||||
OUTPUT_DIR="$2"
|
||||
shift 2
|
||||
;;
|
||||
--port)
|
||||
PORT="$2"
|
||||
shift 2
|
||||
;;
|
||||
--structured-output-ratio)
|
||||
STRUCTURED_OUTPUT_RATIO="$2"
|
||||
shift 2
|
||||
;;
|
||||
--tokenizer-mode)
|
||||
TOKENIZER_MODE="$2"
|
||||
shift 2
|
||||
;;
|
||||
--total-seconds)
|
||||
TOTAL_SECONDS="$2"
|
||||
shift 2
|
||||
;;
|
||||
-h|--help)
|
||||
usage
|
||||
;;
|
||||
*)
|
||||
echo "Unknown argument: $1\n"
|
||||
usage
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# Create output directory if it doesn't exist
|
||||
mkdir -p "$OUTPUT_DIR"
|
||||
|
||||
# Define QPS values to test
|
||||
QPS_VALUES=(70 60 50 25 20 15 10)
|
||||
QPS_VALUES=(25 20 15 10 5 1)
|
||||
|
||||
# Common parameters
|
||||
COMMON_PARAMS="--backend $BACKEND \
|
||||
--model $MODEL \
|
||||
--dataset $DATASET \
|
||||
--structured-output-ratio $GUIDED_RATIO \
|
||||
--structured-output-ratio $STRUCTURED_OUTPUT_RATIO \
|
||||
--save-results \
|
||||
--result-dir $OUTPUT_DIR"
|
||||
--result-dir $OUTPUT_DIR \
|
||||
--output-len $MAX_NEW_TOKENS \
|
||||
--port $PORT \
|
||||
--tokenizer-mode $TOKENIZER_MODE"
|
||||
|
||||
echo "Starting structured output benchmark with model: $MODEL"
|
||||
echo "Backend: $BACKEND"
|
||||
@ -45,12 +111,15 @@ for qps in "${QPS_VALUES[@]}"; do
|
||||
# Construct filename for this run
|
||||
FILENAME="${BACKEND}_${qps}qps_$(basename $MODEL)_${DATASET}_${GIT_HASH}.json"
|
||||
|
||||
NUM_PROMPTS=$(echo "$TOTAL_SECONDS * $qps" | bc)
|
||||
NUM_PROMPTS=${NUM_PROMPTS%.*} # Remove fractional part
|
||||
echo "Running benchmark with $NUM_PROMPTS prompts"
|
||||
|
||||
# Run the benchmark
|
||||
python "$SCRIPT_DIR/benchmark_serving_structured_output.py" $COMMON_PARAMS \
|
||||
--request-rate $qps \
|
||||
--result-filename "$FILENAME" \
|
||||
--tokenizer-mode ${TOKENIZER_MODE:-"auto"} \
|
||||
--port ${PORT:-8000}
|
||||
--num-prompts $NUM_PROMPTS
|
||||
|
||||
echo "Completed benchmark with QPS: $qps"
|
||||
echo "----------------------------------------"
|
||||
|
||||
@ -228,11 +228,26 @@ macro(set_gencode_flags_for_srcs)
|
||||
"${multiValueArgs}" ${ARGN} )
|
||||
|
||||
foreach(_ARCH ${arg_CUDA_ARCHS})
|
||||
string(REPLACE "." "" _ARCH "${_ARCH}")
|
||||
set_gencode_flag_for_srcs(
|
||||
SRCS ${arg_SRCS}
|
||||
ARCH "compute_${_ARCH}"
|
||||
CODE "sm_${_ARCH}")
|
||||
# handle +PTX suffix: generate both sm and ptx codes if requested
|
||||
string(FIND "${_ARCH}" "+PTX" _HAS_PTX)
|
||||
if(NOT _HAS_PTX EQUAL -1)
|
||||
string(REPLACE "+PTX" "" _BASE_ARCH "${_ARCH}")
|
||||
string(REPLACE "." "" _STRIPPED_ARCH "${_BASE_ARCH}")
|
||||
set_gencode_flag_for_srcs(
|
||||
SRCS ${arg_SRCS}
|
||||
ARCH "compute_${_STRIPPED_ARCH}"
|
||||
CODE "sm_${_STRIPPED_ARCH}")
|
||||
set_gencode_flag_for_srcs(
|
||||
SRCS ${arg_SRCS}
|
||||
ARCH "compute_${_STRIPPED_ARCH}"
|
||||
CODE "compute_${_STRIPPED_ARCH}")
|
||||
else()
|
||||
string(REPLACE "." "" _STRIPPED_ARCH "${_ARCH}")
|
||||
set_gencode_flag_for_srcs(
|
||||
SRCS ${arg_SRCS}
|
||||
ARCH "compute_${_STRIPPED_ARCH}"
|
||||
CODE "sm_${_STRIPPED_ARCH}")
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
if (${arg_BUILD_PTX_FOR_ARCH})
|
||||
@ -251,7 +266,10 @@ endmacro()
|
||||
#
|
||||
# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form
|
||||
# `<major>.<minor>[letter]` compute the "loose intersection" with the
|
||||
# `TGT_CUDA_ARCHS` list of gencodes.
|
||||
# `TGT_CUDA_ARCHS` list of gencodes. We also support the `+PTX` suffix in
|
||||
# `SRC_CUDA_ARCHS` which indicates that the PTX code should be built when there
|
||||
# is a CUDA_ARCH in `TGT_CUDA_ARCHS` that is equal to or larger than the
|
||||
# architecture in `SRC_CUDA_ARCHS`.
|
||||
# The loose intersection is defined as:
|
||||
# { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} }
|
||||
# where `<=` is the version comparison operator.
|
||||
@ -268,44 +286,63 @@ endmacro()
|
||||
# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
|
||||
# OUT_CUDA_ARCHS="8.0;8.6;9.0;9.0a"
|
||||
#
|
||||
# Example With PTX:
|
||||
# SRC_CUDA_ARCHS="8.0+PTX"
|
||||
# TGT_CUDA_ARCHS="9.0"
|
||||
# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
|
||||
# OUT_CUDA_ARCHS="8.0+PTX"
|
||||
#
|
||||
function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
|
||||
list(REMOVE_DUPLICATES SRC_CUDA_ARCHS)
|
||||
set(TGT_CUDA_ARCHS_ ${TGT_CUDA_ARCHS})
|
||||
set(_SRC_CUDA_ARCHS "${SRC_CUDA_ARCHS}")
|
||||
set(_TGT_CUDA_ARCHS ${TGT_CUDA_ARCHS})
|
||||
|
||||
# handle +PTX suffix: separate base arch for matching, record PTX requests
|
||||
set(_PTX_ARCHS)
|
||||
foreach(_arch ${_SRC_CUDA_ARCHS})
|
||||
if(_arch MATCHES "\\+PTX$")
|
||||
string(REPLACE "+PTX" "" _base "${_arch}")
|
||||
list(APPEND _PTX_ARCHS "${_base}")
|
||||
list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}")
|
||||
list(APPEND _SRC_CUDA_ARCHS "${_base}")
|
||||
endif()
|
||||
endforeach()
|
||||
list(REMOVE_DUPLICATES _PTX_ARCHS)
|
||||
list(REMOVE_DUPLICATES _SRC_CUDA_ARCHS)
|
||||
|
||||
# if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should
|
||||
# remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS
|
||||
set(_CUDA_ARCHS)
|
||||
if ("9.0a" IN_LIST SRC_CUDA_ARCHS)
|
||||
list(REMOVE_ITEM SRC_CUDA_ARCHS "9.0a")
|
||||
if ("9.0" IN_LIST TGT_CUDA_ARCHS_)
|
||||
list(REMOVE_ITEM TGT_CUDA_ARCHS_ "9.0")
|
||||
if ("9.0a" IN_LIST _SRC_CUDA_ARCHS)
|
||||
list(REMOVE_ITEM _SRC_CUDA_ARCHS "9.0a")
|
||||
if ("9.0" IN_LIST TGT_CUDA_ARCHS)
|
||||
list(REMOVE_ITEM _TGT_CUDA_ARCHS "9.0")
|
||||
set(_CUDA_ARCHS "9.0a")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if ("10.0a" IN_LIST SRC_CUDA_ARCHS)
|
||||
list(REMOVE_ITEM SRC_CUDA_ARCHS "10.0a")
|
||||
if ("10.0a" IN_LIST _SRC_CUDA_ARCHS)
|
||||
list(REMOVE_ITEM _SRC_CUDA_ARCHS "10.0a")
|
||||
if ("10.0" IN_LIST TGT_CUDA_ARCHS)
|
||||
list(REMOVE_ITEM TGT_CUDA_ARCHS_ "10.0")
|
||||
list(REMOVE_ITEM _TGT_CUDA_ARCHS "10.0")
|
||||
set(_CUDA_ARCHS "10.0a")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
list(SORT SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
|
||||
list(SORT _SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
|
||||
|
||||
# for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that
|
||||
# is less or equal to ARCH (but has the same major version since SASS binary
|
||||
# compatibility is only forward compatible within the same major version).
|
||||
foreach(_ARCH ${TGT_CUDA_ARCHS_})
|
||||
foreach(_ARCH ${_TGT_CUDA_ARCHS})
|
||||
set(_TMP_ARCH)
|
||||
# Extract the major version of the target arch
|
||||
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" TGT_ARCH_MAJOR "${_ARCH}")
|
||||
foreach(_SRC_ARCH ${SRC_CUDA_ARCHS})
|
||||
foreach(_SRC_ARCH ${_SRC_CUDA_ARCHS})
|
||||
# Extract the major version of the source arch
|
||||
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" SRC_ARCH_MAJOR "${_SRC_ARCH}")
|
||||
# Check major-version match AND version-less-or-equal
|
||||
# Check version-less-or-equal, and allow PTX arches to match across majors
|
||||
if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH)
|
||||
if (SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR)
|
||||
if (_SRC_ARCH IN_LIST _PTX_ARCHS OR SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR)
|
||||
set(_TMP_ARCH "${_SRC_ARCH}")
|
||||
endif()
|
||||
else()
|
||||
@ -321,6 +358,18 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
|
||||
endforeach()
|
||||
|
||||
list(REMOVE_DUPLICATES _CUDA_ARCHS)
|
||||
|
||||
# reapply +PTX suffix to architectures that requested PTX
|
||||
set(_FINAL_ARCHS)
|
||||
foreach(_arch ${_CUDA_ARCHS})
|
||||
if(_arch IN_LIST _PTX_ARCHS)
|
||||
list(APPEND _FINAL_ARCHS "${_arch}+PTX")
|
||||
else()
|
||||
list(APPEND _FINAL_ARCHS "${_arch}")
|
||||
endif()
|
||||
endforeach()
|
||||
set(_CUDA_ARCHS ${_FINAL_ARCHS})
|
||||
|
||||
set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE)
|
||||
endfunction()
|
||||
|
||||
|
||||
@ -70,6 +70,9 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
|
||||
int64_t num_tokens = input.numel() / input.size(-1); \
|
||||
dim3 grid(num_tokens); \
|
||||
dim3 block(std::min(d, 1024)); \
|
||||
if (num_tokens == 0) { \
|
||||
return; \
|
||||
} \
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
|
||||
VLLM_DISPATCH_FLOATING_TYPES( \
|
||||
|
||||
@ -172,7 +172,7 @@ __device__ void paged_attention_kernel(
|
||||
|
||||
// Load the query to registers.
|
||||
// Each thread in a thread group has a different part of the query.
|
||||
// For example, if the the thread group size is 4, then the first thread in
|
||||
// For example, if the thread group size is 4, then the first thread in
|
||||
// the group has 0, 4, 8, ... th vectors of the query, and the second thread
|
||||
// has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because
|
||||
// q is split from a qkv tensor, it may not be contiguous.
|
||||
@ -259,7 +259,7 @@ __device__ void paged_attention_kernel(
|
||||
|
||||
// Load a key to registers.
|
||||
// Each thread in a thread group has a different part of the key.
|
||||
// For example, if the the thread group size is 4, then the first thread in
|
||||
// For example, if the thread group size is 4, then the first thread in
|
||||
// the group has 0, 4, 8, ... th vectors of the key, and the second thread
|
||||
// has 1, 5, 9, ... th vectors of the key, and so on.
|
||||
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
|
||||
|
||||
401
csrc/attention/vertical_slash_index.cu
Normal file
401
csrc/attention/vertical_slash_index.cu
Normal file
@ -0,0 +1,401 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
#include <cuda.h>
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
__device__ int64_t save_blocks(int* block_offset, int64_t range_start,
|
||||
int64_t range_end, int64_t block_size,
|
||||
int64_t input_block_count, int64_t kv_seqlen) {
|
||||
if (range_start >= kv_seqlen) {
|
||||
return input_block_count;
|
||||
}
|
||||
if (range_end > kv_seqlen) {
|
||||
range_end = kv_seqlen;
|
||||
}
|
||||
int64_t current_block_count = input_block_count;
|
||||
for (int idx = range_start; idx < range_end; idx += block_size) {
|
||||
block_offset[current_block_count++] = idx;
|
||||
}
|
||||
return current_block_count;
|
||||
}
|
||||
|
||||
__global__ void convert_vertical_slash_indexes_kernel(
|
||||
const int* q_seqlens, // [BATCH, ]
|
||||
const int* kv_seqlens, // [BATCH, ]
|
||||
const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
|
||||
const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
|
||||
int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
|
||||
int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
|
||||
int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
|
||||
int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
|
||||
int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
|
||||
int64_t NNZ_V, int64_t NNZ_S,
|
||||
bool causal // True for intra, False for succ
|
||||
) {
|
||||
const int batch_idx = blockIdx.y;
|
||||
const int head_idx = blockIdx.x;
|
||||
const int group_idx = blockIdx.z;
|
||||
|
||||
int64_t q_seqlen = q_seqlens[batch_idx];
|
||||
int64_t kv_seqlen = kv_seqlens[batch_idx];
|
||||
int64_t block_idx_m = group_idx * blockDim.x + threadIdx.x;
|
||||
int64_t start_m = block_idx_m * BLOCK_SIZE_M;
|
||||
if (start_m >= q_seqlen) {
|
||||
return;
|
||||
}
|
||||
int64_t end_m = start_m + BLOCK_SIZE_M;
|
||||
vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V;
|
||||
slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S;
|
||||
int64_t row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m;
|
||||
block_count += row_offset;
|
||||
block_offset += row_offset * NNZ_S;
|
||||
column_count += row_offset;
|
||||
column_index += row_offset * NNZ_V;
|
||||
|
||||
bool has_slash = true;
|
||||
int64_t tmp_col_cnt = 0, tmp_blk_cnt = 0;
|
||||
int64_t s = 0, v = 0;
|
||||
int64_t v_idx = vertical_indexes[v++];
|
||||
int64_t s_idx = slash_indexes[s++];
|
||||
if (causal) {
|
||||
while (s_idx >= end_m + (kv_seqlen - q_seqlen) && s < NNZ_S) {
|
||||
s_idx = slash_indexes[s++];
|
||||
}
|
||||
if (s_idx > end_m + (kv_seqlen - q_seqlen)) has_slash = false;
|
||||
s_idx = max((kv_seqlen - q_seqlen) + end_m - s_idx, BLOCK_SIZE_M);
|
||||
} else {
|
||||
while (s_idx >= end_m + kv_seqlen && s < NNZ_S) {
|
||||
s_idx = slash_indexes[s++];
|
||||
}
|
||||
if (s_idx > end_m + kv_seqlen) has_slash = false;
|
||||
s_idx = max(kv_seqlen + end_m - s_idx, BLOCK_SIZE_M);
|
||||
}
|
||||
|
||||
int64_t range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx;
|
||||
if (!has_slash) {
|
||||
if (causal) {
|
||||
range_start = (kv_seqlen - q_seqlen) + end_m;
|
||||
range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
|
||||
} else {
|
||||
range_start = kv_seqlen;
|
||||
range_end = kv_seqlen + BLOCK_SIZE_N;
|
||||
}
|
||||
}
|
||||
|
||||
bool slash_finished = false;
|
||||
while (1) {
|
||||
if (v_idx < range_end) {
|
||||
if (v_idx < range_start) {
|
||||
column_index[tmp_col_cnt++] = v_idx;
|
||||
}
|
||||
if (v < NNZ_V) {
|
||||
v_idx = vertical_indexes[v++];
|
||||
} else {
|
||||
if (causal)
|
||||
v_idx = end_m + BLOCK_SIZE_N + (kv_seqlen - q_seqlen);
|
||||
else
|
||||
v_idx = end_m + BLOCK_SIZE_N + kv_seqlen;
|
||||
}
|
||||
} else {
|
||||
if ((s < NNZ_S && causal) ||
|
||||
(s < NNZ_S && !causal && slash_indexes[s] >= start_m)) {
|
||||
if (causal)
|
||||
s_idx = max((kv_seqlen - q_seqlen) + end_m - slash_indexes[s++],
|
||||
BLOCK_SIZE_M);
|
||||
else
|
||||
s_idx = max(kv_seqlen + end_m - slash_indexes[s++], BLOCK_SIZE_M);
|
||||
} else {
|
||||
if (v == NNZ_V || (v_idx > range_start && causal)) {
|
||||
// add the last vertical if no more slash
|
||||
if (v == NNZ_V && !causal && v_idx < kv_seqlen) {
|
||||
column_index[tmp_col_cnt++] = v_idx;
|
||||
}
|
||||
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
|
||||
BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
|
||||
break;
|
||||
} else {
|
||||
if (causal) {
|
||||
range_start = (kv_seqlen - q_seqlen) + end_m;
|
||||
range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
|
||||
} else {
|
||||
// if slash_finished but there are vertical left, save current
|
||||
// blocks
|
||||
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
|
||||
BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
|
||||
range_start = kv_seqlen;
|
||||
range_end = kv_seqlen + BLOCK_SIZE_N;
|
||||
}
|
||||
slash_finished = true;
|
||||
}
|
||||
}
|
||||
if (!slash_finished) {
|
||||
if (s_idx > range_end + BLOCK_SIZE_M) {
|
||||
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
|
||||
BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
|
||||
range_start = s_idx - BLOCK_SIZE_M;
|
||||
range_end = s_idx;
|
||||
} else if (s_idx > range_end) {
|
||||
range_end += BLOCK_SIZE_M;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
block_count[0] = tmp_blk_cnt;
|
||||
column_count[0] = tmp_col_cnt;
|
||||
}
|
||||
|
||||
void convert_vertical_slash_indexes_64x64(
|
||||
const int* q_seqlens, // [BATCH, ]
|
||||
const int* kv_seqlens, // [BATCH, ]
|
||||
const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
|
||||
const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
|
||||
int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
|
||||
int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
|
||||
int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
|
||||
int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
|
||||
int64_t BATCH_SIZE, int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M,
|
||||
int64_t BLOCK_SIZE_N, int64_t NNZ_V, int64_t NNZ_S, bool causal) {
|
||||
const int N_THREADS = 64;
|
||||
const dim3 dimBlock(N_THREADS);
|
||||
const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS);
|
||||
convert_vertical_slash_indexes_kernel<<<dimGrid, dimBlock>>>(
|
||||
q_seqlens, kv_seqlens, vertical_indexes, slash_indexes, block_count,
|
||||
block_offset, column_count, column_index, N_HEADS, N_ROWS, BLOCK_SIZE_M,
|
||||
BLOCK_SIZE_N, NNZ_V, NNZ_S, causal);
|
||||
}
|
||||
|
||||
/**
|
||||
* Implements the Algorithm 4 in paper https://arxiv.org/abs/2407.02490.
|
||||
*
|
||||
* This function builds the index of each row of blocks from vertical indices
|
||||
* and slash indices. The vertical indices are treated as points, while the
|
||||
* slash indices are converted as ranges. The output consists of the merged
|
||||
* ranges and separate column indices, where the ranges are represented by
|
||||
* block indices.
|
||||
*
|
||||
* The implementation is referenced from the original MInference repo:
|
||||
* https://github.com/microsoft/MInference/blob/main/csrc/vertical_slash_index.cu.
|
||||
*/
|
||||
void convert_vertical_slash_indexes(
|
||||
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
|
||||
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
|
||||
torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS]
|
||||
torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
|
||||
torch::Tensor q_seqlens, // [BATCH, ]
|
||||
torch::Tensor kv_seqlens, // [BATCH, ]
|
||||
torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
|
||||
torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
|
||||
int64_t context_size, int64_t block_size_M, int64_t block_size_N,
|
||||
bool causal) {
|
||||
cudaSetDevice(q_seqlens.get_device());
|
||||
|
||||
int batch_size = slash_indexes.size(0);
|
||||
int num_heads = slash_indexes.size(1);
|
||||
int nnz_slash = slash_indexes.size(2);
|
||||
int nnz_vertical = vertical_indexes.size(2);
|
||||
int num_rows = (context_size + block_size_M - 1) / block_size_M;
|
||||
|
||||
convert_vertical_slash_indexes_64x64(
|
||||
q_seqlens.data_ptr<int>(), kv_seqlens.data_ptr<int>(),
|
||||
vertical_indexes.data_ptr<int>(), slash_indexes.data_ptr<int>(),
|
||||
block_count.data_ptr<int>(), block_offset.data_ptr<int>(),
|
||||
column_count.data_ptr<int>(), column_index.data_ptr<int>(), batch_size,
|
||||
num_heads, num_rows, block_size_M, block_size_N, nnz_vertical, nnz_slash,
|
||||
causal);
|
||||
}
|
||||
|
||||
__global__ void convert_vertical_slash_indexes_kernel_mergehead(
|
||||
const int* q_seqlens, // [BATCH, ]
|
||||
const int* kv_seqlens, // [BATCH, ]
|
||||
const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
|
||||
const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
|
||||
const int* per_head_vertical_topkv, const int* per_head_slash_topkv,
|
||||
int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
|
||||
int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
|
||||
int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
|
||||
int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
|
||||
int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
|
||||
int64_t NNZ_V, int64_t NNZ_S,
|
||||
bool causal // True for intra, False for succ
|
||||
) {
|
||||
const int batch_idx = blockIdx.y;
|
||||
const int head_idx = blockIdx.x;
|
||||
const int group_idx = blockIdx.z;
|
||||
|
||||
int64_t q_seqlen = q_seqlens[batch_idx];
|
||||
int64_t kv_seqlen = kv_seqlens[batch_idx];
|
||||
int64_t block_idx_m = group_idx * blockDim.x + threadIdx.x;
|
||||
int64_t start_m = block_idx_m * BLOCK_SIZE_M;
|
||||
if (start_m >= q_seqlen) {
|
||||
return;
|
||||
}
|
||||
int64_t end_m = start_m + BLOCK_SIZE_M;
|
||||
vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V;
|
||||
slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S;
|
||||
int64_t row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m;
|
||||
block_count += row_offset;
|
||||
block_offset += row_offset * NNZ_S;
|
||||
column_count += row_offset;
|
||||
column_index += row_offset * NNZ_V;
|
||||
|
||||
// MergeHead: each head has it's unique max topk NNZ_V,NNZ_S. (NNZ_V,NNZ_S
|
||||
// above is buffer size, use to compute offset)
|
||||
NNZ_S = per_head_slash_topkv[head_idx];
|
||||
NNZ_V = per_head_vertical_topkv[head_idx];
|
||||
|
||||
bool has_slash = true;
|
||||
int64_t tmp_col_cnt = 0, tmp_blk_cnt = 0;
|
||||
int64_t s = 0, v = 0;
|
||||
int64_t v_idx = vertical_indexes[v++];
|
||||
int64_t s_idx = slash_indexes[s++];
|
||||
if (causal) {
|
||||
while (s_idx >= end_m + (kv_seqlen - q_seqlen) && s < NNZ_S) {
|
||||
s_idx = slash_indexes[s++];
|
||||
}
|
||||
if (s_idx > end_m + (kv_seqlen - q_seqlen)) has_slash = false;
|
||||
s_idx = max((kv_seqlen - q_seqlen) + end_m - s_idx, BLOCK_SIZE_M);
|
||||
} else {
|
||||
while (s_idx >= end_m + kv_seqlen && s < NNZ_S) {
|
||||
s_idx = slash_indexes[s++];
|
||||
}
|
||||
if (s_idx > end_m + kv_seqlen) has_slash = false;
|
||||
s_idx = max(kv_seqlen + end_m - s_idx, BLOCK_SIZE_M);
|
||||
}
|
||||
|
||||
int64_t range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx;
|
||||
if (!has_slash) {
|
||||
if (causal) {
|
||||
range_start = (kv_seqlen - q_seqlen) + end_m;
|
||||
range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
|
||||
} else {
|
||||
range_start = kv_seqlen;
|
||||
range_end = kv_seqlen + BLOCK_SIZE_N;
|
||||
}
|
||||
}
|
||||
|
||||
bool slash_finished = false;
|
||||
while (1) {
|
||||
if (v_idx < range_end) {
|
||||
if (v_idx < range_start) {
|
||||
column_index[tmp_col_cnt++] = v_idx;
|
||||
}
|
||||
if (v < NNZ_V) {
|
||||
v_idx = vertical_indexes[v++];
|
||||
} else {
|
||||
if (causal)
|
||||
v_idx = end_m + BLOCK_SIZE_N + (kv_seqlen - q_seqlen);
|
||||
else
|
||||
v_idx = end_m + BLOCK_SIZE_N + kv_seqlen;
|
||||
}
|
||||
} else {
|
||||
if ((s < NNZ_S && causal) ||
|
||||
(s < NNZ_S && !causal && slash_indexes[s] >= start_m)) {
|
||||
if (causal)
|
||||
s_idx = max((kv_seqlen - q_seqlen) + end_m - slash_indexes[s++],
|
||||
BLOCK_SIZE_M);
|
||||
else
|
||||
s_idx = max(kv_seqlen + end_m - slash_indexes[s++], BLOCK_SIZE_M);
|
||||
} else {
|
||||
if (v == NNZ_V || (v_idx > range_start && causal)) {
|
||||
// add the last vertical if no more slash
|
||||
if (v == NNZ_V && !causal && v_idx < kv_seqlen) {
|
||||
column_index[tmp_col_cnt++] = v_idx;
|
||||
}
|
||||
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
|
||||
BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
|
||||
break;
|
||||
} else {
|
||||
if (causal) {
|
||||
range_start = (kv_seqlen - q_seqlen) + end_m;
|
||||
range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
|
||||
} else {
|
||||
// if slash_finished but there are vertical left, save current
|
||||
// blocks
|
||||
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
|
||||
BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
|
||||
range_start = kv_seqlen;
|
||||
range_end = kv_seqlen + BLOCK_SIZE_N;
|
||||
}
|
||||
slash_finished = true;
|
||||
}
|
||||
}
|
||||
if (!slash_finished) {
|
||||
if (s_idx > range_end + BLOCK_SIZE_M) {
|
||||
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
|
||||
BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
|
||||
range_start = s_idx - BLOCK_SIZE_M;
|
||||
range_end = s_idx;
|
||||
} else if (s_idx > range_end) {
|
||||
range_end += BLOCK_SIZE_M;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
block_count[0] = tmp_blk_cnt;
|
||||
column_count[0] = tmp_col_cnt;
|
||||
}
|
||||
|
||||
void convert_vertical_slash_indexes_64x64_mergehead(
|
||||
const int* q_seqlens, // [BATCH, ]
|
||||
const int* kv_seqlens, // [BATCH, ]
|
||||
const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
|
||||
const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
|
||||
int* per_head_vertical_topkv, int* per_head_slash_topkv,
|
||||
int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
|
||||
int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
|
||||
int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
|
||||
int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
|
||||
int64_t BATCH_SIZE, int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M,
|
||||
int64_t BLOCK_SIZE_N, int64_t NNZ_V, int64_t NNZ_S, bool causal) {
|
||||
const int N_THREADS = 64;
|
||||
const dim3 dimBlock(N_THREADS);
|
||||
const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS);
|
||||
convert_vertical_slash_indexes_kernel_mergehead<<<dimGrid, dimBlock>>>(
|
||||
q_seqlens, kv_seqlens, vertical_indexes, slash_indexes,
|
||||
per_head_vertical_topkv, per_head_slash_topkv, block_count, block_offset,
|
||||
column_count, column_index, N_HEADS, N_ROWS, BLOCK_SIZE_M, BLOCK_SIZE_N,
|
||||
NNZ_V, NNZ_S, causal);
|
||||
}
|
||||
|
||||
/**
|
||||
* Implements the Algorithm 4 in paper https://arxiv.org/abs/2407.02490.
|
||||
*
|
||||
* Like the above convert_vertical_slash_indexes, but with
|
||||
* pre-computed vertical and slash counts.
|
||||
*/
|
||||
void convert_vertical_slash_indexes_mergehead(
|
||||
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
|
||||
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
|
||||
torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS]
|
||||
torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
|
||||
torch::Tensor q_seqlens, // [BATCH, ]
|
||||
torch::Tensor kv_seqlens, // [BATCH, ]
|
||||
torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
|
||||
torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
|
||||
torch::Tensor vertical_indices_count, // [N_HEADS, ]
|
||||
torch::Tensor slash_indices_count, // [N_HEADS, ]
|
||||
int64_t context_size, int64_t block_size_M, int64_t block_size_N,
|
||||
bool causal) {
|
||||
cudaSetDevice(q_seqlens.get_device());
|
||||
|
||||
int batch_size = slash_indexes.size(0);
|
||||
int num_heads = slash_indexes.size(1);
|
||||
int nnz_slash = slash_indexes.size(2);
|
||||
int nnz_vertical = vertical_indexes.size(2);
|
||||
int num_rows = (context_size + block_size_M - 1) / block_size_M;
|
||||
|
||||
convert_vertical_slash_indexes_64x64_mergehead(
|
||||
q_seqlens.data_ptr<int>(), kv_seqlens.data_ptr<int>(),
|
||||
vertical_indexes.data_ptr<int>(), slash_indexes.data_ptr<int>(),
|
||||
vertical_indices_count.data_ptr<int>(),
|
||||
slash_indices_count.data_ptr<int>(), block_count.data_ptr<int>(),
|
||||
block_offset.data_ptr<int>(), column_count.data_ptr<int>(),
|
||||
column_index.data_ptr<int>(), batch_size, num_heads, num_rows,
|
||||
block_size_M, block_size_N, nnz_vertical, nnz_slash, causal);
|
||||
}
|
||||
@ -315,6 +315,8 @@ static inline constexpr auto kS8 = ScalarType::int_(8);
|
||||
static inline constexpr auto kU8 = ScalarType::uint(8);
|
||||
static inline constexpr auto kU8B128 = ScalarType::uint(8, 128);
|
||||
|
||||
static inline constexpr auto kFE2M1f =
|
||||
ScalarType::float_(2, 1, true, ScalarType::NAN_NONE);
|
||||
static inline constexpr auto kFE3M2f =
|
||||
ScalarType::float_(3, 2, true, ScalarType::NAN_NONE);
|
||||
static inline constexpr auto kFE4M3fn =
|
||||
@ -332,6 +334,7 @@ static inline constexpr auto kInt8 = kS8;
|
||||
static inline constexpr auto kUint8 = kU8;
|
||||
static inline constexpr auto kUint8b128 = kU8B128;
|
||||
|
||||
static inline constexpr auto kFloat4_e2m1f = kFE2M1f;
|
||||
static inline constexpr auto kFloat6_e3m2f = kFE3M2f;
|
||||
static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn;
|
||||
static inline constexpr auto kFloat8_e5m2 = kFE5M2;
|
||||
|
||||
@ -65,5 +65,19 @@
|
||||
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
|
||||
|
||||
#define VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::UInt16, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::UInt32, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::UInt64, __VA_ARGS__)
|
||||
|
||||
#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
|
||||
|
||||
#define VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH( \
|
||||
TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(__VA_ARGS__))
|
||||
|
||||
@ -31,7 +31,10 @@ TEMPLATE = ("template __global__ void Marlin<"
|
||||
|
||||
# int8 with zero point case (vllm::kU8) is also supported,
|
||||
# we don't add it to reduce wheel size.
|
||||
SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn"]
|
||||
SCALAR_TYPES = [
|
||||
"vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn",
|
||||
"vllm::kFE2M1f"
|
||||
]
|
||||
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]
|
||||
|
||||
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
|
||||
@ -39,7 +42,7 @@ THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
|
||||
# = 0 : act order case
|
||||
# = -1 : channelwise quantization
|
||||
# > 0 : group_size=16*group_blocks
|
||||
GROUP_BLOCKS = [0, -1, 2, 4, 8]
|
||||
GROUP_BLOCKS = [0, -1, 1, 2, 4, 8]
|
||||
DTYPES = ["fp16", "bf16"]
|
||||
|
||||
|
||||
@ -72,6 +75,12 @@ def generate_new_kernels():
|
||||
# for fp8
|
||||
if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
|
||||
continue
|
||||
# nvfp4 only supports group_size == 16
|
||||
if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]:
|
||||
continue
|
||||
# other quantization methods don't support group_size = 16
|
||||
if scalar_type != "vllm::kFE2M1f" and group_blocks == 1:
|
||||
continue
|
||||
|
||||
k_blocks = thread_configs[0] // 16
|
||||
n_blocks = thread_configs[1] // 16
|
||||
|
||||
@ -7,17 +7,18 @@
|
||||
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
|
||||
#include "core/scalar_type.hpp"
|
||||
|
||||
#define MARLIN_KERNEL_PARAMS \
|
||||
const int4 *__restrict__ A, const int4 *__restrict__ B, \
|
||||
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
|
||||
const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr, \
|
||||
const int *__restrict__ g_idx, \
|
||||
const int32_t *__restrict__ sorted_token_ids_ptr, \
|
||||
const int32_t *__restrict__ expert_ids_ptr, \
|
||||
const int32_t *__restrict__ num_tokens_past_padded_ptr, \
|
||||
const float *__restrict__ topk_weights_ptr, int top_k, \
|
||||
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
|
||||
int prob_n, int prob_k, int *locks, bool use_atomic_add, \
|
||||
#define MARLIN_KERNEL_PARAMS \
|
||||
const int4 *__restrict__ A, const int4 *__restrict__ B, \
|
||||
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
|
||||
const int4 *__restrict__ scales_ptr, \
|
||||
const uint16_t *__restrict__ scale2_ptr, \
|
||||
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
|
||||
const int32_t *__restrict__ sorted_token_ids_ptr, \
|
||||
const int32_t *__restrict__ expert_ids_ptr, \
|
||||
const int32_t *__restrict__ num_tokens_past_padded_ptr, \
|
||||
const float *__restrict__ topk_weights_ptr, int top_k, \
|
||||
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
|
||||
int prob_n, int prob_k, int *locks, bool use_atomic_add, \
|
||||
bool use_fp32_reduce, int max_shared_mem
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
@ -301,9 +301,11 @@ __global__ void Marlin(
|
||||
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
|
||||
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
|
||||
// (k/groupsize)xn
|
||||
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
|
||||
// (k/groupsize)x(n/pack_factor)
|
||||
const int* __restrict__ g_idx, // int32 group indices of shape k
|
||||
const uint16_t* __restrict__ scale2_ptr, // fp16 global scale (for nvfp4
|
||||
// only)
|
||||
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
|
||||
// (k/groupsize)x(n/pack_factor)
|
||||
const int* __restrict__ g_idx, // int32 group indices of shape k
|
||||
const int32_t* __restrict__ sorted_token_ids_ptr, // moe sorted_ids
|
||||
const int32_t* __restrict__ expert_ids_ptr, // moe expert ids
|
||||
const int32_t* __restrict__ num_tokens_past_padded_ptr, // moe num tokens
|
||||
@ -341,6 +343,16 @@ __global__ void Marlin(
|
||||
extern __shared__ int4 sh[];
|
||||
static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
|
||||
constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8;
|
||||
constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 ||
|
||||
w_type == vllm::kU4B8 || w_type == vllm::kU8B128;
|
||||
// see comments of dequant.h for more details
|
||||
constexpr bool dequant_skip_flop =
|
||||
!is_int_type ||
|
||||
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
|
||||
has_zp && !is_zp_float && !(w_type == vllm::kU8);
|
||||
|
||||
scalar_t2 global_scale;
|
||||
|
||||
constexpr bool has_act_order = group_blocks == 0;
|
||||
|
||||
constexpr int pack_factor = 32 / w_type.size_bits();
|
||||
@ -348,7 +360,8 @@ __global__ void Marlin(
|
||||
constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks);
|
||||
const int group_size =
|
||||
(!has_act_order && group_blocks == -1) ? prob_k : prob_k / num_groups;
|
||||
const int scales_expert_stride = prob_n * prob_k / group_size / 8;
|
||||
const int scales_expert_stride =
|
||||
prob_n * prob_k / group_size / (w_type == vllm::kFE2M1f ? 16 : 8);
|
||||
const int zp_expert_stride =
|
||||
is_zp_float ? prob_n * prob_k / group_size / 8
|
||||
: prob_n * prob_k / group_size / (pack_factor * 4);
|
||||
@ -460,9 +473,16 @@ __global__ void Marlin(
|
||||
if (mul_topk_weights) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
sh_block_topk_weights[tid4 * 4 + i] =
|
||||
Dtype::num2num2(Dtype::float2num(
|
||||
topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]]));
|
||||
int idx = tid4 * 4 + i;
|
||||
idx = idx < block_num_valid_tokens ? idx : 0;
|
||||
if constexpr (w_type == vllm::kFE2M1f) {
|
||||
sh_block_topk_weights[idx] = __hmul2(
|
||||
global_scale, Dtype::num2num2(Dtype::float2num(
|
||||
topk_weights_ptr[sh_block_sorted_ids[idx]])));
|
||||
} else {
|
||||
sh_block_topk_weights[idx] = Dtype::num2num2(
|
||||
Dtype::float2num(topk_weights_ptr[sh_block_sorted_ids[idx]]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -493,6 +513,11 @@ __global__ void Marlin(
|
||||
expert_id = expert_ids_ptr[block_id];
|
||||
}
|
||||
|
||||
if constexpr (w_type == vllm::kFE2M1f) {
|
||||
uint16_t val = scale2_ptr[expert_id];
|
||||
global_scale = Dtype::num2num2(*reinterpret_cast<scalar_t*>(&val));
|
||||
}
|
||||
|
||||
B_expert_off = expert_id * prob_n * prob_k / (pack_factor * 4);
|
||||
scales_ptr += (expert_id - old_expert_id) * scales_expert_stride;
|
||||
if constexpr (has_zp) {
|
||||
@ -606,7 +631,7 @@ __global__ void Marlin(
|
||||
constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
|
||||
constexpr int s_tb_groups =
|
||||
!has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
|
||||
? thread_k_blocks / group_blocks
|
||||
? thread_k_blocks / group_blocks / (w_type == vllm::kFE2M1f ? 2 : 1)
|
||||
: 1;
|
||||
constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
|
||||
int s_gl_rd_delta = s_gl_stride;
|
||||
@ -664,7 +689,8 @@ __global__ void Marlin(
|
||||
if constexpr (group_blocks == -1) {
|
||||
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
|
||||
} else {
|
||||
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
|
||||
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) /
|
||||
(w_type == vllm::kFE2M1f ? 2 : 1) +
|
||||
s_sh_stride * slice_col + threadIdx.x;
|
||||
}
|
||||
}
|
||||
@ -688,10 +714,20 @@ __global__ void Marlin(
|
||||
// we scale a `half2` tile in column-major layout in the former and in
|
||||
// row-major in the latter case.
|
||||
int s_sh_rd;
|
||||
if constexpr (group_blocks != -1)
|
||||
if constexpr (group_blocks != -1 && w_type == vllm::kFE2M1f) {
|
||||
auto warp_id = threadIdx.x / 32;
|
||||
int n_warps = thread_n_blocks / 4;
|
||||
int warp_row = warp_id / n_warps;
|
||||
|
||||
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||
(threadIdx.x % 32) / 4;
|
||||
else if constexpr (group_blocks == -1 && (m_block_size_8 || has_zp))
|
||||
s_sh_rd = s_sh_rd * 2 + warp_row % 2;
|
||||
|
||||
} else if constexpr (group_blocks != -1)
|
||||
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||
(threadIdx.x % 32) / 4;
|
||||
else if constexpr (group_blocks == -1 &&
|
||||
(m_block_size_8 || (has_zp && !dequant_skip_flop)))
|
||||
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||
(threadIdx.x % 32) / 8;
|
||||
else
|
||||
@ -801,7 +837,7 @@ __global__ void Marlin(
|
||||
sh_first_group_id = first_group_id;
|
||||
sh_num_groups = last_group_id - first_group_id + 1;
|
||||
|
||||
if (sh_num_groups < act_s_max_num_groups) {
|
||||
if (sh_num_groups > act_s_max_num_groups) {
|
||||
sh_num_groups = act_s_max_num_groups;
|
||||
}
|
||||
|
||||
@ -1021,12 +1057,19 @@ __global__ void Marlin(
|
||||
cur_k += k_iter_size * (k % b_sh_wr_iters);
|
||||
|
||||
int k_blocks = cur_k / 16;
|
||||
int cur_group_id = k_blocks / group_blocks;
|
||||
int cur_group_id =
|
||||
k_blocks / (group_blocks * (w_type == vllm::kFE2M1f ? 2 : 1));
|
||||
|
||||
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
|
||||
|
||||
reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
|
||||
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
|
||||
if constexpr (w_type_id != vllm::kFE2M1f.id()) {
|
||||
reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
|
||||
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
|
||||
} else {
|
||||
reinterpret_cast<int2*>(&frag_s[k % 2])[0] =
|
||||
reinterpret_cast<int2*>(
|
||||
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1199,22 +1242,7 @@ __global__ void Marlin(
|
||||
};
|
||||
|
||||
auto dequant_data = [&](int q, scalar_t2* frag_b_ptr) {
|
||||
if constexpr (has_zp && is_zp_float || !has_zp) {
|
||||
dequant<scalar_t2, w_type_id>(q, frag_b_ptr);
|
||||
} else {
|
||||
static_assert(has_zp && !is_zp_float);
|
||||
static_assert(w_type_id == vllm::kU4.id() || w_type_id == vllm::kU8.id());
|
||||
// If (has_zp && !is_zp_float),
|
||||
// we use not-zp version `dequant` function
|
||||
// to improve numerical accuracy.
|
||||
// Since both weight and zero point are dequanted using this logic,
|
||||
// the final dequanted weight would be correct.
|
||||
if constexpr (w_type_id == vllm::kU4.id()) {
|
||||
dequant<scalar_t2, vllm::kU4B8.id()>(q, frag_b_ptr);
|
||||
} else if constexpr (w_type_id == vllm::kU8.id()) {
|
||||
dequant<scalar_t2, vllm::kU8B128.id()>(q, frag_b_ptr);
|
||||
}
|
||||
}
|
||||
dequant<scalar_t2, w_type_id, dequant_skip_flop>(q, frag_b_ptr);
|
||||
};
|
||||
|
||||
// Execute the actual tensor core matmul of a sub-tile.
|
||||
@ -1244,13 +1272,23 @@ __global__ void Marlin(
|
||||
dequant_data(zp_quant_1, reinterpret_cast<scalar_t2*>(&frag_zp) + 2);
|
||||
}
|
||||
}
|
||||
if constexpr (has_zp && is_zp_float) {
|
||||
if constexpr (!dequant_skip_flop && has_zp && is_zp_float) {
|
||||
if (is_new_zp) {
|
||||
reinterpret_cast<int4*>(&frag_zp)[0] =
|
||||
reinterpret_cast<int4*>(&frag_zpf[k2])[0];
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (w_type == vllm::kFE2M1f) {
|
||||
int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
|
||||
int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];
|
||||
|
||||
dequant_fp8_scales<scalar_t2>(s_quant_0,
|
||||
reinterpret_cast<scalar_t2*>(&frag_s[k2]));
|
||||
dequant_fp8_scales<scalar_t2>(
|
||||
s_quant_1, reinterpret_cast<scalar_t2*>(&frag_s[k2]) + 2);
|
||||
}
|
||||
|
||||
// We have the m dimension as the inner loop in order to encourage overlapping
|
||||
// dequantization and matmul operations.
|
||||
#pragma unroll
|
||||
@ -1259,7 +1297,10 @@ __global__ void Marlin(
|
||||
FragB frag_b1;
|
||||
int b_quant_0, b_quant_1;
|
||||
|
||||
if constexpr (w_type.size_bits() == 4) {
|
||||
if constexpr (w_type_id == vllm::kFE2M1f.id()) {
|
||||
b_quant_1 = frag_b_quant[k2][0][j];
|
||||
b_quant_0 = b_quant_1 << 8;
|
||||
} else if constexpr (w_type.size_bits() == 4) {
|
||||
b_quant_0 = frag_b_quant[k2][0][j];
|
||||
b_quant_1 = b_quant_0 >> 8;
|
||||
} else {
|
||||
@ -1272,6 +1313,11 @@ __global__ void Marlin(
|
||||
dequant_data(b_quant_0, reinterpret_cast<scalar_t2*>(&frag_b0));
|
||||
dequant_data(b_quant_1, reinterpret_cast<scalar_t2*>(&frag_b1));
|
||||
|
||||
if constexpr (dequant_skip_flop && has_zp && !is_zp_float) {
|
||||
sub_zp<scalar_t>(frag_b0, frag_zp[j], 0);
|
||||
sub_zp<scalar_t>(frag_b1, frag_zp[j], 1);
|
||||
}
|
||||
|
||||
// Apply scale to frag_b0
|
||||
if constexpr (has_act_order) {
|
||||
static_assert(group_blocks != -1);
|
||||
@ -1279,7 +1325,8 @@ __global__ void Marlin(
|
||||
act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0);
|
||||
scale4<scalar_t>(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j],
|
||||
act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1);
|
||||
} else if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
|
||||
} else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float &&
|
||||
group_blocks == -1) {
|
||||
int idx = (threadIdx.x / 4) % 2;
|
||||
scalar_t2 s2 = Dtype::nums2num2(
|
||||
reinterpret_cast<scalar_t*>(&frag_s[j / 2][j % 2 * 2 + 0])[idx],
|
||||
@ -1287,7 +1334,7 @@ __global__ void Marlin(
|
||||
if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2);
|
||||
scale_and_sub<scalar_t>(frag_b0, s2.x, frag_zp[j].x);
|
||||
scale_and_sub<scalar_t>(frag_b1, s2.y, frag_zp[j].y);
|
||||
} else if constexpr (has_zp && group_blocks != -1) {
|
||||
} else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1) {
|
||||
if (is_new_zp)
|
||||
frag_zp[j] = __hmul2(frag_zp[j],
|
||||
*reinterpret_cast<scalar_t2*>(&frag_s[k2][j]));
|
||||
@ -1554,10 +1601,17 @@ __global__ void Marlin(
|
||||
// For per-column quantization we finally apply the scale here (only for
|
||||
// 4-bit)
|
||||
if constexpr (!has_act_order && group_blocks == -1 &&
|
||||
w_type.size_bits() == 4 && !has_zp) {
|
||||
w_type.size_bits() == 4 &&
|
||||
(has_zp && dequant_skip_flop || !has_zp)) {
|
||||
res = __hmul2(res, s[0]);
|
||||
}
|
||||
|
||||
if constexpr (w_type == vllm::kFE2M1f) {
|
||||
if (!mul_topk_weights) {
|
||||
res = __hmul2(res, global_scale);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (m_block_size_8) {
|
||||
((scalar_t*)sh_red)[idx] = res.x;
|
||||
((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y;
|
||||
@ -1648,7 +1702,9 @@ __global__ void Marlin(
|
||||
if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
|
||||
if (i == 0) {
|
||||
fetch_col_zp_to_shared();
|
||||
fetch_col_scale_to_shared();
|
||||
if constexpr (!dequant_skip_flop) {
|
||||
fetch_col_scale_to_shared();
|
||||
}
|
||||
}
|
||||
}
|
||||
fetch_to_shared(i, i, i < slice_iters, i);
|
||||
@ -1711,17 +1767,20 @@ __global__ void Marlin(
|
||||
|
||||
if constexpr (has_act_order) {
|
||||
slice_k_start += tb_k * stages;
|
||||
slice_k_start_shared_fetch += tb_k * stages;
|
||||
int first_group_id = g_idx[slice_k_start];
|
||||
int last_g_idx = slice_k_start + stages * tb_k * 2;
|
||||
if (last_g_idx >= prob_k) {
|
||||
last_g_idx = prob_k - 1;
|
||||
}
|
||||
int last_group_id = g_idx[last_g_idx];
|
||||
if (last_group_id >= sh_first_group_id + sh_num_groups) {
|
||||
fetch_act_order_scales_to_shared(false, first_group_id,
|
||||
last_group_id);
|
||||
__syncthreads();
|
||||
|
||||
if (slice_k_start < prob_k) {
|
||||
slice_k_start_shared_fetch += tb_k * stages;
|
||||
int first_group_id = g_idx[slice_k_start];
|
||||
int last_g_idx = slice_k_start + stages * tb_k * 2;
|
||||
if (last_g_idx >= prob_k) {
|
||||
last_g_idx = prob_k - 1;
|
||||
}
|
||||
int last_group_id = g_idx[last_g_idx];
|
||||
if (last_group_id >= sh_first_group_id + sh_num_groups) {
|
||||
fetch_act_order_scales_to_shared(false, first_group_id,
|
||||
last_group_id);
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
}
|
||||
if (slice_iters == 0) {
|
||||
@ -1737,7 +1796,8 @@ __global__ void Marlin(
|
||||
bool last = slice_idx == slice_count - 1;
|
||||
// For per-column scales, we only fetch them here in the final step before
|
||||
// write-out
|
||||
if constexpr (!has_act_order && group_blocks == -1 && !has_zp) {
|
||||
if constexpr (!has_act_order && group_blocks == -1 &&
|
||||
(has_zp && dequant_skip_flop || !has_zp)) {
|
||||
if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
|
||||
if (s_sh_wr_pred) {
|
||||
cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
|
||||
@ -1747,7 +1807,8 @@ __global__ void Marlin(
|
||||
}
|
||||
|
||||
thread_block_reduce();
|
||||
if constexpr (!has_act_order && group_blocks == -1 && !has_zp) {
|
||||
if constexpr (!has_act_order && group_blocks == -1 &&
|
||||
(has_zp && dequant_skip_flop || !has_zp)) {
|
||||
if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
|
||||
cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
@ -1771,7 +1832,8 @@ __global__ void Marlin(
|
||||
// that converts the fp32 results to fp16 (so that we avoid possible
|
||||
// overflow in fp16)
|
||||
if constexpr (!has_act_order && group_blocks == -1 &&
|
||||
w_type.size_bits() == 8 && !has_zp) {
|
||||
w_type.size_bits() == 8 &&
|
||||
(has_zp && dequant_skip_flop || !has_zp)) {
|
||||
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < thread_m_blocks; i++) {
|
||||
|
||||
@ -291,6 +291,7 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
|
||||
// BIGGROUP: cases for big group size (group_blocks in [-1, 8])
|
||||
// FZP: cases for float-zero-point (is_zp_float = true)
|
||||
// ACT: cases for act order case (group_blocks == 0)
|
||||
// FP4: cases for nvfp4(e2m1) (group_blocks == 1)
|
||||
#define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
|
||||
@ -338,6 +339,21 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
|
||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
|
||||
|
||||
#define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
|
||||
|
||||
#define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
|
||||
|
||||
#define FP4_GET_IF(W_TYPE) \
|
||||
FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
|
||||
FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
|
||||
FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
|
||||
FP4_GET_IF_M234(W_TYPE, 8, 4, 128)
|
||||
|
||||
#define BIGGROUP_GET_IF(W_TYPE) \
|
||||
BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \
|
||||
BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \
|
||||
@ -394,6 +410,8 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
|
||||
|
||||
BIGGROUP_GET_IF(vllm::kFE4M3fn)
|
||||
|
||||
FP4_GET_IF(vllm::kFE2M1f)
|
||||
|
||||
ACT_GET_IF(vllm::kU4B8)
|
||||
ACT_GET_IF(vllm::kU8B128)
|
||||
|
||||
@ -465,7 +483,7 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
|
||||
|
||||
template <typename scalar_t>
|
||||
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
||||
void* zp, void* g_idx, void* perm, void* a_tmp,
|
||||
void* s2, void* zp, void* g_idx, void* perm, void* a_tmp,
|
||||
void* sorted_token_ids, void* expert_ids,
|
||||
void* num_tokens_past_padded, void* topk_weights,
|
||||
int moe_block_size, int top_k, bool mul_topk_weights, bool is_ep,
|
||||
@ -479,14 +497,16 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
||||
bool m_block_size_8 = moe_block_size == 8;
|
||||
|
||||
if (has_zp) {
|
||||
TORCH_CHECK(q_type == vllm::kU4,
|
||||
"q_type must be u4 when has_zp = True. Got = ", q_type.str());
|
||||
TORCH_CHECK(
|
||||
q_type == vllm::kU4 || q_type == vllm::kU8,
|
||||
"q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str());
|
||||
} else {
|
||||
TORCH_CHECK(q_type == vllm::kU4B8 || q_type == vllm::kU8B128 ||
|
||||
q_type == vllm::kFE4M3fn,
|
||||
"q_type must be uint4b8, uint8b128 or fp8e4m3 when has_zp = "
|
||||
"False. Got = ",
|
||||
q_type.str());
|
||||
TORCH_CHECK(
|
||||
q_type == vllm::kU4B8 || q_type == vllm::kU8B128 ||
|
||||
q_type == vllm::kFE4M3fn || q_type == vllm::kFE2M1f,
|
||||
"q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when "
|
||||
"has_zp = False. Got = ",
|
||||
q_type.str());
|
||||
}
|
||||
|
||||
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
|
||||
@ -519,6 +539,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
||||
int4* C_ptr = (int4*)C;
|
||||
int4* C_tmp_ptr = (int4*)C_tmp;
|
||||
const int4* s_ptr = (const int4*)s;
|
||||
const uint16_t* s2_ptr = (const uint16_t*)s2;
|
||||
const int4* zp_ptr = (const int4*)zp;
|
||||
const int* g_idx_ptr = (const int*)g_idx;
|
||||
const int* perm_ptr = (const int*)perm;
|
||||
@ -627,7 +648,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
||||
// avoid ">>>" being formatted to "> > >"
|
||||
// clang-format off
|
||||
kernel<<<blocks, num_threads, max_shared_mem, stream>>>(
|
||||
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr,
|
||||
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr,
|
||||
sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr,
|
||||
topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m,
|
||||
prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce, max_shared_mem);
|
||||
@ -639,6 +660,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
||||
torch::Tensor moe_wna16_marlin_gemm(
|
||||
torch::Tensor& a, std::optional<torch::Tensor> const& c_or_none,
|
||||
torch::Tensor& b_q_weight, torch::Tensor& b_scales,
|
||||
std::optional<torch::Tensor> const& global_scale_or_none,
|
||||
std::optional<torch::Tensor> const& b_zeros_or_none,
|
||||
std::optional<torch::Tensor> const& g_idx_or_none,
|
||||
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
|
||||
@ -790,6 +812,17 @@ torch::Tensor moe_wna16_marlin_gemm(
|
||||
}
|
||||
}
|
||||
|
||||
torch::Tensor global_scale;
|
||||
if (global_scale_or_none.has_value()) {
|
||||
global_scale = global_scale_or_none.value();
|
||||
TORCH_CHECK(b_q_type == vllm::kFE2M1f,
|
||||
"global_scale can only be used for float4_e2m1f.");
|
||||
} else {
|
||||
global_scale = torch::empty({0}, options);
|
||||
TORCH_CHECK(!(b_q_type == vllm::kFE2M1f),
|
||||
"the global_scale parameter must be passed for float4_e2m1f.");
|
||||
}
|
||||
|
||||
torch::Tensor b_zeros;
|
||||
if (b_zeros_or_none.has_value()) {
|
||||
b_zeros = b_zeros_or_none.value();
|
||||
@ -802,13 +835,14 @@ torch::Tensor moe_wna16_marlin_gemm(
|
||||
|
||||
if (has_zp) {
|
||||
TORCH_CHECK(
|
||||
b_q_type == vllm::kU4,
|
||||
"b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str());
|
||||
b_q_type == vllm::kU4 || b_q_type == vllm::kU8,
|
||||
"b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str());
|
||||
} else {
|
||||
TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128 ||
|
||||
b_q_type == vllm::kFE4M3fn,
|
||||
"b_q_type must be uint4b8, uint8b128 or fp8e4m3 when has_zp = "
|
||||
"False. Got = ",
|
||||
b_q_type == vllm::kFE4M3fn || b_q_type == vllm::kFE2M1f,
|
||||
"b_q_type must be uint4b8, uint8b128, float8_e4m3fn or "
|
||||
"float4_e2m1f when "
|
||||
"has_zp = False. Got = ",
|
||||
b_q_type.str());
|
||||
}
|
||||
|
||||
@ -854,9 +888,16 @@ torch::Tensor moe_wna16_marlin_gemm(
|
||||
|
||||
int dev = a.get_device();
|
||||
if (a.scalar_type() == at::ScalarType::Half) {
|
||||
void* scales_ptr;
|
||||
if (b_q_type == vllm::kFE2M1f) {
|
||||
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
|
||||
} else {
|
||||
scales_ptr = b_scales.data_ptr<at::Half>();
|
||||
}
|
||||
|
||||
MARLIN_NAMESPACE_NAME::marlin_mm<half>(
|
||||
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
|
||||
c_tmp.data_ptr<float>(), b_scales.data_ptr<at::Half>(),
|
||||
c_tmp.data_ptr<float>(), scales_ptr, global_scale.data_ptr<at::Half>(),
|
||||
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
|
||||
a_tmp.data_ptr<at::Half>(), sorted_token_ids.data_ptr(),
|
||||
expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(),
|
||||
@ -866,11 +907,18 @@ torch::Tensor moe_wna16_marlin_gemm(
|
||||
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
|
||||
use_atomic_add, use_fp32_reduce, is_zp_float);
|
||||
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
|
||||
void* scales_ptr;
|
||||
if (b_q_type == vllm::kFE2M1f) {
|
||||
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
|
||||
} else {
|
||||
scales_ptr = b_scales.data_ptr<at::BFloat16>();
|
||||
}
|
||||
|
||||
MARLIN_NAMESPACE_NAME::marlin_mm<nv_bfloat16>(
|
||||
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
|
||||
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
|
||||
b_scales.data_ptr<at::BFloat16>(), b_zeros.data_ptr(), g_idx.data_ptr(),
|
||||
perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
|
||||
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(), scales_ptr,
|
||||
global_scale.data_ptr<at::BFloat16>(), b_zeros.data_ptr(),
|
||||
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
|
||||
sorted_token_ids.data_ptr(), expert_ids.data_ptr(),
|
||||
num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(),
|
||||
moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k,
|
||||
|
||||
@ -326,7 +326,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
||||
}
|
||||
|
||||
if (use_global_memory) {
|
||||
VLLM_DISPATCH_INTEGRAL_TYPES(
|
||||
VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
|
||||
topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
|
||||
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
|
||||
// tensors
|
||||
@ -351,7 +351,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
||||
cumsum_buffer.data_ptr<int32_t>());
|
||||
});
|
||||
} else if (use_i16) {
|
||||
VLLM_DISPATCH_INTEGRAL_TYPES(
|
||||
VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
|
||||
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
|
||||
// set dynamic shared mem
|
||||
auto kernel =
|
||||
@ -366,7 +366,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
||||
topk_ids.numel());
|
||||
});
|
||||
} else {
|
||||
VLLM_DISPATCH_INTEGRAL_TYPES(
|
||||
VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
|
||||
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
|
||||
auto kernel =
|
||||
vllm::moe::moe_align_block_size_kernel<scalar_t, int32_t>;
|
||||
@ -391,7 +391,7 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
||||
TORCH_CHECK(num_experts == 256,
|
||||
"sgl_moe_align_block_size kernel only supports deepseek v3.");
|
||||
|
||||
VLLM_DISPATCH_INTEGRAL_TYPES(
|
||||
VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
|
||||
topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] {
|
||||
// calc needed amount of shared mem for `cumsum` tensors
|
||||
auto options_int =
|
||||
|
||||
@ -108,9 +108,17 @@ __launch_bounds__(TPB) __global__
|
||||
}
|
||||
}
|
||||
|
||||
template <int TPB>
|
||||
__launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax, const bool* finished, float* output,
|
||||
int* indices, int* source_rows, const int num_experts, const int k, const int start_expert, const int end_expert)
|
||||
template <int TPB, typename IndType>
|
||||
__launch_bounds__(TPB) __global__ void moeTopK(
|
||||
const float* inputs_after_softmax,
|
||||
const bool* finished,
|
||||
float* output,
|
||||
IndType* indices,
|
||||
int* source_rows,
|
||||
const int num_experts,
|
||||
const int k,
|
||||
const int start_expert,
|
||||
const int end_expert)
|
||||
{
|
||||
|
||||
using cub_kvp = cub::KeyValuePair<int, float>;
|
||||
@ -182,9 +190,9 @@ __launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax
|
||||
2) This implementation assumes k is small, but will work for any k.
|
||||
*/
|
||||
|
||||
template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>
|
||||
template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG, typename IndType>
|
||||
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
|
||||
void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, int* indices,
|
||||
void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, IndType* indices,
|
||||
int* source_rows, const int k, const int start_expert, const int end_expert)
|
||||
{
|
||||
// We begin by enforcing compile time assertions and setting up compile time constants.
|
||||
@ -397,8 +405,8 @@ struct TopkConstants
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
template <int EXPERTS, int WARPS_PER_TB>
|
||||
void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, int* indices,
|
||||
template <int EXPERTS, int WARPS_PER_TB, typename IndType>
|
||||
void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, IndType* indices,
|
||||
int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
|
||||
{
|
||||
static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
|
||||
@ -421,10 +429,11 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
|
||||
token_expert_indices, num_tokens, topk, 0, num_experts, \
|
||||
stream);
|
||||
|
||||
template <typename IndType>
|
||||
void topkGatingSoftmaxKernelLauncher(
|
||||
const float* gating_output,
|
||||
float* topk_weights,
|
||||
int* topk_indicies,
|
||||
IndType* topk_indicies,
|
||||
int* token_expert_indices,
|
||||
float* softmax_workspace,
|
||||
const int num_tokens,
|
||||
@ -493,14 +502,32 @@ void topk_softmax(
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options());
|
||||
vllm::moe::topkGatingSoftmaxKernelLauncher(
|
||||
gating_output.data_ptr<float>(),
|
||||
topk_weights.data_ptr<float>(),
|
||||
topk_indices.data_ptr<int>(),
|
||||
token_expert_indices.data_ptr<int>(),
|
||||
softmax_workspace.data_ptr<float>(),
|
||||
num_tokens,
|
||||
num_experts,
|
||||
topk,
|
||||
stream);
|
||||
|
||||
if(topk_indices.scalar_type() == at::ScalarType::Int)
|
||||
{
|
||||
vllm::moe::topkGatingSoftmaxKernelLauncher(
|
||||
gating_output.data_ptr<float>(),
|
||||
topk_weights.data_ptr<float>(),
|
||||
topk_indices.data_ptr<int>(),
|
||||
token_expert_indices.data_ptr<int>(),
|
||||
softmax_workspace.data_ptr<float>(),
|
||||
num_tokens,
|
||||
num_experts,
|
||||
topk,
|
||||
stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(topk_indices.scalar_type() == at::ScalarType::UInt32);
|
||||
vllm::moe::topkGatingSoftmaxKernelLauncher(
|
||||
gating_output.data_ptr<float>(),
|
||||
topk_weights.data_ptr<float>(),
|
||||
topk_indices.data_ptr<uint32_t>(),
|
||||
token_expert_indices.data_ptr<int>(),
|
||||
softmax_workspace.data_ptr<float>(),
|
||||
num_tokens,
|
||||
num_experts,
|
||||
topk,
|
||||
stream);
|
||||
}
|
||||
}
|
||||
|
||||
@ -44,7 +44,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
||||
|
||||
m.def(
|
||||
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
|
||||
"Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none,"
|
||||
"Tensor! b_q_weight, Tensor! b_scales, Tensor? global_scale, Tensor? "
|
||||
"b_zeros_or_none,"
|
||||
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
|
||||
"Tensor sorted_token_ids,"
|
||||
"Tensor! expert_ids, Tensor! num_tokens_past_padded,"
|
||||
|
||||
37
csrc/ops.h
37
csrc/ops.h
@ -59,6 +59,31 @@ void merge_attn_states(torch::Tensor& output,
|
||||
const torch::Tensor& prefix_lse,
|
||||
const torch::Tensor& suffix_output,
|
||||
const torch::Tensor& suffix_lse);
|
||||
|
||||
void convert_vertical_slash_indexes(
|
||||
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
|
||||
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
|
||||
torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS]
|
||||
torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
|
||||
torch::Tensor q_seqlens, // [BATCH, ]
|
||||
torch::Tensor kv_seqlens, // [BATCH, ]
|
||||
torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
|
||||
torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
|
||||
int64_t context_size, int64_t block_size_M, int64_t block_size_N,
|
||||
bool causal);
|
||||
|
||||
void convert_vertical_slash_indexes_mergehead(
|
||||
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
|
||||
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
|
||||
torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS]
|
||||
torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
|
||||
torch::Tensor q_seqlens, // [BATCH, ]
|
||||
torch::Tensor kv_seqlens, // [BATCH, ]
|
||||
torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
|
||||
torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
|
||||
torch::Tensor vertical_indices_count, // [N_HEADS, ]
|
||||
torch::Tensor slash_indices_count, int64_t context_size,
|
||||
int64_t block_size_M, int64_t block_size_N, bool causal);
|
||||
#endif
|
||||
|
||||
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
|
||||
@ -208,6 +233,12 @@ void cutlass_moe_mm(
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides);
|
||||
|
||||
void cutlass_fp4_group_mm(
|
||||
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
|
||||
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
|
||||
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
|
||||
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets);
|
||||
|
||||
void get_cutlass_moe_mm_data(
|
||||
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
@ -235,6 +266,12 @@ std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a);
|
||||
void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
|
||||
torch::Tensor& output_scale,
|
||||
torch::Tensor const& input_scale);
|
||||
|
||||
void scaled_fp4_experts_quant(
|
||||
torch::Tensor& output, torch::Tensor& output_scale,
|
||||
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
||||
torch::Tensor const& input_offset_by_experts,
|
||||
torch::Tensor const& output_scale_offset_by_experts);
|
||||
#endif
|
||||
|
||||
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||
|
||||
@ -44,7 +44,8 @@ inline __device__ void apply_rotary_embedding(
|
||||
// head_size]
|
||||
const scalar_t* cache_ptr, const int head_size, const int num_heads,
|
||||
const int num_kv_heads, const int rot_dim, const int token_idx,
|
||||
const int64_t query_stride, const int64_t key_stride) {
|
||||
const int64_t query_stride, const int64_t key_stride,
|
||||
const int64_t head_stride) {
|
||||
const int embed_dim = rot_dim / 2;
|
||||
const scalar_t* cos_ptr = cache_ptr;
|
||||
const scalar_t* sin_ptr = cache_ptr + embed_dim;
|
||||
@ -52,7 +53,8 @@ inline __device__ void apply_rotary_embedding(
|
||||
const int nq = num_heads * embed_dim;
|
||||
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
|
||||
const int head_idx = i / embed_dim;
|
||||
const int64_t token_head = token_idx * query_stride + head_idx * head_size;
|
||||
const int64_t token_head =
|
||||
token_idx * query_stride + head_idx * head_stride;
|
||||
const int rot_offset = i % embed_dim;
|
||||
apply_token_rotary_embedding<scalar_t, IS_NEOX>(
|
||||
query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
|
||||
@ -62,7 +64,8 @@ inline __device__ void apply_rotary_embedding(
|
||||
const int nk = num_kv_heads * embed_dim;
|
||||
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
|
||||
const int head_idx = i / embed_dim;
|
||||
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
|
||||
const int64_t token_head =
|
||||
token_idx * key_stride + head_idx * head_stride;
|
||||
const int rot_offset = i % embed_dim;
|
||||
apply_token_rotary_embedding<scalar_t, IS_NEOX>(
|
||||
key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
|
||||
@ -84,7 +87,8 @@ __global__ void rotary_embedding_kernel(
|
||||
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
|
||||
// 2]
|
||||
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
|
||||
const int num_heads, const int num_kv_heads, const int head_size) {
|
||||
const int64_t head_stride, const int num_heads, const int num_kv_heads,
|
||||
const int head_size) {
|
||||
// Each thread block is responsible for one token.
|
||||
const int token_idx = blockIdx.x;
|
||||
int64_t pos = positions[token_idx];
|
||||
@ -92,7 +96,7 @@ __global__ void rotary_embedding_kernel(
|
||||
|
||||
apply_rotary_embedding<scalar_t, IS_NEOX>(
|
||||
query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
|
||||
token_idx, query_stride, key_stride);
|
||||
token_idx, query_stride, key_stride, head_stride);
|
||||
}
|
||||
|
||||
template <typename scalar_t, bool IS_NEOX>
|
||||
@ -109,9 +113,9 @@ __global__ void batched_rotary_embedding_kernel(
|
||||
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
|
||||
// 2]
|
||||
const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len]
|
||||
// or [num_tokens]
|
||||
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
|
||||
const int num_heads, const int num_kv_heads, const int head_size) {
|
||||
const int64_t head_stride, const int num_heads, const int num_kv_heads,
|
||||
const int head_size) {
|
||||
// Each thread block is responsible for one token.
|
||||
const int token_idx = blockIdx.x;
|
||||
int64_t pos = positions[token_idx];
|
||||
@ -121,7 +125,7 @@ __global__ void batched_rotary_embedding_kernel(
|
||||
|
||||
apply_rotary_embedding<scalar_t, IS_NEOX>(
|
||||
query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
|
||||
token_idx, query_stride, key_stride);
|
||||
token_idx, query_stride, key_stride, head_stride);
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@ -179,6 +183,12 @@ void rotary_embedding(
|
||||
int seq_dim_idx = positions_ndim - 1;
|
||||
int64_t query_stride = query.stride(seq_dim_idx);
|
||||
int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
|
||||
// Determine head stride: for [*, heads, head_size] use stride of last dim;
|
||||
// for flat [*, heads*head_size], heads blocks are contiguous of size
|
||||
// head_size
|
||||
int query_ndim = query.dim();
|
||||
int64_t head_stride =
|
||||
(query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size;
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
|
||||
@ -190,14 +200,14 @@ void rotary_embedding(
|
||||
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
|
||||
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
|
||||
cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride, key_stride,
|
||||
num_heads, num_kv_heads, head_size);
|
||||
head_stride, num_heads, num_kv_heads, head_size);
|
||||
} else {
|
||||
vllm::rotary_embedding_kernel<scalar_t, false>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
|
||||
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
|
||||
cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride,
|
||||
key_stride, num_heads, num_kv_heads, head_size);
|
||||
key_stride, head_stride, num_heads, num_kv_heads, head_size);
|
||||
}
|
||||
});
|
||||
}
|
||||
@ -263,6 +273,12 @@ void batched_rotary_embedding(
|
||||
int seq_dim_idx = positions_ndim - 1;
|
||||
int64_t query_stride = query.stride(seq_dim_idx);
|
||||
int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
|
||||
// Determine head stride: for [*, heads, head_size] use stride of last dim;
|
||||
// for flat [*, heads*head_size], heads blocks are contiguous of size
|
||||
// head_size
|
||||
int query_ndim = query.dim();
|
||||
int64_t head_stride =
|
||||
(query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size;
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
|
||||
@ -276,7 +292,7 @@ void batched_rotary_embedding(
|
||||
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
|
||||
cos_sin_cache.data_ptr<scalar_t>(),
|
||||
cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
|
||||
key_stride, num_heads, num_kv_heads, head_size);
|
||||
key_stride, head_stride, num_heads, num_kv_heads, head_size);
|
||||
} else {
|
||||
vllm::batched_rotary_embedding_kernel<scalar_t, false>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
@ -284,7 +300,7 @@ void batched_rotary_embedding(
|
||||
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
|
||||
cos_sin_cache.data_ptr<scalar_t>(),
|
||||
cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
|
||||
key_stride, num_heads, num_kv_heads, head_size);
|
||||
key_stride, head_stride, num_heads, num_kv_heads, head_size);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@ -112,7 +112,8 @@ __global__ void act_and_mul_quant_kernel(
|
||||
void silu_and_mul_quant(torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input, // [..., 2 * d]
|
||||
torch::Tensor& scale) {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn ||
|
||||
out.dtype() == torch::kFloat8_e4m3fnuz);
|
||||
TORCH_CHECK(input.dtype() == torch::kFloat16 ||
|
||||
input.dtype() == torch::kBFloat16);
|
||||
TORCH_CHECK(input.size(-1) % 2 == 0);
|
||||
|
||||
@ -26,7 +26,13 @@ static inline __device__ int8_t float_to_int8_rn(float x) {
|
||||
float dst = std::nearbyint(x);
|
||||
|
||||
// saturate
|
||||
dst = std::clamp(dst, i8_min, i8_max);
|
||||
|
||||
// See https://github.com/pytorch/pytorch/issues/127666
|
||||
// See https://github.com/llvm/llvm-project/issues/95183
|
||||
// hip-clang std::clamp __glibcxx_assert_fail host function when building on
|
||||
// Arch/gcc14. The following replaces std::clamp usage with similar logic
|
||||
// dst = std::clamp(dst, i8_min, i8_max);
|
||||
dst = (dst < i8_min) ? i8_min : (dst > i8_max) ? i8_max : dst;
|
||||
return static_cast<int8_t>(dst);
|
||||
#else
|
||||
// CUDA path
|
||||
@ -79,7 +85,13 @@ static inline __device__ int8_t int32_to_int8(int32_t x) {
|
||||
static_cast<int32_t>(std::numeric_limits<int8_t>::max());
|
||||
|
||||
// saturate
|
||||
int32_t dst = std::clamp(x, i8_min, i8_max);
|
||||
|
||||
// See https://github.com/pytorch/pytorch/issues/127666
|
||||
// See https://github.com/llvm/llvm-project/issues/95183
|
||||
// hip-clang std::clamp __glibcxx_assert_fail host function when building on
|
||||
// Arch/gcc14. The following replaces std::clamp usage with similar logic
|
||||
// int32_t dst = std::clamp(x, i8_min, i8_max);
|
||||
int32_t dst = (x < i8_min) ? i8_min : (x > i8_max) ? i8_max : x;
|
||||
return static_cast<int8_t>(dst);
|
||||
#else
|
||||
// CUDA path
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
#include <torch/all.h>
|
||||
#include "cuda_utils.h"
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
|
||||
template <typename Fp8Func, typename Int8Func, typename BlockwiseFunc>
|
||||
void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
|
||||
@ -28,29 +29,46 @@ void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
|
||||
}
|
||||
}
|
||||
} else {
|
||||
using GroupShape = std::array<int64_t, 2>;
|
||||
auto make_group_shape = [](torch::Tensor const& x,
|
||||
torch::Tensor const& s) -> GroupShape {
|
||||
TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
|
||||
return {cuda_utils::ceil_div(x.size(0), s.size(0)),
|
||||
cuda_utils::ceil_div(x.size(1), s.size(1))};
|
||||
};
|
||||
TORCH_CHECK(a_scales.dim() == 2, "a scale must be 2d tensor.");
|
||||
TORCH_CHECK(b_scales.dim() == 2, "b scale must be 2d tensor.");
|
||||
int32_t version_num = get_sm_version_num();
|
||||
if (version_num >= 100) {
|
||||
TORCH_CHECK(
|
||||
a.size(0) == a_scales.size(0) &&
|
||||
cuda_utils::ceil_div(a.size(1), int64_t(128)) == a_scales.size(1),
|
||||
"a_scale_group_shape must be [1, 128].");
|
||||
TORCH_CHECK(
|
||||
cuda_utils::ceil_div(b.size(0), int64_t(128)) == b_scales.size(0) &&
|
||||
cuda_utils::ceil_div(b.size(1), int64_t(128)) == b_scales.size(1),
|
||||
"b_scale_group_shape must be [128, 128].");
|
||||
} else {
|
||||
// TODO: Remove this after using cutlass sm90 blockwise scaling gemm
|
||||
// kernel, or introducing ceil_div to the load_init() of mainloop.
|
||||
using GroupShape = std::array<int64_t, 2>;
|
||||
auto make_group_shape = [](torch::Tensor const& x,
|
||||
torch::Tensor const& s) -> GroupShape {
|
||||
TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
|
||||
return {cuda_utils::ceil_div(x.size(0), s.size(0)),
|
||||
cuda_utils::ceil_div(x.size(1), s.size(1))};
|
||||
};
|
||||
|
||||
GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
|
||||
GroupShape b_scale_group_shape = make_group_shape(b, b_scales);
|
||||
GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
|
||||
GroupShape b_scale_group_shape = make_group_shape(b, b_scales);
|
||||
|
||||
// 1x128 per-token group scales for activations
|
||||
// 128x128 blockwise scales for weights
|
||||
TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} &&
|
||||
b_scale_group_shape == GroupShape{128, 128} &&
|
||||
a.dtype() == torch::kFloat8_e4m3fn &&
|
||||
b.dtype() == torch::kFloat8_e4m3fn),
|
||||
"cutlass_scaled_mm only supports datatype float8_e4m3fn.\n"
|
||||
"a_scale_group_shape must be [1, 128]. Got: [",
|
||||
a_scale_group_shape[0], ", ", a_scale_group_shape[1],
|
||||
"]\n"
|
||||
"b_scale_group_shape must be [128, 128]. Got: [",
|
||||
b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
|
||||
}
|
||||
|
||||
// 1x128 per-token group scales for activations
|
||||
// 128x128 blockwise scales for weights
|
||||
TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} &&
|
||||
b_scale_group_shape == GroupShape{128, 128} &&
|
||||
a.dtype() == torch::kFloat8_e4m3fn &&
|
||||
b.dtype() == torch::kFloat8_e4m3fn),
|
||||
"cutlass_scaled_mm only supports datatype float8_e4m3fn.\n"
|
||||
"a_scale_group_shape must be [1, 128]. Got: [",
|
||||
a_scale_group_shape[0], ", ", a_scale_group_shape[1],
|
||||
"]\n"
|
||||
"b_scale_group_shape must be [128, 128]. Got: [",
|
||||
b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
|
||||
TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
|
||||
blockwise_func(c, a, b, a_scales, b_scales);
|
||||
}
|
||||
|
||||
@ -29,7 +29,8 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
#endif
|
||||
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
|
||||
void cutlass_moe_mm_sm90(
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
||||
@ -37,12 +38,6 @@ void cutlass_moe_mm_sm90(
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides);
|
||||
|
||||
void get_cutlass_moe_mm_data_caller(
|
||||
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
|
||||
const int64_t num_experts, const int64_t n, const int64_t k);
|
||||
|
||||
#endif
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
|
||||
@ -53,6 +48,15 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
#endif
|
||||
|
||||
#if defined(ENABLE_SCALED_MM_SM90) && ENABLE_SCALED_MM_SM90 || \
|
||||
defined(ENABLE_SCALED_MM_SM100) && ENABLE_SCALED_MM_SM100
|
||||
void get_cutlass_moe_mm_data_caller(
|
||||
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
|
||||
const int64_t num_experts, const int64_t n, const int64_t k);
|
||||
#endif
|
||||
|
||||
void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
@ -224,7 +228,8 @@ void get_cutlass_moe_mm_data(
|
||||
// This function currently gets compiled only if we have a valid cutlass moe
|
||||
// mm to run it for.
|
||||
int32_t version_num = get_sm_version_num();
|
||||
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
|
||||
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
|
||||
(defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM90)
|
||||
get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1,
|
||||
problem_sizes2, input_permutation,
|
||||
output_permutation, num_experts, n, k);
|
||||
|
||||
402
csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu
Normal file
402
csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu
Normal file
@ -0,0 +1,402 @@
|
||||
#include <torch/all.h>
|
||||
#include <cutlass/arch/arch.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/group_array_problem_shape.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include <cassert>
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename ElementAB, typename ElementC, typename ElementSF,
|
||||
typename ElementAccumulator, typename LayoutSFA, typename LayoutSFB,
|
||||
typename ScaleConfig>
|
||||
__global__ void __get_group_gemm_starts(
|
||||
ElementAB** a_offsets, ElementAB** b_offsets, ElementC** out_offsets,
|
||||
ElementSF** a_scales_offsets, ElementSF** b_scales_offsets,
|
||||
ElementAccumulator** alpha_offsets, LayoutSFA* layout_sfa_base_as_int,
|
||||
LayoutSFB* layout_sfb_base_as_int, ElementAB* a_base_as_int,
|
||||
ElementAB* b_base_as_int, ElementC* out_base_as_int,
|
||||
ElementSF* a_scales_base_as_int, ElementSF* b_scales_base_as_int,
|
||||
ElementAccumulator* alphas_base_as_int, const int32_t* expert_offsets,
|
||||
const int32_t* sf_offsets, const int32_t* problem_sizes_as_shapes,
|
||||
const int K, const int N) {
|
||||
int64_t expert_id = threadIdx.x;
|
||||
if (expert_id >= gridDim.x * blockDim.x) {
|
||||
return;
|
||||
}
|
||||
// Originally int32_t but upcasting to int64_t to avoid overflow
|
||||
// during offset calculations
|
||||
int64_t expert_offset = static_cast<int64_t>(expert_offsets[expert_id]);
|
||||
int64_t sf_offset = static_cast<int64_t>(sf_offsets[expert_id]);
|
||||
// size for block in block scale.
|
||||
int64_t group_size = 16;
|
||||
int64_t m = static_cast<int64_t>(problem_sizes_as_shapes[expert_id * 3]);
|
||||
int64_t n = static_cast<int64_t>(problem_sizes_as_shapes[expert_id * 3 + 1]);
|
||||
int64_t k = static_cast<int64_t>(problem_sizes_as_shapes[expert_id * 3 + 2]);
|
||||
assert((m >= 0 && n == N && k == K && k % 2 == 0) &&
|
||||
"unexpected problem sizes");
|
||||
|
||||
int64_t half_k = static_cast<int64_t>(k / 2);
|
||||
int64_t group_k = static_cast<int64_t>(k / group_size);
|
||||
// Shape of A as uint8/byte = [M, K // 2]
|
||||
// Shape of B as uint8/byte = [E, N, K // 2]
|
||||
a_offsets[expert_id] = a_base_as_int + expert_offset * half_k;
|
||||
|
||||
b_offsets[expert_id] = b_base_as_int + expert_id * n * half_k;
|
||||
// Shape of C = [M, N]
|
||||
out_offsets[expert_id] = out_base_as_int + expert_offset * n;
|
||||
// Shape of a_scale = [sum(sf_sizes), K // group_size]
|
||||
a_scales_offsets[expert_id] = a_scales_base_as_int + sf_offset * group_k;
|
||||
|
||||
assert((reinterpret_cast<uintptr_t>(a_scales_offsets[expert_id]) % 128) ==
|
||||
0 &&
|
||||
"TMA requires 128-byte alignment");
|
||||
|
||||
// Shape of B scale = [E, N, K // group_size]
|
||||
b_scales_offsets[expert_id] = b_scales_base_as_int + expert_id * n * group_k;
|
||||
assert((reinterpret_cast<uintptr_t>(b_scales_offsets[expert_id]) % 128) ==
|
||||
0 &&
|
||||
"TMA requires 128-byte alignment");
|
||||
// Shape of alpha = [E]
|
||||
alpha_offsets[expert_id] = alphas_base_as_int + expert_id;
|
||||
|
||||
LayoutSFA* layout_sfa_ptr = layout_sfa_base_as_int + expert_id;
|
||||
LayoutSFB* layout_sfb_ptr = layout_sfb_base_as_int + expert_id;
|
||||
|
||||
*layout_sfa_ptr = ScaleConfig::tile_atom_to_shape_SFA(cute::make_shape(
|
||||
static_cast<int>(m), static_cast<int>(n), static_cast<int>(k), 1));
|
||||
*layout_sfb_ptr = ScaleConfig::tile_atom_to_shape_SFB(cute::make_shape(
|
||||
static_cast<int>(m), static_cast<int>(n), static_cast<int>(k), 1));
|
||||
}
|
||||
|
||||
#define __CALL_GET_STARTS_KERNEL_BLOCKSCALE(ELEMENT_AB_TYPE, SF_TYPE, \
|
||||
TENSOR_C_TYPE, C_TYPE, LayoutSFA, \
|
||||
LayoutSFB, ScaleConfig) \
|
||||
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
|
||||
__get_group_gemm_starts<ELEMENT_AB_TYPE, C_TYPE, SF_TYPE, float, \
|
||||
LayoutSFA, LayoutSFB, ScaleConfig> \
|
||||
<<<1, num_experts, 0, stream>>>( \
|
||||
static_cast<ELEMENT_AB_TYPE**>(a_starts.data_ptr()), \
|
||||
static_cast<ELEMENT_AB_TYPE**>(b_starts.data_ptr()), \
|
||||
static_cast<C_TYPE**>(out_starts.data_ptr()), \
|
||||
static_cast<SF_TYPE**>(a_scales_starts.data_ptr()), \
|
||||
static_cast<SF_TYPE**>(b_scales_starts.data_ptr()), \
|
||||
static_cast<float**>(alpha_starts.data_ptr()), \
|
||||
reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()), \
|
||||
reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr()), \
|
||||
static_cast<ELEMENT_AB_TYPE*>(a_tensors.data_ptr()), \
|
||||
static_cast<ELEMENT_AB_TYPE*>(b_tensors.data_ptr()), \
|
||||
static_cast<C_TYPE*>(out_tensors.data_ptr()), \
|
||||
static_cast<SF_TYPE*>(a_scales.data_ptr()), \
|
||||
static_cast<SF_TYPE*>(b_scales.data_ptr()), \
|
||||
static_cast<float*>(alphas.data_ptr()), \
|
||||
static_cast<int32_t*>(expert_offsets.data_ptr()), \
|
||||
static_cast<int32_t*>(sf_offsets.data_ptr()), \
|
||||
static_cast<int32_t*>(problem_sizes.data_ptr()), K, N); \
|
||||
}
|
||||
|
||||
template <typename LayoutSFA, typename LayoutSFB, typename ScaleConfig>
|
||||
void run_get_group_gemm_starts(
|
||||
const torch::Tensor& a_starts, const torch::Tensor& b_starts,
|
||||
const torch::Tensor& out_starts, const torch::Tensor& a_scales_starts,
|
||||
const torch::Tensor& b_scales_starts, const torch::Tensor& alpha_starts,
|
||||
const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb,
|
||||
/*these are used for their base addresses*/
|
||||
torch::Tensor const& a_tensors, torch::Tensor const& b_tensors,
|
||||
torch::Tensor const& out_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& alphas,
|
||||
torch::Tensor const& expert_offsets, torch::Tensor const& sf_offsets,
|
||||
torch::Tensor const& problem_sizes, int M, int N, int K) {
|
||||
int num_experts = (int)expert_offsets.size(0);
|
||||
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
|
||||
|
||||
TORCH_CHECK(out_tensors.size(1) == N,
|
||||
"Output tensor shape doesn't match expected shape");
|
||||
TORCH_CHECK(K / 2 == b_tensors.size(2),
|
||||
"b_tensors(dim = 2) and a_tensors(dim = 1) trailing"
|
||||
" dimension must match");
|
||||
if (false) {
|
||||
}
|
||||
//(ELEMENT_AB_TYPE, BS_TYPE, TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB,
|
||||
// ScaleConfig)
|
||||
__CALL_GET_STARTS_KERNEL_BLOCKSCALE(
|
||||
cutlass::float_e2m1_t, cutlass::float_ue4m3_t, torch::kBFloat16,
|
||||
cutlass::bfloat16_t, LayoutSFA, LayoutSFB, ScaleConfig)
|
||||
__CALL_GET_STARTS_KERNEL_BLOCKSCALE(cutlass::float_e2m1_t,
|
||||
cutlass::float_ue4m3_t, torch::kFloat16,
|
||||
half, LayoutSFA, LayoutSFB, ScaleConfig)
|
||||
else {
|
||||
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename OutType>
|
||||
void run_fp4_blockwise_scaled_group_mm(
|
||||
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
|
||||
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
|
||||
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
|
||||
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M,
|
||||
int N, int K) {
|
||||
using ProblemShape =
|
||||
cutlass::gemm::GroupProblemShape<Shape<int32_t, int32_t, int32_t>>;
|
||||
using ElementType = cutlass::float_e2m1_t;
|
||||
using ElementSFType = cutlass::float_ue4m3_t;
|
||||
using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
|
||||
using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
|
||||
|
||||
using ElementC = OutType;
|
||||
using ElementD = ElementC;
|
||||
using ElementAccumulator = float;
|
||||
// Layout definitions
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = LayoutC;
|
||||
|
||||
// Alignment constraints
|
||||
static constexpr int AlignmentA = 32;
|
||||
static constexpr int AlignmentB = 32;
|
||||
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
|
||||
// Architecture definitions
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using EpilogueOperatorClass =
|
||||
cutlass::arch::OpClassTensorOp; // Epilogue Operator class tag
|
||||
using MainloopOperatorClass =
|
||||
cutlass::arch::OpClassBlockScaledTensorOp; // Mainloop Operator class tag
|
||||
using StageCountType =
|
||||
cutlass::gemm::collective::StageCountAuto; // Stage count maximized based
|
||||
// on the tile size
|
||||
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
struct MMA1SMConfig {
|
||||
using MmaTileShape = Shape<_128, _128, _128>;
|
||||
using KernelSchedule = cutlass::gemm::
|
||||
KernelPtrArrayTmaWarpSpecialized1SmNvf4Sm100; // Kernel to launch
|
||||
using EpilogueSchedule =
|
||||
cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch
|
||||
};
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, EpilogueOperatorClass, typename MMA1SMConfig::MmaTileShape,
|
||||
ClusterShape, Shape<_128, _64>, ElementAccumulator,
|
||||
ElementAccumulator, ElementC, LayoutC*, AlignmentC, ElementD,
|
||||
LayoutC*, AlignmentD,
|
||||
typename MMA1SMConfig::EpilogueSchedule>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, MainloopOperatorClass, ElementA, LayoutA*, AlignmentA,
|
||||
ElementB, LayoutB*, AlignmentB, ElementAccumulator,
|
||||
typename MMA1SMConfig::MmaTileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
typename MMA1SMConfig::KernelSchedule>::CollectiveOp;
|
||||
|
||||
using GemmKernel =
|
||||
cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop,
|
||||
CollectiveEpilogue>;
|
||||
|
||||
using Gemm1SM = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
using Gemm = Gemm1SM;
|
||||
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::InternalStrideD;
|
||||
|
||||
using LayoutSFA =
|
||||
typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA;
|
||||
using LayoutSFB =
|
||||
typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB;
|
||||
using ScaleConfig =
|
||||
typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
|
||||
|
||||
using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
|
||||
int num_experts = static_cast<int>(expert_offsets.size(0));
|
||||
auto options_int =
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device());
|
||||
|
||||
torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int);
|
||||
torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int);
|
||||
torch::Tensor c_strides1 =
|
||||
torch::full({num_experts}, output.stride(0), options_int);
|
||||
torch::Tensor a_strides1 =
|
||||
torch::full({num_experts}, a.stride(0) * 2, options_int);
|
||||
torch::Tensor b_strides1 =
|
||||
torch::full({num_experts}, b.stride(1) * 2, options_int);
|
||||
|
||||
run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>(
|
||||
a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs,
|
||||
layout_sfa, layout_sfb, a, b, output, a_blockscale, b_blockscales, alphas,
|
||||
expert_offsets, sf_offsets, problem_sizes, M, N, K);
|
||||
|
||||
// Create an instance of the GEMM
|
||||
Gemm gemm_op;
|
||||
|
||||
// Initialize problem_sizes_as_shapes correctly
|
||||
UnderlyingProblemShape* problem_sizes_as_shapes =
|
||||
static_cast<UnderlyingProblemShape*>(problem_sizes.data_ptr());
|
||||
|
||||
// Set the Scheduler info
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::
|
||||
PersistentTileSchedulerSm100GroupParams<
|
||||
typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions;
|
||||
typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
|
||||
scheduler.raster_order = RasterOrderOptions::AlongM;
|
||||
hw_info.device_id = a.get_device();
|
||||
static std::unordered_map<int, int> cached_sm_counts;
|
||||
if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) {
|
||||
cached_sm_counts[hw_info.device_id] =
|
||||
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
|
||||
hw_info.device_id);
|
||||
}
|
||||
hw_info.sm_count = min(cached_sm_counts[hw_info.device_id], INT_MAX);
|
||||
|
||||
// Mainloop Arguments
|
||||
typename GemmKernel::MainloopArguments mainloop_args{
|
||||
static_cast<const ElementType**>(a_ptrs.data_ptr()),
|
||||
static_cast<StrideA*>(a_strides1.data_ptr()),
|
||||
static_cast<const ElementType**>(b_ptrs.data_ptr()),
|
||||
static_cast<StrideB*>(b_strides1.data_ptr()),
|
||||
static_cast<const ElementSFType**>(a_scales_ptrs.data_ptr()),
|
||||
reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()),
|
||||
static_cast<const ElementSFType**>(b_scales_ptrs.data_ptr()),
|
||||
reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr())};
|
||||
|
||||
// Epilogue Arguments
|
||||
typename GemmKernel::EpilogueArguments epilogue_args{
|
||||
{}, // epilogue.thread
|
||||
nullptr,
|
||||
static_cast<StrideC*>(c_strides1.data_ptr()),
|
||||
static_cast<ElementD**>(out_ptrs.data_ptr()),
|
||||
static_cast<StrideC*>(c_strides1.data_ptr())};
|
||||
auto& fusion_args = epilogue_args.thread;
|
||||
fusion_args.alpha_ptr_array =
|
||||
reinterpret_cast<float**>(alpha_ptrs.data_ptr());
|
||||
fusion_args.dAlpha = {_0{}, _0{}, 1};
|
||||
|
||||
// Gemm Arguments
|
||||
typename GemmKernel::Arguments args{
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||
{num_experts, problem_sizes_as_shapes, nullptr},
|
||||
mainloop_args,
|
||||
epilogue_args,
|
||||
hw_info,
|
||||
scheduler};
|
||||
|
||||
size_t workspace_size = Gemm::get_workspace_size(args);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device());
|
||||
|
||||
auto can_implement_status = gemm_op.can_implement(args);
|
||||
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
|
||||
"Failed to implement GEMM");
|
||||
|
||||
// Run the GEMM
|
||||
auto status = gemm_op.initialize(args, workspace.data_ptr());
|
||||
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");
|
||||
|
||||
status = gemm_op.run(args, workspace.data_ptr(), stream);
|
||||
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
|
||||
}
|
||||
|
||||
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
|
||||
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
|
||||
|
||||
#define CHECK_TYPE(x, st, m) \
|
||||
TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
|
||||
#define CHECK_TH_CUDA(x, m) \
|
||||
TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor.")
|
||||
#define CHECK_CONTIGUOUS(x, m) \
|
||||
TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous.")
|
||||
#define CHECK_INPUT(x, st, m) \
|
||||
CHECK_TH_CUDA(x, m); \
|
||||
CHECK_CONTIGUOUS(x, m); \
|
||||
CHECK_TYPE(x, st, m)
|
||||
|
||||
void cutlass_fp4_group_mm(
|
||||
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
|
||||
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
|
||||
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
|
||||
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets) {
|
||||
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
|
||||
// Input validation
|
||||
CHECK_INPUT(a, FLOAT4_E2M1X2, "a");
|
||||
CHECK_INPUT(b, FLOAT4_E2M1X2, "b");
|
||||
CHECK_INPUT(a_blockscale, SF_DTYPE, "a_blockscale");
|
||||
CHECK_INPUT(b_blockscales, SF_DTYPE, "b_blockscales");
|
||||
CHECK_INPUT(alphas, at::ScalarType::Float, "alphas");
|
||||
|
||||
TORCH_CHECK(a_blockscale.dim() == 2,
|
||||
"expected a_blockscale to be of shape [num_experts, rounded_m,"
|
||||
" k // group_size], observed rank: ",
|
||||
a_blockscale.dim())
|
||||
TORCH_CHECK(b_blockscales.dim() == 3,
|
||||
"expected b_blockscale to be of shape: "
|
||||
" [num_experts, n, k // group_size], observed rank: ",
|
||||
b_blockscales.dim())
|
||||
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be a 2D tensor");
|
||||
TORCH_CHECK(problem_sizes.size(1) == 3,
|
||||
"problem_sizes must have the shape (num_experts, 3)");
|
||||
TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0),
|
||||
"Number of experts in problem_sizes must match expert_offsets");
|
||||
TORCH_CHECK(problem_sizes.dtype() == torch::kInt32,
|
||||
"problem_sizes must be int32.");
|
||||
|
||||
int M = static_cast<int>(a.size(0));
|
||||
int N = static_cast<int>(b.size(1));
|
||||
int E = static_cast<int>(b.size(0));
|
||||
int K = static_cast<int>(2 * b.size(2));
|
||||
|
||||
if (output.scalar_type() == torch::kBFloat16) {
|
||||
run_fp4_blockwise_scaled_group_mm<cutlass::bfloat16_t>(
|
||||
output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
|
||||
expert_offsets, sf_offsets, M, N, K);
|
||||
} else {
|
||||
run_fp4_blockwise_scaled_group_mm<cutlass::half_t>(
|
||||
output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
|
||||
expert_offsets, sf_offsets, M, N, K);
|
||||
}
|
||||
#else
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"No compiled cutlass_fp4_group_mm kernel, vLLM must "
|
||||
"be compiled with ENABLE_NVFP4 for SM100+ and CUDA "
|
||||
"12.8 or above.");
|
||||
#endif
|
||||
}
|
||||
404
csrc/quantization/fp4/nvfp4_experts_quant.cu
Normal file
404
csrc/quantization/fp4/nvfp4_experts_quant.cu
Normal file
@ -0,0 +1,404 @@
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp8.h>
|
||||
|
||||
template <typename T>
|
||||
struct TypeConverter {
|
||||
using Type = half2;
|
||||
}; // keep for generality
|
||||
|
||||
template <>
|
||||
struct TypeConverter<half2> {
|
||||
using Type = half;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TypeConverter<half> {
|
||||
using Type = half2;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TypeConverter<__nv_bfloat162> {
|
||||
using Type = __nv_bfloat16;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TypeConverter<__nv_bfloat16> {
|
||||
using Type = __nv_bfloat162;
|
||||
};
|
||||
|
||||
#define ELTS_PER_THREAD 8
|
||||
|
||||
constexpr int CVT_FP4_ELTS_PER_THREAD = 8;
|
||||
constexpr int CVT_FP4_SF_VEC_SIZE = 16;
|
||||
|
||||
// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
|
||||
inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
|
||||
uint32_t val;
|
||||
asm volatile(
|
||||
"{\n"
|
||||
".reg .b8 byte0;\n"
|
||||
".reg .b8 byte1;\n"
|
||||
".reg .b8 byte2;\n"
|
||||
".reg .b8 byte3;\n"
|
||||
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
|
||||
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
|
||||
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
|
||||
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
|
||||
"mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
|
||||
"}"
|
||||
: "=r"(val)
|
||||
: "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]),
|
||||
"f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7]));
|
||||
return val;
|
||||
#else
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
|
||||
inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
|
||||
uint32_t val;
|
||||
asm volatile(
|
||||
"{\n"
|
||||
".reg .b8 byte0;\n"
|
||||
".reg .b8 byte1;\n"
|
||||
".reg .b8 byte2;\n"
|
||||
".reg .b8 byte3;\n"
|
||||
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
|
||||
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
|
||||
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
|
||||
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
|
||||
"mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
|
||||
"}"
|
||||
: "=r"(val)
|
||||
: "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y),
|
||||
"f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y));
|
||||
return val;
|
||||
#else
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
// Fast reciprocal.
|
||||
inline __device__ float reciprocal_approximate_ftz(float a) {
|
||||
float b;
|
||||
asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a));
|
||||
return b;
|
||||
}
|
||||
|
||||
template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF>
|
||||
__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx,
|
||||
int numCols,
|
||||
SFType* SFout) {
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
|
||||
static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 ||
|
||||
CVT_FP4_NUM_THREADS_PER_SF == 2);
|
||||
|
||||
// One pair of threads write one SF to global memory.
|
||||
// TODO: stage through smem for packed STG.32
|
||||
// is it better than STG.8 from 4 threads ?
|
||||
if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) {
|
||||
// SF vector index (16 elements share one SF in the K dimension).
|
||||
int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
|
||||
int32_t mIdx = rowIdx;
|
||||
|
||||
// SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
|
||||
// --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
|
||||
|
||||
int32_t mTileIdx = mIdx / (32 * 4);
|
||||
// SF vector size 16.
|
||||
int factor = CVT_FP4_SF_VEC_SIZE * 4;
|
||||
int32_t numKTiles = (numCols + factor - 1) / factor;
|
||||
int64_t mTileStride = numKTiles * 32 * 4 * 4;
|
||||
|
||||
int32_t kTileIdx = (kIdx / 4);
|
||||
int64_t kTileStride = 32 * 4 * 4;
|
||||
|
||||
// M tile layout [32, 4] is column-major.
|
||||
int32_t outerMIdx = (mIdx % 32);
|
||||
int64_t outerMStride = 4 * 4;
|
||||
|
||||
int32_t innerMIdx = (mIdx % (32 * 4)) / 32;
|
||||
int64_t innerMStride = 4;
|
||||
|
||||
int32_t innerKIdx = (kIdx % 4);
|
||||
int64_t innerKStride = 1;
|
||||
|
||||
// Compute the global offset.
|
||||
int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride +
|
||||
outerMIdx * outerMStride + innerMIdx * innerMStride +
|
||||
innerKIdx * innerKStride;
|
||||
|
||||
return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
|
||||
}
|
||||
#endif
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Define a 16 bytes packed data type.
|
||||
template <class Type>
|
||||
struct PackedVec {
|
||||
typename TypeConverter<Type>::Type elts[4];
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PackedVec<__nv_fp8_e4m3> {
|
||||
__nv_fp8x2_e4m3 elts[8];
|
||||
};
|
||||
|
||||
// Quantizes the provided PackedVec into the uint32_t output
|
||||
template <class Type, bool UE8M0_SF = false>
|
||||
__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
|
||||
uint8_t* SFout) {
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
|
||||
// Get absolute maximum values among the local 8 values.
|
||||
auto localMax = __habs2(vec.elts[0]);
|
||||
|
||||
// Local maximum value.
|
||||
#pragma unroll
|
||||
for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
|
||||
localMax = __hmax2(localMax, __habs2(vec.elts[i]));
|
||||
}
|
||||
|
||||
// Get the absolute maximum among all 16 values (two threads).
|
||||
localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
|
||||
// Get the final absolute maximum values.
|
||||
float vecMax = float(__hmax(localMax.x, localMax.y));
|
||||
|
||||
// Get the SF (max value of the vector / max value of e2m1).
|
||||
// maximum value of e2m1 = 6.0.
|
||||
// TODO: use half as compute data type.
|
||||
float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f));
|
||||
// 8 bits representation of the SF.
|
||||
uint8_t fp8SFVal;
|
||||
// Write the SF to global memory (STG.8).
|
||||
if constexpr (UE8M0_SF) {
|
||||
// Extract the 8 exponent bits from float32.
|
||||
// float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits.
|
||||
uint32_t tmp = reinterpret_cast<uint32_t&>(SFValue) >> 23;
|
||||
fp8SFVal = tmp & 0xff;
|
||||
// Convert back to fp32.
|
||||
reinterpret_cast<uint32_t&>(SFValue) = tmp << 23;
|
||||
} else {
|
||||
// Here SFValue is always positive, so E4M3 is the same as UE4M3.
|
||||
__nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue);
|
||||
reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp;
|
||||
// Convert back to fp32.
|
||||
SFValue = float(tmp);
|
||||
}
|
||||
// Get the output scale.
|
||||
// Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) *
|
||||
// reciprocal(SFScaleVal))
|
||||
float outputScale =
|
||||
SFValue != 0 ? reciprocal_approximate_ftz(
|
||||
SFValue * reciprocal_approximate_ftz(SFScaleVal))
|
||||
: 0.0f;
|
||||
|
||||
if (SFout) {
|
||||
// Write the SF to global memory (STG.8).
|
||||
*SFout = fp8SFVal;
|
||||
}
|
||||
|
||||
// Convert the input to float.
|
||||
float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2];
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
|
||||
if constexpr (std::is_same_v<Type, half>) {
|
||||
fp2Vals[i] = __half22float2(vec.elts[i]);
|
||||
} else {
|
||||
fp2Vals[i] = __bfloat1622float2(vec.elts[i]);
|
||||
}
|
||||
fp2Vals[i].x *= outputScale;
|
||||
fp2Vals[i].y *= outputScale;
|
||||
}
|
||||
|
||||
// Convert to e2m1 values.
|
||||
uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals);
|
||||
|
||||
// Write the e2m1 values to global memory.
|
||||
return e2m1Vec;
|
||||
#else
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
// Use UE4M3 by default.
|
||||
template <class Type, bool UE8M0_SF = false>
|
||||
__global__ void
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
|
||||
__launch_bounds__(512, 4) cvt_fp16_to_fp4(
|
||||
#else
|
||||
cvt_fp16_to_fp4(
|
||||
#endif
|
||||
int32_t numRows, int32_t numCols, Type const* in, float const* SFScale,
|
||||
uint32_t* out, uint32_t* SFout, uint32_t* input_offset_by_experts,
|
||||
uint32_t* output_scale_offset_by_experts, int n_experts) {
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
|
||||
using PackedVec = PackedVec<Type>;
|
||||
static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
|
||||
(CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
|
||||
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
|
||||
"Vec size is not matched.");
|
||||
|
||||
// Input tensor row/col loops.
|
||||
for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
|
||||
for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP4_ELTS_PER_THREAD;
|
||||
colIdx += blockDim.x) {
|
||||
int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx;
|
||||
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
|
||||
// Get the output tensor offset.
|
||||
// Same as inOffset because 8 elements are packed into one uint32_t.
|
||||
int64_t outOffset = inOffset;
|
||||
auto& out_pos = out[outOffset];
|
||||
|
||||
// Find index within the experts.
|
||||
int rowIdx_in_expert = 0;
|
||||
int expert_idx = 0;
|
||||
for (int i = 0; i < n_experts; i++) {
|
||||
if (rowIdx >= input_offset_by_experts[i] &&
|
||||
rowIdx < input_offset_by_experts[i + 1]) {
|
||||
rowIdx_in_expert = rowIdx - input_offset_by_experts[i];
|
||||
expert_idx = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Get the global scaling factor, which will be applied to the SF.
|
||||
// Note SFScale is the same as next GEMM's alpha, which is
|
||||
// (448.f / (Alpha_A / 6.f)).
|
||||
float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];
|
||||
|
||||
int factor = CVT_FP4_SF_VEC_SIZE * 4;
|
||||
// The actual output_scales dim is computed from the padded numCols.
|
||||
int32_t numCols_padded = (numCols + factor - 1) / factor * factor;
|
||||
int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;
|
||||
uint32_t* SFout_in_expert =
|
||||
SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout;
|
||||
|
||||
auto sf_out =
|
||||
cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
|
||||
CVT_FP4_NUM_THREADS_PER_SF>(
|
||||
rowIdx_in_expert, colIdx, numCols, SFout_in_expert);
|
||||
|
||||
out_pos =
|
||||
cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void quant_impl(void* output, void* output_scale, void* input,
|
||||
void* input_global_scale, void* input_offset_by_experts,
|
||||
void* output_scale_offset_by_experts, int m_topk, int k,
|
||||
int n_experts, cudaStream_t stream) {
|
||||
// TODO: this multiProcessorCount should be cached.
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
int multiProcessorCount;
|
||||
cudaDeviceGetAttribute(&multiProcessorCount, cudaDevAttrMultiProcessorCount,
|
||||
device);
|
||||
|
||||
// Grid, Block size.
|
||||
// Each thread converts 8 values.
|
||||
dim3 block(std::min(int(k / ELTS_PER_THREAD), 512));
|
||||
// Get number of blocks per SM (assume we can fully utilize the SM).
|
||||
int const numBlocksPerSM = 2048 / block.x;
|
||||
dim3 grid(std::min(int(m_topk), multiProcessorCount * numBlocksPerSM));
|
||||
|
||||
cvt_fp16_to_fp4<T, false><<<grid, block, 0, stream>>>(
|
||||
m_topk, k, reinterpret_cast<T*>(input),
|
||||
reinterpret_cast<float*>(input_global_scale),
|
||||
reinterpret_cast<uint32_t*>(output),
|
||||
reinterpret_cast<uint32_t*>(output_scale),
|
||||
reinterpret_cast<uint32_t*>(input_offset_by_experts),
|
||||
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts), n_experts);
|
||||
}
|
||||
|
||||
/*Quantization entry for fp4 experts quantization*/
|
||||
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x, m) \
|
||||
TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
|
||||
#define CHECK_INPUT(x, m) \
|
||||
CHECK_TH_CUDA(x, m); \
|
||||
CHECK_CONTIGUOUS(x, m);
|
||||
|
||||
constexpr auto HALF = at::ScalarType::Half;
|
||||
constexpr auto BF16 = at::ScalarType::BFloat16;
|
||||
constexpr auto FLOAT = at::ScalarType::Float;
|
||||
constexpr auto INT = at::ScalarType::Int;
|
||||
constexpr auto UINT8 = at::ScalarType::Byte;
|
||||
|
||||
void scaled_fp4_experts_quant_sm100a(
|
||||
torch::Tensor& output, torch::Tensor& output_scale,
|
||||
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
||||
torch::Tensor const& input_offset_by_experts,
|
||||
torch::Tensor const& output_scale_offset_by_experts) {
|
||||
CHECK_INPUT(output, "output must be a CUDA tensor");
|
||||
CHECK_INPUT(output_scale, "output_scale must be a CUDA tensor");
|
||||
CHECK_INPUT(input, "input must be a CUDA tensor");
|
||||
CHECK_INPUT(input_global_scale, "input_global_scale must be a CUDA tensor");
|
||||
CHECK_INPUT(input_offset_by_experts,
|
||||
"input_offset_by_experts must be a CUDA tensor");
|
||||
CHECK_INPUT(output_scale_offset_by_experts,
|
||||
"output_scale_offset_by_experts must be a CUDA tensor");
|
||||
|
||||
TORCH_CHECK(output.dim() == 2);
|
||||
TORCH_CHECK(output_scale.dim() == 2);
|
||||
TORCH_CHECK(input.dim() == 2);
|
||||
TORCH_CHECK(input_global_scale.dim() == 1);
|
||||
TORCH_CHECK(input_offset_by_experts.dim() == 1);
|
||||
TORCH_CHECK(output_scale_offset_by_experts.dim() == 1);
|
||||
|
||||
TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16);
|
||||
TORCH_CHECK(input_global_scale.scalar_type() == FLOAT);
|
||||
TORCH_CHECK(input_offset_by_experts.scalar_type() == INT);
|
||||
TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT);
|
||||
// output is uint8 (two nvfp4 values are packed into one uint8)
|
||||
// output_scale is int32 (four fp8 values are packed into one int32)
|
||||
TORCH_CHECK(output.scalar_type() == UINT8);
|
||||
TORCH_CHECK(output_scale.scalar_type() == INT);
|
||||
|
||||
const int BLOCK_SIZE = 16;
|
||||
auto m_topk = input.size(0);
|
||||
auto k = input.size(1);
|
||||
TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16");
|
||||
auto n_experts = input_global_scale.size(0);
|
||||
TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1);
|
||||
TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1);
|
||||
TORCH_CHECK(output.size(0) == m_topk);
|
||||
TORCH_CHECK(output.size(1) == k / 2);
|
||||
int scales_k = k / BLOCK_SIZE;
|
||||
// 4 means the swizzle requirement by nvidia nvfp4.
|
||||
int padded_k = (scales_k + (4 - 1)) / 4 * 4;
|
||||
// 4 means 4 fp8 values are packed into one int32
|
||||
TORCH_CHECK(output_scale.size(1) * 4 == padded_k);
|
||||
|
||||
auto in_dtype = input.dtype();
|
||||
at::cuda::CUDAGuard device_guard{(char)input.get_device()};
|
||||
const cudaStream_t stream =
|
||||
at::cuda::getCurrentCUDAStream(input.get_device());
|
||||
if (in_dtype == at::ScalarType::Half) {
|
||||
quant_impl<half>(output.data_ptr(), output_scale.data_ptr(),
|
||||
input.data_ptr(), input_global_scale.data_ptr(),
|
||||
input_offset_by_experts.data_ptr(),
|
||||
output_scale_offset_by_experts.data_ptr(), m_topk, k,
|
||||
n_experts, stream);
|
||||
} else if (in_dtype == at::ScalarType::BFloat16) {
|
||||
quant_impl<__nv_bfloat16>(output.data_ptr(), output_scale.data_ptr(),
|
||||
input.data_ptr(), input_global_scale.data_ptr(),
|
||||
input_offset_by_experts.data_ptr(),
|
||||
output_scale_offset_by_experts.data_ptr(), m_topk,
|
||||
k, n_experts, stream);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Expected input data type to be half or bfloat16");
|
||||
}
|
||||
}
|
||||
@ -23,10 +23,32 @@ void scaled_fp4_quant_sm100a(torch::Tensor const& output,
|
||||
torch::Tensor const& input_sf);
|
||||
#endif
|
||||
|
||||
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
|
||||
void scaled_fp4_experts_quant_sm100a(
|
||||
torch::Tensor& output, torch::Tensor& output_scale,
|
||||
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
||||
torch::Tensor const& input_offset_by_experts,
|
||||
torch::Tensor const& output_scale_offset_by_experts);
|
||||
#endif
|
||||
|
||||
void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
|
||||
torch::Tensor& output_sf, torch::Tensor const& input_sf) {
|
||||
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
|
||||
return scaled_fp4_quant_sm100a(output, input, output_sf, input_sf);
|
||||
#endif
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization");
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization kernel");
|
||||
}
|
||||
|
||||
void scaled_fp4_experts_quant(
|
||||
torch::Tensor& output, torch::Tensor& output_scale,
|
||||
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
||||
torch::Tensor const& input_offset_by_experts,
|
||||
torch::Tensor const& output_scale_offset_by_experts) {
|
||||
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
|
||||
return scaled_fp4_experts_quant_sm100a(
|
||||
output, output_scale, input, input_global_scale, input_offset_by_experts,
|
||||
output_scale_offset_by_experts);
|
||||
#endif
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false,
|
||||
"No compiled nvfp4 experts quantization kernel");
|
||||
}
|
||||
|
||||
@ -21,7 +21,13 @@ static __device__ __forceinline__ int8_t float_to_int8_rn(float const x) {
|
||||
// round
|
||||
float dst = std::nearbyint(x);
|
||||
// saturate
|
||||
dst = std::clamp(dst, i8_min, i8_max);
|
||||
|
||||
// See https://github.com/pytorch/pytorch/issues/127666
|
||||
// See https://github.com/llvm/llvm-project/issues/95183
|
||||
// hip-clang std::clamp __glibcxx_assert_fail host function when building on
|
||||
// Arch/gcc14. The following replaces std::clamp usage with similar logic
|
||||
// dst = std::clamp(dst, i8_min, i8_max);
|
||||
dst = (dst < i8_min) ? i8_min : (dst > i8_max) ? i8_max : dst;
|
||||
return static_cast<int8_t>(dst);
|
||||
#else
|
||||
// CUDA path
|
||||
|
||||
@ -1,3 +1,67 @@
|
||||
/*
|
||||
Fast Dequantization (Converting INT4/INT8/FP4/FP8 to FP16/BF16)
|
||||
|
||||
The process of fast dequantization can be summarized as a combination
|
||||
of bitwise operations and floating-point computations:
|
||||
|
||||
weight =>(bit_op / bitwise operations)=>
|
||||
f16_value =>(flop / floating-point computation)=>
|
||||
dequantized_weight
|
||||
|
||||
Since the dequantized weights typically require subtracting the zero point and
|
||||
applying a scale factor, the floating-point computation step can be fused with
|
||||
the zero-point subtraction and scaling operations.
|
||||
|
||||
The following are the parts that need to be modified for the fused operation
|
||||
of zero-point subtraction and scaling.
|
||||
|
||||
## INT4 => FP16/BF16 or INT8 => FP16
|
||||
|
||||
The floating-point computation is `__hsub2`
|
||||
|
||||
If has zero points:
|
||||
|
||||
flop(bit_op(weight)) - flop(bit_op(zp))
|
||||
= sub(bit_op(weight), bias) - sub(bit_op(zp), bias)
|
||||
= bit_op(weight) - bit_op(zp)
|
||||
|
||||
so we don't need additional modification.
|
||||
|
||||
If has float zero points:
|
||||
|
||||
flop(bit_op(weight)) - fzp
|
||||
= sub(bit_op(weight), bias) - fzp
|
||||
= bit_op(weight) - (fzp + bias)
|
||||
|
||||
where the `fzp + bias` can be computed at weight loading. But this
|
||||
may have accuracy issue, so we should not use this in most cases.
|
||||
|
||||
If has not zero points:
|
||||
|
||||
scale(flop(bit_op(weight)))
|
||||
= scale(sub(bit_op(weight), bias))
|
||||
= scale(bit_op(weight)) - scale(bias)
|
||||
= fma(bit_op(weight), scale_factor, scale(bias))
|
||||
|
||||
where the `scale(bias)` can be cached. But this may have accuracy issue,
|
||||
so we should not use this in most cases.
|
||||
|
||||
|
||||
## INT8 => BF16
|
||||
|
||||
INT8 => BF16 is a special case, it use byte_perm instead of flop.
|
||||
We cannot fused byte_perm with scaling.
|
||||
|
||||
|
||||
## FP4/FP8 => FP16/BF16
|
||||
|
||||
scale(flop(bit_op(weight)))
|
||||
= scale(mul(bit_op(weight), multiplier))
|
||||
= mul(bit_op(weight), scale_factor * multiplier)
|
||||
|
||||
where `scale_factor * multiplier` can be computed at weight loading.
|
||||
|
||||
*/
|
||||
|
||||
#include "marlin_dtypes.cuh"
|
||||
|
||||
@ -27,7 +91,8 @@ __device__ inline uint32_t prmt(uint32_t a) {
|
||||
return res;
|
||||
}
|
||||
|
||||
template <typename scalar_t2, vllm::ScalarTypeId w_type_id>
|
||||
template <typename scalar_t2, vllm::ScalarTypeId w_type_id,
|
||||
bool skip_flop = false>
|
||||
__device__ inline void dequant(int q, scalar_t2* frag_b);
|
||||
|
||||
//
|
||||
@ -40,7 +105,22 @@ __device__ inline void dequant(int q, scalar_t2* frag_b);
|
||||
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
|
||||
//
|
||||
template <>
|
||||
__device__ inline void dequant<half2, vllm::kU4B8.id()>(int q, half2* frag_b) {
|
||||
__device__ inline void dequant<half2, vllm::kU4B8.id(), true>(int q,
|
||||
half2* frag_b) {
|
||||
const int MASK = 0x000f000f;
|
||||
const int EX = 0x64006400;
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||
q >>= 4;
|
||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||
|
||||
frag_b[0] = *reinterpret_cast<half2*>(&lo);
|
||||
frag_b[1] = *reinterpret_cast<half2*>(&hi);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, vllm::kU4B8.id(), false>(int q,
|
||||
half2* frag_b) {
|
||||
const int LO = 0x000f000f;
|
||||
const int HI = 0x00f000f0;
|
||||
const int EX = 0x64006400;
|
||||
@ -62,7 +142,14 @@ __device__ inline void dequant<half2, vllm::kU4B8.id()>(int q, half2* frag_b) {
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, vllm::kU4.id()>(int q, half2* frag_b) {
|
||||
__device__ inline void dequant<half2, vllm::kU4.id(), true>(int q,
|
||||
half2* frag_b) {
|
||||
dequant<half2, vllm::kU4B8.id(), true>(q, frag_b);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, vllm::kU4.id(), false>(int q,
|
||||
half2* frag_b) {
|
||||
const int LO = 0x000f000f;
|
||||
const int HI = 0x00f000f0;
|
||||
const int EX = 0x64006400;
|
||||
@ -84,7 +171,7 @@ __device__ inline void dequant<half2, vllm::kU4.id()>(int q, half2* frag_b) {
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, vllm::kU4B8.id()>(
|
||||
__device__ inline void dequant<nv_bfloat162, vllm::kU4B8.id(), true>(
|
||||
int q, nv_bfloat162* frag_b) {
|
||||
static constexpr uint32_t MASK = 0x000f000f;
|
||||
static constexpr uint32_t EX = 0x43004300;
|
||||
@ -96,39 +183,36 @@ __device__ inline void dequant<nv_bfloat162, vllm::kU4B8.id()>(
|
||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||
// clang-format on
|
||||
|
||||
static constexpr uint32_t MUL = 0x3F803F80;
|
||||
static constexpr uint32_t ADD = 0xC308C308;
|
||||
|
||||
frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&MUL),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&ADD));
|
||||
frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&MUL),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&ADD));
|
||||
frag_b[0] = *reinterpret_cast<nv_bfloat162*>(&lo);
|
||||
frag_b[1] = *reinterpret_cast<nv_bfloat162*>(&hi);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, vllm::kU4.id()>(
|
||||
__device__ inline void dequant<nv_bfloat162, vllm::kU4B8.id(), false>(
|
||||
int q, nv_bfloat162* frag_b) {
|
||||
static constexpr uint32_t MASK = 0x000f000f;
|
||||
static constexpr uint32_t EX = 0x43004300;
|
||||
dequant<nv_bfloat162, vllm::kU4B8.id(), true>(q, frag_b);
|
||||
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
// clang-format off
|
||||
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||
q >>= 4;
|
||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||
// clang-format on
|
||||
static constexpr uint32_t SUB = 0x43084308;
|
||||
|
||||
static constexpr uint32_t MUL = 0x3F803F80;
|
||||
static constexpr uint32_t ADD = 0xC300C300;
|
||||
frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast<const nv_bfloat162*>(&SUB));
|
||||
frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast<const nv_bfloat162*>(&SUB));
|
||||
}
|
||||
|
||||
frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&MUL),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&ADD));
|
||||
frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&MUL),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&ADD));
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, vllm::kU4.id(), true>(
|
||||
int q, nv_bfloat162* frag_b) {
|
||||
dequant<nv_bfloat162, vllm::kU4B8.id(), true>(q, frag_b);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, vllm::kU4.id(), false>(
|
||||
int q, nv_bfloat162* frag_b) {
|
||||
dequant<nv_bfloat162, vllm::kU4.id(), true>(q, frag_b);
|
||||
|
||||
static constexpr uint32_t SUB = 0x43004300;
|
||||
|
||||
frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast<const nv_bfloat162*>(&SUB));
|
||||
frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast<const nv_bfloat162*>(&SUB));
|
||||
}
|
||||
|
||||
//
|
||||
@ -140,8 +224,8 @@ __device__ inline void dequant<nv_bfloat162, vllm::kU4.id()>(
|
||||
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
|
||||
//
|
||||
template <>
|
||||
__device__ inline void dequant<half2, vllm::kU8B128.id()>(int q,
|
||||
half2* frag_b) {
|
||||
__device__ inline void dequant<half2, vllm::kU8B128.id(), true>(int q,
|
||||
half2* frag_b) {
|
||||
static constexpr uint32_t mask_for_elt_01 = 0x5250;
|
||||
static constexpr uint32_t mask_for_elt_23 = 0x5351;
|
||||
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
|
||||
@ -149,33 +233,42 @@ __device__ inline void dequant<half2, vllm::kU8B128.id()>(int q,
|
||||
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
|
||||
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
|
||||
|
||||
frag_b[0] = *reinterpret_cast<half2*>(&lo);
|
||||
frag_b[1] = *reinterpret_cast<half2*>(&hi);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, vllm::kU8B128.id(), false>(
|
||||
int q, half2* frag_b) {
|
||||
dequant<half2, vllm::kU8B128.id(), true>(q, frag_b);
|
||||
|
||||
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
|
||||
|
||||
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
|
||||
frag_b[0] = __hsub2(frag_b[0],
|
||||
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
||||
frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
|
||||
frag_b[1] = __hsub2(frag_b[1],
|
||||
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, vllm::kU8.id()>(int q, half2* frag_b) {
|
||||
static constexpr uint32_t mask_for_elt_01 = 0x5250;
|
||||
static constexpr uint32_t mask_for_elt_23 = 0x5351;
|
||||
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
|
||||
__device__ inline void dequant<half2, vllm::kU8.id(), true>(int q,
|
||||
half2* frag_b) {
|
||||
dequant<half2, vllm::kU8B128.id(), true>(q, frag_b);
|
||||
}
|
||||
|
||||
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
|
||||
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
|
||||
template <>
|
||||
__device__ inline void dequant<half2, vllm::kU8.id(), false>(int q,
|
||||
half2* frag_b) {
|
||||
dequant<half2, vllm::kU8.id(), true>(q, frag_b);
|
||||
|
||||
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400;
|
||||
|
||||
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
|
||||
frag_b[0] = __hsub2(frag_b[0],
|
||||
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
||||
frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
|
||||
frag_b[1] = __hsub2(frag_b[1],
|
||||
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, vllm::kU8B128.id()>(
|
||||
__device__ inline void dequant<nv_bfloat162, vllm::kU8B128.id(), false>(
|
||||
int q, nv_bfloat162* frag_b) {
|
||||
float fp32_intermediates[4];
|
||||
uint32_t* fp32_intermediates_casted =
|
||||
@ -200,7 +293,7 @@ __device__ inline void dequant<nv_bfloat162, vllm::kU8B128.id()>(
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, vllm::kU8.id()>(
|
||||
__device__ inline void dequant<nv_bfloat162, vllm::kU8.id(), false>(
|
||||
int q, nv_bfloat162* frag_b) {
|
||||
float fp32_intermediates[4];
|
||||
uint32_t* fp32_intermediates_casted =
|
||||
@ -225,22 +318,30 @@ __device__ inline void dequant<nv_bfloat162, vllm::kU8.id()>(
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, vllm::kFE4M3fn.id()>(int q,
|
||||
half2* frag_b) {
|
||||
__device__ inline void dequant<half2, vllm::kFE4M3fn.id(), true>(
|
||||
int q, half2* frag_b) {
|
||||
// Constants for FP8 (E4M3) and FP16 formats
|
||||
constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, FP16_EXPONENT = 5;
|
||||
constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5;
|
||||
constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT;
|
||||
|
||||
// Calculate MASK for extracting mantissa and exponent
|
||||
constexpr int MASK1 = 0x80000000;
|
||||
constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA);
|
||||
constexpr int MASK3 = MASK2 & 0x7fffffff;
|
||||
constexpr int MASK = MASK3 | (MASK3 >> 16);
|
||||
// Final MASK value: 0x7F007F00
|
||||
constexpr int MASK = 0x7F007F00;
|
||||
|
||||
// Extract and shift FP8 values to FP16 format
|
||||
int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
|
||||
int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT);
|
||||
q <<= 8;
|
||||
int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
|
||||
|
||||
// Note: reverse indexing is intentional because weights are permuted
|
||||
frag_b[1] = *reinterpret_cast<const half2*>(&Out1);
|
||||
frag_b[0] = *reinterpret_cast<const half2*>(&Out2);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, vllm::kFE4M3fn.id(), false>(
|
||||
int q, half2* frag_b) {
|
||||
dequant<half2, vllm::kFE4M3fn.id(), true>(q, frag_b);
|
||||
|
||||
// Constants for FP8 (E4M3) and FP16 formats
|
||||
constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5;
|
||||
|
||||
// Construct and apply exponent bias
|
||||
constexpr int BIAS_OFFSET =
|
||||
@ -248,28 +349,36 @@ __device__ inline void dequant<half2, vllm::kFE4M3fn.id()>(int q,
|
||||
const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET));
|
||||
|
||||
// Convert to half2 and apply bias
|
||||
// Note: reverse indexing is intentional because weights are permuted
|
||||
frag_b[1] = __hmul2(*reinterpret_cast<const half2*>(&Out1), bias_reg);
|
||||
frag_b[0] = __hmul2(*reinterpret_cast<const half2*>(&Out2), bias_reg);
|
||||
frag_b[1] = __hmul2(frag_b[1], bias_reg);
|
||||
frag_b[0] = __hmul2(frag_b[0], bias_reg);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, vllm::kFE4M3fn.id()>(
|
||||
__device__ inline void dequant<nv_bfloat162, vllm::kFE4M3fn.id(), true>(
|
||||
int q, nv_bfloat162* frag_b) {
|
||||
// Constants for FP8 (E4M3) and BF16 formats
|
||||
constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, BF16_EXPONENT = 8;
|
||||
constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;
|
||||
constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT;
|
||||
|
||||
// Calculate MASK for extracting mantissa and exponent
|
||||
constexpr int MASK1 = 0x80000000;
|
||||
constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA);
|
||||
constexpr int MASK3 = MASK2 & 0x7fffffff;
|
||||
constexpr int MASK = MASK3 | (MASK3 >> 16);
|
||||
// Final MASK value: 0x7F007F00
|
||||
constexpr int MASK = 0x7F007F00;
|
||||
|
||||
// Extract and shift FP8 values to BF16 format
|
||||
int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
|
||||
int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT);
|
||||
q <<= 8;
|
||||
int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
|
||||
|
||||
// Note: reverse indexing is intentional because weights are permuted
|
||||
frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
|
||||
frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, vllm::kFE4M3fn.id(), false>(
|
||||
int q, nv_bfloat162* frag_b) {
|
||||
dequant<nv_bfloat162, vllm::kFE4M3fn.id(), true>(q, frag_b);
|
||||
|
||||
// Constants for FP8 (E4M3) and BF16 formats
|
||||
constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;
|
||||
|
||||
// Construct and apply exponent bias
|
||||
constexpr int BIAS_OFFSET =
|
||||
@ -281,9 +390,116 @@ __device__ inline void dequant<nv_bfloat162, vllm::kFE4M3fn.id()>(
|
||||
__float2bfloat162_rn(*reinterpret_cast<const float*>(&BIAS));
|
||||
|
||||
// Convert to bfloat162 and apply bias
|
||||
frag_b[1] = __hmul2(frag_b[1], bias_reg);
|
||||
frag_b[0] = __hmul2(frag_b[0], bias_reg);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, vllm::kFE2M1f.id(), true>(int q,
|
||||
half2* frag_b) {
|
||||
// Constants for FP4 (E2M1) and FP16 formats
|
||||
constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5;
|
||||
constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP4_EXPONENT;
|
||||
constexpr int MASK = 0x70007000;
|
||||
|
||||
// Extract and shift FP4 values to FP16 format
|
||||
int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
|
||||
q <<= 4;
|
||||
int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
|
||||
|
||||
// Note: reverse indexing is intentional because weights are permuted
|
||||
frag_b[1] = __hmul2(*reinterpret_cast<const nv_bfloat162*>(&Out1), bias_reg);
|
||||
frag_b[0] = __hmul2(*reinterpret_cast<const nv_bfloat162*>(&Out2), bias_reg);
|
||||
frag_b[1] = *reinterpret_cast<const half2*>(&Out1);
|
||||
frag_b[0] = *reinterpret_cast<const half2*>(&Out2);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, vllm::kFE2M1f.id(), false>(
|
||||
int q, half2* frag_b) {
|
||||
dequant<half2, vllm::kFE2M1f.id(), true>(q, frag_b);
|
||||
|
||||
// Constants for FP4 (E2M1) and FP16 formats
|
||||
constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5;
|
||||
|
||||
// Construct and apply exponent bias
|
||||
constexpr int BIAS_OFFSET =
|
||||
(1 << (FP16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1));
|
||||
const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET));
|
||||
|
||||
// Convert to half2 and apply bias
|
||||
frag_b[1] = __hmul2(frag_b[1], bias_reg);
|
||||
frag_b[0] = __hmul2(frag_b[0], bias_reg);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, vllm::kFE2M1f.id(), true>(
|
||||
int q, nv_bfloat162* frag_b) {
|
||||
// Constants for FP4 (E2M1) and FP16 formats
|
||||
constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8;
|
||||
constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP4_EXPONENT;
|
||||
constexpr int MASK = 0x70007000;
|
||||
|
||||
// Extract and shift FP4 values to FP16 format
|
||||
int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
|
||||
q <<= 4;
|
||||
int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
|
||||
|
||||
// Note: reverse indexing is intentional because weights are permuted
|
||||
frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
|
||||
frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, vllm::kFE2M1f.id(), false>(
|
||||
int q, nv_bfloat162* frag_b) {
|
||||
dequant<nv_bfloat162, vllm::kFE2M1f.id(), true>(q, frag_b);
|
||||
|
||||
// Constants for FP4 (E2M1) and BF16 formats
|
||||
constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8;
|
||||
|
||||
// Construct and apply exponent bias
|
||||
constexpr int BIAS_OFFSET =
|
||||
(1 << (BF16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1));
|
||||
// Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent
|
||||
// position
|
||||
constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23;
|
||||
const nv_bfloat162 bias_reg =
|
||||
__float2bfloat162_rn(*reinterpret_cast<const float*>(&BIAS));
|
||||
|
||||
// Convert to half2 and apply bias
|
||||
frag_b[1] = __hmul2(frag_b[1], bias_reg);
|
||||
frag_b[0] = __hmul2(frag_b[0], bias_reg);
|
||||
}
|
||||
|
||||
template <typename scalar_t2>
|
||||
__device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b);
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant_fp8_scales<half2>(int q, half2* frag_b) {
|
||||
int Out1 = (q & 0xFF00FF00) >> 1;
|
||||
;
|
||||
q <<= 8;
|
||||
int Out2 = (q & 0xFF00FF00) >> 1;
|
||||
|
||||
// Note: reverse indexing is intentional because weights are permuted
|
||||
frag_b[1] = *reinterpret_cast<const half2*>(&Out1);
|
||||
frag_b[0] = *reinterpret_cast<const half2*>(&Out2);
|
||||
};
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant_fp8_scales<nv_bfloat162>(int q,
|
||||
nv_bfloat162* frag_b) {
|
||||
constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;
|
||||
constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT;
|
||||
constexpr int MASK = 0x7F007F00;
|
||||
|
||||
// Extract and shift FP8 values to BF16 format
|
||||
int Out1 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT);
|
||||
q <<= 8;
|
||||
int Out2 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT);
|
||||
|
||||
// Note: reverse indexing is intentional because weights are permuted
|
||||
frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
|
||||
frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@ -31,7 +31,10 @@ TEMPLATE = ("template __global__ void Marlin<"
|
||||
|
||||
# int8 with zero point case (vllm::kU8) is also supported,
|
||||
# we don't add it to reduce wheel size.
|
||||
SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn"]
|
||||
SCALAR_TYPES = [
|
||||
"vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn",
|
||||
"vllm::kFE2M1f"
|
||||
]
|
||||
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128),
|
||||
(128, 64, 128)]
|
||||
|
||||
@ -40,7 +43,7 @@ THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
|
||||
# = 0 : act order case
|
||||
# = -1 : channelwise quantization
|
||||
# > 0 : group_size=16*group_blocks
|
||||
GROUP_BLOCKS = [0, -1, 2, 4, 8]
|
||||
GROUP_BLOCKS = [0, 1, -1, 2, 4, 8]
|
||||
DTYPES = ["fp16", "bf16"]
|
||||
|
||||
|
||||
@ -73,6 +76,12 @@ def generate_new_kernels():
|
||||
# for fp8
|
||||
if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
|
||||
continue
|
||||
# nvfp4 only supports group_size == 16
|
||||
if scalar_type == "vllm::kFE2M1f" and group_blocks != 1:
|
||||
continue
|
||||
# other quantization methods don't support group_size = 16
|
||||
if scalar_type != "vllm::kFE2M1f" and group_blocks == 1:
|
||||
continue
|
||||
|
||||
k_blocks = thread_configs[0] // 16
|
||||
n_blocks = thread_configs[1] // 16
|
||||
|
||||
@ -258,6 +258,7 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
|
||||
// BIGGROUP: cases for big group size (group_blocks in [-1, 8])
|
||||
// FZP: cases for float-zero-point (is_zp_float = true)
|
||||
// ACT: cases for act order case (group_blocks == 0)
|
||||
// FP4: cases for nvfp4(e2m1) (group_blocks == 1)
|
||||
#define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
|
||||
@ -314,6 +315,23 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
|
||||
BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) \
|
||||
BIGGROUP_GET_IF_M234(W_TYPE, 4, 8, 128)
|
||||
|
||||
#define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
|
||||
|
||||
#define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
|
||||
|
||||
#define FP4_GET_IF(W_TYPE) \
|
||||
FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
|
||||
FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
|
||||
FP4_GET_IF_M1(W_TYPE, 4, 8, 128) \
|
||||
FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
|
||||
FP4_GET_IF_M234(W_TYPE, 8, 4, 128) \
|
||||
FP4_GET_IF_M234(W_TYPE, 4, 8, 128)
|
||||
|
||||
// We currently have 4-bit models only with group_blocks == 4
|
||||
#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \
|
||||
@ -366,6 +384,8 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
|
||||
COMMON_GET_IF(vllm::kU4B8)
|
||||
COMMON_GET_IF(vllm::kU8B128)
|
||||
|
||||
FP4_GET_IF(vllm::kFE2M1f)
|
||||
|
||||
BIGGROUP_GET_IF(vllm::kFE4M3fn)
|
||||
|
||||
ACT_GET_IF(vllm::kU4B8)
|
||||
@ -434,8 +454,8 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
|
||||
|
||||
template <typename scalar_t>
|
||||
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
||||
void* zp, void* g_idx, void* perm, void* a_tmp, int prob_m,
|
||||
int prob_n, int prob_k, int lda, void* workspace,
|
||||
void* s2, void* zp, void* g_idx, void* perm, void* a_tmp,
|
||||
int prob_m, int prob_n, int prob_k, int lda, void* workspace,
|
||||
vllm::ScalarType const& q_type, bool has_act_order,
|
||||
bool is_k_full, bool has_zp, int num_groups, int group_size,
|
||||
int dev, cudaStream_t stream, int thread_k_init,
|
||||
@ -446,11 +466,12 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
||||
q_type == vllm::kU4 || q_type == vllm::kU8,
|
||||
"q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str());
|
||||
} else {
|
||||
TORCH_CHECK(q_type == vllm::kU4B8 || q_type == vllm::kU8B128 ||
|
||||
q_type == vllm::kFE4M3fn,
|
||||
"q_type must be uint4b8, uint8b128 or float8_e4m3fn when "
|
||||
"has_zp = False. Got = ",
|
||||
q_type.str());
|
||||
TORCH_CHECK(
|
||||
q_type == vllm::kU4B8 || q_type == vllm::kU8B128 ||
|
||||
q_type == vllm::kFE4M3fn || q_type == vllm::kFE2M1f,
|
||||
"q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when "
|
||||
"has_zp = False. Got = ",
|
||||
q_type.str());
|
||||
}
|
||||
|
||||
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
|
||||
@ -483,6 +504,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
||||
int4* C_ptr = (int4*)C;
|
||||
int4* C_tmp_ptr = (int4*)C_tmp;
|
||||
const int4* s_ptr = (const int4*)s;
|
||||
const uint16_t* s2_ptr = (const uint16_t*)s2;
|
||||
const int4* zp_ptr = (const int4*)zp;
|
||||
const int* g_idx_ptr = (const int*)g_idx;
|
||||
const int* perm_ptr = (const int*)perm;
|
||||
@ -601,7 +623,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
||||
// avoid ">>>" being formatted to "> > >"
|
||||
// clang-format off
|
||||
kernel<<<blocks, num_threads, max_shared_mem_new, stream>>>(
|
||||
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, num_groups,
|
||||
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr, num_groups,
|
||||
prob_m_split, prob_n, prob_k, lda, locks, part_use_atomic_add,
|
||||
use_fp32_reduce, max_shared_mem_new);
|
||||
// clang-format on
|
||||
@ -617,6 +639,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
||||
torch::Tensor gptq_marlin_gemm(
|
||||
torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
|
||||
torch::Tensor& b_q_weight, torch::Tensor& b_scales,
|
||||
std::optional<torch::Tensor> const& global_scale_or_none,
|
||||
std::optional<torch::Tensor> const& b_zeros_or_none,
|
||||
std::optional<torch::Tensor> const& g_idx_or_none,
|
||||
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
|
||||
@ -759,6 +782,17 @@ torch::Tensor gptq_marlin_gemm(
|
||||
}
|
||||
}
|
||||
|
||||
torch::Tensor global_scale;
|
||||
if (global_scale_or_none.has_value()) {
|
||||
global_scale = global_scale_or_none.value();
|
||||
TORCH_CHECK(b_q_type == vllm::kFE2M1f,
|
||||
"global_scale can only be used for float4_e2m1f.");
|
||||
} else {
|
||||
global_scale = torch::empty({0}, options);
|
||||
TORCH_CHECK(!(b_q_type == vllm::kFE2M1f),
|
||||
"the global_scale parameter must be passed for float4_e2m1f.");
|
||||
}
|
||||
|
||||
torch::Tensor b_zeros;
|
||||
if (b_zeros_or_none.has_value()) {
|
||||
b_zeros = b_zeros_or_none.value();
|
||||
@ -774,8 +808,9 @@ torch::Tensor gptq_marlin_gemm(
|
||||
"b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str());
|
||||
} else {
|
||||
TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128 ||
|
||||
b_q_type == vllm::kFE4M3fn,
|
||||
"b_q_type must be uint4b8, uint8b128 or float8_e4m3fn when "
|
||||
b_q_type == vllm::kFE4M3fn || b_q_type == vllm::kFE2M1f,
|
||||
"b_q_type must be uint4b8, uint8b128, float8_e4m3fn or "
|
||||
"float4_e2m1f when "
|
||||
"has_zp = False. Got = ",
|
||||
b_q_type.str());
|
||||
}
|
||||
@ -820,22 +855,36 @@ torch::Tensor gptq_marlin_gemm(
|
||||
|
||||
int dev = a.get_device();
|
||||
if (a.scalar_type() == at::ScalarType::Half) {
|
||||
void* scales_ptr;
|
||||
if (b_q_type == vllm::kFE2M1f) {
|
||||
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
|
||||
} else {
|
||||
scales_ptr = b_scales.data_ptr<at::Half>();
|
||||
}
|
||||
|
||||
marlin::marlin_mm<half>(
|
||||
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
|
||||
c_tmp.data_ptr<float>(), b_scales.data_ptr<at::Half>(),
|
||||
c_tmp.data_ptr<float>(), scales_ptr, global_scale.data_ptr<at::Half>(),
|
||||
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
|
||||
a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k, a.stride(0),
|
||||
workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
|
||||
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
|
||||
thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float);
|
||||
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
|
||||
void* scales_ptr;
|
||||
if (b_q_type == vllm::kFE2M1f) {
|
||||
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
|
||||
} else {
|
||||
scales_ptr = b_scales.data_ptr<at::BFloat16>();
|
||||
}
|
||||
|
||||
marlin::marlin_mm<nv_bfloat16>(
|
||||
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
|
||||
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
|
||||
b_scales.data_ptr<at::BFloat16>(), b_zeros.data_ptr(), g_idx.data_ptr(),
|
||||
perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
|
||||
a.stride(0), workspace.data_ptr(), b_q_type, has_act_order, is_k_full,
|
||||
has_zp, num_groups, group_size, dev,
|
||||
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(), scales_ptr,
|
||||
global_scale.data_ptr<at::BFloat16>(), b_zeros.data_ptr(),
|
||||
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
|
||||
size_m, size_n, size_k, a.stride(0), workspace.data_ptr(), b_q_type,
|
||||
has_act_order, is_k_full, has_zp, num_groups, group_size, dev,
|
||||
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
|
||||
use_atomic_add, use_fp32_reduce, is_zp_float);
|
||||
} else {
|
||||
|
||||
@ -7,13 +7,14 @@
|
||||
#include "marlin_dtypes.cuh"
|
||||
#include "core/scalar_type.hpp"
|
||||
|
||||
#define MARLIN_KERNEL_PARAMS \
|
||||
const int4 *__restrict__ A, const int4 *__restrict__ B, \
|
||||
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
|
||||
const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr, \
|
||||
const int *__restrict__ g_idx, int num_groups, int prob_m, int prob_n, \
|
||||
int prob_k, int lda, int *locks, bool use_atomic_add, \
|
||||
bool use_fp32_reduce, int max_shared_mem
|
||||
#define MARLIN_KERNEL_PARAMS \
|
||||
const int4 *__restrict__ A, const int4 *__restrict__ B, \
|
||||
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
|
||||
const int4 *__restrict__ scales_ptr, \
|
||||
const uint16_t *__restrict__ scale2_ptr, \
|
||||
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
|
||||
int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \
|
||||
bool use_atomic_add, bool use_fp32_reduce, int max_shared_mem
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
template <typename scalar_t, // compute dtype, half or nv_float16
|
||||
|
||||
@ -292,9 +292,11 @@ __global__ void Marlin(
|
||||
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
|
||||
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
|
||||
// (k/groupsize)xn
|
||||
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
|
||||
// (k/groupsize)x(n/pack_factor)
|
||||
const int* __restrict__ g_idx, // int32 group indices of shape k
|
||||
const uint16_t* __restrict__ scale2_ptr, // fp16 global scale (for nvfp4
|
||||
// only)
|
||||
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
|
||||
// (k/groupsize)x(n/pack_factor)
|
||||
const int* __restrict__ g_idx, // int32 group indices of shape k
|
||||
int num_groups, // number of scale groups per output channel
|
||||
int prob_m, // batch dimension m
|
||||
int prob_n, // output dimension n
|
||||
@ -325,6 +327,21 @@ __global__ void Marlin(
|
||||
|
||||
static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
|
||||
constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8;
|
||||
constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 ||
|
||||
w_type == vllm::kU4B8 || w_type == vllm::kU8B128;
|
||||
// see comments of dequant.h for more details
|
||||
constexpr bool dequant_skip_flop =
|
||||
!is_int_type ||
|
||||
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
|
||||
has_zp && !is_zp_float && !(w_type == vllm::kU8);
|
||||
|
||||
scalar_t2 global_scale;
|
||||
|
||||
if constexpr (w_type == vllm::kFE2M1f) {
|
||||
uint16_t val = scale2_ptr[0];
|
||||
global_scale = Dtype::num2num2(*reinterpret_cast<scalar_t*>(&val));
|
||||
}
|
||||
|
||||
constexpr bool has_act_order = group_blocks == 0;
|
||||
constexpr int m_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks);
|
||||
|
||||
@ -481,7 +498,7 @@ __global__ void Marlin(
|
||||
constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
|
||||
constexpr int s_tb_groups =
|
||||
!has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
|
||||
? thread_k_blocks / group_blocks
|
||||
? thread_k_blocks / group_blocks / (w_type == vllm::kFE2M1f ? 2 : 1)
|
||||
: 1;
|
||||
constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
|
||||
int s_gl_rd_delta = s_gl_stride;
|
||||
@ -540,7 +557,8 @@ __global__ void Marlin(
|
||||
if constexpr (group_blocks == -1) {
|
||||
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
|
||||
} else {
|
||||
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
|
||||
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) /
|
||||
(w_type == vllm::kFE2M1f ? 2 : 1) +
|
||||
s_sh_stride * slice_col + threadIdx.x;
|
||||
}
|
||||
}
|
||||
@ -564,10 +582,20 @@ __global__ void Marlin(
|
||||
// we scale a `half2` tile in column-major layout in the former and in
|
||||
// row-major in the latter case.
|
||||
int s_sh_rd;
|
||||
if constexpr (group_blocks != -1)
|
||||
if constexpr (group_blocks != -1 && w_type == vllm::kFE2M1f) {
|
||||
auto warp_id = threadIdx.x / 32;
|
||||
int n_warps = thread_n_blocks / 4;
|
||||
int warp_row = warp_id / n_warps;
|
||||
|
||||
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||
(threadIdx.x % 32) / 4;
|
||||
else if constexpr (group_blocks == -1 && (m_block_size_8 || has_zp))
|
||||
s_sh_rd = s_sh_rd * 2 + warp_row % 2;
|
||||
|
||||
} else if constexpr (group_blocks != -1)
|
||||
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||
(threadIdx.x % 32) / 4;
|
||||
else if constexpr (group_blocks == -1 &&
|
||||
(m_block_size_8 || (has_zp && !dequant_skip_flop)))
|
||||
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||
(threadIdx.x % 32) / 8;
|
||||
else
|
||||
@ -681,7 +709,7 @@ __global__ void Marlin(
|
||||
sh_first_group_id = first_group_id;
|
||||
sh_num_groups = last_group_id - first_group_id + 1;
|
||||
|
||||
if (sh_num_groups < act_s_max_num_groups) {
|
||||
if (sh_num_groups > act_s_max_num_groups) {
|
||||
sh_num_groups = act_s_max_num_groups;
|
||||
}
|
||||
|
||||
@ -887,12 +915,19 @@ __global__ void Marlin(
|
||||
cur_k += k_iter_size * (k % b_sh_wr_iters);
|
||||
|
||||
int k_blocks = cur_k / 16;
|
||||
int cur_group_id = k_blocks / group_blocks;
|
||||
int cur_group_id =
|
||||
k_blocks / (group_blocks * (w_type == vllm::kFE2M1f ? 2 : 1));
|
||||
|
||||
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
|
||||
|
||||
reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
|
||||
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
|
||||
if constexpr (w_type_id != vllm::kFE2M1f.id()) {
|
||||
reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
|
||||
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
|
||||
} else {
|
||||
reinterpret_cast<int2*>(&frag_s[k % 2])[0] =
|
||||
reinterpret_cast<int2*>(
|
||||
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1065,22 +1100,7 @@ __global__ void Marlin(
|
||||
};
|
||||
|
||||
auto dequant_data = [&](int q, scalar_t2* frag_b_ptr) {
|
||||
if constexpr (has_zp && is_zp_float || !has_zp) {
|
||||
dequant<scalar_t2, w_type_id>(q, frag_b_ptr);
|
||||
} else {
|
||||
static_assert(has_zp && !is_zp_float);
|
||||
static_assert(w_type_id == vllm::kU4.id() || w_type_id == vllm::kU8.id());
|
||||
// If (has_zp && !is_zp_float),
|
||||
// we use not-zp version `dequant` function
|
||||
// to improve numerical accuracy.
|
||||
// Since both weight and zero point are dequanted using this logic,
|
||||
// the final dequanted weight would be correct.
|
||||
if constexpr (w_type_id == vllm::kU4.id()) {
|
||||
dequant<scalar_t2, vllm::kU4B8.id()>(q, frag_b_ptr);
|
||||
} else if constexpr (w_type_id == vllm::kU8.id()) {
|
||||
dequant<scalar_t2, vllm::kU8B128.id()>(q, frag_b_ptr);
|
||||
}
|
||||
}
|
||||
dequant<scalar_t2, w_type_id, dequant_skip_flop>(q, frag_b_ptr);
|
||||
};
|
||||
|
||||
// Execute the actual tensor core matmul of a sub-tile.
|
||||
@ -1110,13 +1130,23 @@ __global__ void Marlin(
|
||||
dequant_data(zp_quant_1, reinterpret_cast<scalar_t2*>(&frag_zp) + 2);
|
||||
}
|
||||
}
|
||||
if constexpr (has_zp && is_zp_float) {
|
||||
if constexpr (!dequant_skip_flop && has_zp && is_zp_float) {
|
||||
if (is_new_zp) {
|
||||
reinterpret_cast<int4*>(&frag_zp)[0] =
|
||||
reinterpret_cast<int4*>(&frag_zpf[k2])[0];
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (w_type == vllm::kFE2M1f) {
|
||||
int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
|
||||
int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];
|
||||
|
||||
dequant_fp8_scales<scalar_t2>(s_quant_0,
|
||||
reinterpret_cast<scalar_t2*>(&frag_s[k2]));
|
||||
dequant_fp8_scales<scalar_t2>(
|
||||
s_quant_1, reinterpret_cast<scalar_t2*>(&frag_s[k2]) + 2);
|
||||
}
|
||||
|
||||
// We have the m dimension as the inner loop in order to encourage overlapping
|
||||
// dequantization and matmul operations.
|
||||
#pragma unroll
|
||||
@ -1125,7 +1155,10 @@ __global__ void Marlin(
|
||||
FragB frag_b1;
|
||||
int b_quant_0, b_quant_1;
|
||||
|
||||
if constexpr (w_type.size_bits() == 4) {
|
||||
if constexpr (w_type_id == vllm::kFE2M1f.id()) {
|
||||
b_quant_1 = frag_b_quant[k2][0][j];
|
||||
b_quant_0 = b_quant_1 << 8;
|
||||
} else if constexpr (w_type.size_bits() == 4) {
|
||||
b_quant_0 = frag_b_quant[k2][0][j];
|
||||
b_quant_1 = b_quant_0 >> 8;
|
||||
} else {
|
||||
@ -1138,6 +1171,11 @@ __global__ void Marlin(
|
||||
dequant_data(b_quant_0, reinterpret_cast<scalar_t2*>(&frag_b0));
|
||||
dequant_data(b_quant_1, reinterpret_cast<scalar_t2*>(&frag_b1));
|
||||
|
||||
if constexpr (dequant_skip_flop && has_zp && !is_zp_float) {
|
||||
sub_zp<scalar_t>(frag_b0, frag_zp[j], 0);
|
||||
sub_zp<scalar_t>(frag_b1, frag_zp[j], 1);
|
||||
}
|
||||
|
||||
// Apply scale to frag_b0
|
||||
if constexpr (has_act_order) {
|
||||
static_assert(group_blocks != -1);
|
||||
@ -1145,7 +1183,8 @@ __global__ void Marlin(
|
||||
act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0);
|
||||
scale4<scalar_t>(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j],
|
||||
act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1);
|
||||
} else if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
|
||||
} else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float &&
|
||||
group_blocks == -1) {
|
||||
int idx = (threadIdx.x / 4) % 2;
|
||||
scalar_t2 s2 = Dtype::nums2num2(
|
||||
reinterpret_cast<scalar_t*>(&frag_s[j / 2][j % 2 * 2 + 0])[idx],
|
||||
@ -1153,7 +1192,7 @@ __global__ void Marlin(
|
||||
if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2);
|
||||
scale_and_sub<scalar_t>(frag_b0, s2.x, frag_zp[j].x);
|
||||
scale_and_sub<scalar_t>(frag_b1, s2.y, frag_zp[j].y);
|
||||
} else if constexpr (has_zp && group_blocks != -1) {
|
||||
} else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1) {
|
||||
if (is_new_zp)
|
||||
frag_zp[j] = __hmul2(frag_zp[j],
|
||||
*reinterpret_cast<scalar_t2*>(&frag_s[k2][j]));
|
||||
@ -1408,10 +1447,15 @@ __global__ void Marlin(
|
||||
// For per-column quantization we finally apply the scale here (only for
|
||||
// 4-bit)
|
||||
if constexpr (!has_act_order && group_blocks == -1 &&
|
||||
w_type.size_bits() == 4 && !has_zp) {
|
||||
w_type.size_bits() == 4 &&
|
||||
(has_zp && dequant_skip_flop || !has_zp)) {
|
||||
res = __hmul2(res, s[0]);
|
||||
}
|
||||
|
||||
if constexpr (w_type == vllm::kFE2M1f) {
|
||||
res = __hmul2(res, global_scale);
|
||||
}
|
||||
|
||||
if constexpr (m_block_size_8) {
|
||||
((scalar_t*)sh_red)[idx] = res.x;
|
||||
((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y;
|
||||
@ -1488,7 +1532,9 @@ __global__ void Marlin(
|
||||
if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
|
||||
if (i == 0) {
|
||||
fetch_col_zp_to_shared();
|
||||
fetch_col_scale_to_shared();
|
||||
if constexpr (!dequant_skip_flop) {
|
||||
fetch_col_scale_to_shared();
|
||||
}
|
||||
}
|
||||
}
|
||||
fetch_to_shared(i, i, i < slice_iters);
|
||||
@ -1542,16 +1588,20 @@ __global__ void Marlin(
|
||||
|
||||
if constexpr (has_act_order) {
|
||||
slice_k_start += tb_k * stages;
|
||||
slice_k_start_shared_fetch += tb_k * stages;
|
||||
int first_group_id = g_idx[slice_k_start];
|
||||
int last_g_idx = slice_k_start + stages * tb_k * 2;
|
||||
if (last_g_idx >= prob_k) {
|
||||
last_g_idx = prob_k - 1;
|
||||
}
|
||||
int last_group_id = g_idx[last_g_idx];
|
||||
if (last_group_id >= sh_first_group_id + sh_num_groups) {
|
||||
fetch_act_order_scales_to_shared(false, first_group_id, last_group_id);
|
||||
__syncthreads();
|
||||
|
||||
if (slice_k_start < prob_k) {
|
||||
slice_k_start_shared_fetch += tb_k * stages;
|
||||
int first_group_id = g_idx[slice_k_start];
|
||||
int last_g_idx = slice_k_start + stages * tb_k * 2;
|
||||
if (last_g_idx >= prob_k) {
|
||||
last_g_idx = prob_k - 1;
|
||||
}
|
||||
int last_group_id = g_idx[last_g_idx];
|
||||
if (last_group_id >= sh_first_group_id + sh_num_groups) {
|
||||
fetch_act_order_scales_to_shared(false, first_group_id,
|
||||
last_group_id);
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1563,7 +1613,8 @@ __global__ void Marlin(
|
||||
bool last = slice_idx == slice_count - 1;
|
||||
// For per-column scales, we only fetch them here in the final step before
|
||||
// write-out
|
||||
if constexpr (!has_act_order && group_blocks == -1 && !has_zp) {
|
||||
if constexpr (!has_act_order && group_blocks == -1 &&
|
||||
(has_zp && dequant_skip_flop || !has_zp)) {
|
||||
if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
|
||||
if (s_sh_wr_pred) {
|
||||
cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
|
||||
@ -1573,7 +1624,8 @@ __global__ void Marlin(
|
||||
}
|
||||
|
||||
thread_block_reduce();
|
||||
if constexpr (!has_act_order && group_blocks == -1 && !has_zp) {
|
||||
if constexpr (!has_act_order && group_blocks == -1 &&
|
||||
(has_zp && dequant_skip_flop || !has_zp)) {
|
||||
if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
|
||||
cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
@ -1597,7 +1649,8 @@ __global__ void Marlin(
|
||||
// that converts the fp32 results to fp16 (so that we avoid possible
|
||||
// overflow in fp16)
|
||||
if constexpr (!has_act_order && group_blocks == -1 &&
|
||||
w_type.size_bits() == 8 && !has_zp) {
|
||||
w_type.size_bits() == 8 &&
|
||||
(has_zp && dequant_skip_flop || !has_zp)) {
|
||||
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < thread_m_blocks; i++) {
|
||||
|
||||
@ -77,6 +77,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
" Tensor suffix_output,"
|
||||
" Tensor suffix_lse) -> ()");
|
||||
ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states);
|
||||
|
||||
ops.def(
|
||||
"convert_vertical_slash_indexes("
|
||||
" Tensor! block_count, Tensor! block_offset, "
|
||||
" Tensor! column_count, Tensor! column_index, "
|
||||
" Tensor q_seqlens, Tensor q_seqlens, "
|
||||
" Tensor vertical_indexes, Tensor slash_indexes, "
|
||||
" int context_size, int block_size_M, int block_size_N, "
|
||||
" bool causal) -> ()");
|
||||
ops.impl("convert_vertical_slash_indexes", torch::kCUDA,
|
||||
&convert_vertical_slash_indexes);
|
||||
|
||||
ops.def(
|
||||
"convert_vertical_slash_indexes_mergehead("
|
||||
" Tensor! block_count, Tensor! block_offset, "
|
||||
" Tensor! column_count, Tensor! column_index, "
|
||||
" Tensor q_seqlens, Tensor q_seqlens, "
|
||||
" Tensor vertical_indexes, Tensor slash_indexes, "
|
||||
" Tensor vertical_indices_count, Tensor slash_indices_count, "
|
||||
" int context_size, int block_size_M, int block_size_N, "
|
||||
" bool causal) -> ()");
|
||||
ops.impl("convert_vertical_slash_indexes_mergehead", torch::kCUDA,
|
||||
&convert_vertical_slash_indexes_mergehead);
|
||||
#endif
|
||||
|
||||
// Activation ops
|
||||
@ -292,8 +315,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
// gptq_marlin Optimized Quantized GEMM for GPTQ.
|
||||
ops.def(
|
||||
"gptq_marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, "
|
||||
"Tensor b_scales, Tensor? b_zeros_or_none, Tensor? g_idx_or_none, "
|
||||
"Tensor? perm_or_none, Tensor workspace, int b_q_type, "
|
||||
"Tensor b_scales, Tensor? global_scale, Tensor? b_zeros_or_none, Tensor? "
|
||||
"g_idx_or_none, Tensor? perm_or_none, Tensor workspace, int b_q_type, "
|
||||
"SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
|
||||
"bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) -> Tensor",
|
||||
{stride_tag});
|
||||
@ -363,6 +386,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
{stride_tag});
|
||||
ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);
|
||||
|
||||
// cutlass nvfp4 block scaled group GEMM
|
||||
ops.def(
|
||||
"cutlass_fp4_group_mm(Tensor! out, Tensor a, Tensor b,"
|
||||
" Tensor a_blockscale, Tensor b_blockscales, Tensor alphas,"
|
||||
" Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()",
|
||||
{stride_tag});
|
||||
ops.impl("cutlass_fp4_group_mm", torch::kCUDA, &cutlass_fp4_group_mm);
|
||||
|
||||
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
|
||||
// quantization, as well as bias
|
||||
ops.def(
|
||||
@ -492,6 +523,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
" Tensor! output_scale, Tensor input_scale) -> ()");
|
||||
ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant);
|
||||
|
||||
// Compute NVFP4 experts quantization.
|
||||
ops.def(
|
||||
"scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale,"
|
||||
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
|
||||
"Tensor output_scale_offset_by_experts) -> ()");
|
||||
ops.impl("scaled_fp4_experts_quant", torch::kCUDA, &scaled_fp4_experts_quant);
|
||||
|
||||
// Check if cutlass_scaled_mm_fp4 is supported for CUDA devices
|
||||
// of the given capability
|
||||
ops.def("cutlass_scaled_mm_supports_fp4(int cuda_device_capability) -> bool");
|
||||
|
||||
@ -77,7 +77,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
# can be useful for both `dev` and `test`
|
||||
# explicitly set the list to avoid issues with torch 2.2
|
||||
# see https://github.com/pytorch/pytorch/pull/123243
|
||||
ARG torch_cuda_arch_list='7.0 7.5 8.0 8.6 8.9 9.0+PTX'
|
||||
ARG torch_cuda_arch_list='7.0 7.5 8.0 8.9 9.0 10.0+PTX'
|
||||
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
|
||||
# Override the arch list for flash-attn to reduce the binary size
|
||||
ARG vllm_fa_cmake_gpu_arches='80-real;90-real'
|
||||
@ -255,9 +255,10 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
. /etc/environment && \
|
||||
if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \
|
||||
# uv pip install --system https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.4/flashinfer_python-0.2.4+cu124torch2.6-cp38-abi3-linux_x86_64.whl ; \
|
||||
# TESTING: install FlashInfer from source to test 2.7.0 final RC
|
||||
FLASHINFER_ENABLE_AOT=1 TORCH_CUDA_ARCH_LIST='7.5 8.0 8.6 8.9 9.0+PTX' \
|
||||
uv pip install --system --no-build-isolation "git+https://github.com/flashinfer-ai/flashinfer@v0.2.2.post1" ; \
|
||||
FLASHINFER_ENABLE_AOT=1 TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0 10.0+PTX' \
|
||||
uv pip install --system --no-build-isolation "git+https://github.com/flashinfer-ai/flashinfer@e00e8cedbfcb220f328fd36aa8f529f869b01e6b" ; \
|
||||
fi
|
||||
COPY examples examples
|
||||
COPY benchmarks benchmarks
|
||||
|
||||
@ -21,12 +21,8 @@ ENV UV_LINK_MODE=copy
|
||||
# Note: A dummy file 'control' is created in /tmp/ to artificially create dependencies between stages when building stages in parallel
|
||||
# when `--jobs=<N>` is passed with podman build command
|
||||
RUN microdnf install -y openssl-devel dnf \
|
||||
&& dnf install -y https://mirror.stream.centos.org/9-stream/BaseOS/`arch`/os/Packages/centos-gpg-keys-9.0-24.el9.noarch.rpm \
|
||||
https://mirror.stream.centos.org/9-stream/BaseOS/`arch`/os/Packages/centos-stream-repos-9.0-24.el9.noarch.rpm \
|
||||
https://dl.fedoraproject.org/pub/epel/epel-release-latest-9.noarch.rpm \
|
||||
&& dnf config-manager --add-repo https://mirror.stream.centos.org/9-stream/BaseOS/`arch`/os \
|
||||
&& dnf config-manager --add-repo https://mirror.stream.centos.org/9-stream/AppStream/`arch`/os \
|
||||
&& dnf config-manager --set-enabled crb \
|
||||
&& dnf install -y https://dl.fedoraproject.org/pub/epel/epel-release-latest-9.noarch.rpm \
|
||||
&& dnf config-manager --set-enabled codeready-builder-for-rhel-9-ppc64le-rpms \
|
||||
&& dnf install -y \
|
||||
git tar gcc-toolset-13 automake libtool numactl-devel lapack-devel \
|
||||
pkgconfig xsimd zeromq-devel kmod findutils protobuf* \
|
||||
|
||||
@ -10,6 +10,7 @@ chatbox
|
||||
dify
|
||||
dstack
|
||||
helm
|
||||
lobe-chat
|
||||
lws
|
||||
modal
|
||||
open-webui
|
||||
|
||||
13
docs/source/deployment/frameworks/lobe-chat.md
Normal file
13
docs/source/deployment/frameworks/lobe-chat.md
Normal file
@ -0,0 +1,13 @@
|
||||
(deployment-lobe-chat)=
|
||||
|
||||
# Lobe Chat
|
||||
|
||||
[Lobe Chat](https://github.com/lobehub/lobe-chat) is an open-source, modern-design ChatGPT/LLMs UI/Framework.
|
||||
|
||||
Supports speech-synthesis, multi-modal, and extensible (function call) plugin system.
|
||||
|
||||
One-click FREE deployment of your private OpenAI ChatGPT/Claude/Gemini/Groq/Ollama chat application.
|
||||
|
||||
It supports vLLM as a AI model provider to efficiently serve large language models.
|
||||
|
||||
For details, see the tutorial [Using vLLM in LobeChat](https://lobehub.com/docs/usage/providers/vllm).
|
||||
@ -53,6 +53,45 @@ Key points from the PyTorch security guide:
|
||||
- Implement proper authentication and authorization for management interfaces
|
||||
- Follow the principle of least privilege for all system components
|
||||
|
||||
## Security and Firewalls: Protecting Exposed vLLM Systems
|
||||
|
||||
While vLLM is designed to allow unsafe network services to be isolated to
|
||||
private networks, there are components—such as dependencies and underlying
|
||||
frameworks—that may open insecure services listening on all network interfaces,
|
||||
sometimes outside of vLLM's direct control.
|
||||
|
||||
A major concern is the use of `torch.distributed`, which vLLM leverages for
|
||||
distributed communication, including when using vLLM on a single host. When vLLM
|
||||
uses TCP initialization (see [PyTorch TCP Initialization
|
||||
documentation](https://docs.pytorch.org/docs/stable/distributed.html#tcp-initialization)),
|
||||
PyTorch creates a `TCPStore` that, by default, listens on all network
|
||||
interfaces. This means that unless additional protections are put in place,
|
||||
these services may be accessible to any host that can reach your machine via any
|
||||
network interface.
|
||||
|
||||
**From a PyTorch perspective, any use of `torch.distributed` should be
|
||||
considered insecure by default.** This is a known and intentional behavior from
|
||||
the PyTorch team.
|
||||
|
||||
### Firewall Configuration Guidance
|
||||
|
||||
The best way to protect your vLLM system is to carefully configure a firewall to
|
||||
expose only the minimum network surface area necessary. In most cases, this
|
||||
means:
|
||||
|
||||
- **Block all incoming connections except to the TCP port the API server is
|
||||
listening on.**
|
||||
|
||||
- Ensure that ports used for internal communication (such as those for
|
||||
`torch.distributed` and KV cache transfer) are only accessible from trusted
|
||||
hosts or networks.
|
||||
|
||||
- Never expose these internal ports to the public internet or untrusted
|
||||
networks.
|
||||
|
||||
Consult your operating system or application platform documentation for specific
|
||||
firewall configuration instructions.
|
||||
|
||||
## Reporting Security Vulnerabilities
|
||||
|
||||
If you believe you have found a security vulnerability in vLLM, please report it following the project's security policy. For more information on how to report security issues and the project's security policy, please see the [vLLM Security Policy](https://github.com/vllm-project/vllm/blob/main/SECURITY.md).
|
||||
|
||||
@ -415,8 +415,8 @@ The discussion in <gh-issue:10582> about adding prefix cache metrics yielded
|
||||
some interesting points which may be relevant to how we approach
|
||||
future metrics.
|
||||
|
||||
Every time the prefix cache is queried, we record the number of blocks
|
||||
queried and the number of queried blocks present in the cache
|
||||
Every time the prefix cache is queried, we record the number of tokens
|
||||
queried and the number of queried tokens present in the cache
|
||||
(i.e. hits).
|
||||
|
||||
However, the metric of interest is the hit rate - i.e. the number of
|
||||
|
||||
@ -66,7 +66,7 @@ The commit ID `0dfa347e8877a4d4ed19ee56c140fa518470028c` may change over time. P
|
||||
|
||||
The server entrypoint accepts all other LoRA configuration parameters (`max_loras`, `max_lora_rank`, `max_cpu_loras`,
|
||||
etc.), which will apply to all forthcoming requests. Upon querying the `/models` endpoint, we should see our LoRA along
|
||||
with its base model:
|
||||
with its base model (if `jq` is not installed, you can follow [this guide](https://jqlang.org/download/) to install it.):
|
||||
|
||||
```bash
|
||||
curl localhost:8000/v1/models | jq .
|
||||
@ -134,7 +134,7 @@ curl -X POST http://localhost:8000/v1/load_lora_adapter \
|
||||
}'
|
||||
```
|
||||
|
||||
Upon a successful request, the API will respond with a 200 OK status code. If an error occurs, such as if the adapter
|
||||
Upon a successful request, the API will respond with a `200 OK` status code from `vllm serve`, and `curl` returns the response body: `Success: LoRA adapter 'sql_adapter' added successfully`. If an error occurs, such as if the adapter
|
||||
cannot be found or loaded, an appropriate error message will be returned.
|
||||
|
||||
Unloading a LoRA Adapter:
|
||||
@ -142,6 +142,8 @@ Unloading a LoRA Adapter:
|
||||
To unload a LoRA adapter that has been previously loaded, send a POST request to the `/v1/unload_lora_adapter` endpoint
|
||||
with the name or ID of the adapter to be unloaded.
|
||||
|
||||
Upon a successful request, the API responds with a `200 OK` status code from `vllm serve`, and `curl` returns the response body: `Success: LoRA adapter 'sql_adapter' removed successfully`.
|
||||
|
||||
Example request to unload a LoRA adapter:
|
||||
|
||||
```bash
|
||||
@ -157,9 +159,12 @@ Alternatively, you can use the LoRAResolver plugin to dynamically load LoRA adap
|
||||
|
||||
You can set up multiple LoRAResolver plugins if you want to load LoRA adapters from different sources. For example, you might have one resolver for local files and another for S3 storage. vLLM will load the first LoRA adapter that it finds.
|
||||
|
||||
You can either install existing plugins or implement your own.
|
||||
You can either install existing plugins or implement your own. By default, vLLM comes with a [resolver plugin to load LoRA adapters from a local directory.](https://github.com/vllm-project/vllm/tree/main/vllm/plugins/lora_resolvers)
|
||||
To enable this resolver, set `VLLM_ALLOW_RUNTIME_LORA_UPDATING` to True, set `VLLM_PLUGINS` to include `lora_filesystem_resolver`, and then set `VLLM_LORA_RESOLVER_CACHE_DIR` to a local directory. When vLLM receives a request using a LoRA adapter `foobar`,
|
||||
it will first look in the local directory for a directory `foobar`, and attempt to load the contents of that directory as a LoRA adapter. If successful, the request will complete as normal and
|
||||
that adapter will then be available for normal use on the server.
|
||||
|
||||
Steps to implement your own LoRAResolver plugin:
|
||||
Alternatively, follow these example steps to implement your own plugin:
|
||||
1. Implement the LoRAResolver interface.
|
||||
|
||||
Example of a simple S3 LoRAResolver implementation:
|
||||
|
||||
@ -141,10 +141,10 @@ Remember to check whether the `reasoning_content` exists in the response before
|
||||
The reasoning content is also available in the structured output. The structured output engine like `xgrammar` will use the reasoning content to generate structured output. It is only supported in v0 engine now.
|
||||
|
||||
```bash
|
||||
VLLM_USE_V1=0 vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --reasoning-parser deepseek_r1
|
||||
vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --reasoning-parser deepseek_r1
|
||||
```
|
||||
|
||||
Please note that the `VLLM_USE_V1` environment variable must be set to `0` to use the v0 engine.
|
||||
The following is an example client:
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
@ -236,6 +236,13 @@ For Qwen2.5, the chat template in tokenizer_config.json has already included sup
|
||||
|
||||
Flags: `--tool-call-parser hermes`
|
||||
|
||||
### DeepSeek-V3 Models (`deepseek_v3`)
|
||||
|
||||
Supported models:
|
||||
* `deepseek-ai/DeepSeek-V3-0324`
|
||||
|
||||
Flags: `--tool-call-parser deepseek_v3 --chat-template examples/tool_chat_template_deepseekv3.jinja`
|
||||
|
||||
### Models with Pythonic Tool Calls (`pythonic`)
|
||||
|
||||
A growing number of models output a python list to represent tool calls instead of using JSON. This has the advantage of inherently supporting parallel tool calls and removing ambiguity around the JSON schema required for tool calls. The `pythonic` tool parser can support such models.
|
||||
|
||||
@ -19,8 +19,8 @@ If you are using NVIDIA GPUs, you can install vLLM using [pip](https://pypi.org/
|
||||
It's recommended to use [uv](https://docs.astral.sh/uv/), a very fast Python environment manager, to create and manage Python environments. Please follow the [documentation](https://docs.astral.sh/uv/#getting-started) to install `uv`. After installing `uv`, you can create a new Python environment and install vLLM using the following commands:
|
||||
|
||||
```console
|
||||
uv venv myenv --python 3.12 --seed
|
||||
source myenv/bin/activate
|
||||
uv venv --python 3.12 --seed
|
||||
source .venv/bin/activate
|
||||
uv pip install vllm
|
||||
```
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user