Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| f96a3cc713 | |||
| 32c0155774 |
@ -8,12 +8,12 @@ import zipfile
|
|||||||
# Note that we have a 400 MiB quota; please use it wisely.
|
# Note that we have a 400 MiB quota; please use it wisely.
|
||||||
# See https://github.com/pypi/support/issues/3792 .
|
# See https://github.com/pypi/support/issues/3792 .
|
||||||
# Please also sync the value with the one in Dockerfile.
|
# Please also sync the value with the one in Dockerfile.
|
||||||
VLLM_MAX_SIZE_MB = int(os.environ.get("VLLM_MAX_SIZE_MB", 400))
|
VLLM_MAX_SIZE_MB = int(os.environ.get('VLLM_MAX_SIZE_MB', 400))
|
||||||
|
|
||||||
|
|
||||||
def print_top_10_largest_files(zip_file):
|
def print_top_10_largest_files(zip_file):
|
||||||
"""Print the top 10 largest files in the given zip file."""
|
"""Print the top 10 largest files in the given zip file."""
|
||||||
with zipfile.ZipFile(zip_file, "r") as z:
|
with zipfile.ZipFile(zip_file, 'r') as z:
|
||||||
file_sizes = [(f, z.getinfo(f).file_size) for f in z.namelist()]
|
file_sizes = [(f, z.getinfo(f).file_size) for f in z.namelist()]
|
||||||
file_sizes.sort(key=lambda x: x[1], reverse=True)
|
file_sizes.sort(key=lambda x: x[1], reverse=True)
|
||||||
for f, size in file_sizes[:10]:
|
for f, size in file_sizes[:10]:
|
||||||
@ -28,18 +28,14 @@ def check_wheel_size(directory):
|
|||||||
wheel_path = os.path.join(root, file_name)
|
wheel_path = os.path.join(root, file_name)
|
||||||
wheel_size_mb = os.path.getsize(wheel_path) / (1024 * 1024)
|
wheel_size_mb = os.path.getsize(wheel_path) / (1024 * 1024)
|
||||||
if wheel_size_mb > VLLM_MAX_SIZE_MB:
|
if wheel_size_mb > VLLM_MAX_SIZE_MB:
|
||||||
print(
|
print(f"Not allowed: Wheel {wheel_path} is larger "
|
||||||
f"Not allowed: Wheel {wheel_path} is larger "
|
f"({wheel_size_mb:.2f} MB) than the limit "
|
||||||
f"({wheel_size_mb:.2f} MB) than the limit "
|
f"({VLLM_MAX_SIZE_MB} MB).")
|
||||||
f"({VLLM_MAX_SIZE_MB} MB)."
|
|
||||||
)
|
|
||||||
print_top_10_largest_files(wheel_path)
|
print_top_10_largest_files(wheel_path)
|
||||||
return 1
|
return 1
|
||||||
else:
|
else:
|
||||||
print(
|
print(f"Wheel {wheel_path} is within the allowed size "
|
||||||
f"Wheel {wheel_path} is within the allowed size "
|
f"({wheel_size_mb:.2f} MB).")
|
||||||
f"({wheel_size_mb:.2f} MB)."
|
|
||||||
)
|
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
@ -49,4 +45,4 @@ if __name__ == "__main__":
|
|||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
directory = sys.argv[1]
|
directory = sys.argv[1]
|
||||||
sys.exit(check_wheel_size(directory))
|
sys.exit(check_wheel_size(directory))
|
||||||
@ -22,5 +22,5 @@ with open("index.html", "w") as f:
|
|||||||
print(f"Generated index.html for {args.wheel}")
|
print(f"Generated index.html for {args.wheel}")
|
||||||
# cloudfront requires escaping the '+' character
|
# cloudfront requires escaping the '+' character
|
||||||
f.write(
|
f.write(
|
||||||
template.format(wheel=filename, wheel_html_escaped=filename.replace("+", "%2B"))
|
template.format(wheel=filename,
|
||||||
)
|
wheel_html_escaped=filename.replace("+", "%2B")))
|
||||||
|
|||||||
@ -8,14 +8,11 @@ def pytest_addoption(parser):
|
|||||||
parser.addoption(
|
parser.addoption(
|
||||||
"--config-list-file",
|
"--config-list-file",
|
||||||
action="store",
|
action="store",
|
||||||
help="Path to the file listing model config YAMLs (one per line)",
|
help="Path to the file listing model config YAMLs (one per line)")
|
||||||
)
|
parser.addoption("--tp-size",
|
||||||
parser.addoption(
|
action="store",
|
||||||
"--tp-size",
|
default="1",
|
||||||
action="store",
|
help="Tensor parallel size to use for evaluation")
|
||||||
default="1",
|
|
||||||
help="Tensor parallel size to use for evaluation",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
@ -36,8 +33,7 @@ def pytest_generate_tests(metafunc):
|
|||||||
config_dir = config_list_file.parent
|
config_dir = config_list_file.parent
|
||||||
with open(config_list_file, encoding="utf-8") as f:
|
with open(config_list_file, encoding="utf-8") as f:
|
||||||
configs = [
|
configs = [
|
||||||
config_dir / line.strip()
|
config_dir / line.strip() for line in f
|
||||||
for line in f
|
|
||||||
if line.strip() and not line.startswith("#")
|
if line.strip() and not line.startswith("#")
|
||||||
]
|
]
|
||||||
metafunc.parametrize("config_filename", configs)
|
metafunc.parametrize("config_filename", configs)
|
||||||
|
|||||||
@ -16,22 +16,19 @@ RTOL = 0.08
|
|||||||
|
|
||||||
|
|
||||||
def launch_lm_eval(eval_config, tp_size):
|
def launch_lm_eval(eval_config, tp_size):
|
||||||
trust_remote_code = eval_config.get("trust_remote_code", False)
|
trust_remote_code = eval_config.get('trust_remote_code', False)
|
||||||
model_args = (
|
model_args = f"pretrained={eval_config['model_name']}," \
|
||||||
f"pretrained={eval_config['model_name']},"
|
f"tensor_parallel_size={tp_size}," \
|
||||||
f"tensor_parallel_size={tp_size},"
|
f"enforce_eager=true," \
|
||||||
f"enforce_eager=true,"
|
f"add_bos_token=true," \
|
||||||
f"add_bos_token=true,"
|
f"trust_remote_code={trust_remote_code}"
|
||||||
f"trust_remote_code={trust_remote_code}"
|
|
||||||
)
|
|
||||||
results = lm_eval.simple_evaluate(
|
results = lm_eval.simple_evaluate(
|
||||||
model="vllm",
|
model="vllm",
|
||||||
model_args=model_args,
|
model_args=model_args,
|
||||||
tasks=[task["name"] for task in eval_config["tasks"]],
|
tasks=[task["name"] for task in eval_config["tasks"]],
|
||||||
num_fewshot=eval_config["num_fewshot"],
|
num_fewshot=eval_config["num_fewshot"],
|
||||||
limit=eval_config["limit"],
|
limit=eval_config["limit"],
|
||||||
batch_size="auto",
|
batch_size="auto")
|
||||||
)
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
@ -45,10 +42,9 @@ def test_lm_eval_correctness_param(config_filename, tp_size):
|
|||||||
for metric in task["metrics"]:
|
for metric in task["metrics"]:
|
||||||
ground_truth = metric["value"]
|
ground_truth = metric["value"]
|
||||||
measured_value = results["results"][task["name"]][metric["name"]]
|
measured_value = results["results"][task["name"]][metric["name"]]
|
||||||
print(
|
print(f'{task["name"]} | {metric["name"]}: '
|
||||||
f"{task['name']} | {metric['name']}: "
|
f'ground_truth={ground_truth} | measured={measured_value}')
|
||||||
f"ground_truth={ground_truth} | measured={measured_value}"
|
success = success and np.isclose(
|
||||||
)
|
ground_truth, measured_value, rtol=RTOL)
|
||||||
success = success and np.isclose(ground_truth, measured_value, rtol=RTOL)
|
|
||||||
|
|
||||||
assert success
|
assert success
|
||||||
|
|||||||
@ -65,18 +65,18 @@ def read_markdown(file):
|
|||||||
|
|
||||||
|
|
||||||
def results_to_json(latency, throughput, serving):
|
def results_to_json(latency, throughput, serving):
|
||||||
return json.dumps(
|
return json.dumps({
|
||||||
{
|
'latency': latency.to_dict(),
|
||||||
"latency": latency.to_dict(),
|
'throughput': throughput.to_dict(),
|
||||||
"throughput": throughput.to_dict(),
|
'serving': serving.to_dict()
|
||||||
"serving": serving.to_dict(),
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
# collect results
|
# collect results
|
||||||
for test_file in results_folder.glob("*.json"):
|
for test_file in results_folder.glob("*.json"):
|
||||||
|
|
||||||
with open(test_file) as f:
|
with open(test_file) as f:
|
||||||
raw_result = json.loads(f.read())
|
raw_result = json.loads(f.read())
|
||||||
|
|
||||||
@ -120,8 +120,7 @@ if __name__ == "__main__":
|
|||||||
for perc in [10, 25, 50, 75, 90, 99]:
|
for perc in [10, 25, 50, 75, 90, 99]:
|
||||||
# Multiply by 1000 to convert the time unit from s to ms
|
# Multiply by 1000 to convert the time unit from s to ms
|
||||||
raw_result.update(
|
raw_result.update(
|
||||||
{f"P{perc}": 1000 * raw_result["percentiles"][str(perc)]}
|
{f"P{perc}": 1000 * raw_result["percentiles"][str(perc)]})
|
||||||
)
|
|
||||||
raw_result["avg_latency"] = raw_result["avg_latency"] * 1000
|
raw_result["avg_latency"] = raw_result["avg_latency"] * 1000
|
||||||
|
|
||||||
# add the result to raw_result
|
# add the result to raw_result
|
||||||
@ -154,27 +153,26 @@ if __name__ == "__main__":
|
|||||||
serving_results = pd.DataFrame.from_dict(serving_results)
|
serving_results = pd.DataFrame.from_dict(serving_results)
|
||||||
throughput_results = pd.DataFrame.from_dict(throughput_results)
|
throughput_results = pd.DataFrame.from_dict(throughput_results)
|
||||||
|
|
||||||
raw_results_json = results_to_json(
|
raw_results_json = results_to_json(latency_results, throughput_results,
|
||||||
latency_results, throughput_results, serving_results
|
serving_results)
|
||||||
)
|
|
||||||
|
|
||||||
# remap the keys for visualization purposes
|
# remap the keys for visualization purposes
|
||||||
if not latency_results.empty:
|
if not latency_results.empty:
|
||||||
latency_results = latency_results[list(latency_column_mapping.keys())].rename(
|
latency_results = latency_results[list(
|
||||||
columns=latency_column_mapping
|
latency_column_mapping.keys())].rename(
|
||||||
)
|
columns=latency_column_mapping)
|
||||||
if not serving_results.empty:
|
if not serving_results.empty:
|
||||||
serving_results = serving_results[list(serving_column_mapping.keys())].rename(
|
serving_results = serving_results[list(
|
||||||
columns=serving_column_mapping
|
serving_column_mapping.keys())].rename(
|
||||||
)
|
columns=serving_column_mapping)
|
||||||
if not throughput_results.empty:
|
if not throughput_results.empty:
|
||||||
throughput_results = throughput_results[
|
throughput_results = throughput_results[list(
|
||||||
list(throughput_results_column_mapping.keys())
|
throughput_results_column_mapping.keys())].rename(
|
||||||
].rename(columns=throughput_results_column_mapping)
|
columns=throughput_results_column_mapping)
|
||||||
|
|
||||||
processed_results_json = results_to_json(
|
processed_results_json = results_to_json(latency_results,
|
||||||
latency_results, throughput_results, serving_results
|
throughput_results,
|
||||||
)
|
serving_results)
|
||||||
|
|
||||||
for df in [latency_results, serving_results, throughput_results]:
|
for df in [latency_results, serving_results, throughput_results]:
|
||||||
if df.empty:
|
if df.empty:
|
||||||
@ -186,39 +184,38 @@ if __name__ == "__main__":
|
|||||||
# The GPUs sometimes come in format of "GPUTYPE\nGPUTYPE\n...",
|
# The GPUs sometimes come in format of "GPUTYPE\nGPUTYPE\n...",
|
||||||
# we want to turn it into "8xGPUTYPE"
|
# we want to turn it into "8xGPUTYPE"
|
||||||
df["GPU"] = df["GPU"].apply(
|
df["GPU"] = df["GPU"].apply(
|
||||||
lambda x: f"{len(x.split('\n'))}x{x.split('\n')[0]}"
|
lambda x: f"{len(x.split('\n'))}x{x.split('\n')[0]}")
|
||||||
)
|
|
||||||
|
|
||||||
# get markdown tables
|
# get markdown tables
|
||||||
latency_md_table = tabulate(
|
latency_md_table = tabulate(latency_results,
|
||||||
latency_results, headers="keys", tablefmt="pipe", showindex=False
|
headers='keys',
|
||||||
)
|
tablefmt='pipe',
|
||||||
serving_md_table = tabulate(
|
showindex=False)
|
||||||
serving_results, headers="keys", tablefmt="pipe", showindex=False
|
serving_md_table = tabulate(serving_results,
|
||||||
)
|
headers='keys',
|
||||||
throughput_md_table = tabulate(
|
tablefmt='pipe',
|
||||||
throughput_results, headers="keys", tablefmt="pipe", showindex=False
|
showindex=False)
|
||||||
)
|
throughput_md_table = tabulate(throughput_results,
|
||||||
|
headers='keys',
|
||||||
|
tablefmt='pipe',
|
||||||
|
showindex=False)
|
||||||
|
|
||||||
# document the result
|
# document the result
|
||||||
with open(results_folder / "benchmark_results.md", "w") as f:
|
with open(results_folder / "benchmark_results.md", "w") as f:
|
||||||
results = read_markdown(
|
|
||||||
"../.buildkite/nightly-benchmarks/"
|
results = read_markdown("../.buildkite/nightly-benchmarks/" +
|
||||||
+ "performance-benchmarks-descriptions.md"
|
"performance-benchmarks-descriptions.md")
|
||||||
)
|
|
||||||
results = results.format(
|
results = results.format(
|
||||||
latency_tests_markdown_table=latency_md_table,
|
latency_tests_markdown_table=latency_md_table,
|
||||||
throughput_tests_markdown_table=throughput_md_table,
|
throughput_tests_markdown_table=throughput_md_table,
|
||||||
serving_tests_markdown_table=serving_md_table,
|
serving_tests_markdown_table=serving_md_table,
|
||||||
benchmarking_results_in_json_string=processed_results_json,
|
benchmarking_results_in_json_string=processed_results_json)
|
||||||
)
|
|
||||||
f.write(results)
|
f.write(results)
|
||||||
|
|
||||||
# document benchmarking results in json
|
# document benchmarking results in json
|
||||||
with open(results_folder / "benchmark_results.json", "w") as f:
|
with open(results_folder / "benchmark_results.json", "w") as f:
|
||||||
results = (
|
|
||||||
latency_results.to_dict(orient="records")
|
results = latency_results.to_dict(
|
||||||
+ throughput_results.to_dict(orient="records")
|
orient='records') + throughput_results.to_dict(
|
||||||
+ serving_results.to_dict(orient="records")
|
orient='records') + serving_results.to_dict(orient='records')
|
||||||
)
|
|
||||||
f.write(json.dumps(results))
|
f.write(json.dumps(results))
|
||||||
|
|||||||
@ -14,12 +14,15 @@ def main(model, cachedir):
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="Download and save Hugging Face tokenizer"
|
description="Download and save Hugging Face tokenizer")
|
||||||
)
|
parser.add_argument("--model",
|
||||||
parser.add_argument("--model", type=str, required=True, help="Name of the model")
|
type=str,
|
||||||
parser.add_argument(
|
required=True,
|
||||||
"--cachedir", type=str, required=True, help="Directory to save the tokenizer"
|
help="Name of the model")
|
||||||
)
|
parser.add_argument("--cachedir",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Directory to save the tokenizer")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
main(args.model, args.cachedir)
|
main(args.model, args.cachedir)
|
||||||
|
|||||||
@ -11,33 +11,33 @@ from tabulate import tabulate
|
|||||||
|
|
||||||
def parse_arguments():
|
def parse_arguments():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="Parse command line arguments for summary-nightly-results script."
|
description=
|
||||||
)
|
'Parse command line arguments for summary-nightly-results script.')
|
||||||
parser.add_argument(
|
parser.add_argument('--results-folder',
|
||||||
"--results-folder",
|
type=str,
|
||||||
type=str,
|
required=True,
|
||||||
required=True,
|
help='The folder where the results are stored.')
|
||||||
help="The folder where the results are stored.",
|
parser.add_argument('--description',
|
||||||
)
|
type=str,
|
||||||
parser.add_argument(
|
required=True,
|
||||||
"--description", type=str, required=True, help="Description of the results."
|
help='Description of the results.')
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
|
||||||
def get_perf(df, method, model, metric):
|
def get_perf(df, method, model, metric):
|
||||||
|
|
||||||
means = []
|
means = []
|
||||||
|
|
||||||
for qps in [2, 4, 8, 16, "inf"]:
|
for qps in [2, 4, 8, 16, "inf"]:
|
||||||
target = df["Test name"].str.contains(model)
|
target = df['Test name'].str.contains(model)
|
||||||
target = target & df["Engine"].str.contains(method)
|
target = target & df['Engine'].str.contains(method)
|
||||||
target = target & df["Test name"].str.contains("qps_" + str(qps))
|
target = target & df['Test name'].str.contains("qps_" + str(qps))
|
||||||
filtered_df = df[target]
|
filtered_df = df[target]
|
||||||
|
|
||||||
if filtered_df.empty:
|
if filtered_df.empty:
|
||||||
means.append(0.0)
|
means.append(0.)
|
||||||
else:
|
else:
|
||||||
means.append(filtered_df[metric].values[0])
|
means.append(filtered_df[metric].values[0])
|
||||||
|
|
||||||
@ -45,6 +45,7 @@ def get_perf(df, method, model, metric):
|
|||||||
|
|
||||||
|
|
||||||
def get_perf_w_std(df, method, model, metric):
|
def get_perf_w_std(df, method, model, metric):
|
||||||
|
|
||||||
if metric in ["TTFT", "ITL"]:
|
if metric in ["TTFT", "ITL"]:
|
||||||
mean = get_perf(df, method, model, "Mean " + metric + " (ms)")
|
mean = get_perf(df, method, model, "Mean " + metric + " (ms)")
|
||||||
mean = mean.tolist()
|
mean = mean.tolist()
|
||||||
@ -59,8 +60,7 @@ def get_perf_w_std(df, method, model, metric):
|
|||||||
else:
|
else:
|
||||||
assert metric == "Tput"
|
assert metric == "Tput"
|
||||||
mean = get_perf(df, method, model, "Input Tput (tok/s)") + get_perf(
|
mean = get_perf(df, method, model, "Input Tput (tok/s)") + get_perf(
|
||||||
df, method, model, "Output Tput (tok/s)"
|
df, method, model, "Output Tput (tok/s)")
|
||||||
)
|
|
||||||
mean = mean.tolist()
|
mean = mean.tolist()
|
||||||
std = None
|
std = None
|
||||||
|
|
||||||
@ -80,17 +80,18 @@ def main(args):
|
|||||||
# generate markdown table
|
# generate markdown table
|
||||||
df = pd.DataFrame.from_dict(results)
|
df = pd.DataFrame.from_dict(results)
|
||||||
|
|
||||||
md_table = tabulate(df, headers="keys", tablefmt="pipe", showindex=False)
|
md_table = tabulate(df, headers='keys', tablefmt='pipe', showindex=False)
|
||||||
|
|
||||||
with open(args.description) as f:
|
with open(args.description) as f:
|
||||||
description = f.read()
|
description = f.read()
|
||||||
|
|
||||||
description = description.format(nightly_results_benchmarking_table=md_table)
|
description = description.format(
|
||||||
|
nightly_results_benchmarking_table=md_table)
|
||||||
|
|
||||||
with open("nightly_results.md", "w") as f:
|
with open("nightly_results.md", "w") as f:
|
||||||
f.write(description)
|
f.write(description)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == '__main__':
|
||||||
args = parse_arguments()
|
args = parse_arguments()
|
||||||
main(args)
|
main(args)
|
||||||
|
|||||||
@ -34,8 +34,10 @@ serving_column_mapping = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
# collect results
|
# collect results
|
||||||
for test_file in results_folder.glob("*.json"):
|
for test_file in results_folder.glob("*.json"):
|
||||||
|
|
||||||
with open(test_file) as f:
|
with open(test_file) as f:
|
||||||
raw_result = json.loads(f.read())
|
raw_result = json.loads(f.read())
|
||||||
|
|
||||||
@ -54,16 +56,17 @@ if __name__ == "__main__":
|
|||||||
serving_results = pd.DataFrame.from_dict(serving_results)
|
serving_results = pd.DataFrame.from_dict(serving_results)
|
||||||
|
|
||||||
if not serving_results.empty:
|
if not serving_results.empty:
|
||||||
serving_results = serving_results[list(serving_column_mapping.keys())].rename(
|
serving_results = serving_results[list(
|
||||||
columns=serving_column_mapping
|
serving_column_mapping.keys())].rename(
|
||||||
)
|
columns=serving_column_mapping)
|
||||||
|
|
||||||
serving_md_table_with_headers = tabulate(
|
serving_md_table_with_headers = tabulate(serving_results,
|
||||||
serving_results, headers="keys", tablefmt="pipe", showindex=False
|
headers='keys',
|
||||||
)
|
tablefmt='pipe',
|
||||||
|
showindex=False)
|
||||||
# remove the first line of header
|
# remove the first line of header
|
||||||
serving_md_table_lines = serving_md_table_with_headers.split("\n")
|
serving_md_table_lines = serving_md_table_with_headers.split('\n')
|
||||||
serving_md_table_without_header = "\n".join(serving_md_table_lines[2:])
|
serving_md_table_without_header = '\n'.join(serving_md_table_lines[2:])
|
||||||
|
|
||||||
prefix = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
prefix = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||||
prefix = prefix + "_" + os.environ.get("CURRENT_LLM_SERVING_ENGINE")
|
prefix = prefix + "_" + os.environ.get("CURRENT_LLM_SERVING_ENGINE")
|
||||||
@ -73,9 +76,10 @@ if __name__ == "__main__":
|
|||||||
# document results with header.
|
# document results with header.
|
||||||
# for those who want to reproduce our benchmark.
|
# for those who want to reproduce our benchmark.
|
||||||
f.write(serving_md_table_with_headers)
|
f.write(serving_md_table_with_headers)
|
||||||
f.write("\n")
|
f.write('\n')
|
||||||
|
|
||||||
# document benchmarking results in json
|
# document benchmarking results in json
|
||||||
with open(results_folder / f"{prefix}_nightly_results.json", "w") as f:
|
with open(results_folder / f"{prefix}_nightly_results.json", "w") as f:
|
||||||
results = serving_results.to_dict(orient="records")
|
|
||||||
|
results = serving_results.to_dict(orient='records')
|
||||||
f.write(json.dumps(results))
|
f.write(json.dumps(results))
|
||||||
|
|||||||
@ -1,46 +0,0 @@
|
|||||||
# This local pyproject file is part of the migration from yapf to ruff format.
|
|
||||||
# It uses the same core rules as the main pyproject.toml file, but with the
|
|
||||||
# following differences:
|
|
||||||
# - ruff line length is overridden to 88
|
|
||||||
# - deprecated typing ignores (UP006, UP035) have been removed
|
|
||||||
|
|
||||||
[tool.ruff]
|
|
||||||
line-length = 88
|
|
||||||
|
|
||||||
[tool.ruff.lint.per-file-ignores]
|
|
||||||
"vllm/third_party/**" = ["ALL"]
|
|
||||||
"vllm/version.py" = ["F401"]
|
|
||||||
"vllm/_version.py" = ["ALL"]
|
|
||||||
|
|
||||||
[tool.ruff.lint]
|
|
||||||
select = [
|
|
||||||
# pycodestyle
|
|
||||||
"E",
|
|
||||||
# Pyflakes
|
|
||||||
"F",
|
|
||||||
# pyupgrade
|
|
||||||
"UP",
|
|
||||||
# flake8-bugbear
|
|
||||||
"B",
|
|
||||||
# flake8-simplify
|
|
||||||
"SIM",
|
|
||||||
# isort
|
|
||||||
"I",
|
|
||||||
# flake8-logging-format
|
|
||||||
"G",
|
|
||||||
]
|
|
||||||
ignore = [
|
|
||||||
# star imports
|
|
||||||
"F405", "F403",
|
|
||||||
# lambda expression assignment
|
|
||||||
"E731",
|
|
||||||
# Loop control variable not used within loop body
|
|
||||||
"B007",
|
|
||||||
# f-string format
|
|
||||||
"UP032",
|
|
||||||
# Can remove once 3.10+ is the minimum Python version
|
|
||||||
"UP007",
|
|
||||||
]
|
|
||||||
|
|
||||||
[tool.ruff.format]
|
|
||||||
docstring-code-format = true
|
|
||||||
@ -14,7 +14,7 @@ steps:
|
|||||||
agents:
|
agents:
|
||||||
queue: cpu_queue_postmerge
|
queue: cpu_queue_postmerge
|
||||||
commands:
|
commands:
|
||||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.6.3 --build-arg torch_cuda_arch_list='7.0 7.5 8.0 8.9 9.0+PTX' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
|
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.6.3 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
|
||||||
- "mkdir artifacts"
|
- "mkdir artifacts"
|
||||||
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
|
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
|
||||||
- "bash .buildkite/scripts/upload-wheels.sh"
|
- "bash .buildkite/scripts/upload-wheels.sh"
|
||||||
@ -31,7 +31,7 @@ steps:
|
|||||||
agents:
|
agents:
|
||||||
queue: cpu_queue_postmerge
|
queue: cpu_queue_postmerge
|
||||||
commands:
|
commands:
|
||||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=11.8.0 --build-arg torch_cuda_arch_list='7.0 7.5 8.0 8.9 9.0+PTX' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
|
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=11.8.0 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
|
||||||
- "mkdir artifacts"
|
- "mkdir artifacts"
|
||||||
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
|
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
|
||||||
- "bash .buildkite/scripts/upload-wheels.sh"
|
- "bash .buildkite/scripts/upload-wheels.sh"
|
||||||
@ -64,7 +64,7 @@ steps:
|
|||||||
- "docker push vllm/vllm-tpu:$BUILDKITE_COMMIT"
|
- "docker push vllm/vllm-tpu:$BUILDKITE_COMMIT"
|
||||||
plugins:
|
plugins:
|
||||||
- docker-login#v3.0.0:
|
- docker-login#v3.0.0:
|
||||||
username: vllmbot
|
username: vllm
|
||||||
password-env: DOCKERHUB_TOKEN
|
password-env: DOCKERHUB_TOKEN
|
||||||
env:
|
env:
|
||||||
DOCKER_BUILDKIT: "1"
|
DOCKER_BUILDKIT: "1"
|
||||||
|
|||||||
@ -3,9 +3,6 @@
|
|||||||
# This script runs tests inside the corresponding ROCm docker container.
|
# This script runs tests inside the corresponding ROCm docker container.
|
||||||
set -o pipefail
|
set -o pipefail
|
||||||
|
|
||||||
# Export Python path
|
|
||||||
export PYTHONPATH=".."
|
|
||||||
|
|
||||||
# Print ROCm version
|
# Print ROCm version
|
||||||
echo "--- Confirming Clean Initial State"
|
echo "--- Confirming Clean Initial State"
|
||||||
while true; do
|
while true; do
|
||||||
@ -77,23 +74,6 @@ HF_MOUNT="/root/.cache/huggingface"
|
|||||||
|
|
||||||
commands=$@
|
commands=$@
|
||||||
echo "Commands:$commands"
|
echo "Commands:$commands"
|
||||||
|
|
||||||
if [[ $commands == *"pytest -v -s basic_correctness/test_basic_correctness.py"* ]]; then
|
|
||||||
commands=${commands//"pytest -v -s basic_correctness/test_basic_correctness.py"/"VLLM_USE_TRITON_FLASH_ATTN=0 pytest -v -s basic_correctness/test_basic_correctness.py"}
|
|
||||||
fi
|
|
||||||
|
|
||||||
if [[ $commands == *"pytest -v -s models/test_registry.py"* ]]; then
|
|
||||||
commands=${commands//"pytest -v -s models/test_registry.py"/"pytest -v -s models/test_registry.py -k 'not BambaForCausalLM and not GritLM and not Mamba2ForCausalLM and not Zamba2ForCausalLM'"}
|
|
||||||
fi
|
|
||||||
|
|
||||||
if [[ $commands == *"VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2'"* ]]; then
|
|
||||||
commands=${commands//"VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2'"/"VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2 and not BambaForCausalLM and not Gemma2ForCausalLM and not Grok1ModelForCausalLM and not Zamba2ForCausalLM and not Gemma2Model and not GritLM'"}
|
|
||||||
fi
|
|
||||||
|
|
||||||
if [[ $commands == *"pytest -v -s compile/test_basic_correctness.py"* ]]; then
|
|
||||||
commands=${commands//"pytest -v -s compile/test_basic_correctness.py"/"VLLM_USE_TRITON_FLASH_ATTN=0 pytest -v -s compile/test_basic_correctness.py"}
|
|
||||||
fi
|
|
||||||
|
|
||||||
#ignore certain kernels tests
|
#ignore certain kernels tests
|
||||||
if [[ $commands == *" kernels/core"* ]]; then
|
if [[ $commands == *" kernels/core"* ]]; then
|
||||||
commands="${commands} \
|
commands="${commands} \
|
||||||
@ -181,8 +161,6 @@ fi
|
|||||||
|
|
||||||
|
|
||||||
PARALLEL_JOB_COUNT=8
|
PARALLEL_JOB_COUNT=8
|
||||||
MYPYTHONPATH=".."
|
|
||||||
|
|
||||||
# Check whether the command contains a shard flag; if so, run all shards in parallel because the host has 8 GPUs.
|
# Check whether the command contains a shard flag; if so, run all shards in parallel because the host has 8 GPUs.
|
||||||
if [[ $commands == *"--shard-id="* ]]; then
|
if [[ $commands == *"--shard-id="* ]]; then
|
||||||
# assign job count as the number of shards used
|
# assign job count as the number of shards used
|
||||||
@ -203,7 +181,6 @@ if [[ $commands == *"--shard-id="* ]]; then
|
|||||||
-e AWS_SECRET_ACCESS_KEY \
|
-e AWS_SECRET_ACCESS_KEY \
|
||||||
-v "${HF_CACHE}:${HF_MOUNT}" \
|
-v "${HF_CACHE}:${HF_MOUNT}" \
|
||||||
-e "HF_HOME=${HF_MOUNT}" \
|
-e "HF_HOME=${HF_MOUNT}" \
|
||||||
-e "PYTHONPATH=${MYPYTHONPATH}" \
|
|
||||||
--name "${container_name}_${GPU}" \
|
--name "${container_name}_${GPU}" \
|
||||||
"${image_name}" \
|
"${image_name}" \
|
||||||
/bin/bash -c "${commands_gpu}" \
|
/bin/bash -c "${commands_gpu}" \
|
||||||
@ -234,7 +211,6 @@ else
|
|||||||
-e AWS_SECRET_ACCESS_KEY \
|
-e AWS_SECRET_ACCESS_KEY \
|
||||||
-v "${HF_CACHE}:${HF_MOUNT}" \
|
-v "${HF_CACHE}:${HF_MOUNT}" \
|
||||||
-e "HF_HOME=${HF_MOUNT}" \
|
-e "HF_HOME=${HF_MOUNT}" \
|
||||||
-e "PYTHONPATH=${MYPYTHONPATH}" \
|
|
||||||
--name "${container_name}" \
|
--name "${container_name}" \
|
||||||
"${image_name}" \
|
"${image_name}" \
|
||||||
/bin/bash -c "${commands}"
|
/bin/bash -c "${commands}"
|
||||||
|
|||||||
@ -32,12 +32,9 @@ function cpu_tests() {
|
|||||||
set -e
|
set -e
|
||||||
pip install pytest pytest-asyncio einops peft Pillow soundfile transformers_stream_generator matplotlib
|
pip install pytest pytest-asyncio einops peft Pillow soundfile transformers_stream_generator matplotlib
|
||||||
pip install sentence-transformers datamodel_code_generator
|
pip install sentence-transformers datamodel_code_generator
|
||||||
pytest -v -s tests/models/language/generation/test_bart.py -m cpu_model
|
pytest -v -s tests/models/embedding/language/test_cls_models.py::test_classification_models[float-jason9693/Qwen2.5-1.5B-apeach]
|
||||||
pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-openai-community/gpt2]
|
pytest -v -s tests/models/embedding/language/test_embedding.py::test_models[half-BAAI/bge-base-en-v1.5]
|
||||||
pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-facebook/opt-125m]
|
pytest -v -s tests/models/encoder_decoder/language -m cpu_model"
|
||||||
pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-google/gemma-1.1-2b-it]
|
|
||||||
pytest -v -s tests/models/language/pooling/test_classification.py::test_models[float-jason9693/Qwen2.5-1.5B-apeach]
|
|
||||||
pytest -v -s tests/models/language/pooling/test_embedding.py::test_models[half-BAAI/bge-base-en-v1.5]"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# All of the CPU tests are expected to finish in less than 40 mins.
|
# All of the CPU tests are expected to finish in less than 40 mins.
|
||||||
|
|||||||
@ -10,17 +10,15 @@ docker build -t hpu-test-env -f docker/Dockerfile.hpu .
|
|||||||
# Setup cleanup
|
# Setup cleanup
|
||||||
# certain versions of HPU software stack have a bug that can
|
# certain versions of HPU software stack have a bug that can
|
||||||
# override the exit code of the script, so we need to use
|
# override the exit code of the script, so we need to use
|
||||||
# separate remove_docker_containers and remove_docker_containers_and_exit
|
# separate remove_docker_container and remove_docker_container_and_exit
|
||||||
# functions, while other platforms only need one remove_docker_container
|
# functions, while other platforms only need one remove_docker_container
|
||||||
# function.
|
# function.
|
||||||
EXITCODE=1
|
EXITCODE=1
|
||||||
remove_docker_containers() { docker rm -f hpu-test || true; docker rm -f hpu-test-tp2 || true; }
|
remove_docker_container() { docker rm -f hpu-test || true; }
|
||||||
remove_docker_containers_and_exit() { remove_docker_containers; exit $EXITCODE; }
|
remove_docker_container_and_exit() { remove_docker_container; exit $EXITCODE; }
|
||||||
trap remove_docker_containers_and_exit EXIT
|
trap remove_docker_container_and_exit EXIT
|
||||||
remove_docker_containers
|
remove_docker_container
|
||||||
|
|
||||||
# Run the image and launch offline inference
|
# Run the image and launch offline inference
|
||||||
docker run --runtime=habana --name=hpu-test --network=host -e HABANA_VISIBLE_DEVICES=all -e VLLM_SKIP_WARMUP=true --entrypoint="" hpu-test-env python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m
|
docker run --runtime=habana --name=hpu-test --network=host -e HABANA_VISIBLE_DEVICES=all -e VLLM_SKIP_WARMUP=true --entrypoint="" hpu-test-env python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m
|
||||||
docker run --runtime=habana --name=hpu-test-tp2 --network=host -e HABANA_VISIBLE_DEVICES=all -e VLLM_SKIP_WARMUP=true --entrypoint="" hpu-test-env python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --tensor-parallel-size 2
|
|
||||||
|
|
||||||
EXITCODE=$?
|
EXITCODE=$?
|
||||||
|
|||||||
@ -11,14 +11,13 @@ container_name="neuron_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo)"
|
|||||||
HF_CACHE="$(realpath ~)/huggingface"
|
HF_CACHE="$(realpath ~)/huggingface"
|
||||||
mkdir -p "${HF_CACHE}"
|
mkdir -p "${HF_CACHE}"
|
||||||
HF_MOUNT="/root/.cache/huggingface"
|
HF_MOUNT="/root/.cache/huggingface"
|
||||||
HF_TOKEN=$(aws secretsmanager get-secret-value --secret-id "ci/vllm-neuron/hf-token" --region us-west-2 --query 'SecretString' --output text | jq -r .VLLM_NEURON_CI_HF_TOKEN)
|
|
||||||
|
|
||||||
NEURON_COMPILE_CACHE_URL="$(realpath ~)/neuron_compile_cache"
|
NEURON_COMPILE_CACHE_URL="$(realpath ~)/neuron_compile_cache"
|
||||||
mkdir -p "${NEURON_COMPILE_CACHE_URL}"
|
mkdir -p "${NEURON_COMPILE_CACHE_URL}"
|
||||||
NEURON_COMPILE_CACHE_MOUNT="/root/.cache/neuron_compile_cache"
|
NEURON_COMPILE_CACHE_MOUNT="/root/.cache/neuron_compile_cache"
|
||||||
|
|
||||||
# Try building the docker image
|
# Try building the docker image
|
||||||
aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws
|
aws ecr get-login-password --region us-west-2 | docker login --username AWS --password-stdin 763104351884.dkr.ecr.us-west-2.amazonaws.com
|
||||||
|
|
||||||
# prune old image and containers to save disk space, and only once a day
|
# prune old image and containers to save disk space, and only once a day
|
||||||
# by using a timestamp file in tmp.
|
# by using a timestamp file in tmp.
|
||||||
@ -48,16 +47,8 @@ trap remove_docker_container EXIT
|
|||||||
docker run --rm -it --device=/dev/neuron0 --network bridge \
|
docker run --rm -it --device=/dev/neuron0 --network bridge \
|
||||||
-v "${HF_CACHE}:${HF_MOUNT}" \
|
-v "${HF_CACHE}:${HF_MOUNT}" \
|
||||||
-e "HF_HOME=${HF_MOUNT}" \
|
-e "HF_HOME=${HF_MOUNT}" \
|
||||||
-e "HF_TOKEN=${HF_TOKEN}" \
|
|
||||||
-v "${NEURON_COMPILE_CACHE_URL}:${NEURON_COMPILE_CACHE_MOUNT}" \
|
-v "${NEURON_COMPILE_CACHE_URL}:${NEURON_COMPILE_CACHE_MOUNT}" \
|
||||||
-e "NEURON_COMPILE_CACHE_URL=${NEURON_COMPILE_CACHE_MOUNT}" \
|
-e "NEURON_COMPILE_CACHE_URL=${NEURON_COMPILE_CACHE_MOUNT}" \
|
||||||
--name "${container_name}" \
|
--name "${container_name}" \
|
||||||
${image_name} \
|
${image_name} \
|
||||||
/bin/bash -c "
|
/bin/bash -c "python3 /workspace/vllm/examples/offline_inference/neuron.py && python3 -m pytest /workspace/vllm/tests/neuron/1_core/ -v --capture=tee-sys && python3 -m pytest /workspace/vllm/tests/neuron/2_core/ -v --capture=tee-sys"
|
||||||
python3 /workspace/vllm/examples/offline_inference/neuron.py;
|
|
||||||
python3 -m pytest /workspace/vllm/tests/neuron/1_core/ -v --capture=tee-sys;
|
|
||||||
for f in /workspace/vllm/tests/neuron/2_core/*.py; do
|
|
||||||
echo 'Running test file: '$f;
|
|
||||||
python3 -m pytest \$f -v --capture=tee-sys;
|
|
||||||
done
|
|
||||||
"
|
|
||||||
|
|||||||
@ -26,27 +26,27 @@ docker run --privileged --net host --shm-size=16G -it \
|
|||||||
&& tpu-info \
|
&& tpu-info \
|
||||||
&& { \
|
&& { \
|
||||||
echo TEST_0: Running test_perf.py; \
|
echo TEST_0: Running test_perf.py; \
|
||||||
python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_perf.py; \
|
pytest -s -v /workspace/vllm/tests/tpu/test_perf.py; \
|
||||||
echo TEST_0_EXIT_CODE: \$?; \
|
echo TEST_0_EXIT_CODE: \$?; \
|
||||||
} & \
|
} & \
|
||||||
{ \
|
&& { \
|
||||||
echo TEST_1: Running test_compilation.py; \
|
echo TEST_1: Running test_compilation.py; \
|
||||||
python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_compilation.py; \
|
pytest -s -v /workspace/vllm/tests/tpu/test_compilation.py; \
|
||||||
echo TEST_1_EXIT_CODE: \$?; \
|
echo TEST_1_EXIT_CODE: \$?; \
|
||||||
} & \
|
} & \
|
||||||
{ \
|
{ \
|
||||||
echo TEST_2: Running test_basic.py; \
|
echo TEST_2: Running test_basic.py; \
|
||||||
python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_basic.py; \
|
pytest -s -v /workspace/vllm/tests/v1/tpu/test_basic.py; \
|
||||||
echo TEST_2_EXIT_CODE: \$?; \
|
echo TEST_2_EXIT_CODE: \$?; \
|
||||||
} & \
|
} & \
|
||||||
{ \
|
{ \
|
||||||
echo TEST_3: Running test_accuracy.py::test_lm_eval_accuracy_v1_engine; \
|
echo TEST_3: Running test_accuracy.py::test_lm_eval_accuracy_v1_engine; \
|
||||||
python3 -m pytest -s -v /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine; \
|
pytest -s -v /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine; \
|
||||||
echo TEST_3_EXIT_CODE: \$?; \
|
echo TEST_3_EXIT_CODE: \$?; \
|
||||||
} & \
|
} & \
|
||||||
{ \
|
{ \
|
||||||
echo TEST_4: Running test_quantization_accuracy.py; \
|
echo TEST_4: Running test_quantization_accuracy.py; \
|
||||||
python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py; \
|
pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py; \
|
||||||
echo TEST_4_EXIT_CODE: \$?; \
|
echo TEST_4_EXIT_CODE: \$?; \
|
||||||
} & \
|
} & \
|
||||||
{ \
|
{ \
|
||||||
@ -56,43 +56,43 @@ docker run --privileged --net host --shm-size=16G -it \
|
|||||||
} & \
|
} & \
|
||||||
{ \
|
{ \
|
||||||
echo TEST_6: Running test_tpu_model_runner.py; \
|
echo TEST_6: Running test_tpu_model_runner.py; \
|
||||||
python3 -m pytest -s -v /workspace/vllm/tests/tpu/worker/test_tpu_model_runner.py; \
|
pytest -s -v /workspace/vllm/tests/tpu/worker/test_tpu_model_runner.py; \
|
||||||
echo TEST_6_EXIT_CODE: \$?; \
|
echo TEST_6_EXIT_CODE: \$?; \
|
||||||
} & \
|
} & \
|
||||||
{ \
|
&& { \
|
||||||
echo TEST_7: Running test_sampler.py; \
|
echo TEST_7: Running test_sampler.py; \
|
||||||
python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py; \
|
pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py; \
|
||||||
echo TEST_7_EXIT_CODE: \$?; \
|
echo TEST_7_EXIT_CODE: \$?; \
|
||||||
} & \
|
} & \
|
||||||
{ \
|
&& { \
|
||||||
echo TEST_8: Running test_topk_topp_sampler.py; \
|
echo TEST_8: Running test_topk_topp_sampler.py; \
|
||||||
python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_topk_topp_sampler.py; \
|
pytest -s -v /workspace/vllm/tests/v1/tpu/test_topk_topp_sampler.py; \
|
||||||
echo TEST_8_EXIT_CODE: \$?; \
|
echo TEST_8_EXIT_CODE: \$?; \
|
||||||
} & \
|
} & \
|
||||||
{ \
|
&& { \
|
||||||
echo TEST_9: Running test_multimodal.py; \
|
echo TEST_9: Running test_multimodal.py; \
|
||||||
python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_multimodal.py; \
|
pytest -s -v /workspace/vllm/tests/v1/tpu/test_multimodal.py; \
|
||||||
echo TEST_9_EXIT_CODE: \$?; \
|
echo TEST_9_EXIT_CODE: \$?; \
|
||||||
} & \
|
} & \
|
||||||
{ \
|
&& { \
|
||||||
echo TEST_10: Running test_pallas.py; \
|
echo TEST_10: Running test_pallas.py; \
|
||||||
python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py; \
|
pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py; \
|
||||||
echo TEST_10_EXIT_CODE: \$?; \
|
echo TEST_10_EXIT_CODE: \$?; \
|
||||||
} & \
|
} & \
|
||||||
{ \
|
&& { \
|
||||||
echo TEST_11: Running test_struct_output_generate.py; \
|
echo TEST_11: Running test_struct_output_generate.py; \
|
||||||
python3 -m pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py; \
|
pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py; \
|
||||||
echo TEST_11_EXIT_CODE: \$?; \
|
echo TEST_11_EXIT_CODE: \$?; \
|
||||||
} & \
|
} & \
|
||||||
{ \
|
&& { \
|
||||||
echo TEST_12: Running test_moe_pallas.py; \
|
echo TEST_12: Running test_moe_pallas.py; \
|
||||||
python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py; \
|
pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py; \
|
||||||
echo TEST_12_EXIT_CODE: \$?; \
|
echo TEST_12_EXIT_CODE: \$?; \
|
||||||
} & \
|
} & \
|
||||||
# Disable the TPU LoRA tests until the feature is activated
|
# Disable the TPU LoRA tests until the feature is activated
|
||||||
# & { \
|
# && { \
|
||||||
# echo TEST_13: Running test_moe_pallas.py; \
|
# echo TEST_13: Running test_moe_pallas.py; \
|
||||||
# python3 -m pytest -s -v /workspace/vllm/tests/tpu/lora/; \
|
# pytest -s -v /workspace/vllm/tests/tpu/lora/; \
|
||||||
# echo TEST_13_EXIT_CODE: \$?; \
|
# echo TEST_13_EXIT_CODE: \$?; \
|
||||||
# } & \
|
# } & \
|
||||||
wait \
|
wait \
|
||||||
|
|||||||
@ -75,4 +75,3 @@ else
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
aws s3 cp "$wheel" "s3://vllm-wheels/$version/"
|
aws s3 cp "$wheel" "s3://vllm-wheels/$version/"
|
||||||
aws s3 cp index.html "s3://vllm-wheels/$version/vllm/index.html"
|
|
||||||
|
|||||||
@ -32,17 +32,16 @@ steps:
|
|||||||
##### fast check tests #####
|
##### fast check tests #####
|
||||||
|
|
||||||
- label: Documentation Build # 2min
|
- label: Documentation Build # 2min
|
||||||
mirror_hardwares: [amdexperimental]
|
working_dir: "/vllm-workspace/test_docs/docs"
|
||||||
working_dir: "/vllm-workspace/test_docs"
|
|
||||||
fast_check: true
|
fast_check: true
|
||||||
no_gpu: True
|
no_gpu: True
|
||||||
commands:
|
commands:
|
||||||
- pip install -r ../requirements/docs.txt
|
- pip install -r ../../requirements/docs.txt
|
||||||
# TODO: add `--strict` once warnings in docstrings are fixed
|
- SPHINXOPTS=\"-W\" make html
|
||||||
- mkdocs build
|
# Check API reference (if it fails, you may have missing mock imports)
|
||||||
|
- grep \"sig sig-object py\" build/html/api/vllm/vllm.sampling_params.html
|
||||||
|
|
||||||
- label: Async Engine, Inputs, Utils, Worker Test # 24min
|
- label: Async Engine, Inputs, Utils, Worker Test # 24min
|
||||||
mirror_hardwares: [amdexperimental]
|
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/
|
- vllm/
|
||||||
- tests/mq_llm_engine
|
- tests/mq_llm_engine
|
||||||
@ -58,13 +57,11 @@ steps:
|
|||||||
- pytest -v -s async_engine # AsyncLLMEngine
|
- pytest -v -s async_engine # AsyncLLMEngine
|
||||||
- NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py
|
- NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py
|
||||||
- pytest -v -s test_inputs.py
|
- pytest -v -s test_inputs.py
|
||||||
- pytest -v -s test_outputs.py
|
|
||||||
- pytest -v -s multimodal
|
- pytest -v -s multimodal
|
||||||
- pytest -v -s test_utils.py # Utils
|
- pytest -v -s test_utils.py # Utils
|
||||||
- pytest -v -s worker # Worker
|
- pytest -v -s worker # Worker
|
||||||
|
|
||||||
- label: Python-only Installation Test
|
- label: Python-only Installation Test
|
||||||
mirror_hardwares: [amdexperimental]
|
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- tests/standalone_tests/python_only_compile.sh
|
- tests/standalone_tests/python_only_compile.sh
|
||||||
- setup.py
|
- setup.py
|
||||||
@ -72,7 +69,7 @@ steps:
|
|||||||
- bash standalone_tests/python_only_compile.sh
|
- bash standalone_tests/python_only_compile.sh
|
||||||
|
|
||||||
- label: Basic Correctness Test # 30min
|
- label: Basic Correctness Test # 30min
|
||||||
mirror_hardwares: [amdexperimental, amdproduction]
|
#mirror_hardwares: [amd]
|
||||||
fast_check: true
|
fast_check: true
|
||||||
torch_nightly: true
|
torch_nightly: true
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
@ -89,7 +86,6 @@ steps:
|
|||||||
- VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py
|
- VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py
|
||||||
|
|
||||||
- label: Chunked Prefill Test
|
- label: Chunked Prefill Test
|
||||||
mirror_hardwares: [amdexperimental]
|
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/
|
- vllm/
|
||||||
- tests/basic_correctness/test_chunked_prefill
|
- tests/basic_correctness/test_chunked_prefill
|
||||||
@ -98,7 +94,7 @@ steps:
|
|||||||
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
|
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
|
||||||
|
|
||||||
- label: Core Test # 10min
|
- label: Core Test # 10min
|
||||||
mirror_hardwares: [amdexperimental, amdproduction]
|
mirror_hardwares: [amd]
|
||||||
fast_check: true
|
fast_check: true
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/core
|
- vllm/core
|
||||||
@ -108,10 +104,10 @@ steps:
|
|||||||
- pytest -v -s core
|
- pytest -v -s core
|
||||||
|
|
||||||
- label: Entrypoints Test # 40min
|
- label: Entrypoints Test # 40min
|
||||||
mirror_hardwares: [amdexperimental]
|
|
||||||
working_dir: "/vllm-workspace/tests"
|
working_dir: "/vllm-workspace/tests"
|
||||||
fast_check: true
|
fast_check: true
|
||||||
torch_nightly: true
|
torch_nightly: true
|
||||||
|
#mirror_hardwares: [amd]
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/
|
- vllm/
|
||||||
- tests/entrypoints/llm
|
- tests/entrypoints/llm
|
||||||
@ -125,12 +121,11 @@ steps:
|
|||||||
- pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
|
- pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
|
||||||
- pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process
|
- pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process
|
||||||
- VLLM_USE_V1=0 pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process
|
- VLLM_USE_V1=0 pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process
|
||||||
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/
|
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/test_openai_schema.py
|
||||||
- pytest -v -s entrypoints/test_chat_utils.py
|
- pytest -v -s entrypoints/test_chat_utils.py
|
||||||
- VLLM_USE_V1=0 pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests
|
- VLLM_USE_V1=0 pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests
|
||||||
|
|
||||||
- label: Distributed Tests (4 GPUs) # 10min
|
- label: Distributed Tests (4 GPUs) # 10min
|
||||||
mirror_hardwares: [amdexperimental]
|
|
||||||
working_dir: "/vllm-workspace/tests"
|
working_dir: "/vllm-workspace/tests"
|
||||||
num_gpus: 4
|
num_gpus: 4
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
@ -138,7 +133,6 @@ steps:
|
|||||||
- vllm/core/
|
- vllm/core/
|
||||||
- tests/distributed/test_utils
|
- tests/distributed/test_utils
|
||||||
- tests/distributed/test_pynccl
|
- tests/distributed/test_pynccl
|
||||||
- tests/distributed/test_events
|
|
||||||
- tests/spec_decode/e2e/test_integration_dist_tp4
|
- tests/spec_decode/e2e/test_integration_dist_tp4
|
||||||
- tests/compile/test_basic_correctness
|
- tests/compile/test_basic_correctness
|
||||||
- examples/offline_inference/rlhf.py
|
- examples/offline_inference/rlhf.py
|
||||||
@ -149,25 +143,22 @@ steps:
|
|||||||
# test with tp=2 and external_dp=2
|
# test with tp=2 and external_dp=2
|
||||||
- VLLM_USE_V1=0 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
|
- VLLM_USE_V1=0 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
|
||||||
- torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
|
- torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
|
||||||
# test with tp=2 and pp=2
|
|
||||||
- PP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
|
|
||||||
# test with internal dp
|
# test with internal dp
|
||||||
- python3 ../examples/offline_inference/data_parallel.py
|
- python3 ../examples/offline_inference/data_parallel.py
|
||||||
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
|
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
|
||||||
- pytest -v -s distributed/test_utils.py
|
- pytest -v -s distributed/test_utils.py
|
||||||
- pytest -v -s compile/test_basic_correctness.py
|
- pytest -v -s compile/test_basic_correctness.py
|
||||||
- pytest -v -s distributed/test_pynccl.py
|
- pytest -v -s distributed/test_pynccl.py
|
||||||
- pytest -v -s distributed/test_events.py
|
|
||||||
- pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
|
- pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
|
||||||
# TODO: create a dedicated test section for multi-GPU example tests
|
# TODO: create a dedicated test section for multi-GPU example tests
|
||||||
# when we have multiple distributed example tests
|
# when we have multiple distributed example tests
|
||||||
- pushd ../examples/offline_inference
|
- pushd ../examples/offline_inference
|
||||||
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py
|
- python3 rlhf.py
|
||||||
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
|
- RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
|
||||||
- popd
|
- popd
|
||||||
|
|
||||||
- label: Metrics, Tracing Test # 10min
|
- label: Metrics, Tracing Test # 10min
|
||||||
mirror_hardwares: [amdexperimental, amdproduction]
|
mirror_hardwares: [amd]
|
||||||
num_gpus: 2
|
num_gpus: 2
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/
|
- vllm/
|
||||||
@ -181,7 +172,7 @@ steps:
|
|||||||
##### 1 GPU test #####
|
##### 1 GPU test #####
|
||||||
|
|
||||||
- label: Regression Test # 5min
|
- label: Regression Test # 5min
|
||||||
mirror_hardwares: [amdexperimental, amdproduction]
|
#mirror_hardwares: [amd]
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/
|
- vllm/
|
||||||
- tests/test_regression
|
- tests/test_regression
|
||||||
@ -191,7 +182,7 @@ steps:
|
|||||||
working_dir: "/vllm-workspace/tests" # optional
|
working_dir: "/vllm-workspace/tests" # optional
|
||||||
|
|
||||||
- label: Engine Test # 10min
|
- label: Engine Test # 10min
|
||||||
mirror_hardwares: [amdexperimental, amdproduction]
|
mirror_hardwares: [amd]
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/
|
- vllm/
|
||||||
- tests/engine
|
- tests/engine
|
||||||
@ -205,7 +196,7 @@ steps:
|
|||||||
- pytest -v -s tokenization
|
- pytest -v -s tokenization
|
||||||
|
|
||||||
- label: V1 Test
|
- label: V1 Test
|
||||||
mirror_hardwares: [amdexperimental]
|
#mirror_hardwares: [amd]
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/
|
- vllm/
|
||||||
- tests/v1
|
- tests/v1
|
||||||
@ -218,11 +209,10 @@ steps:
|
|||||||
- pytest -v -s v1/worker
|
- pytest -v -s v1/worker
|
||||||
- pytest -v -s v1/structured_output
|
- pytest -v -s v1/structured_output
|
||||||
- pytest -v -s v1/spec_decode
|
- pytest -v -s v1/spec_decode
|
||||||
- pytest -v -s v1/kv_connector/unit
|
|
||||||
- pytest -v -s v1/test_serial_utils.py
|
- pytest -v -s v1/test_serial_utils.py
|
||||||
|
- pytest -v -s v1/test_stats.py
|
||||||
- pytest -v -s v1/test_utils.py
|
- pytest -v -s v1/test_utils.py
|
||||||
- pytest -v -s v1/test_oracle.py
|
- pytest -v -s v1/test_oracle.py
|
||||||
- pytest -v -s v1/test_metrics_reader.py
|
|
||||||
# TODO: accuracy does not match, whether setting
|
# TODO: accuracy does not match, whether setting
|
||||||
# VLLM_USE_FLASHINFER_SAMPLER or not on H100.
|
# VLLM_USE_FLASHINFER_SAMPLER or not on H100.
|
||||||
- pytest -v -s v1/e2e
|
- pytest -v -s v1/e2e
|
||||||
@ -231,8 +221,8 @@ steps:
|
|||||||
- pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine
|
- pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine
|
||||||
|
|
||||||
- label: Examples Test # 25min
|
- label: Examples Test # 25min
|
||||||
mirror_hardwares: [amdexperimental]
|
|
||||||
working_dir: "/vllm-workspace/examples"
|
working_dir: "/vllm-workspace/examples"
|
||||||
|
#mirror_hardwares: [amd]
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/entrypoints
|
- vllm/entrypoints
|
||||||
- examples/
|
- examples/
|
||||||
@ -247,7 +237,7 @@ steps:
|
|||||||
- python3 offline_inference/vision_language.py --seed 0
|
- python3 offline_inference/vision_language.py --seed 0
|
||||||
- python3 offline_inference/vision_language_embedding.py --seed 0
|
- python3 offline_inference/vision_language_embedding.py --seed 0
|
||||||
- python3 offline_inference/vision_language_multi_image.py --seed 0
|
- python3 offline_inference/vision_language_multi_image.py --seed 0
|
||||||
- VLLM_USE_V1=0 python3 others/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 others/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
|
- VLLM_USE_V1=0 python3 other/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 other/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
|
||||||
- python3 offline_inference/encoder_decoder.py
|
- python3 offline_inference/encoder_decoder.py
|
||||||
- python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0
|
- python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0
|
||||||
- python3 offline_inference/basic/classify.py
|
- python3 offline_inference/basic/classify.py
|
||||||
@ -256,7 +246,7 @@ steps:
|
|||||||
- VLLM_USE_V1=0 python3 offline_inference/profiling.py --model facebook/opt-125m run_num_steps --num-steps 2
|
- VLLM_USE_V1=0 python3 offline_inference/profiling.py --model facebook/opt-125m run_num_steps --num-steps 2
|
||||||
|
|
||||||
- label: Prefix Caching Test # 9min
|
- label: Prefix Caching Test # 9min
|
||||||
mirror_hardwares: [amdexperimental, amdproduction]
|
mirror_hardwares: [amd]
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/
|
- vllm/
|
||||||
- tests/prefix_caching
|
- tests/prefix_caching
|
||||||
@ -264,7 +254,6 @@ steps:
|
|||||||
- pytest -v -s prefix_caching
|
- pytest -v -s prefix_caching
|
||||||
|
|
||||||
- label: Samplers Test # 36min
|
- label: Samplers Test # 36min
|
||||||
mirror_hardwares: [amdexperimental]
|
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/model_executor/layers
|
- vllm/model_executor/layers
|
||||||
- vllm/sampling_metadata.py
|
- vllm/sampling_metadata.py
|
||||||
@ -275,7 +264,7 @@ steps:
|
|||||||
- VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers
|
- VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers
|
||||||
|
|
||||||
- label: LogitsProcessor Test # 5min
|
- label: LogitsProcessor Test # 5min
|
||||||
mirror_hardwares: [amdexperimental, amdproduction]
|
mirror_hardwares: [amd]
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/model_executor/layers
|
- vllm/model_executor/layers
|
||||||
- vllm/model_executor/guided_decoding
|
- vllm/model_executor/guided_decoding
|
||||||
@ -286,7 +275,6 @@ steps:
|
|||||||
- pytest -v -s model_executor/test_guided_processors.py
|
- pytest -v -s model_executor/test_guided_processors.py
|
||||||
|
|
||||||
- label: Speculative decoding tests # 40min
|
- label: Speculative decoding tests # 40min
|
||||||
mirror_hardwares: [amdexperimental]
|
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/spec_decode
|
- vllm/spec_decode
|
||||||
- tests/spec_decode
|
- tests/spec_decode
|
||||||
@ -297,7 +285,7 @@ steps:
|
|||||||
- pytest -v -s spec_decode/e2e/test_eagle_correctness.py
|
- pytest -v -s spec_decode/e2e/test_eagle_correctness.py
|
||||||
|
|
||||||
- label: LoRA Test %N # 15min each
|
- label: LoRA Test %N # 15min each
|
||||||
mirror_hardwares: [amdexperimental]
|
#mirror_hardwares: [amd]
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/lora
|
- vllm/lora
|
||||||
- tests/lora
|
- tests/lora
|
||||||
@ -305,7 +293,6 @@ steps:
|
|||||||
parallelism: 4
|
parallelism: 4
|
||||||
|
|
||||||
- label: PyTorch Compilation Unit Tests
|
- label: PyTorch Compilation Unit Tests
|
||||||
mirror_hardwares: [amdexperimental, amdproduction]
|
|
||||||
torch_nightly: true
|
torch_nightly: true
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/
|
- vllm/
|
||||||
@ -313,12 +300,9 @@ steps:
|
|||||||
commands:
|
commands:
|
||||||
- pytest -v -s compile/test_pass_manager.py
|
- pytest -v -s compile/test_pass_manager.py
|
||||||
- pytest -v -s compile/test_fusion.py
|
- pytest -v -s compile/test_fusion.py
|
||||||
- pytest -v -s compile/test_silu_mul_quant_fusion.py
|
|
||||||
- pytest -v -s compile/test_sequence_parallelism.py
|
- pytest -v -s compile/test_sequence_parallelism.py
|
||||||
- pytest -v -s compile/test_async_tp.py
|
|
||||||
|
|
||||||
- label: PyTorch Fullgraph Smoke Test # 9min
|
- label: PyTorch Fullgraph Smoke Test # 9min
|
||||||
mirror_hardwares: [amdexperimental, amdproduction]
|
|
||||||
torch_nightly: true
|
torch_nightly: true
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/
|
- vllm/
|
||||||
@ -330,7 +314,6 @@ steps:
|
|||||||
- pytest -v -s compile/piecewise/test_toy_llama.py
|
- pytest -v -s compile/piecewise/test_toy_llama.py
|
||||||
|
|
||||||
- label: PyTorch Fullgraph Test # 18min
|
- label: PyTorch Fullgraph Test # 18min
|
||||||
mirror_hardwares: [amdexperimental, amdproduction]
|
|
||||||
torch_nightly: true
|
torch_nightly: true
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/
|
- vllm/
|
||||||
@ -339,7 +322,7 @@ steps:
|
|||||||
- pytest -v -s compile/test_full_graph.py
|
- pytest -v -s compile/test_full_graph.py
|
||||||
|
|
||||||
- label: Kernels Core Operation Test
|
- label: Kernels Core Operation Test
|
||||||
mirror_hardwares: [amdexperimental, amdproduction]
|
mirror_hardwares: [amd]
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- csrc/
|
- csrc/
|
||||||
- tests/kernels/core
|
- tests/kernels/core
|
||||||
@ -347,7 +330,7 @@ steps:
|
|||||||
- pytest -v -s kernels/core
|
- pytest -v -s kernels/core
|
||||||
|
|
||||||
- label: Kernels Attention Test %N
|
- label: Kernels Attention Test %N
|
||||||
mirror_hardwares: [amdexperimental, amdproduction]
|
mirror_hardwares: [amd]
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- csrc/attention/
|
- csrc/attention/
|
||||||
- vllm/attention
|
- vllm/attention
|
||||||
@ -358,7 +341,7 @@ steps:
|
|||||||
parallelism: 2
|
parallelism: 2
|
||||||
|
|
||||||
- label: Kernels Quantization Test %N
|
- label: Kernels Quantization Test %N
|
||||||
mirror_hardwares: [amdexperimental, amdproduction]
|
mirror_hardwares: [amd]
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- csrc/quantization/
|
- csrc/quantization/
|
||||||
- vllm/model_executor/layers/quantization
|
- vllm/model_executor/layers/quantization
|
||||||
@ -368,7 +351,7 @@ steps:
|
|||||||
parallelism: 2
|
parallelism: 2
|
||||||
|
|
||||||
- label: Kernels MoE Test
|
- label: Kernels MoE Test
|
||||||
mirror_hardwares: [amdexperimental]
|
#mirror_hardwares: [amd]
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- csrc/moe/
|
- csrc/moe/
|
||||||
- tests/kernels/moe
|
- tests/kernels/moe
|
||||||
@ -377,7 +360,7 @@ steps:
|
|||||||
- pytest -v -s kernels/moe
|
- pytest -v -s kernels/moe
|
||||||
|
|
||||||
- label: Kernels Mamba Test
|
- label: Kernels Mamba Test
|
||||||
mirror_hardwares: [amdexperimental]
|
#mirror_hardwares: [amd]
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- csrc/mamba/
|
- csrc/mamba/
|
||||||
- tests/kernels/mamba
|
- tests/kernels/mamba
|
||||||
@ -385,28 +368,25 @@ steps:
|
|||||||
- pytest -v -s kernels/mamba
|
- pytest -v -s kernels/mamba
|
||||||
|
|
||||||
- label: Tensorizer Test # 11min
|
- label: Tensorizer Test # 11min
|
||||||
mirror_hardwares: [amdexperimental, amdproduction]
|
# mirror_hardwares: [amd]
|
||||||
soft_fail: true
|
soft_fail: true
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/model_executor/model_loader
|
- vllm/model_executor/model_loader
|
||||||
- tests/tensorizer_loader
|
- tests/tensorizer_loader
|
||||||
- tests/entrypoints/openai/test_tensorizer_entrypoint.py
|
|
||||||
commands:
|
commands:
|
||||||
- apt-get update && apt-get install -y curl libsodium23
|
- apt-get update && apt-get install -y curl libsodium23
|
||||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||||
- pytest -v -s tensorizer_loader
|
- pytest -v -s tensorizer_loader
|
||||||
- pytest -v -s entrypoints/openai/test_tensorizer_entrypoint.py
|
|
||||||
|
|
||||||
- label: Benchmarks # 9min
|
- label: Benchmarks # 9min
|
||||||
mirror_hardwares: [amdexperimental, amdproduction]
|
|
||||||
working_dir: "/vllm-workspace/.buildkite"
|
working_dir: "/vllm-workspace/.buildkite"
|
||||||
|
mirror_hardwares: [amd]
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- benchmarks/
|
- benchmarks/
|
||||||
commands:
|
commands:
|
||||||
- bash scripts/run-benchmarks.sh
|
- bash scripts/run-benchmarks.sh
|
||||||
|
|
||||||
- label: Benchmarks CLI Test # 10min
|
- label: Benchmarks CLI Test # 10min
|
||||||
mirror_hardwares: [amdexperimental, amdproduction]
|
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/
|
- vllm/
|
||||||
- tests/benchmarks/
|
- tests/benchmarks/
|
||||||
@ -414,7 +394,6 @@ steps:
|
|||||||
- pytest -v -s benchmarks/
|
- pytest -v -s benchmarks/
|
||||||
|
|
||||||
- label: Quantization Test
|
- label: Quantization Test
|
||||||
mirror_hardwares: [amdexperimental]
|
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- csrc/
|
- csrc/
|
||||||
- vllm/model_executor/layers/quantization
|
- vllm/model_executor/layers/quantization
|
||||||
@ -423,7 +402,6 @@ steps:
|
|||||||
- VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization
|
- VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization
|
||||||
|
|
||||||
- label: LM Eval Small Models # 53min
|
- label: LM Eval Small Models # 53min
|
||||||
mirror_hardwares: [amdexperimental]
|
|
||||||
working_dir: "/vllm-workspace/.buildkite/lm-eval-harness"
|
working_dir: "/vllm-workspace/.buildkite/lm-eval-harness"
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- csrc/
|
- csrc/
|
||||||
@ -433,7 +411,6 @@ steps:
|
|||||||
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-small.txt --tp-size=1
|
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-small.txt --tp-size=1
|
||||||
|
|
||||||
- label: OpenAI API correctness
|
- label: OpenAI API correctness
|
||||||
mirror_hardwares: [amdexperimental]
|
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- csrc/
|
- csrc/
|
||||||
- vllm/entrypoints/openai/
|
- vllm/entrypoints/openai/
|
||||||
@ -442,7 +419,6 @@ steps:
|
|||||||
- pytest -s entrypoints/openai/correctness/
|
- pytest -s entrypoints/openai/correctness/
|
||||||
|
|
||||||
- label: Encoder Decoder tests # 5min
|
- label: Encoder Decoder tests # 5min
|
||||||
mirror_hardwares: [amdexperimental]
|
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/
|
- vllm/
|
||||||
- tests/encoder_decoder
|
- tests/encoder_decoder
|
||||||
@ -450,8 +426,8 @@ steps:
|
|||||||
- pytest -v -s encoder_decoder
|
- pytest -v -s encoder_decoder
|
||||||
|
|
||||||
- label: OpenAI-Compatible Tool Use # 20 min
|
- label: OpenAI-Compatible Tool Use # 20 min
|
||||||
mirror_hardwares: [amdexperimental]
|
|
||||||
fast_check: false
|
fast_check: false
|
||||||
|
#mirror_hardwares: [ amd ]
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/
|
- vllm/
|
||||||
- tests/tool_use
|
- tests/tool_use
|
||||||
@ -463,7 +439,6 @@ steps:
|
|||||||
##### models test #####
|
##### models test #####
|
||||||
|
|
||||||
- label: Basic Models Test # 24min
|
- label: Basic Models Test # 24min
|
||||||
mirror_hardwares: [amdexperimental, amdproduction]
|
|
||||||
torch_nightly: true
|
torch_nightly: true
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/
|
- vllm/
|
||||||
@ -473,55 +448,43 @@ steps:
|
|||||||
- pytest -v -s models/test_registry.py
|
- pytest -v -s models/test_registry.py
|
||||||
- pytest -v -s models/test_utils.py
|
- pytest -v -s models/test_utils.py
|
||||||
- pytest -v -s models/test_vision.py
|
- pytest -v -s models/test_vision.py
|
||||||
- pytest -v -s models/test_initialization.py
|
# V1 Test: https://github.com/vllm-project/vllm/issues/14531
|
||||||
|
- VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2'
|
||||||
|
- VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'llama4'
|
||||||
|
- VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'plamo2'
|
||||||
|
|
||||||
- label: Language Models Test (Standard)
|
- label: Language Models Test (Standard)
|
||||||
mirror_hardwares: [amdexperimental]
|
#mirror_hardwares: [amd]
|
||||||
torch_nightly: true
|
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/
|
- vllm/
|
||||||
- tests/models/language
|
- tests/models/language
|
||||||
commands:
|
commands:
|
||||||
# Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile.
|
# Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile.
|
||||||
- pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8'
|
- pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8'
|
||||||
- pip freeze | grep -E 'torch'
|
|
||||||
- pytest -v -s models/language -m core_model
|
- pytest -v -s models/language -m core_model
|
||||||
|
|
||||||
- label: Language Models Test (Extended Generation) # 1hr20min
|
- label: Language Models Test (Extended)
|
||||||
mirror_hardwares: [amdexperimental]
|
|
||||||
optional: true
|
optional: true
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/
|
- vllm/
|
||||||
- tests/models/language/generation
|
- tests/models/language
|
||||||
commands:
|
commands:
|
||||||
# Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile.
|
# Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile.
|
||||||
- pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8'
|
- pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8'
|
||||||
- pytest -v -s models/language/generation -m 'not core_model'
|
- pytest -v -s models/language -m 'not core_model'
|
||||||
|
|
||||||
- label: Language Models Test (Extended Pooling) # 36min
|
|
||||||
mirror_hardwares: [amdexperimental]
|
|
||||||
optional: true
|
|
||||||
source_file_dependencies:
|
|
||||||
- vllm/
|
|
||||||
- tests/models/language/pooling
|
|
||||||
commands:
|
|
||||||
- pytest -v -s models/language/pooling -m 'not core_model'
|
|
||||||
|
|
||||||
- label: Multi-Modal Models Test (Standard)
|
- label: Multi-Modal Models Test (Standard)
|
||||||
mirror_hardwares: [amdexperimental]
|
#mirror_hardwares: [amd]
|
||||||
torch_nightly: true
|
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/
|
- vllm/
|
||||||
- tests/models/multimodal
|
- tests/models/multimodal
|
||||||
commands:
|
commands:
|
||||||
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
|
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
|
||||||
- pip freeze | grep -E 'torch'
|
|
||||||
- pytest -v -s models/multimodal/processing
|
- pytest -v -s models/multimodal/processing
|
||||||
- pytest -v -s --ignore models/multimodal/generation/test_whisper.py models/multimodal -m core_model
|
- pytest -v -s --ignore models/multimodal/generation/test_whisper.py models/multimodal -m core_model
|
||||||
- cd .. && pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work
|
- cd .. && pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work
|
||||||
|
|
||||||
- label: Multi-Modal Models Test (Extended) 1
|
- label: Multi-Modal Models Test (Extended) 1
|
||||||
mirror_hardwares: [amdexperimental]
|
|
||||||
optional: true
|
optional: true
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/
|
- vllm/
|
||||||
@ -531,7 +494,6 @@ steps:
|
|||||||
- pytest -v -s --ignore models/multimodal/generation/test_common.py --ignore models/multimodal/processing models/multimodal -m 'not core_model'
|
- pytest -v -s --ignore models/multimodal/generation/test_common.py --ignore models/multimodal/processing models/multimodal -m 'not core_model'
|
||||||
|
|
||||||
- label: Multi-Modal Models Test (Extended) 2
|
- label: Multi-Modal Models Test (Extended) 2
|
||||||
mirror_hardwares: [amdexperimental]
|
|
||||||
optional: true
|
optional: true
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/
|
- vllm/
|
||||||
@ -541,7 +503,6 @@ steps:
|
|||||||
- pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=0) and not core_model'
|
- pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=0) and not core_model'
|
||||||
|
|
||||||
- label: Multi-Modal Models Test (Extended) 3
|
- label: Multi-Modal Models Test (Extended) 3
|
||||||
mirror_hardwares: [amdexperimental, amdproduction]
|
|
||||||
optional: true
|
optional: true
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/
|
- vllm/
|
||||||
@ -551,7 +512,7 @@ steps:
|
|||||||
- pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=1) and not core_model'
|
- pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=1) and not core_model'
|
||||||
|
|
||||||
- label: Quantized Models Test
|
- label: Quantized Models Test
|
||||||
mirror_hardwares: [amdexperimental, amdproduction]
|
#mirror_hardwares: [amd]
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/model_executor/layers/quantization
|
- vllm/model_executor/layers/quantization
|
||||||
- tests/models/quantization
|
- tests/models/quantization
|
||||||
@ -560,7 +521,7 @@ steps:
|
|||||||
|
|
||||||
# This test is used only in PR development phase to test individual models and should never run on main
|
# This test is used only in PR development phase to test individual models and should never run on main
|
||||||
- label: Custom Models Test
|
- label: Custom Models Test
|
||||||
mirror_hardwares: [amdexperimental, amdproduction]
|
mirror_hardwares: [amd]
|
||||||
optional: true
|
optional: true
|
||||||
commands:
|
commands:
|
||||||
- echo 'Testing custom models...'
|
- echo 'Testing custom models...'
|
||||||
@ -572,7 +533,7 @@ steps:
|
|||||||
##### multi gpus test #####
|
##### multi gpus test #####
|
||||||
|
|
||||||
- label: Distributed Comm Ops Test # 7min
|
- label: Distributed Comm Ops Test # 7min
|
||||||
mirror_hardwares: [amdexperimental, amdproduction]
|
mirror_hardwares: [amd]
|
||||||
working_dir: "/vllm-workspace/tests"
|
working_dir: "/vllm-workspace/tests"
|
||||||
num_gpus: 2
|
num_gpus: 2
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
@ -583,7 +544,6 @@ steps:
|
|||||||
- pytest -v -s distributed/test_shm_broadcast.py
|
- pytest -v -s distributed/test_shm_broadcast.py
|
||||||
|
|
||||||
- label: 2 Node Tests (4 GPUs in total) # 16min
|
- label: 2 Node Tests (4 GPUs in total) # 16min
|
||||||
mirror_hardwares: [amdexperimental]
|
|
||||||
working_dir: "/vllm-workspace/tests"
|
working_dir: "/vllm-workspace/tests"
|
||||||
num_gpus: 2
|
num_gpus: 2
|
||||||
num_nodes: 2
|
num_nodes: 2
|
||||||
@ -602,7 +562,7 @@ steps:
|
|||||||
- VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
|
- VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
|
||||||
|
|
||||||
- label: Distributed Tests (2 GPUs) # 40min
|
- label: Distributed Tests (2 GPUs) # 40min
|
||||||
mirror_hardwares: [amdexperimental]
|
#mirror_hardwares: [amd]
|
||||||
working_dir: "/vllm-workspace/tests"
|
working_dir: "/vllm-workspace/tests"
|
||||||
num_gpus: 2
|
num_gpus: 2
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
@ -639,14 +599,13 @@ steps:
|
|||||||
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown
|
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown
|
||||||
|
|
||||||
- label: Plugin Tests (2 GPUs) # 40min
|
- label: Plugin Tests (2 GPUs) # 40min
|
||||||
mirror_hardwares: [amdexperimental]
|
|
||||||
working_dir: "/vllm-workspace/tests"
|
working_dir: "/vllm-workspace/tests"
|
||||||
num_gpus: 2
|
num_gpus: 2
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/plugins/
|
- vllm/plugins/
|
||||||
- tests/plugins/
|
- tests/plugins/
|
||||||
commands:
|
commands:
|
||||||
# begin platform plugin and general plugin tests, all the code in-between runs on dummy platform
|
# begin platform plugin tests, all the code in-between runs on dummy platform
|
||||||
- pip install -e ./plugins/vllm_add_dummy_platform
|
- pip install -e ./plugins/vllm_add_dummy_platform
|
||||||
- pytest -v -s plugins_tests/test_platform_plugins.py
|
- pytest -v -s plugins_tests/test_platform_plugins.py
|
||||||
- pip uninstall vllm_add_dummy_platform -y
|
- pip uninstall vllm_add_dummy_platform -y
|
||||||
@ -657,10 +616,8 @@ steps:
|
|||||||
- pytest -v -s distributed/test_distributed_oot.py
|
- pytest -v -s distributed/test_distributed_oot.py
|
||||||
- pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process
|
- pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process
|
||||||
- pytest -v -s models/test_oot_registration.py # it needs a clean process
|
- pytest -v -s models/test_oot_registration.py # it needs a clean process
|
||||||
- pytest -v -s plugins/lora_resolvers # unit tests for in-tree lora resolver plugins
|
|
||||||
|
|
||||||
- label: Multi-step Tests (4 GPUs) # 36min
|
- label: Multi-step Tests (4 GPUs) # 36min
|
||||||
mirror_hardwares: [amdexperimental]
|
|
||||||
working_dir: "/vllm-workspace/tests"
|
working_dir: "/vllm-workspace/tests"
|
||||||
num_gpus: 4
|
num_gpus: 4
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
@ -681,7 +638,6 @@ steps:
|
|||||||
- pytest -v -s multi_step/test_correctness_llm.py
|
- pytest -v -s multi_step/test_correctness_llm.py
|
||||||
|
|
||||||
- label: Pipeline Parallelism Test # 45min
|
- label: Pipeline Parallelism Test # 45min
|
||||||
mirror_hardwares: [amdexperimental, amdproduction]
|
|
||||||
working_dir: "/vllm-workspace/tests"
|
working_dir: "/vllm-workspace/tests"
|
||||||
num_gpus: 4
|
num_gpus: 4
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
@ -695,7 +651,6 @@ steps:
|
|||||||
- pytest -v -s distributed/test_pipeline_parallel.py
|
- pytest -v -s distributed/test_pipeline_parallel.py
|
||||||
|
|
||||||
- label: LoRA TP Test (Distributed)
|
- label: LoRA TP Test (Distributed)
|
||||||
mirror_hardwares: [amdexperimental, amdproduction]
|
|
||||||
num_gpus: 4
|
num_gpus: 4
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/lora
|
- vllm/lora
|
||||||
@ -711,7 +666,6 @@ steps:
|
|||||||
|
|
||||||
|
|
||||||
- label: Weight Loading Multiple GPU Test # 33min
|
- label: Weight Loading Multiple GPU Test # 33min
|
||||||
mirror_hardwares: [amdexperimental]
|
|
||||||
working_dir: "/vllm-workspace/tests"
|
working_dir: "/vllm-workspace/tests"
|
||||||
num_gpus: 2
|
num_gpus: 2
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
@ -721,7 +675,6 @@ steps:
|
|||||||
- bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models.txt
|
- bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models.txt
|
||||||
|
|
||||||
- label: Weight Loading Multiple GPU Test - Large Models # optional
|
- label: Weight Loading Multiple GPU Test - Large Models # optional
|
||||||
mirror_hardwares: [amdexperimental]
|
|
||||||
working_dir: "/vllm-workspace/tests"
|
working_dir: "/vllm-workspace/tests"
|
||||||
num_gpus: 2
|
num_gpus: 2
|
||||||
gpu: a100
|
gpu: a100
|
||||||
|
|||||||
6
.github/CODEOWNERS
vendored
6
.github/CODEOWNERS
vendored
@ -13,7 +13,6 @@
|
|||||||
/vllm/model_executor/guided_decoding @mgoin @russellb
|
/vllm/model_executor/guided_decoding @mgoin @russellb
|
||||||
/vllm/multimodal @DarkLight1337 @ywang96
|
/vllm/multimodal @DarkLight1337 @ywang96
|
||||||
/vllm/vllm_flash_attn @LucasWilkinson
|
/vllm/vllm_flash_attn @LucasWilkinson
|
||||||
/vllm/lora @jeejeelee
|
|
||||||
CMakeLists.txt @tlrmchlsmth
|
CMakeLists.txt @tlrmchlsmth
|
||||||
|
|
||||||
# vLLM V1
|
# vLLM V1
|
||||||
@ -41,8 +40,3 @@ CMakeLists.txt @tlrmchlsmth
|
|||||||
/tests/v1/entrypoints/llm/test_struct_output_generate.py @mgoin @russellb
|
/tests/v1/entrypoints/llm/test_struct_output_generate.py @mgoin @russellb
|
||||||
/tests/v1/structured_output @mgoin @russellb
|
/tests/v1/structured_output @mgoin @russellb
|
||||||
/tests/weight_loading @mgoin @youkaichao
|
/tests/weight_loading @mgoin @youkaichao
|
||||||
/tests/lora @jeejeelee
|
|
||||||
|
|
||||||
# Docs
|
|
||||||
/docs @hmellor
|
|
||||||
mkdocs.yaml @hmellor
|
|
||||||
6
.github/ISSUE_TEMPLATE/400-bug-report.yml
vendored
6
.github/ISSUE_TEMPLATE/400-bug-report.yml
vendored
@ -81,14 +81,14 @@ body:
|
|||||||
required: true
|
required: true
|
||||||
- type: markdown
|
- type: markdown
|
||||||
attributes:
|
attributes:
|
||||||
value: |
|
value: >
|
||||||
⚠️ Please separate bugs of `transformers` implementation or usage from bugs of `vllm`. If you think anything is wrong with the model's output:
|
⚠️ Please separate bugs of `transformers` implementation or usage from bugs of `vllm`. If you think anything is wrong with the models' output:
|
||||||
|
|
||||||
- Try the counterpart of `transformers` first. If the error appears, please go to [their issues](https://github.com/huggingface/transformers/issues?q=is%3Aissue+is%3Aopen+sort%3Aupdated-desc).
|
- Try the counterpart of `transformers` first. If the error appears, please go to [their issues](https://github.com/huggingface/transformers/issues?q=is%3Aissue+is%3Aopen+sort%3Aupdated-desc).
|
||||||
|
|
||||||
- If the error only appears in vllm, please provide the detailed script of how you run `transformers` and `vllm`, also highlight the difference and what you expect.
|
- If the error only appears in vllm, please provide the detailed script of how you run `transformers` and `vllm`, also highlight the difference and what you expect.
|
||||||
|
|
||||||
Thanks for reporting 🙏!
|
Thanks for contributing 🎉!
|
||||||
- type: checkboxes
|
- type: checkboxes
|
||||||
id: askllm
|
id: askllm
|
||||||
attributes:
|
attributes:
|
||||||
|
|||||||
69
.github/ISSUE_TEMPLATE/450-ci-failure.yml
vendored
69
.github/ISSUE_TEMPLATE/450-ci-failure.yml
vendored
@ -1,69 +0,0 @@
|
|||||||
name: 🧪 CI failure report
|
|
||||||
description: Report a failing test.
|
|
||||||
title: "[CI Failure]: "
|
|
||||||
labels: ["ci-failure"]
|
|
||||||
|
|
||||||
body:
|
|
||||||
- type: markdown
|
|
||||||
attributes:
|
|
||||||
value: >
|
|
||||||
#### Include the name of the failing Buildkite step and test file in the title.
|
|
||||||
- type: input
|
|
||||||
attributes:
|
|
||||||
label: Name of failing test
|
|
||||||
description: |
|
|
||||||
Paste in the fully-qualified name of the failing test from the logs.
|
|
||||||
placeholder: |
|
|
||||||
`path/to/test_file.py::test_name[params]`
|
|
||||||
validations:
|
|
||||||
required: true
|
|
||||||
- type: checkboxes
|
|
||||||
attributes:
|
|
||||||
label: Basic information
|
|
||||||
description: Select all items that apply to the failing test.
|
|
||||||
options:
|
|
||||||
- label: Flaky test
|
|
||||||
- label: Can reproduce locally
|
|
||||||
- label: Caused by external libraries (e.g. bug in `transformers`)
|
|
||||||
- type: textarea
|
|
||||||
attributes:
|
|
||||||
label: 🧪 Describe the failing test
|
|
||||||
description: |
|
|
||||||
Please provide a clear and concise description of the failing test.
|
|
||||||
placeholder: |
|
|
||||||
A clear and concise description of the failing test.
|
|
||||||
|
|
||||||
```
|
|
||||||
The error message you got, with the full traceback and the error logs with [dump_input.py:##] if present.
|
|
||||||
```
|
|
||||||
validations:
|
|
||||||
required: true
|
|
||||||
- type: textarea
|
|
||||||
attributes:
|
|
||||||
label: 📝 History of failing test
|
|
||||||
description: |
|
|
||||||
Since when did the test start to fail?
|
|
||||||
You can look up its history via [Buildkite Test Suites](https://buildkite.com/organizations/vllm/analytics/suites/ci-1/tests?branch=main).
|
|
||||||
|
|
||||||
If you have time, identify the PR that caused the test to fail on main. You can do so via the following methods:
|
|
||||||
|
|
||||||
- Use Buildkite Test Suites to find the PR where the test failure first occurred, and reproduce the failure locally.
|
|
||||||
|
|
||||||
- Run [`git bisect`](https://git-scm.com/docs/git-bisect) locally.
|
|
||||||
|
|
||||||
- Manually unblock Buildkite steps for suspected PRs on main and check the results. (authorized users only)
|
|
||||||
placeholder: |
|
|
||||||
Approximate timeline and/or problematic PRs
|
|
||||||
|
|
||||||
A link to the Buildkite analytics of the failing test (if available)
|
|
||||||
validations:
|
|
||||||
required: true
|
|
||||||
- type: textarea
|
|
||||||
attributes:
|
|
||||||
label: CC List.
|
|
||||||
description: >
|
|
||||||
The list of people you want to CC. Usually, this includes those who worked on the PR that failed the test.
|
|
||||||
- type: markdown
|
|
||||||
attributes:
|
|
||||||
value: >
|
|
||||||
Thanks for reporting 🙏!
|
|
||||||
2
.github/PULL_REQUEST_TEMPLATE.md
vendored
2
.github/PULL_REQUEST_TEMPLATE.md
vendored
@ -3,4 +3,4 @@ FILL IN THE PR DESCRIPTION HERE
|
|||||||
FIX #xxxx (*link existing issues this PR will resolve*)
|
FIX #xxxx (*link existing issues this PR will resolve*)
|
||||||
|
|
||||||
<!--- pyml disable-next-line no-emphasis-as-heading -->
|
<!--- pyml disable-next-line no-emphasis-as-heading -->
|
||||||
**BEFORE SUBMITTING, PLEASE READ <https://docs.vllm.ai/en/latest/contributing>** (anything written below this line will be removed by GitHub Actions)
|
**BEFORE SUBMITTING, PLEASE READ <https://docs.vllm.ai/en/latest/contributing/overview.html>** (anything written below this line will be removed by GitHub Actions)
|
||||||
|
|||||||
17
.github/mergify.yml
vendored
17
.github/mergify.yml
vendored
@ -58,7 +58,7 @@ pull_request_rules:
|
|||||||
- files~=^benchmarks/structured_schemas/
|
- files~=^benchmarks/structured_schemas/
|
||||||
- files=benchmarks/benchmark_serving_structured_output.py
|
- files=benchmarks/benchmark_serving_structured_output.py
|
||||||
- files=benchmarks/run_structured_output_benchmark.sh
|
- files=benchmarks/run_structured_output_benchmark.sh
|
||||||
- files=docs/features/structured_outputs.md
|
- files=docs/source/features/structured_outputs.md
|
||||||
- files=examples/offline_inference/structured_outputs.py
|
- files=examples/offline_inference/structured_outputs.py
|
||||||
- files=examples/online_serving/openai_chat_completion_structured_outputs.py
|
- files=examples/online_serving/openai_chat_completion_structured_outputs.py
|
||||||
- files=examples/online_serving/openai_chat_completion_structured_outputs_with_reasoning.py
|
- files=examples/online_serving/openai_chat_completion_structured_outputs_with_reasoning.py
|
||||||
@ -135,7 +135,9 @@ pull_request_rules:
|
|||||||
- files~=^tests/entrypoints/openai/tool_parsers/
|
- files~=^tests/entrypoints/openai/tool_parsers/
|
||||||
- files=tests/entrypoints/openai/test_chat_with_tool_reasoning.py
|
- files=tests/entrypoints/openai/test_chat_with_tool_reasoning.py
|
||||||
- files~=^vllm/entrypoints/openai/tool_parsers/
|
- files~=^vllm/entrypoints/openai/tool_parsers/
|
||||||
- files=docs/features/tool_calling.md
|
- files=docs/source/features/tool_calling.md
|
||||||
|
- files=docs/source/getting_started/examples/openai_chat_completion_client_with_tools.md
|
||||||
|
- files=docs/source/getting_started/examples/chat_with_tools.md
|
||||||
- files~=^examples/tool_chat_*
|
- files~=^examples/tool_chat_*
|
||||||
- files=examples/offline_inference/chat_with_tools.py
|
- files=examples/offline_inference/chat_with_tools.py
|
||||||
- files=examples/online_serving/openai_chat_completion_client_with_tools_required.py
|
- files=examples/online_serving/openai_chat_completion_client_with_tools_required.py
|
||||||
@ -161,17 +163,6 @@ pull_request_rules:
|
|||||||
|
|
||||||
https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork
|
https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork
|
||||||
|
|
||||||
- name: assign reviewer for tensorizer changes
|
|
||||||
conditions:
|
|
||||||
- files~=^vllm/model_executor/model_loader/tensorizer.py
|
|
||||||
- files~=^vllm/model_executor/model_loader/tensorizer_loader.py
|
|
||||||
- files~=^tests/entrypoints/openai/test_tensorizer_entrypoint.py
|
|
||||||
- files~=^tests/tensorizer_loader/
|
|
||||||
actions:
|
|
||||||
assign:
|
|
||||||
users:
|
|
||||||
- "sangstar"
|
|
||||||
|
|
||||||
- name: remove 'needs-rebase' label when conflict is resolved
|
- name: remove 'needs-rebase' label when conflict is resolved
|
||||||
conditions:
|
conditions:
|
||||||
- -conflict
|
- -conflict
|
||||||
|
|||||||
2
.github/scripts/cleanup_pr_body.sh
vendored
2
.github/scripts/cleanup_pr_body.sh
vendored
@ -26,7 +26,7 @@ sed -i '/\*\*BEFORE SUBMITTING, PLEASE READ.*\*\*/,$d' "${NEW}"
|
|||||||
|
|
||||||
# Remove HTML <details> section that includes <summary> text of "PR Checklist (Click to Expand)"
|
# Remove HTML <details> section that includes <summary> text of "PR Checklist (Click to Expand)"
|
||||||
python3 - <<EOF
|
python3 - <<EOF
|
||||||
import regex as re
|
import re
|
||||||
|
|
||||||
with open("${NEW}", "r") as file:
|
with open("${NEW}", "r") as file:
|
||||||
content = file.read()
|
content = file.read()
|
||||||
|
|||||||
2
.github/workflows/add_label_automerge.yml
vendored
2
.github/workflows/add_label_automerge.yml
vendored
@ -1,6 +1,4 @@
|
|||||||
name: Add label on auto-merge enabled
|
name: Add label on auto-merge enabled
|
||||||
permissions:
|
|
||||||
pull-requests: write
|
|
||||||
on:
|
on:
|
||||||
pull_request_target:
|
pull_request_target:
|
||||||
types:
|
types:
|
||||||
|
|||||||
7
.github/workflows/cleanup_pr_body.yml
vendored
7
.github/workflows/cleanup_pr_body.yml
vendored
@ -20,12 +20,7 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
python-version: '3.12'
|
python-version: '3.12'
|
||||||
|
|
||||||
- name: Install Python dependencies
|
|
||||||
run: |
|
|
||||||
python3 -m pip install --upgrade pip
|
|
||||||
python3 -m pip install regex
|
|
||||||
|
|
||||||
- name: Update PR description
|
- name: Update PR description
|
||||||
env:
|
env:
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
run: bash .github/scripts/cleanup_pr_body.sh "${{ github.event.number }}"
|
run: .github/scripts/cleanup_pr_body.sh "${{ github.event.number }}"
|
||||||
|
|||||||
3
.github/workflows/lint-and-deploy.yaml
vendored
3
.github/workflows/lint-and-deploy.yaml
vendored
@ -2,9 +2,6 @@ name: Lint and Deploy Charts
|
|||||||
|
|
||||||
on: pull_request
|
on: pull_request
|
||||||
|
|
||||||
permissions:
|
|
||||||
contents: read
|
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
lint-and-deploy:
|
lint-and-deploy:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
|||||||
3
.github/workflows/pre-commit.yml
vendored
3
.github/workflows/pre-commit.yml
vendored
@ -5,9 +5,6 @@ on:
|
|||||||
push:
|
push:
|
||||||
branches: [main]
|
branches: [main]
|
||||||
|
|
||||||
permissions:
|
|
||||||
contents: read
|
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
pre-commit:
|
pre-commit:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
|||||||
2
.github/workflows/reminder_comment.yml
vendored
2
.github/workflows/reminder_comment.yml
vendored
@ -1,6 +1,4 @@
|
|||||||
name: PR Reminder Comment Bot
|
name: PR Reminder Comment Bot
|
||||||
permissions:
|
|
||||||
pull-requests: write
|
|
||||||
on:
|
on:
|
||||||
pull_request_target:
|
pull_request_target:
|
||||||
types: [opened]
|
types: [opened]
|
||||||
|
|||||||
6
.gitignore
vendored
6
.gitignore
vendored
@ -77,6 +77,11 @@ instance/
|
|||||||
# Scrapy stuff:
|
# Scrapy stuff:
|
||||||
.scrapy
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
docs/source/getting_started/examples/
|
||||||
|
docs/source/api/vllm
|
||||||
|
|
||||||
# PyBuilder
|
# PyBuilder
|
||||||
.pybuilder/
|
.pybuilder/
|
||||||
target/
|
target/
|
||||||
@ -146,7 +151,6 @@ venv.bak/
|
|||||||
|
|
||||||
# mkdocs documentation
|
# mkdocs documentation
|
||||||
/site
|
/site
|
||||||
docs/examples
|
|
||||||
|
|
||||||
# mypy
|
# mypy
|
||||||
.mypy_cache/
|
.mypy_cache/
|
||||||
|
|||||||
@ -16,8 +16,6 @@ repos:
|
|||||||
hooks:
|
hooks:
|
||||||
- id: ruff
|
- id: ruff
|
||||||
args: [--output-format, github, --fix]
|
args: [--output-format, github, --fix]
|
||||||
- id: ruff-format
|
|
||||||
files: ^(.buildkite|benchmarks|examples)/.*
|
|
||||||
- repo: https://github.com/codespell-project/codespell
|
- repo: https://github.com/codespell-project/codespell
|
||||||
rev: v2.4.1
|
rev: v2.4.1
|
||||||
hooks:
|
hooks:
|
||||||
@ -39,7 +37,6 @@ repos:
|
|||||||
rev: v0.9.29
|
rev: v0.9.29
|
||||||
hooks:
|
hooks:
|
||||||
- id: pymarkdown
|
- id: pymarkdown
|
||||||
exclude: '.*\.inc\.md'
|
|
||||||
args: [fix]
|
args: [fix]
|
||||||
- repo: https://github.com/rhysd/actionlint
|
- repo: https://github.com/rhysd/actionlint
|
||||||
rev: v1.7.7
|
rev: v1.7.7
|
||||||
@ -128,21 +125,6 @@ repos:
|
|||||||
name: Update Dockerfile dependency graph
|
name: Update Dockerfile dependency graph
|
||||||
entry: tools/update-dockerfile-graph.sh
|
entry: tools/update-dockerfile-graph.sh
|
||||||
language: script
|
language: script
|
||||||
- id: enforce-import-regex-instead-of-re
|
|
||||||
name: Enforce import regex as re
|
|
||||||
entry: python tools/enforce_regex_import.py
|
|
||||||
language: python
|
|
||||||
types: [python]
|
|
||||||
pass_filenames: false
|
|
||||||
additional_dependencies: [regex]
|
|
||||||
# forbid directly import triton
|
|
||||||
- id: forbid-direct-triton-import
|
|
||||||
name: "Forbid direct 'import triton'"
|
|
||||||
entry: python tools/check_triton_import.py
|
|
||||||
language: python
|
|
||||||
types: [python]
|
|
||||||
pass_filenames: false
|
|
||||||
additional_dependencies: [regex]
|
|
||||||
# Keep `suggestion` last
|
# Keep `suggestion` last
|
||||||
- id: suggestion
|
- id: suggestion
|
||||||
name: Suggestion
|
name: Suggestion
|
||||||
|
|||||||
@ -8,8 +8,12 @@ build:
|
|||||||
tools:
|
tools:
|
||||||
python: "3.12"
|
python: "3.12"
|
||||||
|
|
||||||
mkdocs:
|
sphinx:
|
||||||
configuration: mkdocs.yaml
|
configuration: docs/source/conf.py
|
||||||
|
fail_on_warning: true
|
||||||
|
|
||||||
|
# If using Sphinx, optionally build your docs in additional formats such as PDF
|
||||||
|
formats: []
|
||||||
|
|
||||||
# Optionally declare the Python requirements required to build your docs
|
# Optionally declare the Python requirements required to build your docs
|
||||||
python:
|
python:
|
||||||
|
|||||||
@ -29,6 +29,9 @@ set(ignoreMe "${VLLM_PYTHON_PATH}")
|
|||||||
#
|
#
|
||||||
set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12")
|
set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12")
|
||||||
|
|
||||||
|
# Supported NVIDIA architectures.
|
||||||
|
set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0")
|
||||||
|
|
||||||
# Supported AMD GPU architectures.
|
# Supported AMD GPU architectures.
|
||||||
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201")
|
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201")
|
||||||
|
|
||||||
@ -76,15 +79,6 @@ endif()
|
|||||||
#
|
#
|
||||||
find_package(Torch REQUIRED)
|
find_package(Torch REQUIRED)
|
||||||
|
|
||||||
# Supported NVIDIA architectures.
|
|
||||||
# This check must happen after find_package(Torch) because that's when CMAKE_CUDA_COMPILER_VERSION gets defined
|
|
||||||
if(DEFINED CMAKE_CUDA_COMPILER_VERSION AND
|
|
||||||
CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8)
|
|
||||||
set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0")
|
|
||||||
else()
|
|
||||||
set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
#
|
#
|
||||||
# Forward the non-CUDA device extensions to external CMake scripts.
|
# Forward the non-CUDA device extensions to external CMake scripts.
|
||||||
#
|
#
|
||||||
@ -232,13 +226,10 @@ endif()
|
|||||||
#
|
#
|
||||||
|
|
||||||
set(VLLM_EXT_SRC
|
set(VLLM_EXT_SRC
|
||||||
"csrc/mamba/mamba_ssm/selective_scan_fwd.cu"
|
|
||||||
"csrc/mamba/causal_conv1d/causal_conv1d.cu"
|
|
||||||
"csrc/cache_kernels.cu"
|
"csrc/cache_kernels.cu"
|
||||||
"csrc/attention/paged_attention_v1.cu"
|
"csrc/attention/paged_attention_v1.cu"
|
||||||
"csrc/attention/paged_attention_v2.cu"
|
"csrc/attention/paged_attention_v2.cu"
|
||||||
"csrc/attention/merge_attn_states.cu"
|
"csrc/attention/merge_attn_states.cu"
|
||||||
"csrc/attention/vertical_slash_index.cu"
|
|
||||||
"csrc/pos_encoding_kernels.cu"
|
"csrc/pos_encoding_kernels.cu"
|
||||||
"csrc/activation_kernels.cu"
|
"csrc/activation_kernels.cu"
|
||||||
"csrc/layernorm_kernels.cu"
|
"csrc/layernorm_kernels.cu"
|
||||||
@ -289,13 +280,14 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
FetchContent_MakeAvailable(cutlass)
|
FetchContent_MakeAvailable(cutlass)
|
||||||
|
|
||||||
list(APPEND VLLM_EXT_SRC
|
list(APPEND VLLM_EXT_SRC
|
||||||
|
"csrc/mamba/mamba_ssm/selective_scan_fwd.cu"
|
||||||
|
"csrc/mamba/causal_conv1d/causal_conv1d.cu"
|
||||||
"csrc/quantization/aqlm/gemm_kernels.cu"
|
"csrc/quantization/aqlm/gemm_kernels.cu"
|
||||||
"csrc/quantization/awq/gemm_kernels.cu"
|
"csrc/quantization/awq/gemm_kernels.cu"
|
||||||
"csrc/permute_cols.cu"
|
"csrc/permute_cols.cu"
|
||||||
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
|
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
|
||||||
"csrc/quantization/fp4/nvfp4_quant_entry.cu"
|
"csrc/quantization/fp4/nvfp4_quant_entry.cu"
|
||||||
"csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
|
"csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
|
||||||
"csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu"
|
|
||||||
"csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
|
"csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
|
||||||
"csrc/cutlass_extensions/common.cpp"
|
"csrc/cutlass_extensions/common.cpp"
|
||||||
"csrc/attention/mla/cutlass_mla_entry.cu")
|
"csrc/attention/mla/cutlass_mla_entry.cu")
|
||||||
@ -307,8 +299,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
# Only build Marlin kernels if we are building for at least some compatible archs.
|
# Only build Marlin kernels if we are building for at least some compatible archs.
|
||||||
# Keep building Marlin for 9.0 as there are some group sizes and shapes that
|
# Keep building Marlin for 9.0 as there are some group sizes and shapes that
|
||||||
# are not supported by Machete yet.
|
# are not supported by Machete yet.
|
||||||
# 9.0 for latest bf16 atomicAdd PTX
|
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
|
||||||
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;9.0+PTX" "${CUDA_ARCHS}")
|
|
||||||
if (MARLIN_ARCHS)
|
if (MARLIN_ARCHS)
|
||||||
|
|
||||||
#
|
#
|
||||||
@ -452,9 +443,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
#
|
#
|
||||||
# For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
|
# For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
|
||||||
# kernels for the remaining archs that are not already built for 3x.
|
# kernels for the remaining archs that are not already built for 3x.
|
||||||
# (Build 8.9 for FP8)
|
|
||||||
cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
|
cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
|
||||||
"7.5;8.0;8.9+PTX" "${CUDA_ARCHS}")
|
"7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
|
||||||
# subtract out the archs that are already built for 3x
|
# subtract out the archs that are already built for 3x
|
||||||
list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
|
list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
|
||||||
if (SCALED_MM_2X_ARCHS)
|
if (SCALED_MM_2X_ARCHS)
|
||||||
@ -505,9 +495,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND FP4_ARCHS)
|
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND FP4_ARCHS)
|
||||||
set(SRCS
|
set(SRCS
|
||||||
"csrc/quantization/fp4/nvfp4_quant_kernels.cu"
|
"csrc/quantization/fp4/nvfp4_quant_kernels.cu"
|
||||||
"csrc/quantization/fp4/nvfp4_experts_quant.cu"
|
"csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu")
|
||||||
"csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu"
|
|
||||||
"csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
|
|
||||||
set_gencode_flags_for_srcs(
|
set_gencode_flags_for_srcs(
|
||||||
SRCS "${SRCS}"
|
SRCS "${SRCS}"
|
||||||
CUDA_ARCHS "${FP4_ARCHS}")
|
CUDA_ARCHS "${FP4_ARCHS}")
|
||||||
@ -545,7 +533,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
# The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and only works
|
# The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and only works
|
||||||
# on Hopper). get_cutlass_moe_mm_data should only be compiled if it's possible
|
# on Hopper). get_cutlass_moe_mm_data should only be compiled if it's possible
|
||||||
# to compile MoE kernels that use its output.
|
# to compile MoE kernels that use its output.
|
||||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}")
|
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
|
||||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
|
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
|
||||||
set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu"
|
set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu"
|
||||||
"csrc/quantization/cutlass_w8a8/moe/moe_data.cu")
|
"csrc/quantization/cutlass_w8a8/moe/moe_data.cu")
|
||||||
@ -683,8 +671,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
CUDA_ARCHS "${CUDA_ARCHS}")
|
CUDA_ARCHS "${CUDA_ARCHS}")
|
||||||
|
|
||||||
list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}")
|
list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}")
|
||||||
# 9.0 for latest bf16 atomicAdd PTX
|
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
|
||||||
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;9.0+PTX" "${CUDA_ARCHS}")
|
|
||||||
if (MARLIN_MOE_ARCHS)
|
if (MARLIN_MOE_ARCHS)
|
||||||
|
|
||||||
#
|
#
|
||||||
|
|||||||
@ -1,3 +1,3 @@
|
|||||||
# Contributing to vLLM
|
# Contributing to vLLM
|
||||||
|
|
||||||
You may find information about contributing to vLLM on [docs.vllm.ai](https://docs.vllm.ai/en/latest/contributing).
|
You may find information about contributing to vLLM on [docs.vllm.ai](https://docs.vllm.ai/en/latest/contributing/overview.html).
|
||||||
|
|||||||
12
README.md
12
README.md
@ -1,7 +1,7 @@
|
|||||||
<p align="center">
|
<p align="center">
|
||||||
<picture>
|
<picture>
|
||||||
<source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/assets/logos/vllm-logo-text-dark.png">
|
<source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/source/assets/logos/vllm-logo-text-dark.png">
|
||||||
<img alt="vLLM" src="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/assets/logos/vllm-logo-text-light.png" width=55%>
|
<img alt="vLLM" src="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/source/assets/logos/vllm-logo-text-light.png" width=55%>
|
||||||
</picture>
|
</picture>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
@ -58,7 +58,7 @@ vLLM is fast with:
|
|||||||
- Efficient management of attention key and value memory with [**PagedAttention**](https://blog.vllm.ai/2023/06/20/vllm.html)
|
- Efficient management of attention key and value memory with [**PagedAttention**](https://blog.vllm.ai/2023/06/20/vllm.html)
|
||||||
- Continuous batching of incoming requests
|
- Continuous batching of incoming requests
|
||||||
- Fast model execution with CUDA/HIP graph
|
- Fast model execution with CUDA/HIP graph
|
||||||
- Quantizations: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [AutoRound](https://arxiv.org/abs/2309.05516), INT4, INT8, and FP8.
|
- Quantizations: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), INT4, INT8, and FP8.
|
||||||
- Optimized CUDA kernels, including integration with FlashAttention and FlashInfer.
|
- Optimized CUDA kernels, including integration with FlashAttention and FlashInfer.
|
||||||
- Speculative decoding
|
- Speculative decoding
|
||||||
- Chunked prefill
|
- Chunked prefill
|
||||||
@ -74,7 +74,7 @@ vLLM is flexible and easy to use with:
|
|||||||
- OpenAI-compatible API server
|
- OpenAI-compatible API server
|
||||||
- Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, TPU, and AWS Neuron.
|
- Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, TPU, and AWS Neuron.
|
||||||
- Prefix caching support
|
- Prefix caching support
|
||||||
- Multi-LoRA support
|
- Multi-lora support
|
||||||
|
|
||||||
vLLM seamlessly supports most popular open-source models on HuggingFace, including:
|
vLLM seamlessly supports most popular open-source models on HuggingFace, including:
|
||||||
- Transformer-like LLMs (e.g., Llama)
|
- Transformer-like LLMs (e.g., Llama)
|
||||||
@ -100,14 +100,14 @@ Visit our [documentation](https://docs.vllm.ai/en/latest/) to learn more.
|
|||||||
## Contributing
|
## Contributing
|
||||||
|
|
||||||
We welcome and value any contributions and collaborations.
|
We welcome and value any contributions and collaborations.
|
||||||
Please check out [Contributing to vLLM](https://docs.vllm.ai/en/latest/contributing/index.html) for how to get involved.
|
Please check out [Contributing to vLLM](https://docs.vllm.ai/en/stable/contributing/overview.html) for how to get involved.
|
||||||
|
|
||||||
## Sponsors
|
## Sponsors
|
||||||
|
|
||||||
vLLM is a community project. Our compute resources for development and testing are supported by the following organizations. Thank you for your support!
|
vLLM is a community project. Our compute resources for development and testing are supported by the following organizations. Thank you for your support!
|
||||||
|
|
||||||
<!-- Note: Please sort them in alphabetical order. -->
|
<!-- Note: Please sort them in alphabetical order. -->
|
||||||
<!-- Note: Please keep these consistent with docs/community/sponsors.md -->
|
<!-- Note: Please keep these consistent with docs/source/community/sponsors.md -->
|
||||||
Cash Donations:
|
Cash Donations:
|
||||||
- a16z
|
- a16z
|
||||||
- Dropbox
|
- Dropbox
|
||||||
|
|||||||
@ -146,9 +146,10 @@ python3 vllm/benchmarks/benchmark_serving.py \
|
|||||||
|
|
||||||
``` bash
|
``` bash
|
||||||
VLLM_USE_V1=1 vllm serve meta-llama/Meta-Llama-3-8B-Instruct \
|
VLLM_USE_V1=1 vllm serve meta-llama/Meta-Llama-3-8B-Instruct \
|
||||||
|
--speculative-model "[ngram]" \
|
||||||
--ngram_prompt_lookup_min 2 \
|
--ngram_prompt_lookup_min 2 \
|
||||||
--ngram-prompt-lookup-max 5 \
|
--ngram-prompt-lookup-max 5 \
|
||||||
--speculative_config '{"model": "[ngram]", "num_speculative_tokens": 5}'
|
--num_speculative_tokens 5
|
||||||
```
|
```
|
||||||
|
|
||||||
``` bash
|
``` bash
|
||||||
@ -273,9 +274,10 @@ python3 vllm/benchmarks/benchmark_throughput.py \
|
|||||||
--output-len=100 \
|
--output-len=100 \
|
||||||
--num-prompts=2048 \
|
--num-prompts=2048 \
|
||||||
--async-engine \
|
--async-engine \
|
||||||
|
--speculative-model="[ngram]" \
|
||||||
--ngram_prompt_lookup_min=2 \
|
--ngram_prompt_lookup_min=2 \
|
||||||
--ngram-prompt-lookup-max=5 \
|
--ngram-prompt-lookup-max=5 \
|
||||||
--speculative_config '{"model": "[ngram]", "num_speculative_tokens": 5}'
|
--num_speculative_tokens=5
|
||||||
```
|
```
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|||||||
@ -12,7 +12,8 @@ from typing import Optional, Union
|
|||||||
import aiohttp
|
import aiohttp
|
||||||
import huggingface_hub.constants
|
import huggingface_hub.constants
|
||||||
from tqdm.asyncio import tqdm
|
from tqdm.asyncio import tqdm
|
||||||
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
|
from transformers import (AutoTokenizer, PreTrainedTokenizer,
|
||||||
|
PreTrainedTokenizerFast)
|
||||||
|
|
||||||
# NOTE(simon): do not import vLLM here so the benchmark script
|
# NOTE(simon): do not import vLLM here so the benchmark script
|
||||||
# can run without vLLM installed.
|
# can run without vLLM installed.
|
||||||
@ -42,7 +43,8 @@ class RequestFuncOutput:
|
|||||||
latency: float = 0.0
|
latency: float = 0.0
|
||||||
output_tokens: int = 0
|
output_tokens: int = 0
|
||||||
ttft: float = 0.0 # Time to first token
|
ttft: float = 0.0 # Time to first token
|
||||||
itl: list[float] = field(default_factory=list) # list of inter-token latencies
|
itl: list[float] = field(
|
||||||
|
default_factory=list) # list of inter-token latencies
|
||||||
tpot: float = 0.0 # avg next-token latencies
|
tpot: float = 0.0 # avg next-token latencies
|
||||||
prompt_len: int = 0
|
prompt_len: int = 0
|
||||||
error: str = ""
|
error: str = ""
|
||||||
@ -55,9 +57,8 @@ async def async_request_tgi(
|
|||||||
api_url = request_func_input.api_url
|
api_url = request_func_input.api_url
|
||||||
assert api_url.endswith("generate_stream")
|
assert api_url.endswith("generate_stream")
|
||||||
|
|
||||||
async with aiohttp.ClientSession(
|
async with aiohttp.ClientSession(trust_env=True,
|
||||||
trust_env=True, timeout=AIOHTTP_TIMEOUT
|
timeout=AIOHTTP_TIMEOUT) as session:
|
||||||
) as session:
|
|
||||||
params = {
|
params = {
|
||||||
"max_new_tokens": request_func_input.output_len,
|
"max_new_tokens": request_func_input.output_len,
|
||||||
"do_sample": True,
|
"do_sample": True,
|
||||||
@ -104,7 +105,8 @@ async def async_request_tgi(
|
|||||||
|
|
||||||
# Decoding phase
|
# Decoding phase
|
||||||
else:
|
else:
|
||||||
output.itl.append(timestamp - most_recent_timestamp)
|
output.itl.append(timestamp -
|
||||||
|
most_recent_timestamp)
|
||||||
|
|
||||||
most_recent_timestamp = timestamp
|
most_recent_timestamp = timestamp
|
||||||
|
|
||||||
@ -131,9 +133,8 @@ async def async_request_trt_llm(
|
|||||||
api_url = request_func_input.api_url
|
api_url = request_func_input.api_url
|
||||||
assert api_url.endswith("generate_stream")
|
assert api_url.endswith("generate_stream")
|
||||||
|
|
||||||
async with aiohttp.ClientSession(
|
async with aiohttp.ClientSession(trust_env=True,
|
||||||
trust_env=True, timeout=AIOHTTP_TIMEOUT
|
timeout=AIOHTTP_TIMEOUT) as session:
|
||||||
) as session:
|
|
||||||
payload = {
|
payload = {
|
||||||
"accumulate_tokens": True,
|
"accumulate_tokens": True,
|
||||||
"text_input": request_func_input.prompt,
|
"text_input": request_func_input.prompt,
|
||||||
@ -158,7 +159,8 @@ async def async_request_trt_llm(
|
|||||||
if not chunk_bytes:
|
if not chunk_bytes:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
chunk = chunk_bytes.decode("utf-8").removeprefix("data:")
|
chunk = chunk_bytes.decode("utf-8").removeprefix(
|
||||||
|
"data:")
|
||||||
|
|
||||||
data = json.loads(chunk)
|
data = json.loads(chunk)
|
||||||
output.generated_text += data["text_output"]
|
output.generated_text += data["text_output"]
|
||||||
@ -170,7 +172,8 @@ async def async_request_trt_llm(
|
|||||||
|
|
||||||
# Decoding phase
|
# Decoding phase
|
||||||
else:
|
else:
|
||||||
output.itl.append(timestamp - most_recent_timestamp)
|
output.itl.append(timestamp -
|
||||||
|
most_recent_timestamp)
|
||||||
|
|
||||||
most_recent_timestamp = timestamp
|
most_recent_timestamp = timestamp
|
||||||
|
|
||||||
@ -194,14 +197,9 @@ async def async_request_deepspeed_mii(
|
|||||||
request_func_input: RequestFuncInput,
|
request_func_input: RequestFuncInput,
|
||||||
pbar: Optional[tqdm] = None,
|
pbar: Optional[tqdm] = None,
|
||||||
) -> RequestFuncOutput:
|
) -> RequestFuncOutput:
|
||||||
api_url = request_func_input.api_url
|
async with aiohttp.ClientSession(trust_env=True,
|
||||||
assert api_url.endswith(("completions", "profile")), (
|
timeout=AIOHTTP_TIMEOUT) as session:
|
||||||
"OpenAI Completions API URL must end with 'completions' or 'profile'."
|
|
||||||
)
|
|
||||||
|
|
||||||
async with aiohttp.ClientSession(
|
|
||||||
trust_env=True, timeout=AIOHTTP_TIMEOUT
|
|
||||||
) as session:
|
|
||||||
payload = {
|
payload = {
|
||||||
"model": request_func_input.model,
|
"model": request_func_input.model,
|
||||||
"prompt": request_func_input.prompt,
|
"prompt": request_func_input.prompt,
|
||||||
@ -209,8 +207,6 @@ async def async_request_deepspeed_mii(
|
|||||||
"temperature": 0.01, # deepspeed-mii does not accept 0.0 temp.
|
"temperature": 0.01, # deepspeed-mii does not accept 0.0 temp.
|
||||||
"top_p": 1.0,
|
"top_p": 1.0,
|
||||||
}
|
}
|
||||||
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
|
|
||||||
|
|
||||||
output = RequestFuncOutput()
|
output = RequestFuncOutput()
|
||||||
output.prompt_len = request_func_input.prompt_len
|
output.prompt_len = request_func_input.prompt_len
|
||||||
|
|
||||||
@ -221,21 +217,19 @@ async def async_request_deepspeed_mii(
|
|||||||
|
|
||||||
st = time.perf_counter()
|
st = time.perf_counter()
|
||||||
try:
|
try:
|
||||||
async with session.post(
|
async with session.post(url=request_func_input.api_url,
|
||||||
url=api_url, json=payload, headers=headers
|
json=payload) as response:
|
||||||
) as response:
|
|
||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
parsed_resp = await response.json()
|
parsed_resp = await response.json()
|
||||||
output.latency = time.perf_counter() - st
|
output.latency = time.perf_counter() - st
|
||||||
if "choices" in parsed_resp:
|
if "choices" in parsed_resp:
|
||||||
output.generated_text = parsed_resp["choices"][0]["text"]
|
output.generated_text = parsed_resp["choices"][0][
|
||||||
|
"text"]
|
||||||
elif "text" in parsed_resp:
|
elif "text" in parsed_resp:
|
||||||
output.generated_text = parsed_resp["text"][0]
|
output.generated_text = parsed_resp["text"][0]
|
||||||
else:
|
else:
|
||||||
output.error = (
|
output.error = ("Unexpected response format: "
|
||||||
"Unexpected response format: "
|
"neither 'choices' nor 'text' found")
|
||||||
"neither 'choices' nor 'text' found"
|
|
||||||
)
|
|
||||||
output.success = False
|
output.success = False
|
||||||
output.success = True
|
output.success = True
|
||||||
else:
|
else:
|
||||||
@ -256,17 +250,15 @@ async def async_request_openai_completions(
|
|||||||
pbar: Optional[tqdm] = None,
|
pbar: Optional[tqdm] = None,
|
||||||
) -> RequestFuncOutput:
|
) -> RequestFuncOutput:
|
||||||
api_url = request_func_input.api_url
|
api_url = request_func_input.api_url
|
||||||
assert api_url.endswith(("completions", "profile")), (
|
assert api_url.endswith(
|
||||||
"OpenAI Completions API URL must end with 'completions' or 'profile'."
|
("completions", "profile")
|
||||||
)
|
), "OpenAI Completions API URL must end with 'completions' or 'profile'."
|
||||||
|
|
||||||
async with aiohttp.ClientSession(
|
async with aiohttp.ClientSession(trust_env=True,
|
||||||
trust_env=True, timeout=AIOHTTP_TIMEOUT
|
timeout=AIOHTTP_TIMEOUT) as session:
|
||||||
) as session:
|
|
||||||
payload = {
|
payload = {
|
||||||
"model": request_func_input.model_name
|
"model": request_func_input.model_name \
|
||||||
if request_func_input.model_name
|
if request_func_input.model_name else request_func_input.model,
|
||||||
else request_func_input.model,
|
|
||||||
"prompt": request_func_input.prompt,
|
"prompt": request_func_input.prompt,
|
||||||
"temperature": 0.0,
|
"temperature": 0.0,
|
||||||
"repetition_penalty": 1.0,
|
"repetition_penalty": 1.0,
|
||||||
@ -281,7 +273,9 @@ async def async_request_openai_completions(
|
|||||||
payload["ignore_eos"] = request_func_input.ignore_eos
|
payload["ignore_eos"] = request_func_input.ignore_eos
|
||||||
if request_func_input.extra_body:
|
if request_func_input.extra_body:
|
||||||
payload.update(request_func_input.extra_body)
|
payload.update(request_func_input.extra_body)
|
||||||
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
|
headers = {
|
||||||
|
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
|
||||||
|
}
|
||||||
|
|
||||||
output = RequestFuncOutput()
|
output = RequestFuncOutput()
|
||||||
output.prompt_len = request_func_input.prompt_len
|
output.prompt_len = request_func_input.prompt_len
|
||||||
@ -290,9 +284,8 @@ async def async_request_openai_completions(
|
|||||||
st = time.perf_counter()
|
st = time.perf_counter()
|
||||||
most_recent_timestamp = st
|
most_recent_timestamp = st
|
||||||
try:
|
try:
|
||||||
async with session.post(
|
async with session.post(url=api_url, json=payload,
|
||||||
url=api_url, json=payload, headers=headers
|
headers=headers) as response:
|
||||||
) as response:
|
|
||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
first_chunk_received = False
|
first_chunk_received = False
|
||||||
async for chunk_bytes in response.content:
|
async for chunk_bytes in response.content:
|
||||||
@ -300,7 +293,8 @@ async def async_request_openai_completions(
|
|||||||
if not chunk_bytes:
|
if not chunk_bytes:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
|
chunk = chunk_bytes.decode("utf-8").removeprefix(
|
||||||
|
"data: ")
|
||||||
if chunk != "[DONE]":
|
if chunk != "[DONE]":
|
||||||
data = json.loads(chunk)
|
data = json.loads(chunk)
|
||||||
|
|
||||||
@ -320,20 +314,21 @@ async def async_request_openai_completions(
|
|||||||
|
|
||||||
# Decoding phase
|
# Decoding phase
|
||||||
else:
|
else:
|
||||||
output.itl.append(timestamp - most_recent_timestamp)
|
output.itl.append(timestamp -
|
||||||
|
most_recent_timestamp)
|
||||||
|
|
||||||
most_recent_timestamp = timestamp
|
most_recent_timestamp = timestamp
|
||||||
generated_text += text or ""
|
generated_text += text or ""
|
||||||
elif usage := data.get("usage"):
|
elif usage := data.get("usage"):
|
||||||
output.output_tokens = usage.get("completion_tokens")
|
output.output_tokens = usage.get(
|
||||||
|
"completion_tokens")
|
||||||
if first_chunk_received:
|
if first_chunk_received:
|
||||||
output.success = True
|
output.success = True
|
||||||
else:
|
else:
|
||||||
output.success = False
|
output.success = False
|
||||||
output.error = (
|
output.error = (
|
||||||
"Never received a valid chunk to calculate TTFT."
|
"Never received a valid chunk to calculate TTFT."
|
||||||
"This response will be marked as failed!"
|
"This response will be marked as failed!")
|
||||||
)
|
|
||||||
output.generated_text = generated_text
|
output.generated_text = generated_text
|
||||||
output.latency = most_recent_timestamp - st
|
output.latency = most_recent_timestamp - st
|
||||||
else:
|
else:
|
||||||
@ -354,22 +349,23 @@ async def async_request_openai_chat_completions(
|
|||||||
pbar: Optional[tqdm] = None,
|
pbar: Optional[tqdm] = None,
|
||||||
) -> RequestFuncOutput:
|
) -> RequestFuncOutput:
|
||||||
api_url = request_func_input.api_url
|
api_url = request_func_input.api_url
|
||||||
assert api_url.endswith(("chat/completions", "profile")), (
|
assert api_url.endswith(
|
||||||
"OpenAI Chat Completions API URL must end with 'chat/completions'."
|
("chat/completions", "profile")
|
||||||
)
|
), "OpenAI Chat Completions API URL must end with 'chat/completions'."
|
||||||
|
|
||||||
async with aiohttp.ClientSession(
|
async with aiohttp.ClientSession(trust_env=True,
|
||||||
trust_env=True, timeout=AIOHTTP_TIMEOUT
|
timeout=AIOHTTP_TIMEOUT) as session:
|
||||||
) as session:
|
|
||||||
content = [{"type": "text", "text": request_func_input.prompt}]
|
content = [{"type": "text", "text": request_func_input.prompt}]
|
||||||
if request_func_input.multi_modal_content:
|
if request_func_input.multi_modal_content:
|
||||||
content.append(request_func_input.multi_modal_content)
|
content.append(request_func_input.multi_modal_content)
|
||||||
payload = {
|
payload = {
|
||||||
"model": request_func_input.model_name
|
"model": request_func_input.model_name \
|
||||||
if request_func_input.model_name
|
if request_func_input.model_name else request_func_input.model,
|
||||||
else request_func_input.model,
|
|
||||||
"messages": [
|
"messages": [
|
||||||
{"role": "user", "content": content},
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": content
|
||||||
|
},
|
||||||
],
|
],
|
||||||
"temperature": 0.0,
|
"temperature": 0.0,
|
||||||
"max_completion_tokens": request_func_input.output_len,
|
"max_completion_tokens": request_func_input.output_len,
|
||||||
@ -395,16 +391,16 @@ async def async_request_openai_chat_completions(
|
|||||||
st = time.perf_counter()
|
st = time.perf_counter()
|
||||||
most_recent_timestamp = st
|
most_recent_timestamp = st
|
||||||
try:
|
try:
|
||||||
async with session.post(
|
async with session.post(url=api_url, json=payload,
|
||||||
url=api_url, json=payload, headers=headers
|
headers=headers) as response:
|
||||||
) as response:
|
|
||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
async for chunk_bytes in response.content:
|
async for chunk_bytes in response.content:
|
||||||
chunk_bytes = chunk_bytes.strip()
|
chunk_bytes = chunk_bytes.strip()
|
||||||
if not chunk_bytes:
|
if not chunk_bytes:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
|
chunk = chunk_bytes.decode("utf-8").removeprefix(
|
||||||
|
"data: ")
|
||||||
if chunk != "[DONE]":
|
if chunk != "[DONE]":
|
||||||
timestamp = time.perf_counter()
|
timestamp = time.perf_counter()
|
||||||
data = json.loads(chunk)
|
data = json.loads(chunk)
|
||||||
@ -418,11 +414,13 @@ async def async_request_openai_chat_completions(
|
|||||||
|
|
||||||
# Decoding phase
|
# Decoding phase
|
||||||
else:
|
else:
|
||||||
output.itl.append(timestamp - most_recent_timestamp)
|
output.itl.append(timestamp -
|
||||||
|
most_recent_timestamp)
|
||||||
|
|
||||||
generated_text += content or ""
|
generated_text += content or ""
|
||||||
elif usage := data.get("usage"):
|
elif usage := data.get("usage"):
|
||||||
output.output_tokens = usage.get("completion_tokens")
|
output.output_tokens = usage.get(
|
||||||
|
"completion_tokens")
|
||||||
|
|
||||||
most_recent_timestamp = timestamp
|
most_recent_timestamp = timestamp
|
||||||
|
|
||||||
@ -448,28 +446,25 @@ async def async_request_openai_audio(
|
|||||||
) -> RequestFuncOutput:
|
) -> RequestFuncOutput:
|
||||||
# Lazy import without PlaceholderModule to avoid vllm dep.
|
# Lazy import without PlaceholderModule to avoid vllm dep.
|
||||||
import soundfile
|
import soundfile
|
||||||
|
|
||||||
api_url = request_func_input.api_url
|
api_url = request_func_input.api_url
|
||||||
assert api_url.endswith(("transcriptions", "translations")), (
|
assert api_url.endswith(
|
||||||
"OpenAI Chat Completions API URL must end with 'transcriptions' "
|
("transcriptions", "translations"
|
||||||
)
|
)), "OpenAI Chat Completions API URL must end with 'transcriptions' "
|
||||||
"or `translations`."
|
"or `translations`."
|
||||||
|
|
||||||
async with aiohttp.ClientSession(
|
async with aiohttp.ClientSession(trust_env=True,
|
||||||
trust_env=True, timeout=AIOHTTP_TIMEOUT
|
timeout=AIOHTTP_TIMEOUT) as session:
|
||||||
) as session:
|
|
||||||
content = [{"type": "text", "text": request_func_input.prompt}]
|
content = [{"type": "text", "text": request_func_input.prompt}]
|
||||||
payload = {
|
payload = {
|
||||||
"model": request_func_input.model_name
|
"model": request_func_input.model_name \
|
||||||
if request_func_input.model_name
|
if request_func_input.model_name else request_func_input.model,
|
||||||
else request_func_input.model,
|
|
||||||
"temperature": 0.0,
|
"temperature": 0.0,
|
||||||
"max_completion_tokens": request_func_input.output_len,
|
"max_completion_tokens": request_func_input.output_len,
|
||||||
"stream": True,
|
"stream": True,
|
||||||
"language": "en",
|
"language": "en",
|
||||||
# Flattened due to multipart/form-data
|
# Flattened due to multipart/form-data
|
||||||
"stream_include_usage": True,
|
"stream_include_usage": True,
|
||||||
"stream_continuous_usage_stats": True,
|
"stream_continuous_usage_stats": True
|
||||||
}
|
}
|
||||||
if request_func_input.extra_body:
|
if request_func_input.extra_body:
|
||||||
payload.update(request_func_input.extra_body)
|
payload.update(request_func_input.extra_body)
|
||||||
@ -484,9 +479,9 @@ async def async_request_openai_audio(
|
|||||||
buffer.seek(0)
|
buffer.seek(0)
|
||||||
return buffer
|
return buffer
|
||||||
|
|
||||||
with to_bytes(*request_func_input.multi_modal_content["audio"]) as f:
|
with to_bytes(*request_func_input.multi_modal_content['audio']) as f:
|
||||||
form = aiohttp.FormData()
|
form = aiohttp.FormData()
|
||||||
form.add_field("file", f, content_type="audio/wav")
|
form.add_field('file', f, content_type='audio/wav')
|
||||||
for key, value in payload.items():
|
for key, value in payload.items():
|
||||||
form.add_field(key, str(value))
|
form.add_field(key, str(value))
|
||||||
|
|
||||||
@ -498,22 +493,24 @@ async def async_request_openai_audio(
|
|||||||
st = time.perf_counter()
|
st = time.perf_counter()
|
||||||
most_recent_timestamp = st
|
most_recent_timestamp = st
|
||||||
try:
|
try:
|
||||||
async with session.post(
|
async with session.post(url=api_url,
|
||||||
url=api_url, data=form, headers=headers
|
data=form,
|
||||||
) as response:
|
headers=headers) as response:
|
||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
async for chunk_bytes in response.content:
|
async for chunk_bytes in response.content:
|
||||||
chunk_bytes = chunk_bytes.strip()
|
chunk_bytes = chunk_bytes.strip()
|
||||||
if not chunk_bytes:
|
if not chunk_bytes:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
|
chunk = chunk_bytes.decode("utf-8").removeprefix(
|
||||||
|
"data: ")
|
||||||
if chunk != "[DONE]":
|
if chunk != "[DONE]":
|
||||||
timestamp = time.perf_counter()
|
timestamp = time.perf_counter()
|
||||||
data = json.loads(chunk)
|
data = json.loads(chunk)
|
||||||
|
|
||||||
if choices := data.get("choices"):
|
if choices := data.get("choices"):
|
||||||
content = choices[0]["delta"].get("content")
|
content = choices[0]["delta"].get(
|
||||||
|
"content")
|
||||||
# First token
|
# First token
|
||||||
if ttft == 0.0:
|
if ttft == 0.0:
|
||||||
ttft = timestamp - st
|
ttft = timestamp - st
|
||||||
@ -522,14 +519,12 @@ async def async_request_openai_audio(
|
|||||||
# Decoding phase
|
# Decoding phase
|
||||||
else:
|
else:
|
||||||
output.itl.append(
|
output.itl.append(
|
||||||
timestamp - most_recent_timestamp
|
timestamp - most_recent_timestamp)
|
||||||
)
|
|
||||||
|
|
||||||
generated_text += content or ""
|
generated_text += content or ""
|
||||||
elif usage := data.get("usage"):
|
elif usage := data.get("usage"):
|
||||||
output.output_tokens = usage.get(
|
output.output_tokens = usage.get(
|
||||||
"completion_tokens"
|
"completion_tokens")
|
||||||
)
|
|
||||||
|
|
||||||
most_recent_timestamp = timestamp
|
most_recent_timestamp = timestamp
|
||||||
|
|
||||||
@ -550,7 +545,7 @@ async def async_request_openai_audio(
|
|||||||
|
|
||||||
|
|
||||||
def get_model(pretrained_model_name_or_path: str) -> str:
|
def get_model(pretrained_model_name_or_path: str) -> str:
|
||||||
if os.getenv("VLLM_USE_MODELSCOPE", "False").lower() == "true":
|
if os.getenv('VLLM_USE_MODELSCOPE', 'False').lower() == 'true':
|
||||||
from modelscope import snapshot_download
|
from modelscope import snapshot_download
|
||||||
|
|
||||||
from vllm.model_executor.model_loader.weight_utils import get_lock
|
from vllm.model_executor.model_loader.weight_utils import get_lock
|
||||||
@ -561,8 +556,7 @@ def get_model(pretrained_model_name_or_path: str) -> str:
|
|||||||
model_path = snapshot_download(
|
model_path = snapshot_download(
|
||||||
model_id=pretrained_model_name_or_path,
|
model_id=pretrained_model_name_or_path,
|
||||||
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
|
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
|
||||||
ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"],
|
ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"])
|
||||||
)
|
|
||||||
|
|
||||||
return model_path
|
return model_path
|
||||||
return pretrained_model_name_or_path
|
return pretrained_model_name_or_path
|
||||||
@ -575,23 +569,23 @@ def get_tokenizer(
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
|
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
|
||||||
if pretrained_model_name_or_path is not None and not os.path.exists(
|
if pretrained_model_name_or_path is not None and not os.path.exists(
|
||||||
pretrained_model_name_or_path
|
pretrained_model_name_or_path):
|
||||||
):
|
pretrained_model_name_or_path = get_model(
|
||||||
pretrained_model_name_or_path = get_model(pretrained_model_name_or_path)
|
pretrained_model_name_or_path)
|
||||||
if tokenizer_mode == "slow":
|
if tokenizer_mode == "slow":
|
||||||
if kwargs.get("use_fast", False):
|
if kwargs.get("use_fast", False):
|
||||||
raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
|
raise ValueError(
|
||||||
|
"Cannot use the fast tokenizer in slow tokenizer mode.")
|
||||||
kwargs["use_fast"] = False
|
kwargs["use_fast"] = False
|
||||||
if tokenizer_mode == "mistral":
|
if tokenizer_mode == "mistral":
|
||||||
try:
|
try:
|
||||||
from vllm.transformers_utils.tokenizer import MistralTokenizer
|
from vllm.transformers_utils.tokenizer import MistralTokenizer
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
raise ImportError(
|
raise ImportError("MistralTokenizer requires vllm package.\n"
|
||||||
"MistralTokenizer requires vllm package.\n"
|
"Please install it with `pip install vllm` "
|
||||||
"Please install it with `pip install vllm` "
|
"to use mistral tokenizer mode.") from e
|
||||||
"to use mistral tokenizer mode."
|
return MistralTokenizer.from_pretrained(
|
||||||
) from e
|
str(pretrained_model_name_or_path))
|
||||||
return MistralTokenizer.from_pretrained(str(pretrained_model_name_or_path))
|
|
||||||
else:
|
else:
|
||||||
return AutoTokenizer.from_pretrained(
|
return AutoTokenizer.from_pretrained(
|
||||||
pretrained_model_name_or_path,
|
pretrained_model_name_or_path,
|
||||||
@ -614,7 +608,7 @@ ASYNC_REQUEST_FUNCS = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
OPENAI_COMPATIBLE_BACKENDS = [
|
OPENAI_COMPATIBLE_BACKENDS = [
|
||||||
k
|
k for k, v in ASYNC_REQUEST_FUNCS.items()
|
||||||
for k, v in ASYNC_REQUEST_FUNCS.items()
|
if v in (async_request_openai_completions,
|
||||||
if v in (async_request_openai_completions, async_request_openai_chat_completions)
|
async_request_openai_chat_completions)
|
||||||
]
|
]
|
||||||
|
|||||||
@ -35,7 +35,6 @@ from transformers import PreTrainedTokenizerBase
|
|||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.lora.utils import get_adapter_absolute_path
|
from vllm.lora.utils import get_adapter_absolute_path
|
||||||
from vllm.multimodal import MultiModalDataDict
|
from vllm.multimodal import MultiModalDataDict
|
||||||
from vllm.multimodal.image import convert_image_mode
|
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer
|
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -83,12 +82,14 @@ class BenchmarkDataset(ABC):
|
|||||||
self.dataset_path = dataset_path
|
self.dataset_path = dataset_path
|
||||||
# Set the random seed, ensuring that a None value is replaced with the
|
# Set the random seed, ensuring that a None value is replaced with the
|
||||||
# default seed.
|
# default seed.
|
||||||
self.random_seed = random_seed if random_seed is not None else self.DEFAULT_SEED
|
self.random_seed = (random_seed
|
||||||
|
if random_seed is not None else self.DEFAULT_SEED)
|
||||||
self.data = None
|
self.data = None
|
||||||
|
|
||||||
def apply_multimodal_chat_transformation(
|
def apply_multimodal_chat_transformation(
|
||||||
self, prompt: str, mm_content: Optional[MultiModalDataDict] = None
|
self,
|
||||||
) -> list[dict]:
|
prompt: str,
|
||||||
|
mm_content: Optional[MultiModalDataDict] = None) -> list[dict]:
|
||||||
"""
|
"""
|
||||||
Transform a prompt and optional multimodal content into a chat format.
|
Transform a prompt and optional multimodal content into a chat format.
|
||||||
This method is used for chat models that expect a specific conversation
|
This method is used for chat models that expect a specific conversation
|
||||||
@ -110,7 +111,8 @@ class BenchmarkDataset(ABC):
|
|||||||
NotImplementedError: If a subclass does not implement this method.
|
NotImplementedError: If a subclass does not implement this method.
|
||||||
"""
|
"""
|
||||||
# TODO (jenniferzhao): add support for downloading data
|
# TODO (jenniferzhao): add support for downloading data
|
||||||
raise NotImplementedError("load_data must be implemented in subclasses.")
|
raise NotImplementedError(
|
||||||
|
"load_data must be implemented in subclasses.")
|
||||||
|
|
||||||
def get_random_lora_request(
|
def get_random_lora_request(
|
||||||
self,
|
self,
|
||||||
@ -156,9 +158,8 @@ class BenchmarkDataset(ABC):
|
|||||||
return lora_request, lora_tokenizer_cache[lora_id] or tokenizer
|
return lora_request, lora_tokenizer_cache[lora_id] or tokenizer
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def sample(
|
def sample(self, tokenizer: PreTrainedTokenizerBase,
|
||||||
self, tokenizer: PreTrainedTokenizerBase, num_requests: int
|
num_requests: int) -> list[SampleRequest]:
|
||||||
) -> list[SampleRequest]:
|
|
||||||
"""
|
"""
|
||||||
Abstract method to generate sample requests from the dataset.
|
Abstract method to generate sample requests from the dataset.
|
||||||
|
|
||||||
@ -176,9 +177,8 @@ class BenchmarkDataset(ABC):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError("sample must be implemented in subclasses.")
|
raise NotImplementedError("sample must be implemented in subclasses.")
|
||||||
|
|
||||||
def maybe_oversample_requests(
|
def maybe_oversample_requests(self, requests: list[SampleRequest],
|
||||||
self, requests: list[SampleRequest], num_requests: int
|
num_requests: int) -> None:
|
||||||
) -> None:
|
|
||||||
"""
|
"""
|
||||||
Oversamples the list of requests if its size is less than the desired
|
Oversamples the list of requests if its size is less than the desired
|
||||||
number.
|
number.
|
||||||
@ -189,9 +189,11 @@ class BenchmarkDataset(ABC):
|
|||||||
"""
|
"""
|
||||||
if len(requests) < num_requests:
|
if len(requests) < num_requests:
|
||||||
random.seed(self.random_seed)
|
random.seed(self.random_seed)
|
||||||
additional = random.choices(requests, k=num_requests - len(requests))
|
additional = random.choices(requests,
|
||||||
|
k=num_requests - len(requests))
|
||||||
requests.extend(additional)
|
requests.extend(additional)
|
||||||
logger.info("Oversampled requests to reach %d total samples.", num_requests)
|
logger.info("Oversampled requests to reach %d total samples.",
|
||||||
|
num_requests)
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@ -216,14 +218,14 @@ def is_valid_sequence(
|
|||||||
"""
|
"""
|
||||||
# Check for invalid conditions
|
# Check for invalid conditions
|
||||||
prompt_too_short = prompt_len < min_len
|
prompt_too_short = prompt_len < min_len
|
||||||
output_too_short = (not skip_min_output_len_check) and (output_len < min_len)
|
output_too_short = (not skip_min_output_len_check) and (output_len
|
||||||
|
< min_len)
|
||||||
prompt_too_long = prompt_len > max_prompt_len
|
prompt_too_long = prompt_len > max_prompt_len
|
||||||
combined_too_long = (prompt_len + output_len) > max_total_len
|
combined_too_long = (prompt_len + output_len) > max_total_len
|
||||||
|
|
||||||
# Return True if none of the invalid conditions are met
|
# Return True if none of the invalid conditions are met
|
||||||
return not (
|
return not (prompt_too_short or output_too_short or prompt_too_long
|
||||||
prompt_too_short or output_too_short or prompt_too_long or combined_too_long
|
or combined_too_long)
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@cache
|
@cache
|
||||||
@ -255,28 +257,28 @@ def process_image(image: Any) -> Mapping[str, Any]:
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: If the input is not a supported type.
|
ValueError: If the input is not a supported type.
|
||||||
"""
|
"""
|
||||||
if isinstance(image, dict) and "bytes" in image:
|
if isinstance(image, dict) and 'bytes' in image:
|
||||||
image = Image.open(BytesIO(image["bytes"]))
|
image = Image.open(BytesIO(image['bytes']))
|
||||||
if isinstance(image, Image.Image):
|
if isinstance(image, Image.Image):
|
||||||
image = convert_image_mode(image, "RGB")
|
image = image.convert("RGB")
|
||||||
with io.BytesIO() as image_data:
|
with io.BytesIO() as image_data:
|
||||||
image.save(image_data, format="JPEG")
|
image.save(image_data, format="JPEG")
|
||||||
image_base64 = base64.b64encode(image_data.getvalue()).decode("utf-8")
|
image_base64 = base64.b64encode(
|
||||||
|
image_data.getvalue()).decode("utf-8")
|
||||||
return {
|
return {
|
||||||
"type": "image_url",
|
"type": "image_url",
|
||||||
"image_url": {"url": f"data:image/jpeg;base64,{image_base64}"},
|
"image_url": {
|
||||||
|
"url": f"data:image/jpeg;base64,{image_base64}"
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
if isinstance(image, str):
|
if isinstance(image, str):
|
||||||
image_url = (
|
image_url = (image if image.startswith(
|
||||||
image if image.startswith(("http://", "file://")) else f"file://{image}"
|
("http://", "file://")) else f"file://{image}")
|
||||||
)
|
|
||||||
return {"type": "image_url", "image_url": {"url": image_url}}
|
return {"type": "image_url", "image_url": {"url": image_url}}
|
||||||
|
|
||||||
raise ValueError(
|
raise ValueError(f"Invalid image input {image}. Must be a PIL.Image.Image"
|
||||||
f"Invalid image input {image}. Must be a PIL.Image.Image"
|
" or str or dictionary with raw image bytes.")
|
||||||
" or str or dictionary with raw image bytes."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@ -316,11 +318,8 @@ class RandomDataset(BenchmarkDataset):
|
|||||||
num_special_tokens = tokenizer.num_special_tokens_to_add()
|
num_special_tokens = tokenizer.num_special_tokens_to_add()
|
||||||
real_input_len = input_len - num_special_tokens
|
real_input_len = input_len - num_special_tokens
|
||||||
|
|
||||||
prefix_token_ids = (
|
prefix_token_ids = (np.random.randint(
|
||||||
np.random.randint(0, vocab_size, size=prefix_len).tolist()
|
0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else [])
|
||||||
if prefix_len > 0
|
|
||||||
else []
|
|
||||||
)
|
|
||||||
|
|
||||||
# New sampling logic: [X * (1 - b), X * (1 + b)]
|
# New sampling logic: [X * (1 - b), X * (1 + b)]
|
||||||
input_low = int(real_input_len * (1 - range_ratio))
|
input_low = int(real_input_len * (1 - range_ratio))
|
||||||
@ -330,17 +329,21 @@ class RandomDataset(BenchmarkDataset):
|
|||||||
|
|
||||||
# Add logging for debugging
|
# Add logging for debugging
|
||||||
logger.info("Sampling input_len from [%s, %s]", input_low, input_high)
|
logger.info("Sampling input_len from [%s, %s]", input_low, input_high)
|
||||||
logger.info("Sampling output_len from [%s, %s]", output_low, output_high)
|
logger.info("Sampling output_len from [%s, %s]", output_low,
|
||||||
|
output_high)
|
||||||
|
|
||||||
input_lens = np.random.randint(input_low, input_high + 1, size=num_requests)
|
input_lens = np.random.randint(input_low,
|
||||||
output_lens = np.random.randint(output_low, output_high + 1, size=num_requests)
|
input_high + 1,
|
||||||
|
size=num_requests)
|
||||||
|
output_lens = np.random.randint(output_low,
|
||||||
|
output_high + 1,
|
||||||
|
size=num_requests)
|
||||||
offsets = np.random.randint(0, vocab_size, size=num_requests)
|
offsets = np.random.randint(0, vocab_size, size=num_requests)
|
||||||
|
|
||||||
requests = []
|
requests = []
|
||||||
for i in range(num_requests):
|
for i in range(num_requests):
|
||||||
inner_seq = (
|
inner_seq = ((offsets[i] + i + np.arange(input_lens[i])) %
|
||||||
(offsets[i] + i + np.arange(input_lens[i])) % vocab_size
|
vocab_size).tolist()
|
||||||
).tolist()
|
|
||||||
token_sequence = prefix_token_ids + inner_seq
|
token_sequence = prefix_token_ids + inner_seq
|
||||||
prompt = tokenizer.decode(token_sequence)
|
prompt = tokenizer.decode(token_sequence)
|
||||||
# After decoding the prompt we have to encode and decode it again.
|
# After decoding the prompt we have to encode and decode it again.
|
||||||
@ -351,9 +354,8 @@ class RandomDataset(BenchmarkDataset):
|
|||||||
# [1650, 939, 486] -> ['Ġcall', 'sh', 'ere']
|
# [1650, 939, 486] -> ['Ġcall', 'sh', 'ere']
|
||||||
# To avoid uncontrolled change of the prompt length,
|
# To avoid uncontrolled change of the prompt length,
|
||||||
# the encoded sequence is truncated before being decode again.
|
# the encoded sequence is truncated before being decode again.
|
||||||
re_encoded_sequence = tokenizer.encode(prompt, add_special_tokens=False)[
|
re_encoded_sequence = tokenizer.encode(
|
||||||
: input_lens[i]
|
prompt, add_special_tokens=False)[:input_lens[i]]
|
||||||
]
|
|
||||||
prompt = tokenizer.decode(re_encoded_sequence)
|
prompt = tokenizer.decode(re_encoded_sequence)
|
||||||
total_input_len = prefix_len + int(input_lens[i])
|
total_input_len = prefix_len + int(input_lens[i])
|
||||||
requests.append(
|
requests.append(
|
||||||
@ -361,8 +363,7 @@ class RandomDataset(BenchmarkDataset):
|
|||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
prompt_len=total_input_len,
|
prompt_len=total_input_len,
|
||||||
expected_output_len=int(output_lens[i]),
|
expected_output_len=int(output_lens[i]),
|
||||||
)
|
))
|
||||||
)
|
|
||||||
return requests
|
return requests
|
||||||
|
|
||||||
|
|
||||||
@ -389,8 +390,7 @@ class ShareGPTDataset(BenchmarkDataset):
|
|||||||
self.data = json.load(f)
|
self.data = json.load(f)
|
||||||
# Filter entries with at least two conversation turns.
|
# Filter entries with at least two conversation turns.
|
||||||
self.data = [
|
self.data = [
|
||||||
entry
|
entry for entry in self.data
|
||||||
for entry in self.data
|
|
||||||
if "conversations" in entry and len(entry["conversations"]) >= 2
|
if "conversations" in entry and len(entry["conversations"]) >= 2
|
||||||
]
|
]
|
||||||
random.seed(self.random_seed)
|
random.seed(self.random_seed)
|
||||||
@ -416,28 +416,27 @@ class ShareGPTDataset(BenchmarkDataset):
|
|||||||
)
|
)
|
||||||
|
|
||||||
lora_request, tokenizer = self.get_random_lora_request(
|
lora_request, tokenizer = self.get_random_lora_request(
|
||||||
tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path
|
tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path)
|
||||||
)
|
|
||||||
prompt_ids = tokenizer(prompt).input_ids
|
prompt_ids = tokenizer(prompt).input_ids
|
||||||
completion_ids = tokenizer(completion).input_ids
|
completion_ids = tokenizer(completion).input_ids
|
||||||
prompt_len = len(prompt_ids)
|
prompt_len = len(prompt_ids)
|
||||||
new_output_len = len(completion_ids) if output_len is None else output_len
|
new_output_len = (len(completion_ids)
|
||||||
if not is_valid_sequence(
|
if output_len is None else output_len)
|
||||||
prompt_len,
|
if not is_valid_sequence(prompt_len,
|
||||||
new_output_len,
|
new_output_len,
|
||||||
skip_min_output_len_check=output_len is not None,
|
skip_min_output_len_check=output_len
|
||||||
):
|
is not None):
|
||||||
continue
|
continue
|
||||||
if enable_multimodal_chat:
|
if enable_multimodal_chat:
|
||||||
prompt = self.apply_multimodal_chat_transformation(prompt, None)
|
prompt = self.apply_multimodal_chat_transformation(
|
||||||
|
prompt, None)
|
||||||
samples.append(
|
samples.append(
|
||||||
SampleRequest(
|
SampleRequest(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
prompt_len=prompt_len,
|
prompt_len=prompt_len,
|
||||||
expected_output_len=new_output_len,
|
expected_output_len=new_output_len,
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
)
|
))
|
||||||
)
|
|
||||||
self.maybe_oversample_requests(samples, num_requests)
|
self.maybe_oversample_requests(samples, num_requests)
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
@ -483,20 +482,20 @@ class SonnetDataset(BenchmarkDataset):
|
|||||||
) -> list:
|
) -> list:
|
||||||
# Calculate average token length for a poem line.
|
# Calculate average token length for a poem line.
|
||||||
tokenized_lines = [tokenizer(line).input_ids for line in self.data]
|
tokenized_lines = [tokenizer(line).input_ids for line in self.data]
|
||||||
avg_len = sum(len(tokens) for tokens in tokenized_lines) / len(tokenized_lines)
|
avg_len = sum(len(tokens)
|
||||||
|
for tokens in tokenized_lines) / len(tokenized_lines)
|
||||||
|
|
||||||
# Build the base prompt.
|
# Build the base prompt.
|
||||||
base_prompt = "Pick as many lines as you can from these poem lines:\n"
|
base_prompt = "Pick as many lines as you can from these poem lines:\n"
|
||||||
base_msg = [{"role": "user", "content": base_prompt}]
|
base_msg = [{"role": "user", "content": base_prompt}]
|
||||||
base_fmt = tokenizer.apply_chat_template(
|
base_fmt = tokenizer.apply_chat_template(base_msg,
|
||||||
base_msg, add_generation_prompt=True, tokenize=False
|
add_generation_prompt=True,
|
||||||
)
|
tokenize=False)
|
||||||
base_offset = len(tokenizer(base_fmt).input_ids)
|
base_offset = len(tokenizer(base_fmt).input_ids)
|
||||||
if input_len <= base_offset:
|
if input_len <= base_offset:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"'input_len' must be higher than the base prompt length "
|
f"'input_len' must be higher than the base prompt length "
|
||||||
f"({base_offset})."
|
f"({base_offset}).")
|
||||||
)
|
|
||||||
|
|
||||||
# Determine how many poem lines to use.
|
# Determine how many poem lines to use.
|
||||||
num_input_lines = round((input_len - base_offset) / avg_len)
|
num_input_lines = round((input_len - base_offset) / avg_len)
|
||||||
@ -505,23 +504,21 @@ class SonnetDataset(BenchmarkDataset):
|
|||||||
|
|
||||||
samples = []
|
samples = []
|
||||||
while len(samples) < num_requests:
|
while len(samples) < num_requests:
|
||||||
extra_lines = random.choices(
|
extra_lines = random.choices(self.data,
|
||||||
self.data, k=num_input_lines - num_prefix_lines
|
k=num_input_lines - num_prefix_lines)
|
||||||
)
|
|
||||||
prompt = f"{base_prompt}{''.join(prefix_lines + extra_lines)}"
|
prompt = f"{base_prompt}{''.join(prefix_lines + extra_lines)}"
|
||||||
msg = [{"role": "user", "content": prompt}]
|
msg = [{"role": "user", "content": prompt}]
|
||||||
prompt_formatted = tokenizer.apply_chat_template(
|
prompt_formatted = tokenizer.apply_chat_template(
|
||||||
msg, add_generation_prompt=True, tokenize=False
|
msg, add_generation_prompt=True, tokenize=False)
|
||||||
)
|
|
||||||
prompt_len = len(tokenizer(prompt_formatted).input_ids)
|
prompt_len = len(tokenizer(prompt_formatted).input_ids)
|
||||||
if prompt_len <= input_len:
|
if prompt_len <= input_len:
|
||||||
samples.append(
|
samples.append(
|
||||||
SampleRequest(
|
SampleRequest(
|
||||||
prompt=prompt_formatted if return_prompt_formatted else prompt,
|
prompt=prompt_formatted
|
||||||
|
if return_prompt_formatted else prompt,
|
||||||
prompt_len=prompt_len,
|
prompt_len=prompt_len,
|
||||||
expected_output_len=output_len,
|
expected_output_len=output_len,
|
||||||
)
|
))
|
||||||
)
|
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
|
||||||
@ -541,9 +538,7 @@ class BurstGPTDataset(BenchmarkDataset):
|
|||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.load_data()
|
self.load_data()
|
||||||
|
|
||||||
def load_data(
|
def load_data(self, ):
|
||||||
self,
|
|
||||||
):
|
|
||||||
if self.dataset_path is None:
|
if self.dataset_path is None:
|
||||||
raise ValueError("dataset_path must be provided for loading data.")
|
raise ValueError("dataset_path must be provided for loading data.")
|
||||||
|
|
||||||
@ -557,7 +552,8 @@ class BurstGPTDataset(BenchmarkDataset):
|
|||||||
|
|
||||||
def _sample_loaded_data(self, num_requests: int) -> list:
|
def _sample_loaded_data(self, num_requests: int) -> list:
|
||||||
if num_requests <= len(self.data):
|
if num_requests <= len(self.data):
|
||||||
data = self.data.sample(n=num_requests, random_state=self.random_seed)
|
data = self.data.sample(n=num_requests,
|
||||||
|
random_state=self.random_seed)
|
||||||
else:
|
else:
|
||||||
data = self.data.sample(
|
data = self.data.sample(
|
||||||
n=num_requests,
|
n=num_requests,
|
||||||
@ -581,8 +577,7 @@ class BurstGPTDataset(BenchmarkDataset):
|
|||||||
input_len = int(data[i][2])
|
input_len = int(data[i][2])
|
||||||
output_len = int(data[i][3])
|
output_len = int(data[i][3])
|
||||||
lora_req, tokenizer = self.get_random_lora_request(
|
lora_req, tokenizer = self.get_random_lora_request(
|
||||||
tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path
|
tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path)
|
||||||
)
|
|
||||||
vocab_size = tokenizer.vocab_size
|
vocab_size = tokenizer.vocab_size
|
||||||
# Generate a synthetic prompt: a list of token IDs computed as (i +
|
# Generate a synthetic prompt: a list of token IDs computed as (i +
|
||||||
# j) modulo vocab_size.
|
# j) modulo vocab_size.
|
||||||
@ -594,8 +589,7 @@ class BurstGPTDataset(BenchmarkDataset):
|
|||||||
prompt_len=input_len,
|
prompt_len=input_len,
|
||||||
expected_output_len=output_len,
|
expected_output_len=output_len,
|
||||||
lora_request=lora_req,
|
lora_request=lora_req,
|
||||||
)
|
))
|
||||||
)
|
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
|
||||||
@ -638,23 +632,20 @@ class HuggingFaceDataset(BenchmarkDataset):
|
|||||||
|
|
||||||
class ConversationDataset(HuggingFaceDataset):
|
class ConversationDataset(HuggingFaceDataset):
|
||||||
"""Dataset for conversation data with multimodal support."""
|
"""Dataset for conversation data with multimodal support."""
|
||||||
|
|
||||||
SUPPORTED_DATASET_PATHS = {
|
SUPPORTED_DATASET_PATHS = {
|
||||||
"lmms-lab/LLaVA-OneVision-Data",
|
'lmms-lab/LLaVA-OneVision-Data', 'Aeala/ShareGPT_Vicuna_unfiltered'
|
||||||
"Aeala/ShareGPT_Vicuna_unfiltered",
|
|
||||||
}
|
}
|
||||||
IS_MULTIMODAL = True
|
IS_MULTIMODAL = True
|
||||||
|
|
||||||
def sample(
|
def sample(self,
|
||||||
self,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
num_requests: int,
|
||||||
num_requests: int,
|
output_len: Optional[int] = None,
|
||||||
output_len: Optional[int] = None,
|
enable_multimodal_chat: bool = False,
|
||||||
enable_multimodal_chat: bool = False,
|
**kwargs) -> list:
|
||||||
**kwargs,
|
|
||||||
) -> list:
|
|
||||||
# Filter examples with at least 2 conversations
|
# Filter examples with at least 2 conversations
|
||||||
filtered_data = self.data.filter(lambda x: len(x["conversations"]) >= 2)
|
filtered_data = self.data.filter(
|
||||||
|
lambda x: len(x["conversations"]) >= 2)
|
||||||
sampled_requests = []
|
sampled_requests = []
|
||||||
dynamic_output = output_len is None
|
dynamic_output = output_len is None
|
||||||
|
|
||||||
@ -670,22 +661,24 @@ class ConversationDataset(HuggingFaceDataset):
|
|||||||
completion_len = len(completion_ids)
|
completion_len = len(completion_ids)
|
||||||
output_len = completion_len if dynamic_output else output_len
|
output_len = completion_len if dynamic_output else output_len
|
||||||
assert isinstance(output_len, int) and output_len > 0
|
assert isinstance(output_len, int) and output_len > 0
|
||||||
if dynamic_output and not is_valid_sequence(prompt_len, completion_len):
|
if dynamic_output and not is_valid_sequence(
|
||||||
|
prompt_len, completion_len):
|
||||||
continue
|
continue
|
||||||
mm_content = process_image(item["image"]) if "image" in item else None
|
mm_content = process_image(
|
||||||
|
item["image"]) if "image" in item else None
|
||||||
if enable_multimodal_chat:
|
if enable_multimodal_chat:
|
||||||
# Note: when chat is enabled the request prompt_len is no longer
|
# Note: when chat is enabled the request prompt_len is no longer
|
||||||
# accurate and we will be using request output to count the
|
# accurate and we will be using request output to count the
|
||||||
# actual prompt len and output len
|
# actual prompt len and output len
|
||||||
prompt = self.apply_multimodal_chat_transformation(prompt, mm_content)
|
prompt = self.apply_multimodal_chat_transformation(
|
||||||
|
prompt, mm_content)
|
||||||
sampled_requests.append(
|
sampled_requests.append(
|
||||||
SampleRequest(
|
SampleRequest(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
prompt_len=prompt_len,
|
prompt_len=prompt_len,
|
||||||
expected_output_len=output_len,
|
expected_output_len=output_len,
|
||||||
multi_modal_data=mm_content,
|
multi_modal_data=mm_content,
|
||||||
)
|
))
|
||||||
)
|
|
||||||
self.maybe_oversample_requests(sampled_requests, num_requests)
|
self.maybe_oversample_requests(sampled_requests, num_requests)
|
||||||
return sampled_requests
|
return sampled_requests
|
||||||
|
|
||||||
@ -702,8 +695,10 @@ class VisionArenaDataset(HuggingFaceDataset):
|
|||||||
|
|
||||||
DEFAULT_OUTPUT_LEN = 128
|
DEFAULT_OUTPUT_LEN = 128
|
||||||
SUPPORTED_DATASET_PATHS = {
|
SUPPORTED_DATASET_PATHS = {
|
||||||
"lmarena-ai/VisionArena-Chat": lambda x: x["conversation"][0][0]["content"],
|
"lmarena-ai/VisionArena-Chat":
|
||||||
"lmarena-ai/vision-arena-bench-v0.1": lambda x: x["turns"][0][0]["content"],
|
lambda x: x["conversation"][0][0]["content"],
|
||||||
|
"lmarena-ai/vision-arena-bench-v0.1":
|
||||||
|
lambda x: x["turns"][0][0]["content"]
|
||||||
}
|
}
|
||||||
IS_MULTIMODAL = True
|
IS_MULTIMODAL = True
|
||||||
|
|
||||||
@ -715,14 +710,16 @@ class VisionArenaDataset(HuggingFaceDataset):
|
|||||||
enable_multimodal_chat: bool = False,
|
enable_multimodal_chat: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> list:
|
) -> list:
|
||||||
output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
|
output_len = (output_len
|
||||||
|
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
|
||||||
sampled_requests = []
|
sampled_requests = []
|
||||||
for item in self.data:
|
for item in self.data:
|
||||||
if len(sampled_requests) >= num_requests:
|
if len(sampled_requests) >= num_requests:
|
||||||
break
|
break
|
||||||
parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path)
|
parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path)
|
||||||
if parser_fn is None:
|
if parser_fn is None:
|
||||||
raise ValueError(f"Unsupported dataset path: {self.dataset_path}")
|
raise ValueError(
|
||||||
|
f"Unsupported dataset path: {self.dataset_path}")
|
||||||
prompt = parser_fn(item)
|
prompt = parser_fn(item)
|
||||||
mm_content = process_image(item["images"][0])
|
mm_content = process_image(item["images"][0])
|
||||||
prompt_len = len(tokenizer(prompt).input_ids)
|
prompt_len = len(tokenizer(prompt).input_ids)
|
||||||
@ -730,15 +727,15 @@ class VisionArenaDataset(HuggingFaceDataset):
|
|||||||
# Note: when chat is enabled the request prompt_len is no longer
|
# Note: when chat is enabled the request prompt_len is no longer
|
||||||
# accurate and we will be using request output to count the
|
# accurate and we will be using request output to count the
|
||||||
# actual prompt len
|
# actual prompt len
|
||||||
prompt = self.apply_multimodal_chat_transformation(prompt, mm_content)
|
prompt = self.apply_multimodal_chat_transformation(
|
||||||
|
prompt, mm_content)
|
||||||
sampled_requests.append(
|
sampled_requests.append(
|
||||||
SampleRequest(
|
SampleRequest(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
prompt_len=prompt_len,
|
prompt_len=prompt_len,
|
||||||
expected_output_len=output_len,
|
expected_output_len=output_len,
|
||||||
multi_modal_data=mm_content,
|
multi_modal_data=mm_content,
|
||||||
)
|
))
|
||||||
)
|
|
||||||
self.maybe_oversample_requests(sampled_requests, num_requests)
|
self.maybe_oversample_requests(sampled_requests, num_requests)
|
||||||
return sampled_requests
|
return sampled_requests
|
||||||
|
|
||||||
@ -763,15 +760,14 @@ class InstructCoderDataset(HuggingFaceDataset):
|
|||||||
"likaixin/InstructCoder",
|
"likaixin/InstructCoder",
|
||||||
}
|
}
|
||||||
|
|
||||||
def sample(
|
def sample(self,
|
||||||
self,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
num_requests: int,
|
||||||
num_requests: int,
|
output_len: Optional[int] = None,
|
||||||
output_len: Optional[int] = None,
|
enable_multimodal_chat: bool = False,
|
||||||
enable_multimodal_chat: bool = False,
|
**kwargs) -> list:
|
||||||
**kwargs,
|
output_len = (output_len
|
||||||
) -> list:
|
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
|
||||||
output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
|
|
||||||
sampled_requests = []
|
sampled_requests = []
|
||||||
for item in self.data:
|
for item in self.data:
|
||||||
if len(sampled_requests) >= num_requests:
|
if len(sampled_requests) >= num_requests:
|
||||||
@ -783,8 +779,7 @@ class InstructCoderDataset(HuggingFaceDataset):
|
|||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
prompt_len=prompt_len,
|
prompt_len=prompt_len,
|
||||||
expected_output_len=output_len,
|
expected_output_len=output_len,
|
||||||
)
|
))
|
||||||
)
|
|
||||||
self.maybe_oversample_requests(sampled_requests, num_requests)
|
self.maybe_oversample_requests(sampled_requests, num_requests)
|
||||||
return sampled_requests
|
return sampled_requests
|
||||||
|
|
||||||
@ -799,38 +794,38 @@ class MTBenchDataset(HuggingFaceDataset):
|
|||||||
MT-Bench Dataset.
|
MT-Bench Dataset.
|
||||||
https://huggingface.co/datasets/philschmid/mt-bench
|
https://huggingface.co/datasets/philschmid/mt-bench
|
||||||
|
|
||||||
We create a single turn dataset for MT-Bench.
|
We create a single turn dataset for MT-Bench.
|
||||||
This is similar to Spec decoding benchmark setup in vLLM
|
This is similar to Spec decoding benchmark setup in vLLM
|
||||||
https://github.com/vllm-project/vllm/blob/9d98ab5ec/examples/offline_inference/eagle.py#L14-L18
|
https://github.com/vllm-project/vllm/blob/9d98ab5ec/examples/offline_inference/eagle.py#L14-L18
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
|
|
||||||
DEFAULT_OUTPUT_LEN = 256 # avg len used in SD bench in vLLM
|
DEFAULT_OUTPUT_LEN = 256 # avg len used in SD bench in vLLM
|
||||||
SUPPORTED_DATASET_PATHS = {
|
SUPPORTED_DATASET_PATHS = {
|
||||||
"philschmid/mt-bench",
|
"philschmid/mt-bench",
|
||||||
}
|
}
|
||||||
|
|
||||||
def sample(
|
def sample(self,
|
||||||
self,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
num_requests: int,
|
||||||
num_requests: int,
|
output_len: Optional[int] = None,
|
||||||
output_len: Optional[int] = None,
|
enable_multimodal_chat: bool = False,
|
||||||
enable_multimodal_chat: bool = False,
|
**kwargs) -> list:
|
||||||
**kwargs,
|
output_len = (output_len
|
||||||
) -> list:
|
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
|
||||||
output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
|
|
||||||
sampled_requests = []
|
sampled_requests = []
|
||||||
|
|
||||||
for item in self.data:
|
for item in self.data:
|
||||||
if len(sampled_requests) >= num_requests:
|
if len(sampled_requests) >= num_requests:
|
||||||
break
|
break
|
||||||
prompt = item["turns"][0]
|
prompt = item['turns'][0]
|
||||||
|
|
||||||
# apply template
|
# apply template
|
||||||
prompt = tokenizer.apply_chat_template(
|
prompt = tokenizer.apply_chat_template([{
|
||||||
[{"role": "user", "content": prompt}],
|
"role": "user",
|
||||||
add_generation_prompt=True,
|
"content": prompt
|
||||||
tokenize=False,
|
}],
|
||||||
)
|
add_generation_prompt=True,
|
||||||
|
tokenize=False)
|
||||||
|
|
||||||
prompt_len = len(tokenizer(prompt).input_ids)
|
prompt_len = len(tokenizer(prompt).input_ids)
|
||||||
sampled_requests.append(
|
sampled_requests.append(
|
||||||
@ -838,8 +833,7 @@ class MTBenchDataset(HuggingFaceDataset):
|
|||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
prompt_len=prompt_len,
|
prompt_len=prompt_len,
|
||||||
expected_output_len=output_len,
|
expected_output_len=output_len,
|
||||||
)
|
))
|
||||||
)
|
|
||||||
self.maybe_oversample_requests(sampled_requests, num_requests)
|
self.maybe_oversample_requests(sampled_requests, num_requests)
|
||||||
return sampled_requests
|
return sampled_requests
|
||||||
|
|
||||||
@ -853,27 +847,23 @@ class AIMODataset(HuggingFaceDataset):
|
|||||||
"""
|
"""
|
||||||
Dataset class for processing a AIMO dataset with reasoning questions.
|
Dataset class for processing a AIMO dataset with reasoning questions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
SUPPORTED_DATASET_PATHS = {
|
SUPPORTED_DATASET_PATHS = {
|
||||||
"AI-MO/aimo-validation-aime",
|
"AI-MO/aimo-validation-aime", "AI-MO/NuminaMath-1.5",
|
||||||
"AI-MO/NuminaMath-1.5",
|
"AI-MO/NuminaMath-CoT"
|
||||||
"AI-MO/NuminaMath-CoT",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def sample(
|
def sample(self,
|
||||||
self,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
num_requests: int,
|
||||||
num_requests: int,
|
output_len: Optional[int] = None,
|
||||||
output_len: Optional[int] = None,
|
**kwargs) -> list:
|
||||||
**kwargs,
|
|
||||||
) -> list:
|
|
||||||
sampled_requests = []
|
sampled_requests = []
|
||||||
dynamic_output = output_len is None
|
dynamic_output = output_len is None
|
||||||
|
|
||||||
for item in self.data:
|
for item in self.data:
|
||||||
if len(sampled_requests) >= num_requests:
|
if len(sampled_requests) >= num_requests:
|
||||||
break
|
break
|
||||||
prompt, completion = item["problem"], item["solution"]
|
prompt, completion = item['problem'], item["solution"]
|
||||||
|
|
||||||
prompt_ids = tokenizer(prompt).input_ids
|
prompt_ids = tokenizer(prompt).input_ids
|
||||||
completion_ids = tokenizer(completion).input_ids
|
completion_ids = tokenizer(completion).input_ids
|
||||||
@ -881,9 +871,10 @@ class AIMODataset(HuggingFaceDataset):
|
|||||||
completion_len = len(completion_ids)
|
completion_len = len(completion_ids)
|
||||||
output_len = completion_len if dynamic_output else output_len
|
output_len = completion_len if dynamic_output else output_len
|
||||||
assert isinstance(output_len, int) and output_len > 0
|
assert isinstance(output_len, int) and output_len > 0
|
||||||
if dynamic_output and not is_valid_sequence(
|
if dynamic_output and not is_valid_sequence(prompt_len,
|
||||||
prompt_len, completion_len, max_prompt_len=2048, max_total_len=32000
|
completion_len,
|
||||||
):
|
max_prompt_len=2048,
|
||||||
|
max_total_len=32000):
|
||||||
continue
|
continue
|
||||||
sampled_requests.append(
|
sampled_requests.append(
|
||||||
SampleRequest(
|
SampleRequest(
|
||||||
@ -891,8 +882,7 @@ class AIMODataset(HuggingFaceDataset):
|
|||||||
prompt_len=prompt_len,
|
prompt_len=prompt_len,
|
||||||
expected_output_len=output_len,
|
expected_output_len=output_len,
|
||||||
multi_modal_data=None,
|
multi_modal_data=None,
|
||||||
)
|
))
|
||||||
)
|
|
||||||
self.maybe_oversample_requests(sampled_requests, num_requests)
|
self.maybe_oversample_requests(sampled_requests, num_requests)
|
||||||
return sampled_requests
|
return sampled_requests
|
||||||
|
|
||||||
@ -915,25 +905,25 @@ You are a code completion assistant and your task is to analyze user edits and t
|
|||||||
|
|
||||||
### Response:
|
### Response:
|
||||||
|
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
|
|
||||||
|
|
||||||
def _format_zeta_prompt(
|
def _format_zeta_prompt(
|
||||||
sample: dict, original_start_marker: str = "<|editable_region_start|>"
|
sample: dict,
|
||||||
) -> dict:
|
original_start_marker: str = "<|editable_region_start|>") -> dict:
|
||||||
"""Format the zeta prompt for the Next Edit Prediction (NEP) dataset.
|
"""Format the zeta prompt for the Next Edit Prediction (NEP) dataset.
|
||||||
|
|
||||||
This function formats examples from the NEP dataset
|
This function formats examples from the NEP dataset
|
||||||
into prompts and expected outputs. It could be
|
into prompts and expected outputs. It could be
|
||||||
further extended to support more NEP datasets.
|
further extended to support more NEP datasets.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
sample: The dataset sample containing events,
|
sample: The dataset sample containing events,
|
||||||
inputs, and outputs.
|
inputs, and outputs.
|
||||||
original_start_marker: The marker indicating the
|
original_start_marker: The marker indicating the
|
||||||
start of the editable region. Defaults to
|
start of the editable region. Defaults to
|
||||||
"<|editable_region_start|>".
|
"<|editable_region_start|>".
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A dictionary with the formatted prompts and expected outputs.
|
A dictionary with the formatted prompts and expected outputs.
|
||||||
"""
|
"""
|
||||||
@ -963,8 +953,10 @@ class NextEditPredictionDataset(HuggingFaceDataset):
|
|||||||
"zed-industries/zeta": _format_zeta_prompt,
|
"zed-industries/zeta": _format_zeta_prompt,
|
||||||
}
|
}
|
||||||
|
|
||||||
def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int, **kwargs):
|
def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int,
|
||||||
formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(self.dataset_path)
|
**kwargs):
|
||||||
|
formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(
|
||||||
|
self.dataset_path)
|
||||||
if formatting_prompt_func is None:
|
if formatting_prompt_func is None:
|
||||||
raise ValueError(f"Unsupported dataset path: {self.dataset_path}")
|
raise ValueError(f"Unsupported dataset path: {self.dataset_path}")
|
||||||
samples = []
|
samples = []
|
||||||
@ -975,10 +967,8 @@ class NextEditPredictionDataset(HuggingFaceDataset):
|
|||||||
prompt=sample["prompt"],
|
prompt=sample["prompt"],
|
||||||
prompt_len=len(tokenizer(sample["prompt"]).input_ids),
|
prompt_len=len(tokenizer(sample["prompt"]).input_ids),
|
||||||
expected_output_len=len(
|
expected_output_len=len(
|
||||||
tokenizer(sample["expected_output"]).input_ids
|
tokenizer(sample["expected_output"]).input_ids),
|
||||||
),
|
))
|
||||||
)
|
|
||||||
)
|
|
||||||
if len(samples) >= num_requests:
|
if len(samples) >= num_requests:
|
||||||
break
|
break
|
||||||
self.maybe_oversample_requests(samples, num_requests)
|
self.maybe_oversample_requests(samples, num_requests)
|
||||||
@ -1007,22 +997,18 @@ class ASRDataset(HuggingFaceDataset):
|
|||||||
| AMI | Meetings | Spontaneous | ihm, sdm |
|
| AMI | Meetings | Spontaneous | ihm, sdm |
|
||||||
+----------------+----------------------------------------+--------------------------+-----------------------------+
|
+----------------+----------------------------------------+--------------------------+-----------------------------+
|
||||||
|
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
|
|
||||||
SUPPORTED_DATASET_PATHS = {
|
SUPPORTED_DATASET_PATHS = {
|
||||||
"openslr/librispeech_asr",
|
"openslr/librispeech_asr", "facebook/voxpopuli", "LIUM/tedlium",
|
||||||
"facebook/voxpopuli",
|
"edinburghcstr/ami", "speechcolab/gigaspeech", "kensho/spgispeech"
|
||||||
"LIUM/tedlium",
|
|
||||||
"edinburghcstr/ami",
|
|
||||||
"speechcolab/gigaspeech",
|
|
||||||
"kensho/spgispeech",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
DEFAULT_OUTPUT_LEN = 128
|
DEFAULT_OUTPUT_LEN = 128
|
||||||
IS_MULTIMODAL = True
|
IS_MULTIMODAL = True
|
||||||
|
|
||||||
# TODO Whisper-specific. Abstract interface when more models are supported.
|
# TODO Whisper-specific. Abstract interface when more models are supported.
|
||||||
TRANSCRIPTION_PREAMBLE = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"
|
TRANSCRIPTION_PREAMBLE = "<|startoftranscript|><|en|><|transcribe|>"\
|
||||||
|
"<|notimestamps|>"
|
||||||
skip_long_audios: bool = True
|
skip_long_audios: bool = True
|
||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
@ -1033,8 +1019,8 @@ class ASRDataset(HuggingFaceDataset):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
) -> list:
|
) -> list:
|
||||||
import librosa
|
import librosa
|
||||||
|
output_len = (output_len
|
||||||
output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
|
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
|
||||||
prompt = ASRDataset.TRANSCRIPTION_PREAMBLE
|
prompt = ASRDataset.TRANSCRIPTION_PREAMBLE
|
||||||
prompt_len = len(tokenizer(prompt).input_ids)
|
prompt_len = len(tokenizer(prompt).input_ids)
|
||||||
sampled_requests = []
|
sampled_requests = []
|
||||||
@ -1057,14 +1043,10 @@ class ASRDataset(HuggingFaceDataset):
|
|||||||
prompt_len=prompt_len,
|
prompt_len=prompt_len,
|
||||||
expected_output_len=output_len,
|
expected_output_len=output_len,
|
||||||
multi_modal_data=mm_content,
|
multi_modal_data=mm_content,
|
||||||
)
|
))
|
||||||
)
|
|
||||||
if skipped:
|
if skipped:
|
||||||
logger.warning(
|
logger.warning("%d samples discarded from dataset due to" \
|
||||||
"%d samples discarded from dataset due to"
|
" their length being greater than" \
|
||||||
" their length being greater than"
|
" what Whisper supports.", skipped)
|
||||||
" what Whisper supports.",
|
|
||||||
skipped,
|
|
||||||
)
|
|
||||||
self.maybe_oversample_requests(sampled_requests, num_requests)
|
self.maybe_oversample_requests(sampled_requests, num_requests)
|
||||||
return sampled_requests
|
return sampled_requests
|
||||||
|
|||||||
@ -11,9 +11,9 @@ from typing import Any, Optional
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
|
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
from vllm.engine.arg_utils import EngineArgs
|
from vllm.engine.arg_utils import EngineArgs
|
||||||
from vllm.inputs import PromptType
|
from vllm.inputs import PromptType
|
||||||
@ -21,14 +21,13 @@ from vllm.sampling_params import BeamSearchParams
|
|||||||
from vllm.utils import FlexibleArgumentParser
|
from vllm.utils import FlexibleArgumentParser
|
||||||
|
|
||||||
|
|
||||||
def save_to_pytorch_benchmark_format(
|
def save_to_pytorch_benchmark_format(args: argparse.Namespace,
|
||||||
args: argparse.Namespace, results: dict[str, Any]
|
results: dict[str, Any]) -> None:
|
||||||
) -> None:
|
|
||||||
pt_records = convert_to_pytorch_benchmark_format(
|
pt_records = convert_to_pytorch_benchmark_format(
|
||||||
args=args,
|
args=args,
|
||||||
metrics={"latency": results["latencies"]},
|
metrics={"latency": results["latencies"]},
|
||||||
extra_info={k: results[k] for k in ["avg_latency", "percentiles"]},
|
extra_info={k: results[k]
|
||||||
)
|
for k in ["avg_latency", "percentiles"]})
|
||||||
if pt_records:
|
if pt_records:
|
||||||
pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json"
|
pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json"
|
||||||
write_to_json(pt_file, pt_records)
|
write_to_json(pt_file, pt_records)
|
||||||
@ -43,11 +42,9 @@ def main(args: argparse.Namespace):
|
|||||||
# the engine will automatically process the request in multiple batches.
|
# the engine will automatically process the request in multiple batches.
|
||||||
llm = LLM(**dataclasses.asdict(engine_args))
|
llm = LLM(**dataclasses.asdict(engine_args))
|
||||||
assert llm.llm_engine.model_config.max_model_len >= (
|
assert llm.llm_engine.model_config.max_model_len >= (
|
||||||
args.input_len + args.output_len
|
args.input_len +
|
||||||
), (
|
args.output_len), ("Please ensure that max_model_len is greater than"
|
||||||
"Please ensure that max_model_len is greater than"
|
" the sum of input_len and output_len.")
|
||||||
" the sum of input_len and output_len."
|
|
||||||
)
|
|
||||||
|
|
||||||
sampling_params = SamplingParams(
|
sampling_params = SamplingParams(
|
||||||
n=args.n,
|
n=args.n,
|
||||||
@ -58,16 +55,18 @@ def main(args: argparse.Namespace):
|
|||||||
detokenize=not args.disable_detokenize,
|
detokenize=not args.disable_detokenize,
|
||||||
)
|
)
|
||||||
print(sampling_params)
|
print(sampling_params)
|
||||||
dummy_prompt_token_ids = np.random.randint(
|
dummy_prompt_token_ids = np.random.randint(10000,
|
||||||
10000, size=(args.batch_size, args.input_len)
|
size=(args.batch_size,
|
||||||
)
|
args.input_len))
|
||||||
dummy_prompts: list[PromptType] = [
|
dummy_prompts: list[PromptType] = [{
|
||||||
{"prompt_token_ids": batch} for batch in dummy_prompt_token_ids.tolist()
|
"prompt_token_ids": batch
|
||||||
]
|
} for batch in dummy_prompt_token_ids.tolist()]
|
||||||
|
|
||||||
def llm_generate():
|
def llm_generate():
|
||||||
if not args.use_beam_search:
|
if not args.use_beam_search:
|
||||||
llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False)
|
llm.generate(dummy_prompts,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
use_tqdm=False)
|
||||||
else:
|
else:
|
||||||
llm.beam_search(
|
llm.beam_search(
|
||||||
dummy_prompts,
|
dummy_prompts,
|
||||||
@ -81,13 +80,12 @@ def main(args: argparse.Namespace):
|
|||||||
def run_to_completion(profile_dir: Optional[str] = None):
|
def run_to_completion(profile_dir: Optional[str] = None):
|
||||||
if profile_dir:
|
if profile_dir:
|
||||||
with torch.profiler.profile(
|
with torch.profiler.profile(
|
||||||
activities=[
|
activities=[
|
||||||
torch.profiler.ProfilerActivity.CPU,
|
torch.profiler.ProfilerActivity.CPU,
|
||||||
torch.profiler.ProfilerActivity.CUDA,
|
torch.profiler.ProfilerActivity.CUDA,
|
||||||
],
|
],
|
||||||
on_trace_ready=torch.profiler.tensorboard_trace_handler(
|
on_trace_ready=torch.profiler.tensorboard_trace_handler(
|
||||||
str(profile_dir)
|
str(profile_dir)),
|
||||||
),
|
|
||||||
) as p:
|
) as p:
|
||||||
llm_generate()
|
llm_generate()
|
||||||
print(p.key_averages().table(sort_by="self_cuda_time_total"))
|
print(p.key_averages().table(sort_by="self_cuda_time_total"))
|
||||||
@ -105,9 +103,8 @@ def main(args: argparse.Namespace):
|
|||||||
if args.profile:
|
if args.profile:
|
||||||
profile_dir = args.profile_result_dir
|
profile_dir = args.profile_result_dir
|
||||||
if not profile_dir:
|
if not profile_dir:
|
||||||
profile_dir = (
|
profile_dir = (Path(".") / "vllm_benchmark_result" /
|
||||||
Path(".") / "vllm_benchmark_result" / f"latency_result_{time.time()}"
|
f"latency_result_{time.time()}")
|
||||||
)
|
|
||||||
print(f"Profiling (results will be saved to '{profile_dir}')...")
|
print(f"Profiling (results will be saved to '{profile_dir}')...")
|
||||||
run_to_completion(profile_dir=profile_dir)
|
run_to_completion(profile_dir=profile_dir)
|
||||||
return
|
return
|
||||||
@ -138,8 +135,7 @@ def main(args: argparse.Namespace):
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = FlexibleArgumentParser(
|
parser = FlexibleArgumentParser(
|
||||||
description="Benchmark the latency of processing a single batch of "
|
description="Benchmark the latency of processing a single batch of "
|
||||||
"requests till completion."
|
"requests till completion.")
|
||||||
)
|
|
||||||
parser.add_argument("--input-len", type=int, default=32)
|
parser.add_argument("--input-len", type=int, default=32)
|
||||||
parser.add_argument("--output-len", type=int, default=128)
|
parser.add_argument("--output-len", type=int, default=128)
|
||||||
parser.add_argument("--batch-size", type=int, default=8)
|
parser.add_argument("--batch-size", type=int, default=8)
|
||||||
@ -156,9 +152,10 @@ if __name__ == "__main__":
|
|||||||
default=10,
|
default=10,
|
||||||
help="Number of iterations to run for warmup.",
|
help="Number of iterations to run for warmup.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument("--num-iters",
|
||||||
"--num-iters", type=int, default=30, help="Number of iterations to run."
|
type=int,
|
||||||
)
|
default=30,
|
||||||
|
help="Number of iterations to run.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--profile",
|
"--profile",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
@ -168,10 +165,8 @@ if __name__ == "__main__":
|
|||||||
"--profile-result-dir",
|
"--profile-result-dir",
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
help=(
|
help=("path to save the pytorch profiler output. Can be visualized "
|
||||||
"path to save the pytorch profiler output. Can be visualized "
|
"with ui.perfetto.dev or Tensorboard."),
|
||||||
"with ui.perfetto.dev or Tensorboard."
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--output-json",
|
"--output-json",
|
||||||
@ -182,15 +177,10 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--disable-detokenize",
|
"--disable-detokenize",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help=(
|
help=("Do not detokenize responses (i.e. do not include "
|
||||||
"Do not detokenize responses (i.e. do not include "
|
"detokenization time in the latency measurement)"),
|
||||||
"detokenization time in the latency measurement)"
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser = EngineArgs.add_cli_args(parser)
|
parser = EngineArgs.add_cli_args(parser)
|
||||||
# V1 enables prefix caching by default which skews the latency
|
|
||||||
# numbers. We need to disable prefix caching by default.
|
|
||||||
parser.set_defaults(enable_prefix_caching=False)
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
main(args)
|
main(args)
|
||||||
|
|||||||
@ -76,7 +76,7 @@ def repeat_prompts(prompts, repeat_count, mode: str):
|
|||||||
- 'random': Shuffle the prompts randomly after repetition.
|
- 'random': Shuffle the prompts randomly after repetition.
|
||||||
- 'tile': Repeat the entire prompt list in sequence.
|
- 'tile': Repeat the entire prompt list in sequence.
|
||||||
Example: [1, 2, 3] -> [1, 2, 3, 1, 2, 3].
|
Example: [1, 2, 3] -> [1, 2, 3, 1, 2, 3].
|
||||||
- 'interleave': Repeat each prompt consecutively before moving to
|
- 'interleave': Repeat each prompt consecutively before moving to
|
||||||
the next. Example: [1, 2, 3] -> [1, 1, 2, 2, 3, 3].
|
the next. Example: [1, 2, 3] -> [1, 1, 2, 2, 3, 3].
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -86,21 +86,20 @@ def repeat_prompts(prompts, repeat_count, mode: str):
|
|||||||
ValueError: If an invalid mode is provided.
|
ValueError: If an invalid mode is provided.
|
||||||
"""
|
"""
|
||||||
print("Repeat mode: ", mode)
|
print("Repeat mode: ", mode)
|
||||||
if mode == "random":
|
if mode == 'random':
|
||||||
repeated_prompts = prompts * repeat_count
|
repeated_prompts = prompts * repeat_count
|
||||||
random.shuffle(repeated_prompts)
|
random.shuffle(repeated_prompts)
|
||||||
return repeated_prompts
|
return repeated_prompts
|
||||||
elif mode == "tile":
|
elif mode == 'tile':
|
||||||
return prompts * repeat_count
|
return prompts * repeat_count
|
||||||
elif mode == "interleave":
|
elif mode == 'interleave':
|
||||||
repeated_prompts = []
|
repeated_prompts = []
|
||||||
for prompt in prompts:
|
for prompt in prompts:
|
||||||
repeated_prompts.extend([prompt] * repeat_count)
|
repeated_prompts.extend([prompt] * repeat_count)
|
||||||
return repeated_prompts
|
return repeated_prompts
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(f"Invalid mode: {mode}, only support "
|
||||||
f"Invalid mode: {mode}, only support 'random', 'tile', 'interleave'"
|
"'random', 'tile', 'interleave'")
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
@ -110,16 +109,16 @@ def main(args):
|
|||||||
# we append the document id at the beginning to avoid any of the document
|
# we append the document id at the beginning to avoid any of the document
|
||||||
# being the prefix of other documents
|
# being the prefix of other documents
|
||||||
prompts = [
|
prompts = [
|
||||||
str(i) + " ".join(["hi"] * args.document_length)
|
str(i) + ' '.join(['hi'] * args.document_length)
|
||||||
for i in range(args.num_documents)
|
for i in range(args.num_documents)
|
||||||
]
|
]
|
||||||
|
|
||||||
prompts = repeat_prompts(prompts, args.repeat_count, mode=args.repeat_mode)
|
prompts = repeat_prompts(prompts, args.repeat_count, mode=args.repeat_mode)
|
||||||
|
|
||||||
warmup_prompts = [
|
warmup_prompts = [
|
||||||
"This is warm up request " + str(i) + " ".join(["hi"] * args.document_length)
|
"This is warm up request " + str(i) + \
|
||||||
for i in range(args.num_documents)
|
' '.join(['hi'] * args.document_length)
|
||||||
]
|
for i in range(args.num_documents)]
|
||||||
|
|
||||||
# Create the LLM engine
|
# Create the LLM engine
|
||||||
engine_args = EngineArgs.from_cli_args(args)
|
engine_args = EngineArgs.from_cli_args(args)
|
||||||
@ -143,52 +142,42 @@ def main(args):
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = FlexibleArgumentParser(
|
parser = FlexibleArgumentParser(
|
||||||
description="Benchmark the performance with or "
|
description=
|
||||||
"without automatic prefix caching."
|
'Benchmark the performance with or without automatic prefix caching.')
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--document-length",
|
'--document-length',
|
||||||
type=int,
|
type=int,
|
||||||
# Roughly the number of tokens for a system paper,
|
# Roughly the number of tokens for a system paper,
|
||||||
# excluding images
|
# excluding images
|
||||||
default=20000,
|
default=20000,
|
||||||
help="Range of input lengths for sampling prompts, "
|
help='Range of input lengths for sampling prompts,'
|
||||||
'specified as "min:max" (e.g., "128:256").',
|
'specified as "min:max" (e.g., "128:256").')
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument('--num-documents',
|
||||||
"--num-documents",
|
type=int,
|
||||||
type=int,
|
default=8,
|
||||||
default=8,
|
help='Range of input lengths for sampling prompts,'
|
||||||
help="Range of input lengths for sampling prompts, "
|
'specified as "min:max" (e.g., "128:256").')
|
||||||
'specified as "min:max" (e.g., "128:256").',
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument("--output-len", type=int, default=10)
|
parser.add_argument('--output-len', type=int, default=10)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument('--repeat-count',
|
||||||
"--repeat-count",
|
type=int,
|
||||||
type=int,
|
default=2,
|
||||||
default=2,
|
help='Number of times to repeat each prompt')
|
||||||
help="Number of times to repeat each prompt",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument("--repeat-mode",
|
||||||
"--repeat-mode",
|
type=str,
|
||||||
type=str,
|
default='random',
|
||||||
default="random",
|
help='The mode to repeat prompts. The supported '
|
||||||
help="The mode to repeat prompts. The supported "
|
'modes are "random", "tile", and "interleave". '
|
||||||
'modes are "random", "tile", and "interleave". '
|
'See repeat_prompts() in the source code for details.')
|
||||||
"See repeat_prompts() in the source code for details.",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument("--shuffle-seed",
|
||||||
"--shuffle-seed",
|
type=int,
|
||||||
type=int,
|
default=0,
|
||||||
default=0,
|
help='Random seed when the repeat mode is "random"')
|
||||||
help='Random seed when the repeat mode is "random"',
|
|
||||||
)
|
|
||||||
|
|
||||||
parser = EngineArgs.add_cli_args(parser)
|
parser = EngineArgs.add_cli_args(parser)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|||||||
@ -63,7 +63,8 @@ class Request:
|
|||||||
output_len: int
|
output_len: int
|
||||||
|
|
||||||
|
|
||||||
def sample_tokens(tokenizer: PreTrainedTokenizerBase, length: int) -> list[int]:
|
def sample_tokens(tokenizer: PreTrainedTokenizerBase,
|
||||||
|
length: int) -> list[int]:
|
||||||
vocab = tokenizer.get_vocab()
|
vocab = tokenizer.get_vocab()
|
||||||
all_special_ids = set(tokenizer.all_special_ids)
|
all_special_ids = set(tokenizer.all_special_ids)
|
||||||
|
|
||||||
@ -90,10 +91,8 @@ def sample_requests_from_dataset(
|
|||||||
# Filter out the conversations with less than 2 turns.
|
# Filter out the conversations with less than 2 turns.
|
||||||
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
|
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
|
||||||
# Only keep the first two turns of each conversation.
|
# Only keep the first two turns of each conversation.
|
||||||
dataset = [
|
dataset = [(data["conversations"][0]["value"],
|
||||||
(data["conversations"][0]["value"], data["conversations"][1]["value"])
|
data["conversations"][1]["value"]) for data in dataset]
|
||||||
for data in dataset
|
|
||||||
]
|
|
||||||
|
|
||||||
# Shuffle the dataset.
|
# Shuffle the dataset.
|
||||||
random.shuffle(dataset)
|
random.shuffle(dataset)
|
||||||
@ -114,9 +113,8 @@ def sample_requests_from_dataset(
|
|||||||
completion = dataset[i][1]
|
completion = dataset[i][1]
|
||||||
completion_token_ids = tokenizer(completion).input_ids
|
completion_token_ids = tokenizer(completion).input_ids
|
||||||
prompt_len = len(prompt_token_ids)
|
prompt_len = len(prompt_token_ids)
|
||||||
output_len = (
|
output_len = (len(completion_token_ids)
|
||||||
len(completion_token_ids) if fixed_output_len is None else fixed_output_len
|
if fixed_output_len is None else fixed_output_len)
|
||||||
)
|
|
||||||
if min_len <= prompt_len <= max_len:
|
if min_len <= prompt_len <= max_len:
|
||||||
filtered_requests.append(Request(prompt, prompt_len, output_len))
|
filtered_requests.append(Request(prompt, prompt_len, output_len))
|
||||||
|
|
||||||
@ -130,27 +128,27 @@ def sample_requests_from_random(
|
|||||||
fixed_output_len: Optional[int],
|
fixed_output_len: Optional[int],
|
||||||
prefix_len: int,
|
prefix_len: int,
|
||||||
) -> list[Request]:
|
) -> list[Request]:
|
||||||
|
|
||||||
requests = []
|
requests = []
|
||||||
prefix_token_ids = sample_tokens(tokenizer, prefix_len)
|
prefix_token_ids = sample_tokens(tokenizer, prefix_len)
|
||||||
min_len, max_len = input_length_range
|
min_len, max_len = input_length_range
|
||||||
|
|
||||||
for i in range(num_requests):
|
for i in range(num_requests):
|
||||||
unique_part_token_ids = sample_tokens(
|
unique_part_token_ids = sample_tokens(
|
||||||
tokenizer, random.randint(min_len - prefix_len, max_len - prefix_len)
|
tokenizer,
|
||||||
)
|
random.randint(min_len - prefix_len, max_len - prefix_len))
|
||||||
prompt_token_ids = prefix_token_ids + unique_part_token_ids
|
prompt_token_ids = prefix_token_ids + unique_part_token_ids
|
||||||
prompt = tokenizer.decode(prompt_token_ids)
|
prompt = tokenizer.decode(prompt_token_ids)
|
||||||
prompt_len = len(prompt_token_ids)
|
prompt_len = len(prompt_token_ids)
|
||||||
assert min_len <= prompt_len <= max_len, (
|
assert (min_len <= prompt_len <= max_len
|
||||||
f"prompt_len {prompt_len} out of range {min_len}:{max_len}"
|
), f"prompt_len {prompt_len} out of range {min_len}:{max_len}"
|
||||||
)
|
|
||||||
requests.append(Request(prompt, prompt_len, fixed_output_len))
|
requests.append(Request(prompt, prompt_len, fixed_output_len))
|
||||||
return requests
|
return requests
|
||||||
|
|
||||||
|
|
||||||
def repeat_and_sort_requests(
|
def repeat_and_sort_requests(requests: list[Request],
|
||||||
requests: list[Request], repeat_count: int, sort: bool = False
|
repeat_count: int,
|
||||||
) -> list[str]:
|
sort: bool = False) -> list[str]:
|
||||||
repeated_requests = requests * repeat_count
|
repeated_requests = requests * repeat_count
|
||||||
if sort:
|
if sort:
|
||||||
repeated_requests.sort(key=lambda x: x[1])
|
repeated_requests.sort(key=lambda x: x[1])
|
||||||
@ -161,14 +159,14 @@ def repeat_and_sort_requests(
|
|||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
tokenizer = get_tokenizer(args.model, trust_remote_code=True)
|
tokenizer = get_tokenizer(args.model, trust_remote_code=True)
|
||||||
input_length_range = tuple(map(int, args.input_length_range.split(":")))
|
input_length_range = tuple(map(int, args.input_length_range.split(':')))
|
||||||
random.seed(args.seed)
|
random.seed(args.seed)
|
||||||
if args.dataset_path is not None:
|
if args.dataset_path is not None:
|
||||||
if args.prefix_len > 0:
|
if args.prefix_len > 0:
|
||||||
raise ValueError(
|
raise ValueError("prefix-len is not supported when "
|
||||||
"prefix-len is not supported when dataset-path is provided."
|
"dataset-path is provided.")
|
||||||
)
|
print(f"Start to sample {args.num_prompts} prompts "
|
||||||
print(f"Start to sample {args.num_prompts} prompts from {args.dataset_path}")
|
f"from {args.dataset_path}")
|
||||||
filtered_requests = sample_requests_from_dataset(
|
filtered_requests = sample_requests_from_dataset(
|
||||||
dataset_path=args.dataset_path,
|
dataset_path=args.dataset_path,
|
||||||
num_requests=args.num_prompts,
|
num_requests=args.num_prompts,
|
||||||
@ -198,16 +196,14 @@ def main(args):
|
|||||||
|
|
||||||
llm = LLM(**dataclasses.asdict(engine_args))
|
llm = LLM(**dataclasses.asdict(engine_args))
|
||||||
|
|
||||||
sampling_params = SamplingParams(
|
sampling_params = SamplingParams(temperature=0,
|
||||||
temperature=0,
|
max_tokens=args.output_len,
|
||||||
max_tokens=args.output_len,
|
detokenize=not args.disable_detokenize)
|
||||||
detokenize=not args.disable_detokenize,
|
|
||||||
)
|
|
||||||
|
|
||||||
print("Testing filtered requests")
|
print("Testing filtered requests")
|
||||||
prompts = repeat_and_sort_requests(
|
prompts = repeat_and_sort_requests(filtered_requests,
|
||||||
filtered_requests, repeat_count=args.repeat_count, sort=args.sort
|
repeat_count=args.repeat_count,
|
||||||
)
|
sort=args.sort)
|
||||||
|
|
||||||
print("------start generating------")
|
print("------start generating------")
|
||||||
test_prefix(
|
test_prefix(
|
||||||
@ -219,35 +215,29 @@ def main(args):
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = FlexibleArgumentParser(
|
parser = FlexibleArgumentParser(
|
||||||
description="Benchmark the performance with or without "
|
description=
|
||||||
"automatic prefix caching."
|
'Benchmark the performance with or without automatic prefix caching.')
|
||||||
)
|
parser.add_argument("--dataset-path",
|
||||||
parser.add_argument(
|
type=str,
|
||||||
"--dataset-path", type=str, default=None, help="Path to the dataset."
|
default=None,
|
||||||
)
|
help="Path to the dataset.")
|
||||||
parser.add_argument("--output-len", type=int, default=10)
|
parser.add_argument('--output-len', type=int, default=10)
|
||||||
parser.add_argument(
|
parser.add_argument('--num-prompts',
|
||||||
"--num-prompts",
|
type=int,
|
||||||
type=int,
|
required=True,
|
||||||
required=True,
|
help="Number of the prompts sampled from dataset")
|
||||||
help="Number of the prompts sampled from dataset",
|
parser.add_argument('--repeat-count',
|
||||||
)
|
type=int,
|
||||||
parser.add_argument(
|
default=1,
|
||||||
"--repeat-count",
|
help='Number of times to repeat each prompt')
|
||||||
type=int,
|
parser.add_argument('--sort',
|
||||||
default=1,
|
action='store_true',
|
||||||
help="Number of times to repeat each prompt",
|
help='Sort prompts by input length')
|
||||||
)
|
parser.add_argument('--input-length-range',
|
||||||
parser.add_argument(
|
type=str,
|
||||||
"--sort", action="store_true", help="Sort prompts by input length"
|
required=True,
|
||||||
)
|
help='Range of input lengths for sampling prompts,'
|
||||||
parser.add_argument(
|
'specified as "min:max" (e.g., "128:256").')
|
||||||
"--input-length-range",
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
help="Range of input lengths for sampling prompts,"
|
|
||||||
'specified as "min:max" (e.g., "128:256").',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--prefix-len",
|
"--prefix-len",
|
||||||
type=int,
|
type=int,
|
||||||
@ -258,12 +248,10 @@ if __name__ == "__main__":
|
|||||||
"when dataset-path is not provided.",
|
"when dataset-path is not provided.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--disable-detokenize",
|
'--disable-detokenize',
|
||||||
action="store_true",
|
action='store_true',
|
||||||
help=(
|
help=("Do not detokenize responses (i.e. do not include "
|
||||||
"Do not detokenize responses (i.e. do not include "
|
"detokenization time in the latency measurement)"),
|
||||||
"detokenization time in the latency measurement)"
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser = EngineArgs.add_cli_args(parser)
|
parser = EngineArgs.add_cli_args(parser)
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
"""Benchmark offline prioritization."""
|
"""Benchmark offline prioritization."""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import json
|
import json
|
||||||
@ -14,7 +13,7 @@ from vllm.engine.arg_utils import EngineArgs
|
|||||||
from vllm.utils import FlexibleArgumentParser
|
from vllm.utils import FlexibleArgumentParser
|
||||||
|
|
||||||
|
|
||||||
# Select a equi-probable random priority
|
#Select a equi-probable random priority
|
||||||
def get_random_flag():
|
def get_random_flag():
|
||||||
return 0 if random.random() < 0.5 else 1
|
return 0 if random.random() < 0.5 else 1
|
||||||
|
|
||||||
@ -34,10 +33,8 @@ def sample_requests(
|
|||||||
# Filter out the conversations with less than 2 turns.
|
# Filter out the conversations with less than 2 turns.
|
||||||
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
|
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
|
||||||
# Only keep the first two turns of each conversation.
|
# Only keep the first two turns of each conversation.
|
||||||
dataset = [
|
dataset = [(data["conversations"][0]["value"],
|
||||||
(data["conversations"][0]["value"], data["conversations"][1]["value"])
|
data["conversations"][1]["value"]) for data in dataset]
|
||||||
for data in dataset
|
|
||||||
]
|
|
||||||
|
|
||||||
# Shuffle the dataset.
|
# Shuffle the dataset.
|
||||||
random.shuffle(dataset)
|
random.shuffle(dataset)
|
||||||
@ -54,9 +51,8 @@ def sample_requests(
|
|||||||
completion = dataset[i][1]
|
completion = dataset[i][1]
|
||||||
completion_token_ids = tokenizer(completion).input_ids
|
completion_token_ids = tokenizer(completion).input_ids
|
||||||
prompt_len = len(prompt_token_ids)
|
prompt_len = len(prompt_token_ids)
|
||||||
output_len = (
|
output_len = len(completion_token_ids
|
||||||
len(completion_token_ids) if fixed_output_len is None else fixed_output_len
|
) if fixed_output_len is None else fixed_output_len
|
||||||
)
|
|
||||||
if prompt_len < 4 or output_len < 4:
|
if prompt_len < 4 or output_len < 4:
|
||||||
# Prune too short sequences.
|
# Prune too short sequences.
|
||||||
continue
|
continue
|
||||||
@ -78,16 +74,13 @@ def run_vllm(
|
|||||||
disable_detokenize: bool = False,
|
disable_detokenize: bool = False,
|
||||||
) -> float:
|
) -> float:
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
|
|
||||||
llm = LLM(**dataclasses.asdict(engine_args))
|
llm = LLM(**dataclasses.asdict(engine_args))
|
||||||
|
|
||||||
assert all(
|
assert all(
|
||||||
llm.llm_engine.model_config.max_model_len >= (request[1] + request[2])
|
llm.llm_engine.model_config.max_model_len >= (request[1] + request[2])
|
||||||
for request in requests
|
for request in requests), (
|
||||||
), (
|
"Please ensure that max_model_len is greater than the sum of"
|
||||||
"Please ensure that max_model_len is greater than the sum of"
|
" input_len and output_len for all requests.")
|
||||||
" input_len and output_len for all requests."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add the requests to the engine.
|
# Add the requests to the engine.
|
||||||
prompts = []
|
prompts = []
|
||||||
@ -104,8 +97,7 @@ def run_vllm(
|
|||||||
ignore_eos=True,
|
ignore_eos=True,
|
||||||
max_tokens=output_len,
|
max_tokens=output_len,
|
||||||
detokenize=not disable_detokenize,
|
detokenize=not disable_detokenize,
|
||||||
)
|
))
|
||||||
)
|
|
||||||
|
|
||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
llm.generate(prompts, sampling_params, priority=priority, use_tqdm=True)
|
llm.generate(prompts, sampling_params, priority=priority, use_tqdm=True)
|
||||||
@ -119,33 +111,26 @@ def main(args: argparse.Namespace):
|
|||||||
|
|
||||||
# Sample the requests.
|
# Sample the requests.
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
args.tokenizer, trust_remote_code=args.trust_remote_code
|
args.tokenizer, trust_remote_code=args.trust_remote_code)
|
||||||
)
|
|
||||||
if args.dataset is None:
|
if args.dataset is None:
|
||||||
# Synthesize a prompt with the given input length.
|
# Synthesize a prompt with the given input length.
|
||||||
prompt = "hi" * (args.input_len - 1)
|
prompt = "hi" * (args.input_len - 1)
|
||||||
requests = [
|
requests = [(prompt, args.input_len, args.output_len,
|
||||||
(prompt, args.input_len, args.output_len, get_random_flag())
|
get_random_flag()) for _ in range(args.num_prompts)]
|
||||||
for _ in range(args.num_prompts)
|
|
||||||
]
|
|
||||||
else:
|
else:
|
||||||
requests = sample_requests(
|
requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
|
||||||
args.dataset, args.num_prompts, tokenizer, args.output_len
|
args.output_len)
|
||||||
)
|
|
||||||
|
|
||||||
if args.backend == "vllm":
|
if args.backend == "vllm":
|
||||||
elapsed_time = run_vllm(
|
elapsed_time = run_vllm(requests, args.n,
|
||||||
requests, args.n, EngineArgs.from_cli_args(args), args.disable_detokenize
|
EngineArgs.from_cli_args(args),
|
||||||
)
|
args.disable_detokenize)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown backend: {args.backend}")
|
raise ValueError(f"Unknown backend: {args.backend}")
|
||||||
total_num_tokens = sum(
|
total_num_tokens = sum(prompt_len + output_len
|
||||||
prompt_len + output_len for _, prompt_len, output_len, priority in requests
|
for _, prompt_len, output_len, priority in requests)
|
||||||
)
|
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
|
||||||
print(
|
f"{total_num_tokens / elapsed_time:.2f} tokens/s")
|
||||||
f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
|
|
||||||
f"{total_num_tokens / elapsed_time:.2f} tokens/s"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Output JSON results if specified
|
# Output JSON results if specified
|
||||||
if args.output_json:
|
if args.output_json:
|
||||||
@ -162,44 +147,41 @@ def main(args: argparse.Namespace):
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = FlexibleArgumentParser(description="Benchmark the throughput.")
|
parser = FlexibleArgumentParser(description="Benchmark the throughput.")
|
||||||
|
parser.add_argument("--backend",
|
||||||
|
type=str,
|
||||||
|
choices=["vllm", "hf", "mii"],
|
||||||
|
default="vllm")
|
||||||
|
parser.add_argument("--dataset",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path to the dataset.")
|
||||||
|
parser.add_argument("--input-len",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Input prompt length for each request")
|
||||||
|
parser.add_argument("--output-len",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Output length for each request. Overrides the "
|
||||||
|
"output length from the dataset.")
|
||||||
|
parser.add_argument("--n",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Number of generated sequences per prompt.")
|
||||||
|
parser.add_argument("--num-prompts",
|
||||||
|
type=int,
|
||||||
|
default=200,
|
||||||
|
help="Number of prompts to process.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--backend", type=str, choices=["vllm", "hf", "mii"], default="vllm"
|
'--output-json',
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--dataset", type=str, default=None, help="Path to the dataset."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--input-len",
|
|
||||||
type=int,
|
|
||||||
default=None,
|
|
||||||
help="Input prompt length for each request",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--output-len",
|
|
||||||
type=int,
|
|
||||||
default=None,
|
|
||||||
help="Output length for each request. Overrides the "
|
|
||||||
"output length from the dataset.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--n", type=int, default=1, help="Number of generated sequences per prompt."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--num-prompts", type=int, default=200, help="Number of prompts to process."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--output-json",
|
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
help="Path to save the throughput results in JSON format.",
|
help='Path to save the throughput results in JSON format.')
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--disable-detokenize",
|
'--disable-detokenize',
|
||||||
action="store_true",
|
action='store_true',
|
||||||
help=(
|
help=("Do not detokenize responses (i.e. do not include "
|
||||||
"Do not detokenize responses (i.e. do not include "
|
"detokenization time in the latency measurement)"),
|
||||||
"detokenization time in the latency measurement)"
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser = EngineArgs.add_cli_args(parser)
|
parser = EngineArgs.add_cli_args(parser)
|
||||||
|
|||||||
@ -20,7 +20,6 @@ On the client side, run:
|
|||||||
--endpoint /generate_stream
|
--endpoint /generate_stream
|
||||||
to the end of the command above.
|
to the end of the command above.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
import gc
|
import gc
|
||||||
@ -35,16 +34,12 @@ from datetime import datetime
|
|||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from backend_request_func import (ASYNC_REQUEST_FUNCS,
|
||||||
|
OPENAI_COMPATIBLE_BACKENDS, RequestFuncInput,
|
||||||
|
RequestFuncOutput)
|
||||||
from tqdm.asyncio import tqdm
|
from tqdm.asyncio import tqdm
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
from backend_request_func import (
|
|
||||||
ASYNC_REQUEST_FUNCS,
|
|
||||||
OPENAI_COMPATIBLE_BACKENDS,
|
|
||||||
RequestFuncInput,
|
|
||||||
RequestFuncOutput,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@ -55,21 +50,12 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
from argparse import ArgumentParser as FlexibleArgumentParser
|
from argparse import ArgumentParser as FlexibleArgumentParser
|
||||||
|
|
||||||
from benchmark_dataset import (
|
from benchmark_dataset import (AIMODataset, ASRDataset, BurstGPTDataset,
|
||||||
AIMODataset,
|
ConversationDataset, HuggingFaceDataset,
|
||||||
ASRDataset,
|
InstructCoderDataset, MTBenchDataset,
|
||||||
BurstGPTDataset,
|
NextEditPredictionDataset, RandomDataset,
|
||||||
ConversationDataset,
|
SampleRequest, ShareGPTDataset, SonnetDataset,
|
||||||
HuggingFaceDataset,
|
VisionArenaDataset)
|
||||||
InstructCoderDataset,
|
|
||||||
MTBenchDataset,
|
|
||||||
NextEditPredictionDataset,
|
|
||||||
RandomDataset,
|
|
||||||
SampleRequest,
|
|
||||||
ShareGPTDataset,
|
|
||||||
SonnetDataset,
|
|
||||||
VisionArenaDataset,
|
|
||||||
)
|
|
||||||
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
|
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
|
||||||
|
|
||||||
MILLISECONDS_TO_SECONDS_CONVERSION = 1000
|
MILLISECONDS_TO_SECONDS_CONVERSION = 1000
|
||||||
@ -132,8 +118,7 @@ async def get_request(
|
|||||||
|
|
||||||
# Calculate scale parameter theta to maintain the desired request_rate.
|
# Calculate scale parameter theta to maintain the desired request_rate.
|
||||||
assert burstiness > 0, (
|
assert burstiness > 0, (
|
||||||
f"A positive burstiness factor is expected, but given {burstiness}."
|
f"A positive burstiness factor is expected, but given {burstiness}.")
|
||||||
)
|
|
||||||
theta = 1.0 / (request_rate * burstiness)
|
theta = 1.0 / (request_rate * burstiness)
|
||||||
|
|
||||||
for request in input_requests:
|
for request in input_requests:
|
||||||
@ -179,10 +164,8 @@ def calculate_metrics(
|
|||||||
# bundled together
|
# bundled together
|
||||||
# Note : this may inflate the output token count slightly
|
# Note : this may inflate the output token count slightly
|
||||||
output_len = len(
|
output_len = len(
|
||||||
tokenizer(
|
tokenizer(outputs[i].generated_text,
|
||||||
outputs[i].generated_text, add_special_tokens=False
|
add_special_tokens=False).input_ids)
|
||||||
).input_ids
|
|
||||||
)
|
|
||||||
actual_output_lens.append(output_len)
|
actual_output_lens.append(output_len)
|
||||||
total_input += input_requests[i].prompt_len
|
total_input += input_requests[i].prompt_len
|
||||||
tpot = 0
|
tpot = 0
|
||||||
@ -205,19 +188,16 @@ def calculate_metrics(
|
|||||||
|
|
||||||
if "ttft" in goodput_config_dict:
|
if "ttft" in goodput_config_dict:
|
||||||
valid_metrics.append(ttfts)
|
valid_metrics.append(ttfts)
|
||||||
slo_values.append(
|
slo_values.append(goodput_config_dict["ttft"] /
|
||||||
goodput_config_dict["ttft"] / MILLISECONDS_TO_SECONDS_CONVERSION
|
MILLISECONDS_TO_SECONDS_CONVERSION)
|
||||||
)
|
|
||||||
if "tpot" in goodput_config_dict:
|
if "tpot" in goodput_config_dict:
|
||||||
valid_metrics.append(all_tpots)
|
valid_metrics.append(all_tpots)
|
||||||
slo_values.append(
|
slo_values.append(goodput_config_dict["tpot"] /
|
||||||
goodput_config_dict["tpot"] / MILLISECONDS_TO_SECONDS_CONVERSION
|
MILLISECONDS_TO_SECONDS_CONVERSION)
|
||||||
)
|
|
||||||
if "e2el" in goodput_config_dict:
|
if "e2el" in goodput_config_dict:
|
||||||
valid_metrics.append(e2els)
|
valid_metrics.append(e2els)
|
||||||
slo_values.append(
|
slo_values.append(goodput_config_dict["e2el"] /
|
||||||
goodput_config_dict["e2el"] / MILLISECONDS_TO_SECONDS_CONVERSION
|
MILLISECONDS_TO_SECONDS_CONVERSION)
|
||||||
)
|
|
||||||
|
|
||||||
for req_metric in zip(*valid_metrics):
|
for req_metric in zip(*valid_metrics):
|
||||||
is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)])
|
is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)])
|
||||||
@ -228,8 +208,7 @@ def calculate_metrics(
|
|||||||
warnings.warn(
|
warnings.warn(
|
||||||
"All requests failed. This is likely due to a misconfiguration "
|
"All requests failed. This is likely due to a misconfiguration "
|
||||||
"on the benchmark arguments.",
|
"on the benchmark arguments.",
|
||||||
stacklevel=2,
|
stacklevel=2)
|
||||||
)
|
|
||||||
metrics = BenchmarkMetrics(
|
metrics = BenchmarkMetrics(
|
||||||
completed=completed,
|
completed=completed,
|
||||||
total_input=total_input,
|
total_input=total_input,
|
||||||
@ -238,31 +217,27 @@ def calculate_metrics(
|
|||||||
request_goodput=good_completed / dur_s,
|
request_goodput=good_completed / dur_s,
|
||||||
output_throughput=sum(actual_output_lens) / dur_s,
|
output_throughput=sum(actual_output_lens) / dur_s,
|
||||||
total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s,
|
total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s,
|
||||||
mean_ttft_ms=np.mean(ttfts or 0)
|
mean_ttft_ms=np.mean(ttfts or 0) *
|
||||||
* 1000, # ttfts is empty if streaming is not supported by backend
|
1000, # ttfts is empty if streaming is not supported by backend
|
||||||
std_ttft_ms=np.std(ttfts or 0) * 1000,
|
std_ttft_ms=np.std(ttfts or 0) * 1000,
|
||||||
median_ttft_ms=np.median(ttfts or 0) * 1000,
|
median_ttft_ms=np.median(ttfts or 0) * 1000,
|
||||||
percentiles_ttft_ms=[
|
percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000)
|
||||||
(p, np.percentile(ttfts or 0, p) * 1000) for p in selected_percentiles
|
for p in selected_percentiles],
|
||||||
],
|
|
||||||
mean_tpot_ms=np.mean(tpots or 0) * 1000,
|
mean_tpot_ms=np.mean(tpots or 0) * 1000,
|
||||||
std_tpot_ms=np.std(tpots or 0) * 1000,
|
std_tpot_ms=np.std(tpots or 0) * 1000,
|
||||||
median_tpot_ms=np.median(tpots or 0) * 1000,
|
median_tpot_ms=np.median(tpots or 0) * 1000,
|
||||||
percentiles_tpot_ms=[
|
percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000)
|
||||||
(p, np.percentile(tpots or 0, p) * 1000) for p in selected_percentiles
|
for p in selected_percentiles],
|
||||||
],
|
|
||||||
mean_itl_ms=np.mean(itls or 0) * 1000,
|
mean_itl_ms=np.mean(itls or 0) * 1000,
|
||||||
std_itl_ms=np.std(itls or 0) * 1000,
|
std_itl_ms=np.std(itls or 0) * 1000,
|
||||||
median_itl_ms=np.median(itls or 0) * 1000,
|
median_itl_ms=np.median(itls or 0) * 1000,
|
||||||
percentiles_itl_ms=[
|
percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000)
|
||||||
(p, np.percentile(itls or 0, p) * 1000) for p in selected_percentiles
|
for p in selected_percentiles],
|
||||||
],
|
|
||||||
mean_e2el_ms=np.mean(e2els or 0) * 1000,
|
mean_e2el_ms=np.mean(e2els or 0) * 1000,
|
||||||
std_e2el_ms=np.std(e2els or 0) * 1000,
|
std_e2el_ms=np.std(e2els or 0) * 1000,
|
||||||
median_e2el_ms=np.median(e2els or 0) * 1000,
|
median_e2el_ms=np.median(e2els or 0) * 1000,
|
||||||
percentiles_e2el_ms=[
|
percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000)
|
||||||
(p, np.percentile(e2els or 0, p) * 1000) for p in selected_percentiles
|
for p in selected_percentiles],
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return metrics, actual_output_lens
|
return metrics, actual_output_lens
|
||||||
@ -295,12 +270,10 @@ async def benchmark(
|
|||||||
raise ValueError(f"Unknown backend: {backend}")
|
raise ValueError(f"Unknown backend: {backend}")
|
||||||
|
|
||||||
print("Starting initial single prompt test run...")
|
print("Starting initial single prompt test run...")
|
||||||
test_prompt, test_prompt_len, test_output_len, test_mm_content = (
|
test_prompt, test_prompt_len, test_output_len, test_mm_content = \
|
||||||
input_requests[0].prompt,
|
input_requests[0].prompt, input_requests[0].prompt_len, \
|
||||||
input_requests[0].prompt_len,
|
input_requests[0].expected_output_len, \
|
||||||
input_requests[0].expected_output_len,
|
input_requests[0].multi_modal_data
|
||||||
input_requests[0].multi_modal_data,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert test_mm_content is None or isinstance(test_mm_content, dict)
|
assert test_mm_content is None or isinstance(test_mm_content, dict)
|
||||||
test_input = RequestFuncInput(
|
test_input = RequestFuncInput(
|
||||||
@ -320,36 +293,36 @@ async def benchmark(
|
|||||||
if not test_output.success:
|
if not test_output.success:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Initial test run failed - Please make sure benchmark arguments "
|
"Initial test run failed - Please make sure benchmark arguments "
|
||||||
f"are correctly specified. Error: {test_output.error}"
|
f"are correctly specified. Error: {test_output.error}")
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
print("Initial test run completed. Starting main benchmark run...")
|
print("Initial test run completed. Starting main benchmark run...")
|
||||||
|
|
||||||
if lora_modules:
|
if lora_modules:
|
||||||
# For each input request, choose a LoRA module at random.
|
# For each input request, choose a LoRA module at random.
|
||||||
lora_modules = iter(
|
lora_modules = iter(
|
||||||
[random.choice(lora_modules) for _ in range(len(input_requests))]
|
[random.choice(lora_modules) \
|
||||||
)
|
for _ in range(len(input_requests))])
|
||||||
|
|
||||||
if profile:
|
if profile:
|
||||||
print("Starting profiler...")
|
print("Starting profiler...")
|
||||||
profile_input = RequestFuncInput(
|
profile_input = RequestFuncInput(model=model_id,
|
||||||
model=model_id,
|
model_name=model_name,
|
||||||
model_name=model_name,
|
prompt=test_prompt,
|
||||||
prompt=test_prompt,
|
api_url=base_url + "/start_profile",
|
||||||
api_url=base_url + "/start_profile",
|
prompt_len=test_prompt_len,
|
||||||
prompt_len=test_prompt_len,
|
output_len=test_output_len,
|
||||||
output_len=test_output_len,
|
logprobs=logprobs,
|
||||||
logprobs=logprobs,
|
multi_modal_content=test_mm_content,
|
||||||
multi_modal_content=test_mm_content,
|
ignore_eos=ignore_eos,
|
||||||
ignore_eos=ignore_eos,
|
extra_body=extra_body)
|
||||||
extra_body=extra_body,
|
|
||||||
)
|
|
||||||
profile_output = await request_func(request_func_input=profile_input)
|
profile_output = await request_func(request_func_input=profile_input)
|
||||||
if profile_output.success:
|
if profile_output.success:
|
||||||
print("Profiler started")
|
print("Profiler started")
|
||||||
|
|
||||||
distribution = "Poisson process" if burstiness == 1.0 else "Gamma distribution"
|
if burstiness == 1.0:
|
||||||
|
distribution = "Poisson process"
|
||||||
|
else:
|
||||||
|
distribution = "Gamma distribution"
|
||||||
|
|
||||||
print(f"Traffic request rate: {request_rate}")
|
print(f"Traffic request rate: {request_rate}")
|
||||||
print(f"Burstiness factor: {burstiness} ({distribution})")
|
print(f"Burstiness factor: {burstiness} ({distribution})")
|
||||||
@ -361,45 +334,42 @@ async def benchmark(
|
|||||||
# and it will simplify the code in limited_request_func.
|
# and it will simplify the code in limited_request_func.
|
||||||
# semaphore = (asyncio.Semaphore(max_concurrency)
|
# semaphore = (asyncio.Semaphore(max_concurrency)
|
||||||
# if max_concurrency else contextlib.nullcontext())
|
# if max_concurrency else contextlib.nullcontext())
|
||||||
semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None
|
semaphore = (asyncio.Semaphore(max_concurrency)
|
||||||
|
if max_concurrency else None)
|
||||||
|
|
||||||
async def limited_request_func(request_func_input, pbar):
|
async def limited_request_func(request_func_input, pbar):
|
||||||
if semaphore is None:
|
if semaphore is None:
|
||||||
return await request_func(request_func_input=request_func_input, pbar=pbar)
|
return await request_func(request_func_input=request_func_input,
|
||||||
|
pbar=pbar)
|
||||||
async with semaphore:
|
async with semaphore:
|
||||||
return await request_func(request_func_input=request_func_input, pbar=pbar)
|
return await request_func(request_func_input=request_func_input,
|
||||||
|
pbar=pbar)
|
||||||
|
|
||||||
benchmark_start_time = time.perf_counter()
|
benchmark_start_time = time.perf_counter()
|
||||||
tasks: list[asyncio.Task] = []
|
tasks: list[asyncio.Task] = []
|
||||||
async for request in get_request(input_requests, request_rate, burstiness):
|
async for request in get_request(input_requests, request_rate, burstiness):
|
||||||
prompt, prompt_len, output_len, mm_content = (
|
prompt, prompt_len, output_len, mm_content = request.prompt, \
|
||||||
request.prompt,
|
request.prompt_len, request.expected_output_len, \
|
||||||
request.prompt_len,
|
request.multi_modal_data
|
||||||
request.expected_output_len,
|
|
||||||
request.multi_modal_data,
|
|
||||||
)
|
|
||||||
req_model_id, req_model_name = model_id, model_name
|
req_model_id, req_model_name = model_id, model_name
|
||||||
if lora_modules:
|
if lora_modules:
|
||||||
req_lora_module = next(lora_modules)
|
req_lora_module = next(lora_modules)
|
||||||
req_model_id, req_model_name = req_lora_module, req_lora_module
|
req_model_id, req_model_name = req_lora_module, req_lora_module
|
||||||
|
|
||||||
request_func_input = RequestFuncInput(
|
request_func_input = RequestFuncInput(model=req_model_id,
|
||||||
model=req_model_id,
|
model_name=req_model_name,
|
||||||
model_name=req_model_name,
|
prompt=prompt,
|
||||||
prompt=prompt,
|
api_url=api_url,
|
||||||
api_url=api_url,
|
prompt_len=prompt_len,
|
||||||
prompt_len=prompt_len,
|
output_len=output_len,
|
||||||
output_len=output_len,
|
logprobs=logprobs,
|
||||||
logprobs=logprobs,
|
multi_modal_content=mm_content,
|
||||||
multi_modal_content=mm_content,
|
ignore_eos=ignore_eos,
|
||||||
ignore_eos=ignore_eos,
|
extra_body=extra_body)
|
||||||
extra_body=extra_body,
|
|
||||||
)
|
|
||||||
tasks.append(
|
tasks.append(
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
limited_request_func(request_func_input=request_func_input, pbar=pbar)
|
limited_request_func(request_func_input=request_func_input,
|
||||||
)
|
pbar=pbar)))
|
||||||
)
|
|
||||||
outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
|
outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
if profile:
|
if profile:
|
||||||
@ -431,32 +401,22 @@ async def benchmark(
|
|||||||
goodput_config_dict=goodput_config_dict,
|
goodput_config_dict=goodput_config_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
print("{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
|
print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
|
||||||
print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
|
print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
|
||||||
print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
|
print("{:<40} {:<10.2f}".format("Benchmark duration (s):",
|
||||||
|
benchmark_duration))
|
||||||
print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
|
print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
|
||||||
print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
|
print("{:<40} {:<10}".format("Total generated tokens:",
|
||||||
print(
|
metrics.total_output))
|
||||||
"{:<40} {:<10.2f}".format(
|
print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
|
||||||
"Request throughput (req/s):", metrics.request_throughput
|
metrics.request_throughput))
|
||||||
)
|
|
||||||
)
|
|
||||||
if goodput_config_dict:
|
if goodput_config_dict:
|
||||||
print(
|
print("{:<40} {:<10.2f}".format("Request goodput (req/s):",
|
||||||
"{:<40} {:<10.2f}".format(
|
metrics.request_goodput))
|
||||||
"Request goodput (req/s):", metrics.request_goodput
|
print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
|
||||||
)
|
metrics.output_throughput))
|
||||||
)
|
print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):",
|
||||||
print(
|
metrics.total_token_throughput))
|
||||||
"{:<40} {:<10.2f}".format(
|
|
||||||
"Output token throughput (tok/s):", metrics.output_throughput
|
|
||||||
)
|
|
||||||
)
|
|
||||||
print(
|
|
||||||
"{:<40} {:<10.2f}".format(
|
|
||||||
"Total Token throughput (tok/s):", metrics.total_token_throughput
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
"duration": benchmark_duration,
|
"duration": benchmark_duration,
|
||||||
@ -464,7 +424,8 @@ async def benchmark(
|
|||||||
"total_input_tokens": metrics.total_input,
|
"total_input_tokens": metrics.total_input,
|
||||||
"total_output_tokens": metrics.total_output,
|
"total_output_tokens": metrics.total_output,
|
||||||
"request_throughput": metrics.request_throughput,
|
"request_throughput": metrics.request_throughput,
|
||||||
"request_goodput:": metrics.request_goodput if goodput_config_dict else None,
|
"request_goodput:":
|
||||||
|
metrics.request_goodput if goodput_config_dict else None,
|
||||||
"output_throughput": metrics.output_throughput,
|
"output_throughput": metrics.output_throughput,
|
||||||
"total_token_throughput": metrics.total_token_throughput,
|
"total_token_throughput": metrics.total_token_throughput,
|
||||||
"input_lens": [output.prompt_len for output in outputs],
|
"input_lens": [output.prompt_len for output in outputs],
|
||||||
@ -487,35 +448,29 @@ async def benchmark(
|
|||||||
# metric.
|
# metric.
|
||||||
if metric_attribute_name not in selected_percentile_metrics:
|
if metric_attribute_name not in selected_percentile_metrics:
|
||||||
return
|
return
|
||||||
print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
|
print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-'))
|
||||||
print(
|
print("{:<40} {:<10.2f}".format(
|
||||||
"{:<40} {:<10.2f}".format(
|
f"Mean {metric_name} (ms):",
|
||||||
f"Mean {metric_name} (ms):",
|
getattr(metrics, f"mean_{metric_attribute_name}_ms")))
|
||||||
getattr(metrics, f"mean_{metric_attribute_name}_ms"),
|
print("{:<40} {:<10.2f}".format(
|
||||||
)
|
f"Median {metric_name} (ms):",
|
||||||
)
|
getattr(metrics, f"median_{metric_attribute_name}_ms")))
|
||||||
print(
|
|
||||||
"{:<40} {:<10.2f}".format(
|
|
||||||
f"Median {metric_name} (ms):",
|
|
||||||
getattr(metrics, f"median_{metric_attribute_name}_ms"),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
result[f"mean_{metric_attribute_name}_ms"] = getattr(
|
result[f"mean_{metric_attribute_name}_ms"] = getattr(
|
||||||
metrics, f"mean_{metric_attribute_name}_ms"
|
metrics, f"mean_{metric_attribute_name}_ms")
|
||||||
)
|
|
||||||
result[f"median_{metric_attribute_name}_ms"] = getattr(
|
result[f"median_{metric_attribute_name}_ms"] = getattr(
|
||||||
metrics, f"median_{metric_attribute_name}_ms"
|
metrics, f"median_{metric_attribute_name}_ms")
|
||||||
)
|
|
||||||
result[f"std_{metric_attribute_name}_ms"] = getattr(
|
result[f"std_{metric_attribute_name}_ms"] = getattr(
|
||||||
metrics, f"std_{metric_attribute_name}_ms"
|
metrics, f"std_{metric_attribute_name}_ms")
|
||||||
)
|
for p, value in getattr(metrics,
|
||||||
for p, value in getattr(metrics, f"percentiles_{metric_attribute_name}_ms"):
|
f"percentiles_{metric_attribute_name}_ms"):
|
||||||
p_word = str(int(p)) if int(p) == p else str(p)
|
p_word = str(int(p)) if int(p) == p else str(p)
|
||||||
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value))
|
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):",
|
||||||
|
value))
|
||||||
result[f"p{p_word}_{metric_attribute_name}_ms"] = value
|
result[f"p{p_word}_{metric_attribute_name}_ms"] = value
|
||||||
|
|
||||||
process_one_metric("ttft", "TTFT", "Time to First Token")
|
process_one_metric("ttft", "TTFT", "Time to First Token")
|
||||||
process_one_metric("tpot", "TPOT", "Time per Output Token (excl. 1st token)")
|
process_one_metric("tpot", "TPOT",
|
||||||
|
"Time per Output Token (excl. 1st token)")
|
||||||
process_one_metric("itl", "ITL", "Inter-token Latency")
|
process_one_metric("itl", "ITL", "Inter-token Latency")
|
||||||
process_one_metric("e2el", "E2EL", "End-to-end Latency")
|
process_one_metric("e2el", "E2EL", "End-to-end Latency")
|
||||||
|
|
||||||
@ -535,14 +490,12 @@ def check_goodput_args(args):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Invalid metric name found, {slo_name}: {slo_val}. "
|
f"Invalid metric name found, {slo_name}: {slo_val}. "
|
||||||
"The service level objective name should be one of "
|
"The service level objective name should be one of "
|
||||||
f"{str(VALID_NAMES)}. "
|
f"{str(VALID_NAMES)}. ")
|
||||||
)
|
|
||||||
if slo_val < 0:
|
if slo_val < 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Invalid value found, {slo_name}: {slo_val}. "
|
f"Invalid value found, {slo_name}: {slo_val}. "
|
||||||
"The service level objective value should be "
|
"The service level objective value should be "
|
||||||
"non-negative."
|
"non-negative.")
|
||||||
)
|
|
||||||
return goodput_config_dict
|
return goodput_config_dict
|
||||||
|
|
||||||
|
|
||||||
@ -555,42 +508,31 @@ def parse_goodput(slo_pairs):
|
|||||||
except ValueError as err:
|
except ValueError as err:
|
||||||
raise argparse.ArgumentTypeError(
|
raise argparse.ArgumentTypeError(
|
||||||
"Invalid format found for service level objectives. "
|
"Invalid format found for service level objectives. "
|
||||||
'Specify service level objectives for goodput as "KEY:VALUE" '
|
"Specify service level objectives for goodput as \"KEY:VALUE\" "
|
||||||
"pairs, where the key is a metric name, and the value is a "
|
"pairs, where the key is a metric name, and the value is a "
|
||||||
"number in milliseconds."
|
"number in milliseconds.") from err
|
||||||
) from err
|
|
||||||
return goodput_config_dict
|
return goodput_config_dict
|
||||||
|
|
||||||
|
|
||||||
def save_to_pytorch_benchmark_format(
|
def save_to_pytorch_benchmark_format(args: argparse.Namespace,
|
||||||
args: argparse.Namespace, results: dict[str, Any], file_name: str
|
results: dict[str, Any],
|
||||||
) -> None:
|
file_name: str) -> None:
|
||||||
metrics = [
|
metrics = [
|
||||||
"median_ttft_ms",
|
"median_ttft_ms", "mean_ttft_ms", "std_ttft_ms", "p99_ttft_ms",
|
||||||
"mean_ttft_ms",
|
"mean_tpot_ms", "median_tpot_ms", "std_tpot_ms", "p99_tpot_ms",
|
||||||
"std_ttft_ms",
|
"median_itl_ms", "mean_itl_ms", "std_itl_ms", "p99_itl_ms"
|
||||||
"p99_ttft_ms",
|
|
||||||
"mean_tpot_ms",
|
|
||||||
"median_tpot_ms",
|
|
||||||
"std_tpot_ms",
|
|
||||||
"p99_tpot_ms",
|
|
||||||
"median_itl_ms",
|
|
||||||
"mean_itl_ms",
|
|
||||||
"std_itl_ms",
|
|
||||||
"p99_itl_ms",
|
|
||||||
]
|
]
|
||||||
# These raw data might be useful, but they are rather big. They can be added
|
# These raw data might be useful, but they are rather big. They can be added
|
||||||
# later if needed
|
# later if needed
|
||||||
ignored_metrics = ["ttfts", "itls", "generated_texts", "errors"]
|
ignored_metrics = ["ttfts", "itls", "generated_texts", "errors"]
|
||||||
pt_records = convert_to_pytorch_benchmark_format(
|
pt_records = convert_to_pytorch_benchmark_format(
|
||||||
args=args,
|
args=args,
|
||||||
metrics={k: [results[k]] for k in metrics},
|
metrics={k: [results[k]]
|
||||||
|
for k in metrics},
|
||||||
extra_info={
|
extra_info={
|
||||||
k: results[k]
|
k: results[k]
|
||||||
for k in results
|
for k in results if k not in metrics and k not in ignored_metrics
|
||||||
if k not in metrics and k not in ignored_metrics
|
})
|
||||||
},
|
|
||||||
)
|
|
||||||
if pt_records:
|
if pt_records:
|
||||||
# Don't use json suffix here as we don't want CI to pick it up
|
# Don't use json suffix here as we don't want CI to pick it up
|
||||||
pt_file = f"{os.path.splitext(file_name)[0]}.pytorch.json"
|
pt_file = f"{os.path.splitext(file_name)[0]}.pytorch.json"
|
||||||
@ -615,42 +557,34 @@ def main(args: argparse.Namespace):
|
|||||||
api_url = f"http://{args.host}:{args.port}{args.endpoint}"
|
api_url = f"http://{args.host}:{args.port}{args.endpoint}"
|
||||||
base_url = f"http://{args.host}:{args.port}"
|
base_url = f"http://{args.host}:{args.port}"
|
||||||
|
|
||||||
tokenizer = get_tokenizer(
|
tokenizer = get_tokenizer(tokenizer_id,
|
||||||
tokenizer_id,
|
tokenizer_mode=tokenizer_mode,
|
||||||
tokenizer_mode=tokenizer_mode,
|
trust_remote_code=args.trust_remote_code)
|
||||||
trust_remote_code=args.trust_remote_code,
|
|
||||||
)
|
|
||||||
|
|
||||||
if args.dataset_name is None:
|
if args.dataset_name is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Please specify '--dataset-name' and the corresponding "
|
"Please specify '--dataset-name' and the corresponding "
|
||||||
"'--dataset-path' if required."
|
"'--dataset-path' if required.")
|
||||||
)
|
|
||||||
|
|
||||||
if args.dataset_name == "sonnet":
|
if args.dataset_name == "sonnet":
|
||||||
dataset = SonnetDataset(dataset_path=args.dataset_path)
|
dataset = SonnetDataset(dataset_path=args.dataset_path)
|
||||||
# For the "sonnet" dataset, formatting depends on the backend.
|
# For the "sonnet" dataset, formatting depends on the backend.
|
||||||
if args.backend == "openai-chat":
|
if args.backend == "openai-chat":
|
||||||
input_requests = dataset.sample(
|
input_requests = dataset.sample(num_requests=args.num_prompts,
|
||||||
num_requests=args.num_prompts,
|
input_len=args.sonnet_input_len,
|
||||||
input_len=args.sonnet_input_len,
|
output_len=args.sonnet_output_len,
|
||||||
output_len=args.sonnet_output_len,
|
prefix_len=args.sonnet_prefix_len,
|
||||||
prefix_len=args.sonnet_prefix_len,
|
tokenizer=tokenizer,
|
||||||
tokenizer=tokenizer,
|
return_prompt_formatted=False)
|
||||||
return_prompt_formatted=False,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
assert tokenizer.chat_template or tokenizer.default_chat_template, (
|
assert tokenizer.chat_template or tokenizer.default_chat_template, (
|
||||||
"Tokenizer/model must have chat template for sonnet dataset."
|
"Tokenizer/model must have chat template for sonnet dataset.")
|
||||||
)
|
input_requests = dataset.sample(num_requests=args.num_prompts,
|
||||||
input_requests = dataset.sample(
|
input_len=args.sonnet_input_len,
|
||||||
num_requests=args.num_prompts,
|
output_len=args.sonnet_output_len,
|
||||||
input_len=args.sonnet_input_len,
|
prefix_len=args.sonnet_prefix_len,
|
||||||
output_len=args.sonnet_output_len,
|
tokenizer=tokenizer,
|
||||||
prefix_len=args.sonnet_prefix_len,
|
return_prompt_formatted=True)
|
||||||
tokenizer=tokenizer,
|
|
||||||
return_prompt_formatted=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
elif args.dataset_name == "hf":
|
elif args.dataset_name == "hf":
|
||||||
# all following datasets are implemented from the
|
# all following datasets are implemented from the
|
||||||
@ -677,30 +611,23 @@ def main(args: argparse.Namespace):
|
|||||||
dataset_class = ASRDataset
|
dataset_class = ASRDataset
|
||||||
args.hf_split = "train"
|
args.hf_split = "train"
|
||||||
else:
|
else:
|
||||||
supported_datasets = set(
|
supported_datasets = set([
|
||||||
[
|
dataset_name for cls in HuggingFaceDataset.__subclasses__()
|
||||||
dataset_name
|
for dataset_name in cls.SUPPORTED_DATASET_PATHS
|
||||||
for cls in HuggingFaceDataset.__subclasses__()
|
])
|
||||||
for dataset_name in cls.SUPPORTED_DATASET_PATHS
|
|
||||||
]
|
|
||||||
)
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unsupported dataset path: {args.dataset_path}. "
|
f"Unsupported dataset path: {args.dataset_path}. "
|
||||||
"Huggingface dataset only supports dataset_path"
|
"Huggingface dataset only supports dataset_path"
|
||||||
f" from one of following: {supported_datasets}. "
|
f" from one of following: {supported_datasets}. "
|
||||||
"Please consider contributing if you would "
|
"Please consider contributing if you would "
|
||||||
"like to add support for additional dataset formats."
|
"like to add support for additional dataset formats.")
|
||||||
)
|
|
||||||
|
|
||||||
if dataset_class.IS_MULTIMODAL and backend not in [
|
if (dataset_class.IS_MULTIMODAL and backend not in \
|
||||||
"openai-chat",
|
["openai-chat", "openai-audio"]):
|
||||||
"openai-audio",
|
|
||||||
]:
|
|
||||||
# multi-modal benchmark is only available on OpenAI Chat backend.
|
# multi-modal benchmark is only available on OpenAI Chat backend.
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Multi-modal content is only supported on 'openai-chat' and "
|
"Multi-modal content is only supported on 'openai-chat' and " \
|
||||||
"'openai-audio' backend."
|
"'openai-audio' backend.")
|
||||||
)
|
|
||||||
input_requests = dataset_class(
|
input_requests = dataset_class(
|
||||||
dataset_path=args.dataset_path,
|
dataset_path=args.dataset_path,
|
||||||
dataset_subset=args.hf_subset,
|
dataset_subset=args.hf_subset,
|
||||||
@ -715,24 +642,26 @@ def main(args: argparse.Namespace):
|
|||||||
else:
|
else:
|
||||||
# For datasets that follow a similar structure, use a mapping.
|
# For datasets that follow a similar structure, use a mapping.
|
||||||
dataset_mapping = {
|
dataset_mapping = {
|
||||||
"sharegpt": lambda: ShareGPTDataset(
|
"sharegpt":
|
||||||
random_seed=args.seed, dataset_path=args.dataset_path
|
lambda: ShareGPTDataset(random_seed=args.seed,
|
||||||
).sample(
|
dataset_path=args.dataset_path).sample(
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
num_requests=args.num_prompts,
|
num_requests=args.num_prompts,
|
||||||
output_len=args.sharegpt_output_len,
|
output_len=args.sharegpt_output_len,
|
||||||
),
|
),
|
||||||
"burstgpt": lambda: BurstGPTDataset(
|
"burstgpt":
|
||||||
random_seed=args.seed, dataset_path=args.dataset_path
|
lambda: BurstGPTDataset(random_seed=args.seed,
|
||||||
).sample(tokenizer=tokenizer, num_requests=args.num_prompts),
|
dataset_path=args.dataset_path).
|
||||||
"random": lambda: RandomDataset(dataset_path=args.dataset_path).sample(
|
sample(tokenizer=tokenizer, num_requests=args.num_prompts),
|
||||||
|
"random":
|
||||||
|
lambda: RandomDataset(dataset_path=args.dataset_path).sample(
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
num_requests=args.num_prompts,
|
num_requests=args.num_prompts,
|
||||||
prefix_len=args.random_prefix_len,
|
prefix_len=args.random_prefix_len,
|
||||||
input_len=args.random_input_len,
|
input_len=args.random_input_len,
|
||||||
output_len=args.random_output_len,
|
output_len=args.random_output_len,
|
||||||
range_ratio=args.random_range_ratio,
|
range_ratio=args.random_range_ratio,
|
||||||
),
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -748,16 +677,15 @@ def main(args: argparse.Namespace):
|
|||||||
"top_p": args.top_p,
|
"top_p": args.top_p,
|
||||||
"top_k": args.top_k,
|
"top_k": args.top_k,
|
||||||
"min_p": args.min_p,
|
"min_p": args.min_p,
|
||||||
"temperature": args.temperature,
|
"temperature": args.temperature
|
||||||
}.items()
|
}.items() if v is not None
|
||||||
if v is not None
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Sampling parameters are only supported by openai-compatible backend.
|
# Sampling parameters are only supported by openai-compatible backend.
|
||||||
if sampling_params and args.backend not in OPENAI_COMPATIBLE_BACKENDS:
|
if sampling_params and args.backend not in OPENAI_COMPATIBLE_BACKENDS:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Sampling parameters are only supported by openai-compatible backends."
|
"Sampling parameters are only supported by openai-compatible "
|
||||||
)
|
"backends.")
|
||||||
|
|
||||||
if "temperature" not in sampling_params:
|
if "temperature" not in sampling_params:
|
||||||
sampling_params["temperature"] = 0.0 # Default to greedy decoding.
|
sampling_params["temperature"] = 0.0 # Default to greedy decoding.
|
||||||
@ -781,14 +709,15 @@ def main(args: argparse.Namespace):
|
|||||||
disable_tqdm=args.disable_tqdm,
|
disable_tqdm=args.disable_tqdm,
|
||||||
profile=args.profile,
|
profile=args.profile,
|
||||||
selected_percentile_metrics=args.percentile_metrics.split(","),
|
selected_percentile_metrics=args.percentile_metrics.split(","),
|
||||||
selected_percentiles=[float(p) for p in args.metric_percentiles.split(",")],
|
selected_percentiles=[
|
||||||
|
float(p) for p in args.metric_percentiles.split(",")
|
||||||
|
],
|
||||||
ignore_eos=args.ignore_eos,
|
ignore_eos=args.ignore_eos,
|
||||||
goodput_config_dict=goodput_config_dict,
|
goodput_config_dict=goodput_config_dict,
|
||||||
max_concurrency=args.max_concurrency,
|
max_concurrency=args.max_concurrency,
|
||||||
lora_modules=args.lora_modules,
|
lora_modules=args.lora_modules,
|
||||||
extra_body=sampling_params,
|
extra_body=sampling_params,
|
||||||
)
|
))
|
||||||
)
|
|
||||||
|
|
||||||
# Save config and results to json
|
# Save config and results to json
|
||||||
if args.save_result or args.append_result:
|
if args.save_result or args.append_result:
|
||||||
@ -813,9 +742,8 @@ def main(args: argparse.Namespace):
|
|||||||
"Invalid metadata format. Please use KEY=VALUE format."
|
"Invalid metadata format. Please use KEY=VALUE format."
|
||||||
)
|
)
|
||||||
# Traffic
|
# Traffic
|
||||||
result_json["request_rate"] = (
|
result_json["request_rate"] = (args.request_rate if args.request_rate
|
||||||
args.request_rate if args.request_rate < float("inf") else "inf"
|
< float("inf") else "inf")
|
||||||
)
|
|
||||||
result_json["burstiness"] = args.burstiness
|
result_json["burstiness"] = args.burstiness
|
||||||
result_json["max_concurrency"] = args.max_concurrency
|
result_json["max_concurrency"] = args.max_concurrency
|
||||||
|
|
||||||
@ -825,31 +753,24 @@ def main(args: argparse.Namespace):
|
|||||||
if not args.save_detailed:
|
if not args.save_detailed:
|
||||||
# Remove fields with too many data points
|
# Remove fields with too many data points
|
||||||
for field in [
|
for field in [
|
||||||
"input_lens",
|
"input_lens", "output_lens", "ttfts", "itls",
|
||||||
"output_lens",
|
"generated_texts", "errors"
|
||||||
"ttfts",
|
|
||||||
"itls",
|
|
||||||
"generated_texts",
|
|
||||||
"errors",
|
|
||||||
]:
|
]:
|
||||||
if field in result_json:
|
if field in result_json:
|
||||||
del result_json[field]
|
del result_json[field]
|
||||||
|
|
||||||
# Save to file
|
# Save to file
|
||||||
base_model_id = model_id.split("/")[-1]
|
base_model_id = model_id.split("/")[-1]
|
||||||
max_concurrency_str = (
|
max_concurrency_str = (f"-concurrency{args.max_concurrency}"
|
||||||
f"-concurrency{args.max_concurrency}"
|
if args.max_concurrency is not None else "")
|
||||||
if args.max_concurrency is not None
|
file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" #noqa
|
||||||
else ""
|
|
||||||
)
|
|
||||||
file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
|
|
||||||
if args.result_filename:
|
if args.result_filename:
|
||||||
file_name = args.result_filename
|
file_name = args.result_filename
|
||||||
if args.result_dir:
|
if args.result_dir:
|
||||||
file_name = os.path.join(args.result_dir, file_name)
|
file_name = os.path.join(args.result_dir, file_name)
|
||||||
with open(
|
with open(file_name,
|
||||||
file_name, mode="a+" if args.append_result else "w", encoding="utf-8"
|
mode="a+" if args.append_result else "w",
|
||||||
) as outfile:
|
encoding='utf-8') as outfile:
|
||||||
# Append a newline.
|
# Append a newline.
|
||||||
if args.append_result and outfile.tell() != 0:
|
if args.append_result and outfile.tell() != 0:
|
||||||
outfile.write("\n")
|
outfile.write("\n")
|
||||||
@ -859,8 +780,7 @@ def main(args: argparse.Namespace):
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = FlexibleArgumentParser(
|
parser = FlexibleArgumentParser(
|
||||||
description="Benchmark the online serving throughput."
|
description="Benchmark the online serving throughput.")
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--backend",
|
"--backend",
|
||||||
type=str,
|
type=str,
|
||||||
@ -889,13 +809,11 @@ if __name__ == "__main__":
|
|||||||
choices=["sharegpt", "burstgpt", "sonnet", "random", "hf"],
|
choices=["sharegpt", "burstgpt", "sonnet", "random", "hf"],
|
||||||
help="Name of the dataset to benchmark on.",
|
help="Name of the dataset to benchmark on.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument("--dataset-path",
|
||||||
"--dataset-path",
|
type=str,
|
||||||
type=str,
|
default=None,
|
||||||
default=None,
|
help="Path to the sharegpt/sonnet dataset. "
|
||||||
help="Path to the sharegpt/sonnet dataset. "
|
"Or the huggingface dataset ID if using HF dataset.")
|
||||||
"Or the huggingface dataset ID if using HF dataset.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-concurrency",
|
"--max-concurrency",
|
||||||
type=int,
|
type=int,
|
||||||
@ -907,8 +825,7 @@ if __name__ == "__main__":
|
|||||||
"initiated, this argument will control how many are actually allowed "
|
"initiated, this argument will control how many are actually allowed "
|
||||||
"to execute at a time. This means that when used in combination, the "
|
"to execute at a time. This means that when used in combination, the "
|
||||||
"actual request rate may be lower than specified with --request-rate, "
|
"actual request rate may be lower than specified with --request-rate, "
|
||||||
"if the server is not processing requests fast enough to keep up.",
|
"if the server is not processing requests fast enough to keep up.")
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model",
|
"--model",
|
||||||
@ -919,7 +836,8 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--tokenizer",
|
"--tokenizer",
|
||||||
type=str,
|
type=str,
|
||||||
help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
|
help=
|
||||||
|
"Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
|
||||||
)
|
)
|
||||||
parser.add_argument("--use-beam-search", action="store_true")
|
parser.add_argument("--use-beam-search", action="store_true")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -932,13 +850,11 @@ if __name__ == "__main__":
|
|||||||
"--logprobs",
|
"--logprobs",
|
||||||
type=int,
|
type=int,
|
||||||
default=None,
|
default=None,
|
||||||
help=(
|
help=("Number of logprobs-per-token to compute & return as part of "
|
||||||
"Number of logprobs-per-token to compute & return as part of "
|
"the request. If unspecified, then either (1) if beam search "
|
||||||
"the request. If unspecified, then either (1) if beam search "
|
"is disabled, no logprobs are computed & a single dummy "
|
||||||
"is disabled, no logprobs are computed & a single dummy "
|
"logprob is returned for each token; or (2) if beam search "
|
||||||
"logprob is returned for each token; or (2) if beam search "
|
"is enabled 1 logprob per token is computed"),
|
||||||
"is enabled 1 logprob per token is computed"
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--request-rate",
|
"--request-rate",
|
||||||
@ -1022,38 +938,35 @@ if __name__ == "__main__":
|
|||||||
"--ignore-eos",
|
"--ignore-eos",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Set ignore_eos flag when sending the benchmark request."
|
help="Set ignore_eos flag when sending the benchmark request."
|
||||||
"Warning: ignore_eos is not supported in deepspeed_mii and tgi.",
|
"Warning: ignore_eos is not supported in deepspeed_mii and tgi.")
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--percentile-metrics",
|
"--percentile-metrics",
|
||||||
type=str,
|
type=str,
|
||||||
default="ttft,tpot,itl",
|
default="ttft,tpot,itl",
|
||||||
help="Comma-separated list of selected metrics to report percentils. "
|
help="Comma-separated list of selected metrics to report percentils. "
|
||||||
"This argument specifies the metrics to report percentiles. "
|
"This argument specifies the metrics to report percentiles. "
|
||||||
'Allowed metric names are "ttft", "tpot", "itl", "e2el". '
|
"Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". "
|
||||||
'Default value is "ttft,tpot,itl".',
|
"Default value is \"ttft,tpot,itl\".")
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--metric-percentiles",
|
"--metric-percentiles",
|
||||||
type=str,
|
type=str,
|
||||||
default="99",
|
default="99",
|
||||||
help="Comma-separated list of percentiles for selected metrics. "
|
help="Comma-separated list of percentiles for selected metrics. "
|
||||||
'To report 25-th, 50-th, and 75-th percentiles, use "25,50,75". '
|
"To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". "
|
||||||
'Default value is "99". '
|
"Default value is \"99\". "
|
||||||
'Use "--percentile-metrics" to select metrics.',
|
"Use \"--percentile-metrics\" to select metrics.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--goodput",
|
"--goodput",
|
||||||
nargs="+",
|
nargs="+",
|
||||||
required=False,
|
required=False,
|
||||||
help='Specify service level objectives for goodput as "KEY:VALUE" '
|
help="Specify service level objectives for goodput as \"KEY:VALUE\" "
|
||||||
"pairs, where the key is a metric name, and the value is in "
|
"pairs, where the key is a metric name, and the value is in "
|
||||||
'milliseconds. Multiple "KEY:VALUE" pairs can be provided, '
|
"milliseconds. Multiple \"KEY:VALUE\" pairs can be provided, "
|
||||||
"separated by spaces. Allowed request level metric names are "
|
"separated by spaces. Allowed request level metric names are "
|
||||||
'"ttft", "tpot", "e2el". For more context on the definition of '
|
"\"ttft\", \"tpot\", \"e2el\". For more context on the definition of "
|
||||||
"goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
|
"goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
|
||||||
"and the blog: https://hao-ai-lab.github.io/blogs/distserve",
|
"and the blog: https://hao-ai-lab.github.io/blogs/distserve")
|
||||||
)
|
|
||||||
|
|
||||||
# group for dataset specific arguments
|
# group for dataset specific arguments
|
||||||
sonnet_group = parser.add_argument_group("sonnet dataset options")
|
sonnet_group = parser.add_argument_group("sonnet dataset options")
|
||||||
@ -1061,19 +974,22 @@ if __name__ == "__main__":
|
|||||||
"--sonnet-input-len",
|
"--sonnet-input-len",
|
||||||
type=int,
|
type=int,
|
||||||
default=550,
|
default=550,
|
||||||
help="Number of input tokens per request, used only for sonnet dataset.",
|
help=
|
||||||
|
"Number of input tokens per request, used only for sonnet dataset.",
|
||||||
)
|
)
|
||||||
sonnet_group.add_argument(
|
sonnet_group.add_argument(
|
||||||
"--sonnet-output-len",
|
"--sonnet-output-len",
|
||||||
type=int,
|
type=int,
|
||||||
default=150,
|
default=150,
|
||||||
help="Number of output tokens per request, used only for sonnet dataset.",
|
help=
|
||||||
|
"Number of output tokens per request, used only for sonnet dataset.",
|
||||||
)
|
)
|
||||||
sonnet_group.add_argument(
|
sonnet_group.add_argument(
|
||||||
"--sonnet-prefix-len",
|
"--sonnet-prefix-len",
|
||||||
type=int,
|
type=int,
|
||||||
default=200,
|
default=200,
|
||||||
help="Number of prefix tokens per request, used only for sonnet dataset.",
|
help=
|
||||||
|
"Number of prefix tokens per request, used only for sonnet dataset.",
|
||||||
)
|
)
|
||||||
|
|
||||||
sharegpt_group = parser.add_argument_group("sharegpt dataset options")
|
sharegpt_group = parser.add_argument_group("sharegpt dataset options")
|
||||||
@ -1082,21 +998,22 @@ if __name__ == "__main__":
|
|||||||
type=int,
|
type=int,
|
||||||
default=None,
|
default=None,
|
||||||
help="Output length for each request. Overrides the output length "
|
help="Output length for each request. Overrides the output length "
|
||||||
"from the ShareGPT dataset.",
|
"from the ShareGPT dataset.")
|
||||||
)
|
|
||||||
|
|
||||||
random_group = parser.add_argument_group("random dataset options")
|
random_group = parser.add_argument_group("random dataset options")
|
||||||
random_group.add_argument(
|
random_group.add_argument(
|
||||||
"--random-input-len",
|
"--random-input-len",
|
||||||
type=int,
|
type=int,
|
||||||
default=1024,
|
default=1024,
|
||||||
help="Number of input tokens per request, used only for random sampling.",
|
help=
|
||||||
|
"Number of input tokens per request, used only for random sampling.",
|
||||||
)
|
)
|
||||||
random_group.add_argument(
|
random_group.add_argument(
|
||||||
"--random-output-len",
|
"--random-output-len",
|
||||||
type=int,
|
type=int,
|
||||||
default=128,
|
default=128,
|
||||||
help="Number of output tokens per request, used only for random sampling.",
|
help=
|
||||||
|
"Number of output tokens per request, used only for random sampling.",
|
||||||
)
|
)
|
||||||
random_group.add_argument(
|
random_group.add_argument(
|
||||||
"--random-range-ratio",
|
"--random-range-ratio",
|
||||||
@ -1111,23 +1028,23 @@ if __name__ == "__main__":
|
|||||||
"--random-prefix-len",
|
"--random-prefix-len",
|
||||||
type=int,
|
type=int,
|
||||||
default=0,
|
default=0,
|
||||||
help=(
|
help=("Number of fixed prefix tokens before the random context "
|
||||||
"Number of fixed prefix tokens before the random context "
|
"in a request. "
|
||||||
"in a request. "
|
"The total input length is the sum of `random-prefix-len` and "
|
||||||
"The total input length is the sum of `random-prefix-len` and "
|
"a random "
|
||||||
"a random "
|
"context length sampled from [input_len * (1 - range_ratio), "
|
||||||
"context length sampled from [input_len * (1 - range_ratio), "
|
"input_len * (1 + range_ratio)]."),
|
||||||
"input_len * (1 + range_ratio)]."
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
hf_group = parser.add_argument_group("hf dataset options")
|
hf_group = parser.add_argument_group("hf dataset options")
|
||||||
hf_group.add_argument(
|
hf_group.add_argument("--hf-subset",
|
||||||
"--hf-subset", type=str, default=None, help="Subset of the HF dataset."
|
type=str,
|
||||||
)
|
default=None,
|
||||||
hf_group.add_argument(
|
help="Subset of the HF dataset.")
|
||||||
"--hf-split", type=str, default=None, help="Split of the HF dataset."
|
hf_group.add_argument("--hf-split",
|
||||||
)
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Split of the HF dataset.")
|
||||||
hf_group.add_argument(
|
hf_group.add_argument(
|
||||||
"--hf-output-len",
|
"--hf-output-len",
|
||||||
type=int,
|
type=int,
|
||||||
@ -1141,58 +1058,52 @@ if __name__ == "__main__":
|
|||||||
"--top-p",
|
"--top-p",
|
||||||
type=float,
|
type=float,
|
||||||
default=None,
|
default=None,
|
||||||
help="Top-p sampling parameter. Only has effect on openai-compatible backends.",
|
help="Top-p sampling parameter. Only has effect on openai-compatible "
|
||||||
)
|
"backends.")
|
||||||
sampling_group.add_argument(
|
sampling_group.add_argument(
|
||||||
"--top-k",
|
"--top-k",
|
||||||
type=int,
|
type=int,
|
||||||
default=None,
|
default=None,
|
||||||
help="Top-k sampling parameter. Only has effect on openai-compatible backends.",
|
help="Top-k sampling parameter. Only has effect on openai-compatible "
|
||||||
)
|
"backends.")
|
||||||
sampling_group.add_argument(
|
sampling_group.add_argument(
|
||||||
"--min-p",
|
"--min-p",
|
||||||
type=float,
|
type=float,
|
||||||
default=None,
|
default=None,
|
||||||
help="Min-p sampling parameter. Only has effect on openai-compatible backends.",
|
help="Min-p sampling parameter. Only has effect on openai-compatible "
|
||||||
)
|
"backends.")
|
||||||
sampling_group.add_argument(
|
sampling_group.add_argument(
|
||||||
"--temperature",
|
"--temperature",
|
||||||
type=float,
|
type=float,
|
||||||
default=None,
|
default=None,
|
||||||
help="Temperature sampling parameter. Only has effect on "
|
help="Temperature sampling parameter. Only has effect on "
|
||||||
"openai-compatible backends. If not specified, default to greedy "
|
"openai-compatible backends. If not specified, default to greedy "
|
||||||
"decoding (i.e. temperature==0.0).",
|
"decoding (i.e. temperature==0.0).")
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--tokenizer-mode",
|
'--tokenizer-mode',
|
||||||
type=str,
|
type=str,
|
||||||
default="auto",
|
default="auto",
|
||||||
choices=["auto", "slow", "mistral", "custom"],
|
choices=['auto', 'slow', 'mistral', 'custom'],
|
||||||
help='The tokenizer mode.\n\n* "auto" will use the '
|
help='The tokenizer mode.\n\n* "auto" will use the '
|
||||||
'fast tokenizer if available.\n* "slow" will '
|
'fast tokenizer if available.\n* "slow" will '
|
||||||
"always use the slow tokenizer. \n* "
|
'always use the slow tokenizer. \n* '
|
||||||
'"mistral" will always use the `mistral_common` tokenizer. \n*'
|
'"mistral" will always use the `mistral_common` tokenizer. \n*'
|
||||||
'"custom" will use --tokenizer to select the preregistered tokenizer.',
|
'"custom" will use --tokenizer to select the preregistered tokenizer.')
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument("--served-model-name",
|
||||||
"--served-model-name",
|
type=str,
|
||||||
type=str,
|
default=None,
|
||||||
default=None,
|
help="The model name used in the API. "
|
||||||
help="The model name used in the API. "
|
"If not specified, the model name will be the "
|
||||||
"If not specified, the model name will be the "
|
"same as the ``--model`` argument. ")
|
||||||
"same as the ``--model`` argument. ",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument("--lora-modules",
|
||||||
"--lora-modules",
|
nargs='+',
|
||||||
nargs="+",
|
default=None,
|
||||||
default=None,
|
help="A subset of LoRA module names passed in when "
|
||||||
help="A subset of LoRA module names passed in when "
|
"launching the server. For each request, the "
|
||||||
"launching the server. For each request, the "
|
"script chooses a LoRA module at random.")
|
||||||
"script chooses a LoRA module at random.",
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|||||||
@ -19,7 +19,6 @@ On the client side, run:
|
|||||||
--endpoint /generate_stream
|
--endpoint /generate_stream
|
||||||
to the end of the command above.
|
to the end of the command above.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
import copy
|
import copy
|
||||||
@ -37,15 +36,11 @@ from typing import Optional
|
|||||||
import datasets
|
import datasets
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput,
|
||||||
|
RequestFuncOutput)
|
||||||
from tqdm.asyncio import tqdm
|
from tqdm.asyncio import tqdm
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
from backend_request_func import (
|
|
||||||
ASYNC_REQUEST_FUNCS,
|
|
||||||
RequestFuncInput,
|
|
||||||
RequestFuncOutput,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@ -57,8 +52,7 @@ except ImportError:
|
|||||||
from argparse import ArgumentParser as FlexibleArgumentParser
|
from argparse import ArgumentParser as FlexibleArgumentParser
|
||||||
|
|
||||||
from vllm.v1.structured_output.backend_xgrammar import (
|
from vllm.v1.structured_output.backend_xgrammar import (
|
||||||
has_xgrammar_unsupported_json_features,
|
has_xgrammar_unsupported_json_features)
|
||||||
)
|
|
||||||
|
|
||||||
MILLISECONDS_TO_SECONDS_CONVERSION = 1000
|
MILLISECONDS_TO_SECONDS_CONVERSION = 1000
|
||||||
|
|
||||||
@ -104,7 +98,6 @@ class SampleRequest:
|
|||||||
prompt_len: The length of the prompt in tokens.
|
prompt_len: The length of the prompt in tokens.
|
||||||
expected_output_len: The expected length of the output in tokens.
|
expected_output_len: The expected length of the output in tokens.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
prompt: str
|
prompt: str
|
||||||
prompt_len: int
|
prompt_len: int
|
||||||
expected_output_len: int
|
expected_output_len: int
|
||||||
@ -113,28 +106,32 @@ class SampleRequest:
|
|||||||
completion: str = None
|
completion: str = None
|
||||||
|
|
||||||
|
|
||||||
def sample_requests(
|
def sample_requests(tokenizer: PreTrainedTokenizerBase,
|
||||||
tokenizer: PreTrainedTokenizerBase, args: argparse.Namespace
|
args: argparse.Namespace) -> list[SampleRequest]:
|
||||||
) -> list[SampleRequest]:
|
if args.dataset == 'json' or args.dataset == 'json-unique':
|
||||||
if args.dataset == "json" or args.dataset == "json-unique":
|
|
||||||
if args.json_schema_path is None:
|
if args.json_schema_path is None:
|
||||||
dir_path = os.path.dirname(os.path.realpath(__file__))
|
dir_path = os.path.dirname(os.path.realpath(__file__))
|
||||||
args.json_schema_path = os.path.join(
|
args.json_schema_path = os.path.join(dir_path,
|
||||||
dir_path, "structured_schemas", "structured_schema_1.json"
|
"structured_schemas",
|
||||||
)
|
"structured_schema_1.json")
|
||||||
json_schemas = []
|
json_schemas = []
|
||||||
with open(args.json_schema_path) as f:
|
with open(args.json_schema_path) as f:
|
||||||
schema = json.load(f)
|
schema = json.load(f)
|
||||||
|
|
||||||
if args.dataset == "json-unique":
|
if args.dataset == 'json-unique':
|
||||||
json_schemas = [copy.deepcopy(schema) for _ in range(args.num_prompts)]
|
json_schemas = [
|
||||||
|
copy.deepcopy(schema) for _ in range(args.num_prompts)
|
||||||
|
]
|
||||||
for i in range(len(json_schemas)):
|
for i in range(len(json_schemas)):
|
||||||
if "properties" not in json_schemas[i]:
|
if "properties" not in json_schemas[i]:
|
||||||
json_schemas[i]["properties"] = {}
|
json_schemas[i]["properties"] = {}
|
||||||
json_schemas[i]["properties"][f"__optional_field_{uuid.uuid4()}"] = {
|
json_schemas[i]["properties"][
|
||||||
"type": "string",
|
f"__optional_field_{uuid.uuid4()}"] = {
|
||||||
"description": "An unique optional field to avoid cached schemas",
|
"type":
|
||||||
}
|
"string",
|
||||||
|
"description":
|
||||||
|
"An unique optional field to avoid cached schemas"
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
json_schemas = [schema] * args.num_prompts
|
json_schemas = [schema] * args.num_prompts
|
||||||
|
|
||||||
@ -145,13 +142,11 @@ def sample_requests(
|
|||||||
return json_schemas[index % len(json_schemas)]
|
return json_schemas[index % len(json_schemas)]
|
||||||
|
|
||||||
requests = [
|
requests = [
|
||||||
SampleRequest(
|
SampleRequest(prompt=gen_prompt(i),
|
||||||
prompt=gen_prompt(i),
|
prompt_len=len(tokenizer(gen_prompt(i)).input_ids),
|
||||||
prompt_len=len(tokenizer(gen_prompt(i)).input_ids),
|
expected_output_len=args.output_len,
|
||||||
expected_output_len=args.output_len,
|
schema=get_schema(i),
|
||||||
schema=get_schema(i),
|
structure_type=args.structure_type)
|
||||||
structure_type=args.structure_type,
|
|
||||||
)
|
|
||||||
for i in range(args.num_prompts)
|
for i in range(args.num_prompts)
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -175,13 +170,11 @@ def sample_requests(
|
|||||||
input_len = len(tokenizer(prompt).input_ids)
|
input_len = len(tokenizer(prompt).input_ids)
|
||||||
print(f"Input length of the prompt: {input_len} tokens")
|
print(f"Input length of the prompt: {input_len} tokens")
|
||||||
requests = [
|
requests = [
|
||||||
SampleRequest(
|
SampleRequest(prompt=prompt,
|
||||||
prompt=prompt,
|
prompt_len=input_len,
|
||||||
prompt_len=input_len,
|
expected_output_len=args.output_len,
|
||||||
expected_output_len=args.output_len,
|
schema=schema,
|
||||||
schema=schema,
|
structure_type=args.structure_type)
|
||||||
structure_type=args.structure_type,
|
|
||||||
)
|
|
||||||
for _ in range(args.num_prompts)
|
for _ in range(args.num_prompts)
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -195,13 +188,11 @@ def sample_requests(
|
|||||||
input_len = len(tokenizer(prompt).input_ids)
|
input_len = len(tokenizer(prompt).input_ids)
|
||||||
print(f"Input length of the prompt: {input_len} tokens")
|
print(f"Input length of the prompt: {input_len} tokens")
|
||||||
requests = [
|
requests = [
|
||||||
SampleRequest(
|
SampleRequest(prompt=prompt,
|
||||||
prompt=prompt,
|
prompt_len=input_len,
|
||||||
prompt_len=input_len,
|
expected_output_len=args.output_len,
|
||||||
expected_output_len=args.output_len,
|
schema=regex,
|
||||||
schema=regex,
|
structure_type=args.structure_type)
|
||||||
structure_type=args.structure_type,
|
|
||||||
)
|
|
||||||
for _ in range(args.num_prompts)
|
for _ in range(args.num_prompts)
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -212,55 +203,48 @@ def sample_requests(
|
|||||||
input_len = len(tokenizer(prompt).input_ids)
|
input_len = len(tokenizer(prompt).input_ids)
|
||||||
print(f"Input length of the prompt: {input_len} tokens")
|
print(f"Input length of the prompt: {input_len} tokens")
|
||||||
requests = [
|
requests = [
|
||||||
SampleRequest(
|
SampleRequest(prompt=prompt,
|
||||||
prompt=prompt,
|
prompt_len=input_len,
|
||||||
prompt_len=input_len,
|
expected_output_len=args.output_len,
|
||||||
expected_output_len=args.output_len,
|
schema=choice,
|
||||||
schema=choice,
|
structure_type=args.structure_type)
|
||||||
structure_type=args.structure_type,
|
|
||||||
)
|
|
||||||
for _ in range(args.num_prompts)
|
for _ in range(args.num_prompts)
|
||||||
]
|
]
|
||||||
|
|
||||||
elif args.dataset == "xgrammar_bench":
|
elif args.dataset == "xgrammar_bench":
|
||||||
requests: list[SampleRequest] = []
|
requests: list[SampleRequest] = []
|
||||||
dataset = datasets.load_dataset("NousResearch/json-mode-eval", split="train")
|
dataset = datasets.load_dataset("NousResearch/json-mode-eval",
|
||||||
|
split="train")
|
||||||
full_dataset_len = len(dataset)
|
full_dataset_len = len(dataset)
|
||||||
|
|
||||||
def _filter_func(item):
|
def _filter_func(item):
|
||||||
import json
|
import json
|
||||||
|
|
||||||
schema = json.loads(item["schema"])
|
schema = json.loads(item["schema"])
|
||||||
return not has_xgrammar_unsupported_json_features(schema)
|
return not has_xgrammar_unsupported_json_features(schema)
|
||||||
|
|
||||||
dataset = dataset.filter(_filter_func)
|
dataset = dataset.filter(_filter_func)
|
||||||
num_filtered_out = full_dataset_len - len(dataset)
|
num_filtered_out = full_dataset_len - len(dataset)
|
||||||
print(
|
print(f"dataset has {len(dataset)} entries after filtering "
|
||||||
f"dataset has {len(dataset)} entries after filtering "
|
f"out {num_filtered_out} entries with unsupported features")
|
||||||
f"out {num_filtered_out} entries with unsupported features"
|
|
||||||
)
|
|
||||||
len_dataset = len(dataset)
|
len_dataset = len(dataset)
|
||||||
for data_point_idx in range(args.num_prompts):
|
for data_point_idx in range(args.num_prompts):
|
||||||
idx = data_point_idx
|
idx = data_point_idx
|
||||||
while idx >= len_dataset:
|
while idx >= len_dataset:
|
||||||
idx -= len_dataset
|
idx -= len_dataset
|
||||||
schema = dataset["schema"][idx]
|
schema = dataset["schema"][idx]
|
||||||
prompt = tokenizer.apply_chat_template(
|
prompt = tokenizer.apply_chat_template(dataset["prompt"][idx],
|
||||||
dataset["prompt"][idx], tokenize=False, add_generation_prompt=True
|
tokenize=False,
|
||||||
)
|
add_generation_prompt=True)
|
||||||
input_len = len(tokenizer(prompt).input_ids)
|
input_len = len(tokenizer(prompt).input_ids)
|
||||||
completion = dataset["completion"][idx]
|
completion = dataset["completion"][idx]
|
||||||
|
|
||||||
requests.append(
|
requests.append(
|
||||||
SampleRequest(
|
SampleRequest(prompt=prompt,
|
||||||
prompt=prompt,
|
prompt_len=input_len,
|
||||||
prompt_len=input_len,
|
expected_output_len=args.output_len,
|
||||||
expected_output_len=args.output_len,
|
schema=schema,
|
||||||
schema=schema,
|
structure_type=args.structure_type,
|
||||||
structure_type=args.structure_type,
|
completion=completion))
|
||||||
completion=completion,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return requests
|
return requests
|
||||||
|
|
||||||
@ -292,8 +276,7 @@ async def get_request(
|
|||||||
|
|
||||||
# Calculate scale parameter theta to maintain the desired request_rate.
|
# Calculate scale parameter theta to maintain the desired request_rate.
|
||||||
assert burstiness > 0, (
|
assert burstiness > 0, (
|
||||||
f"A positive burstiness factor is expected, but given {burstiness}."
|
f"A positive burstiness factor is expected, but given {burstiness}.")
|
||||||
)
|
|
||||||
theta = 1.0 / (request_rate * burstiness)
|
theta = 1.0 / (request_rate * burstiness)
|
||||||
|
|
||||||
for i, request in enumerate(input_requests):
|
for i, request in enumerate(input_requests):
|
||||||
@ -335,8 +318,8 @@ def calculate_metrics(
|
|||||||
# multiple output tokens may be bundled together
|
# multiple output tokens may be bundled together
|
||||||
# Note : this may inflate the output token count slightly
|
# Note : this may inflate the output token count slightly
|
||||||
output_len = len(
|
output_len = len(
|
||||||
tokenizer(outputs[i].generated_text, add_special_tokens=False).input_ids
|
tokenizer(outputs[i].generated_text,
|
||||||
)
|
add_special_tokens=False).input_ids)
|
||||||
actual_output_lens.append(output_len)
|
actual_output_lens.append(output_len)
|
||||||
total_input += input_requests[i].prompt_len
|
total_input += input_requests[i].prompt_len
|
||||||
tpot = 0
|
tpot = 0
|
||||||
@ -360,19 +343,16 @@ def calculate_metrics(
|
|||||||
|
|
||||||
if "ttft" in goodput_config_dict:
|
if "ttft" in goodput_config_dict:
|
||||||
valid_metrics.append(ttfts)
|
valid_metrics.append(ttfts)
|
||||||
slo_values.append(
|
slo_values.append(goodput_config_dict["ttft"] /
|
||||||
goodput_config_dict["ttft"] / MILLISECONDS_TO_SECONDS_CONVERSION
|
MILLISECONDS_TO_SECONDS_CONVERSION)
|
||||||
)
|
|
||||||
if "tpot" in goodput_config_dict:
|
if "tpot" in goodput_config_dict:
|
||||||
valid_metrics.append(all_tpots)
|
valid_metrics.append(all_tpots)
|
||||||
slo_values.append(
|
slo_values.append(goodput_config_dict["tpot"] /
|
||||||
goodput_config_dict["tpot"] / MILLISECONDS_TO_SECONDS_CONVERSION
|
MILLISECONDS_TO_SECONDS_CONVERSION)
|
||||||
)
|
|
||||||
if "e2el" in goodput_config_dict:
|
if "e2el" in goodput_config_dict:
|
||||||
valid_metrics.append(e2els)
|
valid_metrics.append(e2els)
|
||||||
slo_values.append(
|
slo_values.append(goodput_config_dict["e2el"] /
|
||||||
goodput_config_dict["e2el"] / MILLISECONDS_TO_SECONDS_CONVERSION
|
MILLISECONDS_TO_SECONDS_CONVERSION)
|
||||||
)
|
|
||||||
|
|
||||||
for req_metric in zip(*valid_metrics):
|
for req_metric in zip(*valid_metrics):
|
||||||
is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)])
|
is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)])
|
||||||
@ -383,8 +363,7 @@ def calculate_metrics(
|
|||||||
warnings.warn(
|
warnings.warn(
|
||||||
"All requests failed. This is likely due to a misconfiguration "
|
"All requests failed. This is likely due to a misconfiguration "
|
||||||
"on the benchmark arguments.",
|
"on the benchmark arguments.",
|
||||||
stacklevel=2,
|
stacklevel=2)
|
||||||
)
|
|
||||||
metrics = BenchmarkMetrics(
|
metrics = BenchmarkMetrics(
|
||||||
completed=completed,
|
completed=completed,
|
||||||
total_input=total_input,
|
total_input=total_input,
|
||||||
@ -393,31 +372,27 @@ def calculate_metrics(
|
|||||||
request_goodput=good_completed / dur_s,
|
request_goodput=good_completed / dur_s,
|
||||||
output_throughput=sum(actual_output_lens) / dur_s,
|
output_throughput=sum(actual_output_lens) / dur_s,
|
||||||
total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s,
|
total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s,
|
||||||
mean_ttft_ms=np.mean(ttfts or 0)
|
mean_ttft_ms=np.mean(ttfts or 0) *
|
||||||
* 1000, # ttfts is empty if streaming is not supported by backend
|
1000, # ttfts is empty if streaming is not supported by backend
|
||||||
std_ttft_ms=np.std(ttfts or 0) * 1000,
|
std_ttft_ms=np.std(ttfts or 0) * 1000,
|
||||||
median_ttft_ms=np.median(ttfts or 0) * 1000,
|
median_ttft_ms=np.median(ttfts or 0) * 1000,
|
||||||
percentiles_ttft_ms=[
|
percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000)
|
||||||
(p, np.percentile(ttfts or 0, p) * 1000) for p in selected_percentiles
|
for p in selected_percentiles],
|
||||||
],
|
|
||||||
mean_tpot_ms=np.mean(tpots or 0) * 1000,
|
mean_tpot_ms=np.mean(tpots or 0) * 1000,
|
||||||
std_tpot_ms=np.std(tpots or 0) * 1000,
|
std_tpot_ms=np.std(tpots or 0) * 1000,
|
||||||
median_tpot_ms=np.median(tpots or 0) * 1000,
|
median_tpot_ms=np.median(tpots or 0) * 1000,
|
||||||
percentiles_tpot_ms=[
|
percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000)
|
||||||
(p, np.percentile(tpots or 0, p) * 1000) for p in selected_percentiles
|
for p in selected_percentiles],
|
||||||
],
|
|
||||||
mean_itl_ms=np.mean(itls or 0) * 1000,
|
mean_itl_ms=np.mean(itls or 0) * 1000,
|
||||||
std_itl_ms=np.std(itls or 0) * 1000,
|
std_itl_ms=np.std(itls or 0) * 1000,
|
||||||
median_itl_ms=np.median(itls or 0) * 1000,
|
median_itl_ms=np.median(itls or 0) * 1000,
|
||||||
percentiles_itl_ms=[
|
percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000)
|
||||||
(p, np.percentile(itls or 0, p) * 1000) for p in selected_percentiles
|
for p in selected_percentiles],
|
||||||
],
|
|
||||||
mean_e2el_ms=np.mean(e2els or 0) * 1000,
|
mean_e2el_ms=np.mean(e2els or 0) * 1000,
|
||||||
std_e2el_ms=np.std(e2els or 0) * 1000,
|
std_e2el_ms=np.std(e2els or 0) * 1000,
|
||||||
median_e2el_ms=np.median(e2els or 0) * 1000,
|
median_e2el_ms=np.median(e2els or 0) * 1000,
|
||||||
percentiles_e2el_ms=[
|
percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000)
|
||||||
(p, np.percentile(e2els or 0, p) * 1000) for p in selected_percentiles
|
for p in selected_percentiles],
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return metrics, actual_output_lens
|
return metrics, actual_output_lens
|
||||||
@ -454,13 +429,12 @@ async def benchmark(
|
|||||||
|
|
||||||
print("Starting initial single prompt test run...")
|
print("Starting initial single prompt test run...")
|
||||||
structured_output_req_idx = random.sample(
|
structured_output_req_idx = random.sample(
|
||||||
range(len(input_requests)), int(len(input_requests) * structured_output_ratio)
|
range(len(input_requests)),
|
||||||
)
|
int(len(input_requests) * structured_output_ratio))
|
||||||
|
|
||||||
test_request = input_requests[0]
|
test_request = input_requests[0]
|
||||||
test_req_extra_body = (
|
test_req_extra_body = (prepare_extra_body(test_request)
|
||||||
prepare_extra_body(test_request) if 0 in structured_output_req_idx else None
|
if 0 in structured_output_req_idx else None)
|
||||||
)
|
|
||||||
test_input = RequestFuncInput(
|
test_input = RequestFuncInput(
|
||||||
model=model_id,
|
model=model_id,
|
||||||
prompt=test_request.prompt,
|
prompt=test_request.prompt,
|
||||||
@ -474,8 +448,7 @@ async def benchmark(
|
|||||||
if not test_output.success:
|
if not test_output.success:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Initial test run failed - Please make sure benchmark arguments "
|
"Initial test run failed - Please make sure benchmark arguments "
|
||||||
f"are correctly specified. Error: {test_output.error}"
|
f"are correctly specified. Error: {test_output.error}")
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
print("Initial test run completed. Starting main benchmark run...")
|
print("Initial test run completed. Starting main benchmark run...")
|
||||||
|
|
||||||
@ -494,7 +467,10 @@ async def benchmark(
|
|||||||
if profile_output.success:
|
if profile_output.success:
|
||||||
print("Profiler started")
|
print("Profiler started")
|
||||||
|
|
||||||
distribution = "Poisson process" if burstiness == 1.0 else "Gamma distribution"
|
if burstiness == 1.0:
|
||||||
|
distribution = "Poisson process"
|
||||||
|
else:
|
||||||
|
distribution = "Gamma distribution"
|
||||||
|
|
||||||
print(f"Traffic request rate: {request_rate}")
|
print(f"Traffic request rate: {request_rate}")
|
||||||
print(f"Burstiness factor: {burstiness} ({distribution})")
|
print(f"Burstiness factor: {burstiness} ({distribution})")
|
||||||
@ -506,21 +482,24 @@ async def benchmark(
|
|||||||
# and it will simplify the code in limited_request_func.
|
# and it will simplify the code in limited_request_func.
|
||||||
# semaphore = (asyncio.Semaphore(max_concurrency)
|
# semaphore = (asyncio.Semaphore(max_concurrency)
|
||||||
# if max_concurrency else contextlib.nullcontext())
|
# if max_concurrency else contextlib.nullcontext())
|
||||||
semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None
|
semaphore = (asyncio.Semaphore(max_concurrency)
|
||||||
|
if max_concurrency else None)
|
||||||
|
|
||||||
async def limited_request_func(request_func_input, pbar):
|
async def limited_request_func(request_func_input, pbar):
|
||||||
if semaphore is None:
|
if semaphore is None:
|
||||||
return await request_func(request_func_input=request_func_input, pbar=pbar)
|
return await request_func(request_func_input=request_func_input,
|
||||||
|
pbar=pbar)
|
||||||
async with semaphore:
|
async with semaphore:
|
||||||
return await request_func(request_func_input=request_func_input, pbar=pbar)
|
return await request_func(request_func_input=request_func_input,
|
||||||
|
pbar=pbar)
|
||||||
|
|
||||||
benchmark_start_time = time.perf_counter()
|
benchmark_start_time = time.perf_counter()
|
||||||
tasks: list[asyncio.Task] = []
|
tasks: list[asyncio.Task] = []
|
||||||
expected: list[str] = []
|
expected: list[str] = []
|
||||||
async for i, request in get_request(input_requests, request_rate, burstiness):
|
async for i, request in get_request(input_requests, request_rate,
|
||||||
extra_body = (
|
burstiness):
|
||||||
prepare_extra_body(request) if i in structured_output_req_idx else None
|
extra_body = prepare_extra_body(
|
||||||
)
|
request) if i in structured_output_req_idx else None
|
||||||
request_func_input = RequestFuncInput(
|
request_func_input = RequestFuncInput(
|
||||||
model=model_id,
|
model=model_id,
|
||||||
prompt=request.prompt,
|
prompt=request.prompt,
|
||||||
@ -533,9 +512,8 @@ async def benchmark(
|
|||||||
expected.append(request.completion)
|
expected.append(request.completion)
|
||||||
tasks.append(
|
tasks.append(
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
limited_request_func(request_func_input=request_func_input, pbar=pbar)
|
limited_request_func(request_func_input=request_func_input,
|
||||||
)
|
pbar=pbar)))
|
||||||
)
|
|
||||||
outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
|
outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
if profile:
|
if profile:
|
||||||
@ -567,58 +545,54 @@ async def benchmark(
|
|||||||
goodput_config_dict=goodput_config_dict,
|
goodput_config_dict=goodput_config_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
print("{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
|
print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
|
||||||
print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
|
print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
|
||||||
print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
|
print("{:<40} {:<10.2f}".format("Benchmark duration (s):",
|
||||||
|
benchmark_duration))
|
||||||
print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
|
print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
|
||||||
print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
|
print("{:<40} {:<10}".format("Total generated tokens:",
|
||||||
print(
|
metrics.total_output))
|
||||||
"{:<40} {:<10.2f}".format(
|
print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
|
||||||
"Request throughput (req/s):", metrics.request_throughput
|
metrics.request_throughput))
|
||||||
)
|
|
||||||
)
|
|
||||||
if goodput_config_dict:
|
if goodput_config_dict:
|
||||||
print(
|
print("{:<40} {:<10.2f}".format("Request goodput (req/s):",
|
||||||
"{:<40} {:<10.2f}".format(
|
metrics.request_goodput))
|
||||||
"Request goodput (req/s):", metrics.request_goodput
|
print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
|
||||||
)
|
metrics.output_throughput))
|
||||||
)
|
print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):",
|
||||||
print(
|
metrics.total_token_throughput))
|
||||||
"{:<40} {:<10.2f}".format(
|
|
||||||
"Output token throughput (tok/s):", metrics.output_throughput
|
|
||||||
)
|
|
||||||
)
|
|
||||||
print(
|
|
||||||
"{:<40} {:<10.2f}".format(
|
|
||||||
"Total Token throughput (tok/s):", metrics.total_token_throughput
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
"duration": benchmark_duration,
|
"duration":
|
||||||
"completed": metrics.completed,
|
benchmark_duration,
|
||||||
"total_input_tokens": metrics.total_input,
|
"completed":
|
||||||
"total_output_tokens": metrics.total_output,
|
metrics.completed,
|
||||||
"request_throughput": metrics.request_throughput,
|
"total_input_tokens":
|
||||||
"output_throughput": metrics.output_throughput,
|
metrics.total_input,
|
||||||
"total_token_throughput": metrics.total_token_throughput,
|
"total_output_tokens":
|
||||||
"ttft_description": pd.Series([output.ttft for output in outputs])
|
metrics.total_output,
|
||||||
.describe()
|
"request_throughput":
|
||||||
.to_dict(),
|
metrics.request_throughput,
|
||||||
"tpot_description": pd.Series([output.tpot for output in outputs])
|
"output_throughput":
|
||||||
.describe()
|
metrics.output_throughput,
|
||||||
.to_dict(),
|
"total_token_throughput":
|
||||||
|
metrics.total_token_throughput,
|
||||||
|
"ttft_description":
|
||||||
|
pd.Series([output.ttft for output in outputs]).describe().to_dict(),
|
||||||
|
"tpot_description":
|
||||||
|
pd.Series([output.tpot for output in outputs]).describe().to_dict(),
|
||||||
"input_lens": [output.prompt_len for output in outputs],
|
"input_lens": [output.prompt_len for output in outputs],
|
||||||
"output_lens": actual_output_lens,
|
"output_lens":
|
||||||
|
actual_output_lens,
|
||||||
"ttfts": [output.ttft for output in outputs],
|
"ttfts": [output.ttft for output in outputs],
|
||||||
"itls": [output.itl for output in outputs],
|
"itls": [output.itl for output in outputs],
|
||||||
"errors": [output.error for output in outputs],
|
"errors": [output.error for output in outputs],
|
||||||
}
|
}
|
||||||
|
|
||||||
ret = [
|
ret = [{
|
||||||
{"generated": output.generated_text, "expected": gt}
|
'generated': output.generated_text,
|
||||||
for output, gt in zip(outputs, expected)
|
'expected': gt
|
||||||
]
|
} for output, gt in zip(outputs, expected)]
|
||||||
|
|
||||||
def process_one_metric(
|
def process_one_metric(
|
||||||
# E.g., "ttft"
|
# E.g., "ttft"
|
||||||
@ -632,35 +606,29 @@ async def benchmark(
|
|||||||
# metric.
|
# metric.
|
||||||
if metric_attribute_name not in selected_percentile_metrics:
|
if metric_attribute_name not in selected_percentile_metrics:
|
||||||
return
|
return
|
||||||
print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
|
print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-'))
|
||||||
print(
|
print("{:<40} {:<10.2f}".format(
|
||||||
"{:<40} {:<10.2f}".format(
|
f"Mean {metric_name} (ms):",
|
||||||
f"Mean {metric_name} (ms):",
|
getattr(metrics, f"mean_{metric_attribute_name}_ms")))
|
||||||
getattr(metrics, f"mean_{metric_attribute_name}_ms"),
|
print("{:<40} {:<10.2f}".format(
|
||||||
)
|
f"Median {metric_name} (ms):",
|
||||||
)
|
getattr(metrics, f"median_{metric_attribute_name}_ms")))
|
||||||
print(
|
|
||||||
"{:<40} {:<10.2f}".format(
|
|
||||||
f"Median {metric_name} (ms):",
|
|
||||||
getattr(metrics, f"median_{metric_attribute_name}_ms"),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
result[f"mean_{metric_attribute_name}_ms"] = getattr(
|
result[f"mean_{metric_attribute_name}_ms"] = getattr(
|
||||||
metrics, f"mean_{metric_attribute_name}_ms"
|
metrics, f"mean_{metric_attribute_name}_ms")
|
||||||
)
|
|
||||||
result[f"median_{metric_attribute_name}_ms"] = getattr(
|
result[f"median_{metric_attribute_name}_ms"] = getattr(
|
||||||
metrics, f"median_{metric_attribute_name}_ms"
|
metrics, f"median_{metric_attribute_name}_ms")
|
||||||
)
|
|
||||||
result[f"std_{metric_attribute_name}_ms"] = getattr(
|
result[f"std_{metric_attribute_name}_ms"] = getattr(
|
||||||
metrics, f"std_{metric_attribute_name}_ms"
|
metrics, f"std_{metric_attribute_name}_ms")
|
||||||
)
|
for p, value in getattr(metrics,
|
||||||
for p, value in getattr(metrics, f"percentiles_{metric_attribute_name}_ms"):
|
f"percentiles_{metric_attribute_name}_ms"):
|
||||||
p_word = str(int(p)) if int(p) == p else str(p)
|
p_word = str(int(p)) if int(p) == p else str(p)
|
||||||
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value))
|
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):",
|
||||||
|
value))
|
||||||
result[f"p{p_word}_{metric_attribute_name}_ms"] = value
|
result[f"p{p_word}_{metric_attribute_name}_ms"] = value
|
||||||
|
|
||||||
process_one_metric("ttft", "TTFT", "Time to First Token")
|
process_one_metric("ttft", "TTFT", "Time to First Token")
|
||||||
process_one_metric("tpot", "TPOT", "Time per Output Token (excl. 1st token)")
|
process_one_metric("tpot", "TPOT",
|
||||||
|
"Time per Output Token (excl. 1st token)")
|
||||||
process_one_metric("itl", "ITL", "Inter-token Latency")
|
process_one_metric("itl", "ITL", "Inter-token Latency")
|
||||||
process_one_metric("e2el", "E2EL", "End-to-end Latency")
|
process_one_metric("e2el", "E2EL", "End-to-end Latency")
|
||||||
|
|
||||||
@ -670,13 +638,13 @@ async def benchmark(
|
|||||||
|
|
||||||
|
|
||||||
def evaluate(ret, args):
|
def evaluate(ret, args):
|
||||||
|
|
||||||
def _eval_correctness_json(expected, actual):
|
def _eval_correctness_json(expected, actual):
|
||||||
# extract json string from string using regex
|
# extract json string from string using regex
|
||||||
import regex as re
|
import re
|
||||||
|
actual = actual.replace('\n', '').replace(' ', '').strip()
|
||||||
actual = actual.replace("\n", "").replace(" ", "").strip()
|
|
||||||
try:
|
try:
|
||||||
actual = re.search(r"\{.*\}", actual).group()
|
actual = re.search(r'\{.*\}', actual).group()
|
||||||
actual = json.loads(actual)
|
actual = json.loads(actual)
|
||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
@ -687,33 +655,29 @@ def evaluate(ret, args):
|
|||||||
return actual in args.choice
|
return actual in args.choice
|
||||||
|
|
||||||
def _eval_correctness_regex(expected, actual):
|
def _eval_correctness_regex(expected, actual):
|
||||||
import regex as re
|
import re
|
||||||
|
|
||||||
return re.match(args.regex, actual) is not None
|
return re.match(args.regex, actual) is not None
|
||||||
|
|
||||||
def _eval_correctness(expected, actual):
|
def _eval_correctness(expected, actual):
|
||||||
if args.structure_type == "guided_json":
|
if args.structure_type == 'guided_json':
|
||||||
return _eval_correctness_json(expected, actual)
|
return _eval_correctness_json(expected, actual)
|
||||||
elif args.structure_type == "guided_regex":
|
elif args.structure_type == 'guided_regex':
|
||||||
return _eval_correctness_regex(expected, actual)
|
return _eval_correctness_regex(expected, actual)
|
||||||
elif args.structure_type == "guided_choice":
|
elif args.structure_type == 'guided_choice':
|
||||||
return _eval_correctness_choice(expected, actual)
|
return _eval_correctness_choice(expected, actual)
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
scores = []
|
scores = []
|
||||||
for res in ret:
|
for res in ret:
|
||||||
score = _eval_correctness(res["expected"], res["generated"])
|
score = _eval_correctness(res['expected'], res['generated'])
|
||||||
res["correctness"] = score
|
res['correctness'] = score
|
||||||
scores.append(score)
|
scores.append(score)
|
||||||
|
|
||||||
not_none_scores = [score for score in scores if score is not None]
|
not_none_scores = [score for score in scores if score is not None]
|
||||||
|
|
||||||
return (
|
return (sum(not_none_scores) / len(not_none_scores) *
|
||||||
(sum(not_none_scores) / len(not_none_scores) * 100)
|
100) if len(not_none_scores) > 0 else None
|
||||||
if len(not_none_scores) > 0
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def parse_goodput(slo_pairs):
|
def parse_goodput(slo_pairs):
|
||||||
@ -725,10 +689,9 @@ def parse_goodput(slo_pairs):
|
|||||||
except ValueError as err:
|
except ValueError as err:
|
||||||
raise argparse.ArgumentTypeError(
|
raise argparse.ArgumentTypeError(
|
||||||
"Invalid format found for service level objectives. "
|
"Invalid format found for service level objectives. "
|
||||||
'Specify service level objectives for goodput as "KEY:VALUE" '
|
"Specify service level objectives for goodput as \"KEY:VALUE\" "
|
||||||
"pairs, where the key is a metric name, and the value is a "
|
"pairs, where the key is a metric name, and the value is a "
|
||||||
"number in milliseconds."
|
"number in milliseconds.") from err
|
||||||
) from err
|
|
||||||
return goodput_config_dict
|
return goodput_config_dict
|
||||||
|
|
||||||
|
|
||||||
@ -742,14 +705,12 @@ def check_goodput_args(args):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Invalid metric name found, {slo_name}: {slo_val}. "
|
f"Invalid metric name found, {slo_name}: {slo_val}. "
|
||||||
"The service level objective name should be one of "
|
"The service level objective name should be one of "
|
||||||
f"{str(VALID_NAMES)}. "
|
f"{str(VALID_NAMES)}. ")
|
||||||
)
|
|
||||||
if slo_val < 0:
|
if slo_val < 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Invalid value found, {slo_name}: {slo_val}. "
|
f"Invalid value found, {slo_name}: {slo_val}. "
|
||||||
"The service level objective value should be "
|
"The service level objective value should be "
|
||||||
"non-negative."
|
"non-negative.")
|
||||||
)
|
|
||||||
return goodput_config_dict
|
return goodput_config_dict
|
||||||
|
|
||||||
|
|
||||||
@ -775,19 +736,19 @@ def main(args: argparse.Namespace):
|
|||||||
tokenizer_mode=args.tokenizer_mode,
|
tokenizer_mode=args.tokenizer_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.dataset == "grammar":
|
if args.dataset == 'grammar':
|
||||||
args.structure_type = "guided_grammar"
|
args.structure_type = 'guided_grammar'
|
||||||
elif args.dataset == "regex":
|
elif args.dataset == 'regex':
|
||||||
args.structure_type = "guided_regex"
|
args.structure_type = 'guided_regex'
|
||||||
elif args.dataset == "choice":
|
elif args.dataset == 'choice':
|
||||||
args.structure_type = "guided_choice"
|
args.structure_type = 'guided_choice'
|
||||||
else:
|
else:
|
||||||
args.structure_type = "guided_json"
|
args.structure_type = 'guided_json'
|
||||||
|
|
||||||
if args.no_structured_output:
|
if args.no_structured_output:
|
||||||
args.structured_output_ratio = 0
|
args.structured_output_ratio = 0
|
||||||
if args.save_results:
|
if args.save_results:
|
||||||
result_file_name = f"{args.structured_output_ratio}guided"
|
result_file_name = f'{args.structured_output_ratio}guided'
|
||||||
result_file_name += f"_{backend}"
|
result_file_name += f"_{backend}"
|
||||||
result_file_name += f"_{args.request_rate}qps"
|
result_file_name += f"_{args.request_rate}qps"
|
||||||
result_file_name += f"_{args.model.split('/')[-1]}"
|
result_file_name += f"_{args.model.split('/')[-1]}"
|
||||||
@ -815,29 +776,36 @@ def main(args: argparse.Namespace):
|
|||||||
disable_tqdm=args.disable_tqdm,
|
disable_tqdm=args.disable_tqdm,
|
||||||
profile=args.profile,
|
profile=args.profile,
|
||||||
selected_percentile_metrics=args.percentile_metrics.split(","),
|
selected_percentile_metrics=args.percentile_metrics.split(","),
|
||||||
selected_percentiles=[float(p) for p in args.metric_percentiles.split(",")],
|
selected_percentiles=[
|
||||||
|
float(p) for p in args.metric_percentiles.split(",")
|
||||||
|
],
|
||||||
ignore_eos=args.ignore_eos,
|
ignore_eos=args.ignore_eos,
|
||||||
max_concurrency=args.max_concurrency,
|
max_concurrency=args.max_concurrency,
|
||||||
structured_output_ratio=args.structured_output_ratio,
|
structured_output_ratio=args.structured_output_ratio,
|
||||||
goodput_config_dict=goodput_config_dict,
|
goodput_config_dict=goodput_config_dict,
|
||||||
)
|
))
|
||||||
)
|
|
||||||
|
|
||||||
# Save config and results to json
|
# Save config and results to json
|
||||||
score = evaluate(ret, args)
|
score = evaluate(ret, args)
|
||||||
print("correct_rate(%)", score, "\n")
|
print("correct_rate(%)", score, '\n')
|
||||||
if args.save_results:
|
if args.save_results:
|
||||||
results = {
|
results = {
|
||||||
"backend": backend,
|
"backend":
|
||||||
"model_id": model_id,
|
backend,
|
||||||
"tokenizer_id": tokenizer_id,
|
"model_id":
|
||||||
"num_prompts": args.num_prompts,
|
model_id,
|
||||||
"request_rate": args.request_rate
|
"tokenizer_id":
|
||||||
if args.request_rate < float("inf")
|
tokenizer_id,
|
||||||
else "inf",
|
"num_prompts":
|
||||||
"burstiness": args.burstiness,
|
args.num_prompts,
|
||||||
"max_concurrency": args.max_concurrency,
|
"request_rate":
|
||||||
"correct_rate(%)": score,
|
args.request_rate if args.request_rate < float("inf") else "inf",
|
||||||
|
"burstiness":
|
||||||
|
args.burstiness,
|
||||||
|
"max_concurrency":
|
||||||
|
args.max_concurrency,
|
||||||
|
"correct_rate(%)":
|
||||||
|
score
|
||||||
}
|
}
|
||||||
results = {"outputs": ret, **results, **benchmark_result}
|
results = {"outputs": ret, **results, **benchmark_result}
|
||||||
|
|
||||||
@ -846,14 +814,13 @@ def main(args: argparse.Namespace):
|
|||||||
result_file_name = args.result_filename
|
result_file_name = args.result_filename
|
||||||
if args.result_dir:
|
if args.result_dir:
|
||||||
result_file_name = os.path.join(args.result_dir, result_file_name)
|
result_file_name = os.path.join(args.result_dir, result_file_name)
|
||||||
with open(result_file_name, "w", encoding="utf-8") as outfile:
|
with open(result_file_name, "w", encoding='utf-8') as outfile:
|
||||||
json.dump(results, outfile, indent=4)
|
json.dump(results, outfile, indent=4)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = FlexibleArgumentParser(
|
parser = FlexibleArgumentParser(
|
||||||
description="Benchmark the online serving throughput."
|
description="Benchmark the online serving throughput.")
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--backend",
|
"--backend",
|
||||||
type=str,
|
type=str,
|
||||||
@ -875,14 +842,16 @@ if __name__ == "__main__":
|
|||||||
default="/v1/completions",
|
default="/v1/completions",
|
||||||
help="API endpoint.",
|
help="API endpoint.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument("--dataset",
|
||||||
"--dataset",
|
default='json',
|
||||||
default="json",
|
choices=[
|
||||||
choices=["json", "json-unique", "grammar", "regex", "choice", "xgrammar_bench"],
|
'json', 'json-unique', 'grammar', 'regex',
|
||||||
)
|
'choice', 'xgrammar_bench'
|
||||||
parser.add_argument(
|
])
|
||||||
"--json-schema-path", type=str, default=None, help="Path to json schema."
|
parser.add_argument("--json-schema-path",
|
||||||
)
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path to json schema.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-concurrency",
|
"--max-concurrency",
|
||||||
type=int,
|
type=int,
|
||||||
@ -894,8 +863,7 @@ if __name__ == "__main__":
|
|||||||
"initiated, this argument will control how many are actually allowed "
|
"initiated, this argument will control how many are actually allowed "
|
||||||
"to execute at a time. This means that when used in combination, the "
|
"to execute at a time. This means that when used in combination, the "
|
||||||
"actual request rate may be lower than specified with --request-rate, "
|
"actual request rate may be lower than specified with --request-rate, "
|
||||||
"if the server is not processing requests fast enough to keep up.",
|
"if the server is not processing requests fast enough to keep up.")
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model",
|
"--model",
|
||||||
type=str,
|
type=str,
|
||||||
@ -905,13 +873,15 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--tokenizer",
|
"--tokenizer",
|
||||||
type=str,
|
type=str,
|
||||||
help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
|
help=
|
||||||
|
"Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--tokenizer-mode",
|
"--tokenizer-mode",
|
||||||
type=str,
|
type=str,
|
||||||
default="auto",
|
default="auto",
|
||||||
help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
|
help=
|
||||||
|
"Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num-prompts",
|
"--num-prompts",
|
||||||
@ -988,51 +958,44 @@ if __name__ == "__main__":
|
|||||||
"--ignore-eos",
|
"--ignore-eos",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Set ignore_eos flag when sending the benchmark request."
|
help="Set ignore_eos flag when sending the benchmark request."
|
||||||
"Warning: ignore_eos is not supported in deepspeed_mii and tgi.",
|
"Warning: ignore_eos is not supported in deepspeed_mii and tgi.")
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--percentile-metrics",
|
"--percentile-metrics",
|
||||||
type=str,
|
type=str,
|
||||||
default="ttft,tpot,itl",
|
default="ttft,tpot,itl",
|
||||||
help="Comma-separated list of selected metrics to report percentils. "
|
help="Comma-separated list of selected metrics to report percentils. "
|
||||||
"This argument specifies the metrics to report percentiles. "
|
"This argument specifies the metrics to report percentiles. "
|
||||||
'Allowed metric names are "ttft", "tpot", "itl", "e2el". '
|
"Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". "
|
||||||
'Default value is "ttft,tpot,itl".',
|
"Default value is \"ttft,tpot,itl\".")
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--metric-percentiles",
|
"--metric-percentiles",
|
||||||
type=str,
|
type=str,
|
||||||
default="99",
|
default="99",
|
||||||
help="Comma-separated list of percentiles for selected metrics. "
|
help="Comma-separated list of percentiles for selected metrics. "
|
||||||
'To report 25-th, 50-th, and 75-th percentiles, use "25,50,75". '
|
"To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". "
|
||||||
'Default value is "99". '
|
"Default value is \"99\". "
|
||||||
'Use "--percentile-metrics" to select metrics.',
|
"Use \"--percentile-metrics\" to select metrics.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--goodput",
|
"--goodput",
|
||||||
nargs="+",
|
nargs="+",
|
||||||
required=False,
|
required=False,
|
||||||
help='Specify service level objectives for goodput as "KEY:VALUE" '
|
help="Specify service level objectives for goodput as \"KEY:VALUE\" "
|
||||||
"pairs, where the key is a metric name, and the value is in "
|
"pairs, where the key is a metric name, and the value is in "
|
||||||
'milliseconds. Multiple "KEY:VALUE" pairs can be provided, '
|
"milliseconds. Multiple \"KEY:VALUE\" pairs can be provided, "
|
||||||
"separated by spaces. Allowed request level metric names are "
|
"separated by spaces. Allowed request level metric names are "
|
||||||
'"ttft", "tpot", "e2el". For more context on the definition of '
|
"\"ttft\", \"tpot\", \"e2el\". For more context on the definition of "
|
||||||
"goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
|
"goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
|
||||||
"and the blog: https://hao-ai-lab.github.io/blogs/distserve",
|
"and the blog: https://hao-ai-lab.github.io/blogs/distserve")
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument("--no-structured-output",
|
||||||
"--no-structured-output",
|
action='store_true',
|
||||||
action="store_true",
|
default=False,
|
||||||
default=False,
|
help="Whether to disable JSON decoding or not.")
|
||||||
help="Whether to disable JSON decoding or not.",
|
parser.add_argument("--structured-output-ratio",
|
||||||
)
|
type=float,
|
||||||
parser.add_argument(
|
default=1.0,
|
||||||
"--structured-output-ratio",
|
help="Ratio of Structured Outputs requests")
|
||||||
type=float,
|
|
||||||
default=1.0,
|
|
||||||
help="Ratio of Structured Outputs requests",
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
main(args)
|
main(args)
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
"""Benchmark offline inference throughput."""
|
"""Benchmark offline inference throughput."""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import json
|
import json
|
||||||
@ -12,25 +11,18 @@ from typing import Any, Optional, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import uvloop
|
import uvloop
|
||||||
from tqdm import tqdm
|
from benchmark_dataset import (AIMODataset, BurstGPTDataset,
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase
|
ConversationDataset, InstructCoderDataset,
|
||||||
|
RandomDataset, SampleRequest, ShareGPTDataset,
|
||||||
from benchmark_dataset import (
|
SonnetDataset, VisionArenaDataset)
|
||||||
AIMODataset,
|
|
||||||
BurstGPTDataset,
|
|
||||||
ConversationDataset,
|
|
||||||
InstructCoderDataset,
|
|
||||||
RandomDataset,
|
|
||||||
SampleRequest,
|
|
||||||
ShareGPTDataset,
|
|
||||||
SonnetDataset,
|
|
||||||
VisionArenaDataset,
|
|
||||||
)
|
|
||||||
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
|
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
|
||||||
|
from tqdm import tqdm
|
||||||
|
from transformers import (AutoModelForCausalLM, AutoTokenizer,
|
||||||
|
PreTrainedTokenizerBase)
|
||||||
|
|
||||||
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
|
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
|
||||||
from vllm.entrypoints.openai.api_server import (
|
from vllm.entrypoints.openai.api_server import (
|
||||||
build_async_engine_client_from_engine_args,
|
build_async_engine_client_from_engine_args)
|
||||||
)
|
|
||||||
from vllm.inputs import TextPrompt, TokensPrompt
|
from vllm.inputs import TextPrompt, TokensPrompt
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.outputs import RequestOutput
|
from vllm.outputs import RequestOutput
|
||||||
@ -45,30 +37,23 @@ def run_vllm(
|
|||||||
disable_detokenize: bool = False,
|
disable_detokenize: bool = False,
|
||||||
) -> tuple[float, Optional[list[RequestOutput]]]:
|
) -> tuple[float, Optional[list[RequestOutput]]]:
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
|
|
||||||
llm = LLM(**dataclasses.asdict(engine_args))
|
llm = LLM(**dataclasses.asdict(engine_args))
|
||||||
assert all(
|
assert all(
|
||||||
llm.llm_engine.model_config.max_model_len
|
llm.llm_engine.model_config.max_model_len >= (
|
||||||
>= (request.prompt_len + request.expected_output_len)
|
request.prompt_len + request.expected_output_len)
|
||||||
for request in requests
|
for request in requests), (
|
||||||
), (
|
"Please ensure that max_model_len is greater than the sum of"
|
||||||
"Please ensure that max_model_len is greater than the sum of"
|
" prompt_len and expected_output_len for all requests.")
|
||||||
" prompt_len and expected_output_len for all requests."
|
|
||||||
)
|
|
||||||
# Add the requests to the engine.
|
# Add the requests to the engine.
|
||||||
prompts: list[Union[TextPrompt, TokensPrompt]] = []
|
prompts: list[Union[TextPrompt, TokensPrompt]] = []
|
||||||
sampling_params: list[SamplingParams] = []
|
sampling_params: list[SamplingParams] = []
|
||||||
for request in requests:
|
for request in requests:
|
||||||
prompts.append(
|
prompts.append(
|
||||||
TokensPrompt(
|
TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"],
|
||||||
prompt_token_ids=request.prompt["prompt_token_ids"],
|
multi_modal_data=request.multi_modal_data)
|
||||||
multi_modal_data=request.multi_modal_data,
|
if "prompt_token_ids" in request.prompt else \
|
||||||
)
|
TextPrompt(prompt=request.prompt,
|
||||||
if "prompt_token_ids" in request.prompt
|
multi_modal_data=request.multi_modal_data))
|
||||||
else TextPrompt(
|
|
||||||
prompt=request.prompt, multi_modal_data=request.multi_modal_data
|
|
||||||
)
|
|
||||||
)
|
|
||||||
sampling_params.append(
|
sampling_params.append(
|
||||||
SamplingParams(
|
SamplingParams(
|
||||||
n=n,
|
n=n,
|
||||||
@ -77,8 +62,7 @@ def run_vllm(
|
|||||||
ignore_eos=True,
|
ignore_eos=True,
|
||||||
max_tokens=request.expected_output_len,
|
max_tokens=request.expected_output_len,
|
||||||
detokenize=not disable_detokenize,
|
detokenize=not disable_detokenize,
|
||||||
)
|
))
|
||||||
)
|
|
||||||
lora_requests: Optional[list[LoRARequest]] = None
|
lora_requests: Optional[list[LoRARequest]] = None
|
||||||
if engine_args.enable_lora:
|
if engine_args.enable_lora:
|
||||||
lora_requests = [request.lora_request for request in requests]
|
lora_requests = [request.lora_request for request in requests]
|
||||||
@ -88,9 +72,10 @@ def run_vllm(
|
|||||||
outputs = None
|
outputs = None
|
||||||
if not use_beam_search:
|
if not use_beam_search:
|
||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
outputs = llm.generate(
|
outputs = llm.generate(prompts,
|
||||||
prompts, sampling_params, lora_request=lora_requests, use_tqdm=True
|
sampling_params,
|
||||||
)
|
lora_request=lora_requests,
|
||||||
|
use_tqdm=True)
|
||||||
end = time.perf_counter()
|
end = time.perf_counter()
|
||||||
else:
|
else:
|
||||||
assert lora_requests is None, "BeamSearch API does not support LoRA"
|
assert lora_requests is None, "BeamSearch API does not support LoRA"
|
||||||
@ -106,35 +91,30 @@ def run_vllm(
|
|||||||
beam_width=n,
|
beam_width=n,
|
||||||
max_tokens=output_len,
|
max_tokens=output_len,
|
||||||
ignore_eos=True,
|
ignore_eos=True,
|
||||||
),
|
))
|
||||||
)
|
|
||||||
end = time.perf_counter()
|
end = time.perf_counter()
|
||||||
return end - start, outputs
|
return end - start, outputs
|
||||||
|
|
||||||
|
|
||||||
def run_vllm_chat(
|
def run_vllm_chat(
|
||||||
requests: list[SampleRequest],
|
requests: list[SampleRequest],
|
||||||
n: int,
|
n: int,
|
||||||
engine_args: EngineArgs,
|
engine_args: EngineArgs,
|
||||||
disable_detokenize: bool = False,
|
disable_detokenize: bool = False) -> tuple[float, list[RequestOutput]]:
|
||||||
) -> tuple[float, list[RequestOutput]]:
|
|
||||||
"""
|
"""
|
||||||
Run vLLM chat benchmark. This function is recommended ONLY for benchmarking
|
Run vLLM chat benchmark. This function is recommended ONLY for benchmarking
|
||||||
multimodal models as it properly handles multimodal inputs and chat
|
multimodal models as it properly handles multimodal inputs and chat
|
||||||
formatting. For non-multimodal models, use run_vllm() instead.
|
formatting. For non-multimodal models, use run_vllm() instead.
|
||||||
"""
|
"""
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
|
|
||||||
llm = LLM(**dataclasses.asdict(engine_args))
|
llm = LLM(**dataclasses.asdict(engine_args))
|
||||||
|
|
||||||
assert all(
|
assert all(
|
||||||
llm.llm_engine.model_config.max_model_len
|
llm.llm_engine.model_config.max_model_len >= (
|
||||||
>= (request.prompt_len + request.expected_output_len)
|
request.prompt_len + request.expected_output_len)
|
||||||
for request in requests
|
for request in requests), (
|
||||||
), (
|
"Please ensure that max_model_len is greater than the sum of "
|
||||||
"Please ensure that max_model_len is greater than the sum of "
|
"prompt_len and expected_output_len for all requests.")
|
||||||
"prompt_len and expected_output_len for all requests."
|
|
||||||
)
|
|
||||||
|
|
||||||
prompts = []
|
prompts = []
|
||||||
sampling_params: list[SamplingParams] = []
|
sampling_params: list[SamplingParams] = []
|
||||||
@ -148,8 +128,7 @@ def run_vllm_chat(
|
|||||||
ignore_eos=True,
|
ignore_eos=True,
|
||||||
max_tokens=request.expected_output_len,
|
max_tokens=request.expected_output_len,
|
||||||
detokenize=not disable_detokenize,
|
detokenize=not disable_detokenize,
|
||||||
)
|
))
|
||||||
)
|
|
||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
outputs = llm.chat(prompts, sampling_params, use_tqdm=True)
|
outputs = llm.chat(prompts, sampling_params, use_tqdm=True)
|
||||||
end = time.perf_counter()
|
end = time.perf_counter()
|
||||||
@ -166,17 +145,13 @@ async def run_vllm_async(
|
|||||||
from vllm import SamplingParams
|
from vllm import SamplingParams
|
||||||
|
|
||||||
async with build_async_engine_client_from_engine_args(
|
async with build_async_engine_client_from_engine_args(
|
||||||
engine_args, disable_frontend_multiprocessing
|
engine_args, disable_frontend_multiprocessing) as llm:
|
||||||
) as llm:
|
|
||||||
model_config = await llm.get_model_config()
|
|
||||||
assert all(
|
assert all(
|
||||||
model_config.max_model_len
|
llm.model_config.max_model_len >= (request.prompt_len +
|
||||||
>= (request.prompt_len + request.expected_output_len)
|
request.expected_output_len)
|
||||||
for request in requests
|
for request in requests), (
|
||||||
), (
|
"Please ensure that max_model_len is greater than the sum of"
|
||||||
"Please ensure that max_model_len is greater than the sum of"
|
" prompt_len and expected_output_len for all requests.")
|
||||||
" prompt_len and expected_output_len for all requests."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add the requests to the engine.
|
# Add the requests to the engine.
|
||||||
prompts: list[Union[TextPrompt, TokensPrompt]] = []
|
prompts: list[Union[TextPrompt, TokensPrompt]] = []
|
||||||
@ -184,15 +159,11 @@ async def run_vllm_async(
|
|||||||
lora_requests: list[Optional[LoRARequest]] = []
|
lora_requests: list[Optional[LoRARequest]] = []
|
||||||
for request in requests:
|
for request in requests:
|
||||||
prompts.append(
|
prompts.append(
|
||||||
TokensPrompt(
|
TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"],
|
||||||
prompt_token_ids=request.prompt["prompt_token_ids"],
|
multi_modal_data=request.multi_modal_data)
|
||||||
multi_modal_data=request.multi_modal_data,
|
if "prompt_token_ids" in request.prompt else \
|
||||||
)
|
TextPrompt(prompt=request.prompt,
|
||||||
if "prompt_token_ids" in request.prompt
|
multi_modal_data=request.multi_modal_data))
|
||||||
else TextPrompt(
|
|
||||||
prompt=request.prompt, multi_modal_data=request.multi_modal_data
|
|
||||||
)
|
|
||||||
)
|
|
||||||
sampling_params.append(
|
sampling_params.append(
|
||||||
SamplingParams(
|
SamplingParams(
|
||||||
n=n,
|
n=n,
|
||||||
@ -201,16 +172,17 @@ async def run_vllm_async(
|
|||||||
ignore_eos=True,
|
ignore_eos=True,
|
||||||
max_tokens=request.expected_output_len,
|
max_tokens=request.expected_output_len,
|
||||||
detokenize=not disable_detokenize,
|
detokenize=not disable_detokenize,
|
||||||
)
|
))
|
||||||
)
|
|
||||||
lora_requests.append(request.lora_request)
|
lora_requests.append(request.lora_request)
|
||||||
|
|
||||||
generators = []
|
generators = []
|
||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
for i, (prompt, sp, lr) in enumerate(
|
for i, (prompt, sp,
|
||||||
zip(prompts, sampling_params, lora_requests)
|
lr) in enumerate(zip(prompts, sampling_params, lora_requests)):
|
||||||
):
|
generator = llm.generate(prompt,
|
||||||
generator = llm.generate(prompt, sp, lora_request=lr, request_id=f"test{i}")
|
sp,
|
||||||
|
lora_request=lr,
|
||||||
|
request_id=f"test{i}")
|
||||||
generators.append(generator)
|
generators.append(generator)
|
||||||
all_gens = merge_async_iterators(*generators)
|
all_gens = merge_async_iterators(*generators)
|
||||||
async for i, res in all_gens:
|
async for i, res in all_gens:
|
||||||
@ -229,8 +201,7 @@ def run_hf(
|
|||||||
disable_detokenize: bool = False,
|
disable_detokenize: bool = False,
|
||||||
) -> float:
|
) -> float:
|
||||||
llm = AutoModelForCausalLM.from_pretrained(
|
llm = AutoModelForCausalLM.from_pretrained(
|
||||||
model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code
|
model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
|
||||||
)
|
|
||||||
if llm.config.model_type == "llama":
|
if llm.config.model_type == "llama":
|
||||||
# To enable padding in the HF backend.
|
# To enable padding in the HF backend.
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
@ -253,15 +224,14 @@ def run_hf(
|
|||||||
# Check if we can add more requests to the batch.
|
# Check if we can add more requests to the batch.
|
||||||
next_prompt_len = requests[i + 1].prompt_len
|
next_prompt_len = requests[i + 1].prompt_len
|
||||||
next_output_len = requests[i + 1].expected_output_len
|
next_output_len = requests[i + 1].expected_output_len
|
||||||
if (
|
if (max(max_prompt_len, next_prompt_len) +
|
||||||
max(max_prompt_len, next_prompt_len)
|
max(max_output_len, next_output_len)) <= 2048:
|
||||||
+ max(max_output_len, next_output_len)
|
|
||||||
) <= 2048:
|
|
||||||
# We can add more requests to the batch.
|
# We can add more requests to the batch.
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Generate the sequences.
|
# Generate the sequences.
|
||||||
input_ids = tokenizer(batch, return_tensors="pt", padding=True).input_ids
|
input_ids = tokenizer(batch, return_tensors="pt",
|
||||||
|
padding=True).input_ids
|
||||||
llm_outputs = llm.generate(
|
llm_outputs = llm.generate(
|
||||||
input_ids=input_ids.cuda(),
|
input_ids=input_ids.cuda(),
|
||||||
do_sample=True,
|
do_sample=True,
|
||||||
@ -291,7 +261,6 @@ def run_mii(
|
|||||||
output_len: int,
|
output_len: int,
|
||||||
) -> float:
|
) -> float:
|
||||||
from mii import client, serve
|
from mii import client, serve
|
||||||
|
|
||||||
llm = serve(model, tensor_parallel=tensor_parallel_size)
|
llm = serve(model, tensor_parallel=tensor_parallel_size)
|
||||||
prompts = [request.prompt for request in requests]
|
prompts = [request.prompt for request in requests]
|
||||||
|
|
||||||
@ -303,9 +272,8 @@ def run_mii(
|
|||||||
return end - start
|
return end - start
|
||||||
|
|
||||||
|
|
||||||
def save_to_pytorch_benchmark_format(
|
def save_to_pytorch_benchmark_format(args: argparse.Namespace,
|
||||||
args: argparse.Namespace, results: dict[str, Any]
|
results: dict[str, Any]) -> None:
|
||||||
) -> None:
|
|
||||||
pt_records = convert_to_pytorch_benchmark_format(
|
pt_records = convert_to_pytorch_benchmark_format(
|
||||||
args=args,
|
args=args,
|
||||||
metrics={
|
metrics={
|
||||||
@ -313,9 +281,9 @@ def save_to_pytorch_benchmark_format(
|
|||||||
"tokens_per_second": [results["tokens_per_second"]],
|
"tokens_per_second": [results["tokens_per_second"]],
|
||||||
},
|
},
|
||||||
extra_info={
|
extra_info={
|
||||||
k: results[k] for k in ["elapsed_time", "num_requests", "total_num_tokens"]
|
k: results[k]
|
||||||
},
|
for k in ["elapsed_time", "num_requests", "total_num_tokens"]
|
||||||
)
|
})
|
||||||
if pt_records:
|
if pt_records:
|
||||||
# Don't use json suffix here as we don't want CI to pick it up
|
# Don't use json suffix here as we don't want CI to pick it up
|
||||||
pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json"
|
pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json"
|
||||||
@ -347,8 +315,7 @@ def get_requests(args, tokenizer):
|
|||||||
sample_kwargs["enable_multimodal_chat"] = True
|
sample_kwargs["enable_multimodal_chat"] = True
|
||||||
elif args.dataset_name == "sonnet":
|
elif args.dataset_name == "sonnet":
|
||||||
assert tokenizer.chat_template or tokenizer.default_chat_template, (
|
assert tokenizer.chat_template or tokenizer.default_chat_template, (
|
||||||
"Tokenizer/model must have chat template for sonnet dataset."
|
"Tokenizer/model must have chat template for sonnet dataset.")
|
||||||
)
|
|
||||||
dataset_cls = SonnetDataset
|
dataset_cls = SonnetDataset
|
||||||
sample_kwargs["prefix_len"] = args.prefix_len
|
sample_kwargs["prefix_len"] = args.prefix_len
|
||||||
sample_kwargs["return_prompt_formatted"] = True
|
sample_kwargs["return_prompt_formatted"] = True
|
||||||
@ -357,21 +324,21 @@ def get_requests(args, tokenizer):
|
|||||||
elif args.dataset_name == "hf":
|
elif args.dataset_name == "hf":
|
||||||
if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
|
if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
|
||||||
dataset_cls = VisionArenaDataset
|
dataset_cls = VisionArenaDataset
|
||||||
common_kwargs["dataset_subset"] = None
|
common_kwargs['dataset_subset'] = None
|
||||||
common_kwargs["dataset_split"] = "train"
|
common_kwargs['dataset_split'] = "train"
|
||||||
sample_kwargs["enable_multimodal_chat"] = True
|
sample_kwargs["enable_multimodal_chat"] = True
|
||||||
elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
|
elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
|
||||||
dataset_cls = InstructCoderDataset
|
dataset_cls = InstructCoderDataset
|
||||||
common_kwargs["dataset_split"] = "train"
|
common_kwargs['dataset_split'] = "train"
|
||||||
elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
|
elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
|
||||||
dataset_cls = ConversationDataset
|
dataset_cls = ConversationDataset
|
||||||
common_kwargs["dataset_subset"] = args.hf_subset
|
common_kwargs['dataset_subset'] = args.hf_subset
|
||||||
common_kwargs["dataset_split"] = args.hf_split
|
common_kwargs['dataset_split'] = args.hf_split
|
||||||
sample_kwargs["enable_multimodal_chat"] = True
|
sample_kwargs["enable_multimodal_chat"] = True
|
||||||
elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
|
elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
|
||||||
dataset_cls = AIMODataset
|
dataset_cls = AIMODataset
|
||||||
common_kwargs["dataset_subset"] = None
|
common_kwargs['dataset_subset'] = None
|
||||||
common_kwargs["dataset_split"] = "train"
|
common_kwargs['dataset_split'] = "train"
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown dataset name: {args.dataset_name}")
|
raise ValueError(f"Unknown dataset name: {args.dataset_name}")
|
||||||
# Remove None values
|
# Remove None values
|
||||||
@ -386,10 +353,10 @@ def main(args: argparse.Namespace):
|
|||||||
random.seed(args.seed)
|
random.seed(args.seed)
|
||||||
# Sample the requests.
|
# Sample the requests.
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
args.tokenizer, trust_remote_code=args.trust_remote_code
|
args.tokenizer, trust_remote_code=args.trust_remote_code)
|
||||||
)
|
|
||||||
requests = get_requests(args, tokenizer)
|
requests = get_requests(args, tokenizer)
|
||||||
is_multi_modal = any(request.multi_modal_data is not None for request in requests)
|
is_multi_modal = any(request.multi_modal_data is not None
|
||||||
|
for request in requests)
|
||||||
request_outputs: Optional[list[RequestOutput]] = None
|
request_outputs: Optional[list[RequestOutput]] = None
|
||||||
if args.backend == "vllm":
|
if args.backend == "vllm":
|
||||||
if args.async_engine:
|
if args.async_engine:
|
||||||
@ -400,34 +367,23 @@ def main(args: argparse.Namespace):
|
|||||||
AsyncEngineArgs.from_cli_args(args),
|
AsyncEngineArgs.from_cli_args(args),
|
||||||
args.disable_frontend_multiprocessing,
|
args.disable_frontend_multiprocessing,
|
||||||
args.disable_detokenize,
|
args.disable_detokenize,
|
||||||
)
|
))
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
elapsed_time, request_outputs = run_vllm(
|
elapsed_time, request_outputs = run_vllm(
|
||||||
requests,
|
requests, args.n, EngineArgs.from_cli_args(args),
|
||||||
args.n,
|
args.disable_detokenize)
|
||||||
EngineArgs.from_cli_args(args),
|
|
||||||
args.disable_detokenize,
|
|
||||||
)
|
|
||||||
elif args.backend == "hf":
|
elif args.backend == "hf":
|
||||||
assert args.tensor_parallel_size == 1
|
assert args.tensor_parallel_size == 1
|
||||||
elapsed_time = run_hf(
|
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
|
||||||
requests,
|
args.hf_max_batch_size, args.trust_remote_code,
|
||||||
args.model,
|
args.disable_detokenize)
|
||||||
tokenizer,
|
|
||||||
args.n,
|
|
||||||
args.hf_max_batch_size,
|
|
||||||
args.trust_remote_code,
|
|
||||||
args.disable_detokenize,
|
|
||||||
)
|
|
||||||
elif args.backend == "mii":
|
elif args.backend == "mii":
|
||||||
elapsed_time = run_mii(
|
elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
|
||||||
requests, args.model, args.tensor_parallel_size, args.output_len
|
args.output_len)
|
||||||
)
|
|
||||||
elif args.backend == "vllm-chat":
|
elif args.backend == "vllm-chat":
|
||||||
elapsed_time, request_outputs = run_vllm_chat(
|
elapsed_time, request_outputs = run_vllm_chat(
|
||||||
requests, args.n, EngineArgs.from_cli_args(args), args.disable_detokenize
|
requests, args.n, EngineArgs.from_cli_args(args),
|
||||||
)
|
args.disable_detokenize)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown backend: {args.backend}")
|
raise ValueError(f"Unknown backend: {args.backend}")
|
||||||
|
|
||||||
@ -439,31 +395,28 @@ def main(args: argparse.Namespace):
|
|||||||
for ro in request_outputs:
|
for ro in request_outputs:
|
||||||
if not isinstance(ro, RequestOutput):
|
if not isinstance(ro, RequestOutput):
|
||||||
continue
|
continue
|
||||||
total_prompt_tokens += (
|
total_prompt_tokens += len(
|
||||||
len(ro.prompt_token_ids) if ro.prompt_token_ids else 0
|
ro.prompt_token_ids) if ro.prompt_token_ids else 0
|
||||||
)
|
total_output_tokens += sum(
|
||||||
total_output_tokens += sum(len(o.token_ids) for o in ro.outputs if o)
|
len(o.token_ids) for o in ro.outputs if o)
|
||||||
total_num_tokens = total_prompt_tokens + total_output_tokens
|
total_num_tokens = total_prompt_tokens + total_output_tokens
|
||||||
else:
|
else:
|
||||||
total_num_tokens = sum(r.prompt_len + r.expected_output_len for r in requests)
|
total_num_tokens = sum(r.prompt_len + r.expected_output_len
|
||||||
|
for r in requests)
|
||||||
total_output_tokens = sum(r.expected_output_len for r in requests)
|
total_output_tokens = sum(r.expected_output_len for r in requests)
|
||||||
total_prompt_tokens = total_num_tokens - total_output_tokens
|
total_prompt_tokens = total_num_tokens - total_output_tokens
|
||||||
|
|
||||||
if is_multi_modal and args.backend != "vllm-chat":
|
if is_multi_modal and args.backend != "vllm-chat":
|
||||||
print(
|
print("\033[91mWARNING\033[0m: Multi-modal request with "
|
||||||
"\033[91mWARNING\033[0m: Multi-modal request with "
|
f"{args.backend} backend detected. The "
|
||||||
f"{args.backend} backend detected. The "
|
"following metrics are not accurate because image tokens are not"
|
||||||
"following metrics are not accurate because image tokens are not"
|
" counted. See vllm-project/vllm/issues/9778 for details.")
|
||||||
" counted. See vllm-project/vllm/issues/9778 for details."
|
|
||||||
)
|
|
||||||
# TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
|
# TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
|
||||||
# vllm-chat backend counts the image tokens now
|
# vllm-chat backend counts the image tokens now
|
||||||
|
|
||||||
print(
|
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
|
||||||
f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
|
f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
|
||||||
f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
|
f"{total_output_tokens / elapsed_time:.2f} output tokens/s")
|
||||||
f"{total_output_tokens / elapsed_time:.2f} output tokens/s"
|
|
||||||
)
|
|
||||||
print(f"Total num prompt tokens: {total_prompt_tokens}")
|
print(f"Total num prompt tokens: {total_prompt_tokens}")
|
||||||
print(f"Total num output tokens: {total_output_tokens}")
|
print(f"Total num output tokens: {total_output_tokens}")
|
||||||
|
|
||||||
@ -491,8 +444,7 @@ def validate_args(args):
|
|||||||
warnings.warn(
|
warnings.warn(
|
||||||
"The '--dataset' argument will be deprecated in the next release. "
|
"The '--dataset' argument will be deprecated in the next release. "
|
||||||
"Please use '--dataset-name' and '--dataset-path' instead.",
|
"Please use '--dataset-name' and '--dataset-path' instead.",
|
||||||
stacklevel=2,
|
stacklevel=2)
|
||||||
)
|
|
||||||
args.dataset_path = args.dataset
|
args.dataset_path = args.dataset
|
||||||
|
|
||||||
if not getattr(args, "tokenizer", None):
|
if not getattr(args, "tokenizer", None):
|
||||||
@ -505,8 +457,9 @@ def validate_args(args):
|
|||||||
|
|
||||||
# === Dataset Configuration ===
|
# === Dataset Configuration ===
|
||||||
if not args.dataset and not args.dataset_path:
|
if not args.dataset and not args.dataset_path:
|
||||||
print("When dataset path is not set, it will default to random dataset")
|
print(
|
||||||
args.dataset_name = "random"
|
"When dataset path is not set, it will default to random dataset")
|
||||||
|
args.dataset_name = 'random'
|
||||||
if args.input_len is None:
|
if args.input_len is None:
|
||||||
raise ValueError("input_len must be provided for a random dataset")
|
raise ValueError("input_len must be provided for a random dataset")
|
||||||
|
|
||||||
@ -514,55 +467,41 @@ def validate_args(args):
|
|||||||
# --hf-subset and --hf-split: only used
|
# --hf-subset and --hf-split: only used
|
||||||
# when dataset_name is 'hf'
|
# when dataset_name is 'hf'
|
||||||
if args.dataset_name != "hf" and (
|
if args.dataset_name != "hf" and (
|
||||||
getattr(args, "hf_subset", None) is not None
|
getattr(args, "hf_subset", None) is not None
|
||||||
or getattr(args, "hf_split", None) is not None
|
or getattr(args, "hf_split", None) is not None):
|
||||||
):
|
warnings.warn("--hf-subset and --hf-split will be ignored \
|
||||||
warnings.warn(
|
|
||||||
"--hf-subset and --hf-split will be ignored \
|
|
||||||
since --dataset-name is not 'hf'.",
|
since --dataset-name is not 'hf'.",
|
||||||
stacklevel=2,
|
stacklevel=2)
|
||||||
)
|
|
||||||
elif args.dataset_name == "hf":
|
elif args.dataset_name == "hf":
|
||||||
if args.dataset_path in (
|
if args.dataset_path in (
|
||||||
VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys()
|
VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys()
|
||||||
| ConversationDataset.SUPPORTED_DATASET_PATHS
|
| ConversationDataset.SUPPORTED_DATASET_PATHS):
|
||||||
):
|
assert args.backend == "vllm-chat", f"{args.dataset_path} needs to use vllm-chat as the backend." #noqa: E501
|
||||||
assert args.backend == "vllm-chat", (
|
elif args.dataset_path in (InstructCoderDataset.SUPPORTED_DATASET_PATHS
|
||||||
f"{args.dataset_path} needs to use vllm-chat as the backend."
|
| AIMODataset.SUPPORTED_DATASET_PATHS):
|
||||||
) # noqa: E501
|
assert args.backend == "vllm", f"{args.dataset_path} needs to use vllm as the backend." #noqa: E501
|
||||||
elif args.dataset_path in (
|
|
||||||
InstructCoderDataset.SUPPORTED_DATASET_PATHS
|
|
||||||
| AIMODataset.SUPPORTED_DATASET_PATHS
|
|
||||||
):
|
|
||||||
assert args.backend == "vllm", (
|
|
||||||
f"{args.dataset_path} needs to use vllm as the backend."
|
|
||||||
) # noqa: E501
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"{args.dataset_path} is not supported by hf dataset.")
|
raise ValueError(
|
||||||
|
f"{args.dataset_path} is not supported by hf dataset.")
|
||||||
|
|
||||||
# --random-range-ratio: only used when dataset_name is 'random'
|
# --random-range-ratio: only used when dataset_name is 'random'
|
||||||
if args.dataset_name != "random" and args.random_range_ratio is not None:
|
if args.dataset_name != 'random' and args.random_range_ratio is not None:
|
||||||
warnings.warn(
|
warnings.warn("--random-range-ratio will be ignored since \
|
||||||
"--random-range-ratio will be ignored since \
|
|
||||||
--dataset-name is not 'random'.",
|
--dataset-name is not 'random'.",
|
||||||
stacklevel=2,
|
stacklevel=2)
|
||||||
)
|
|
||||||
|
|
||||||
# --prefix-len: only used when dataset_name is 'random', 'sonnet', or not
|
# --prefix-len: only used when dataset_name is 'random', 'sonnet', or not
|
||||||
# set.
|
# set.
|
||||||
if (
|
if args.dataset_name not in {"random", "sonnet", None
|
||||||
args.dataset_name not in {"random", "sonnet", None}
|
} and args.prefix_len is not None:
|
||||||
and args.prefix_len is not None
|
warnings.warn("--prefix-len will be ignored since --dataset-name\
|
||||||
):
|
|
||||||
warnings.warn(
|
|
||||||
"--prefix-len will be ignored since --dataset-name\
|
|
||||||
is not 'random', 'sonnet', or not set.",
|
is not 'random', 'sonnet', or not set.",
|
||||||
stacklevel=2,
|
stacklevel=2)
|
||||||
)
|
|
||||||
|
|
||||||
# === LoRA Settings ===
|
# === LoRA Settings ===
|
||||||
if getattr(args, "enable_lora", False) and args.backend != "vllm":
|
if getattr(args, "enable_lora", False) and args.backend != "vllm":
|
||||||
raise ValueError("LoRA benchmarking is only supported for vLLM backend")
|
raise ValueError(
|
||||||
|
"LoRA benchmarking is only supported for vLLM backend")
|
||||||
if getattr(args, "enable_lora", False) and args.lora_path is None:
|
if getattr(args, "enable_lora", False) and args.lora_path is None:
|
||||||
raise ValueError("LoRA path must be provided when enable_lora is True")
|
raise ValueError("LoRA path must be provided when enable_lora is True")
|
||||||
|
|
||||||
@ -572,10 +511,8 @@ def validate_args(args):
|
|||||||
if args.backend != "hf" and args.hf_max_batch_size is not None:
|
if args.backend != "hf" and args.hf_max_batch_size is not None:
|
||||||
raise ValueError("HF max batch size is only for HF backend.")
|
raise ValueError("HF max batch size is only for HF backend.")
|
||||||
|
|
||||||
if (
|
if args.backend in {"hf", "mii"} and getattr(args, "quantization",
|
||||||
args.backend in {"hf", "mii"}
|
None) is not None:
|
||||||
and getattr(args, "quantization", None) is not None
|
|
||||||
):
|
|
||||||
raise ValueError("Quantization is only for vLLM backend.")
|
raise ValueError("Quantization is only for vLLM backend.")
|
||||||
|
|
||||||
if args.backend == "mii" and args.dtype != "auto":
|
if args.backend == "mii" and args.dtype != "auto":
|
||||||
@ -583,32 +520,29 @@ def validate_args(args):
|
|||||||
if args.backend == "mii" and args.n != 1:
|
if args.backend == "mii" and args.n != 1:
|
||||||
raise ValueError("n must be 1 for MII backend.")
|
raise ValueError("n must be 1 for MII backend.")
|
||||||
if args.backend == "mii" and args.tokenizer != args.model:
|
if args.backend == "mii" and args.tokenizer != args.model:
|
||||||
raise ValueError("Tokenizer must be the same as the model for MII backend.")
|
raise ValueError(
|
||||||
|
"Tokenizer must be the same as the model for MII backend.")
|
||||||
|
|
||||||
# --data-parallel is not supported currently.
|
# --data-parallel is not supported currently.
|
||||||
# https://github.com/vllm-project/vllm/issues/16222
|
# https://github.com/vllm-project/vllm/issues/16222
|
||||||
if args.data_parallel_size > 1:
|
if args.data_parallel_size > 1:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Data parallel is not supported in offline benchmark, \
|
"Data parallel is not supported in offline benchmark, \
|
||||||
please use benchmark serving instead"
|
please use benchmark serving instead")
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = FlexibleArgumentParser(description="Benchmark the throughput.")
|
parser = FlexibleArgumentParser(description="Benchmark the throughput.")
|
||||||
parser.add_argument(
|
parser.add_argument("--backend",
|
||||||
"--backend",
|
type=str,
|
||||||
type=str,
|
choices=["vllm", "hf", "mii", "vllm-chat"],
|
||||||
choices=["vllm", "hf", "mii", "vllm-chat"],
|
default="vllm")
|
||||||
default="vllm",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--dataset-name",
|
"--dataset-name",
|
||||||
type=str,
|
type=str,
|
||||||
choices=["sharegpt", "random", "sonnet", "burstgpt", "hf"],
|
choices=["sharegpt", "random", "sonnet", "burstgpt", "hf"],
|
||||||
help="Name of the dataset to benchmark on.",
|
help="Name of the dataset to benchmark on.",
|
||||||
default="sharegpt",
|
default="sharegpt")
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--dataset",
|
"--dataset",
|
||||||
type=str,
|
type=str,
|
||||||
@ -616,70 +550,57 @@ if __name__ == "__main__":
|
|||||||
help="Path to the ShareGPT dataset, will be deprecated in\
|
help="Path to the ShareGPT dataset, will be deprecated in\
|
||||||
the next release. The dataset is expected to "
|
the next release. The dataset is expected to "
|
||||||
"be a json in form of list[dict[..., conversations: "
|
"be a json in form of list[dict[..., conversations: "
|
||||||
"list[dict[..., value: <prompt_or_response>]]]]",
|
"list[dict[..., value: <prompt_or_response>]]]]")
|
||||||
)
|
parser.add_argument("--dataset-path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path to the dataset")
|
||||||
|
parser.add_argument("--input-len",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Input prompt length for each request")
|
||||||
|
parser.add_argument("--output-len",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Output length for each request. Overrides the "
|
||||||
|
"output length from the dataset.")
|
||||||
|
parser.add_argument("--n",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Number of generated sequences per prompt.")
|
||||||
|
parser.add_argument("--num-prompts",
|
||||||
|
type=int,
|
||||||
|
default=1000,
|
||||||
|
help="Number of prompts to process.")
|
||||||
|
parser.add_argument("--hf-max-batch-size",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Maximum batch size for HF backend.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--dataset-path", type=str, default=None, help="Path to the dataset"
|
'--output-json',
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--input-len",
|
|
||||||
type=int,
|
|
||||||
default=None,
|
|
||||||
help="Input prompt length for each request",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--output-len",
|
|
||||||
type=int,
|
|
||||||
default=None,
|
|
||||||
help="Output length for each request. Overrides the "
|
|
||||||
"output length from the dataset.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--n", type=int, default=1, help="Number of generated sequences per prompt."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--num-prompts", type=int, default=1000, help="Number of prompts to process."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--hf-max-batch-size",
|
|
||||||
type=int,
|
|
||||||
default=None,
|
|
||||||
help="Maximum batch size for HF backend.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--output-json",
|
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
help="Path to save the throughput results in JSON format.",
|
help='Path to save the throughput results in JSON format.')
|
||||||
)
|
parser.add_argument("--async-engine",
|
||||||
parser.add_argument(
|
action='store_true',
|
||||||
"--async-engine",
|
default=False,
|
||||||
action="store_true",
|
help="Use vLLM async engine rather than LLM class.")
|
||||||
default=False,
|
parser.add_argument("--disable-frontend-multiprocessing",
|
||||||
help="Use vLLM async engine rather than LLM class.",
|
action='store_true',
|
||||||
)
|
default=False,
|
||||||
parser.add_argument(
|
help="Disable decoupled async engine frontend.")
|
||||||
"--disable-frontend-multiprocessing",
|
|
||||||
action="store_true",
|
|
||||||
default=False,
|
|
||||||
help="Disable decoupled async engine frontend.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--disable-detokenize",
|
"--disable-detokenize",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help=(
|
help=("Do not detokenize the response (i.e. do not include "
|
||||||
"Do not detokenize the response (i.e. do not include "
|
"detokenization time in the measurement)"))
|
||||||
"detokenization time in the measurement)"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
# LoRA
|
# LoRA
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lora-path",
|
"--lora-path",
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
help="Path to the LoRA adapters to use. This can be an absolute path, "
|
help="Path to the lora adapters to use. This can be an absolute path, "
|
||||||
"a relative path, or a Hugging Face model identifier.",
|
"a relative path, or a Hugging Face model identifier.")
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--prefix-len",
|
"--prefix-len",
|
||||||
type=int,
|
type=int,
|
||||||
@ -693,8 +614,7 @@ if __name__ == "__main__":
|
|||||||
f"prefix_len (default: {SonnetDataset.DEFAULT_PREFIX_LEN}) "
|
f"prefix_len (default: {SonnetDataset.DEFAULT_PREFIX_LEN}) "
|
||||||
"controls how much of the input is fixed lines versus "
|
"controls how much of the input is fixed lines versus "
|
||||||
"random lines, but the total input length remains approximately "
|
"random lines, but the total input length remains approximately "
|
||||||
"input_len tokens.",
|
"input_len tokens.")
|
||||||
)
|
|
||||||
# random dataset
|
# random dataset
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--random-range-ratio",
|
"--random-range-ratio",
|
||||||
@ -708,12 +628,14 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
# hf dtaset
|
# hf dtaset
|
||||||
parser.add_argument(
|
parser.add_argument("--hf-subset",
|
||||||
"--hf-subset", type=str, default=None, help="Subset of the HF dataset."
|
type=str,
|
||||||
)
|
default=None,
|
||||||
parser.add_argument(
|
help="Subset of the HF dataset.")
|
||||||
"--hf-split", type=str, default=None, help="Split of the HF dataset."
|
parser.add_argument("--hf-split",
|
||||||
)
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Split of the HF dataset.")
|
||||||
|
|
||||||
parser = AsyncEngineArgs.add_cli_args(parser)
|
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|||||||
@ -7,9 +7,9 @@ import os
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
def convert_to_pytorch_benchmark_format(
|
def convert_to_pytorch_benchmark_format(args: argparse.Namespace,
|
||||||
args: argparse.Namespace, metrics: dict[str, list], extra_info: dict[str, Any]
|
metrics: dict[str, list],
|
||||||
) -> list:
|
extra_info: dict[str, Any]) -> list:
|
||||||
"""
|
"""
|
||||||
Save the benchmark results in the format used by PyTorch OSS benchmark with
|
Save the benchmark results in the format used by PyTorch OSS benchmark with
|
||||||
on metric per record
|
on metric per record
|
||||||
@ -37,12 +37,12 @@ def convert_to_pytorch_benchmark_format(
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
tp = record["benchmark"]["extra_info"]["args"].get("tensor_parallel_size")
|
tp = record["benchmark"]["extra_info"]["args"].get(
|
||||||
|
"tensor_parallel_size")
|
||||||
# Save tensor_parallel_size parameter if it's part of the metadata
|
# Save tensor_parallel_size parameter if it's part of the metadata
|
||||||
if not tp and "tensor_parallel_size" in extra_info:
|
if not tp and "tensor_parallel_size" in extra_info:
|
||||||
record["benchmark"]["extra_info"]["args"]["tensor_parallel_size"] = (
|
record["benchmark"]["extra_info"]["args"][
|
||||||
extra_info["tensor_parallel_size"]
|
"tensor_parallel_size"] = extra_info["tensor_parallel_size"]
|
||||||
)
|
|
||||||
|
|
||||||
records.append(record)
|
records.append(record)
|
||||||
|
|
||||||
@ -50,6 +50,7 @@ def convert_to_pytorch_benchmark_format(
|
|||||||
|
|
||||||
|
|
||||||
class InfEncoder(json.JSONEncoder):
|
class InfEncoder(json.JSONEncoder):
|
||||||
|
|
||||||
def clear_inf(self, o: Any):
|
def clear_inf(self, o: Any):
|
||||||
if isinstance(o, dict):
|
if isinstance(o, dict):
|
||||||
return {k: self.clear_inf(v) for k, v in o.items()}
|
return {k: self.clear_inf(v) for k, v in o.items()}
|
||||||
|
|||||||
@ -23,9 +23,8 @@ DEFAULT_TP_SIZES = [1]
|
|||||||
|
|
||||||
|
|
||||||
# bench
|
# bench
|
||||||
def bench_fn(
|
def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
|
||||||
label: str, sub_label: str, description: str, fn: Callable, *args, **kwargs
|
**kwargs) -> TMeasurement:
|
||||||
) -> TMeasurement:
|
|
||||||
min_run_time = 1
|
min_run_time = 1
|
||||||
|
|
||||||
globals = {
|
globals = {
|
||||||
@ -42,18 +41,16 @@ def bench_fn(
|
|||||||
).blocked_autorange(min_run_time=min_run_time)
|
).blocked_autorange(min_run_time=min_run_time)
|
||||||
|
|
||||||
|
|
||||||
def bench_int8(
|
def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
|
||||||
dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str
|
sub_label: str) -> Iterable[TMeasurement]:
|
||||||
) -> Iterable[TMeasurement]:
|
|
||||||
assert dtype == torch.int8
|
assert dtype == torch.int8
|
||||||
b_compressed, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k)
|
b_compressed, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k)
|
||||||
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||||
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||||
bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)
|
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
||||||
out = ops.cutlass_scaled_sparse_mm(
|
out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b,
|
||||||
a, b_compressed, e, scale_a, scale_b, torch.bfloat16
|
torch.bfloat16)
|
||||||
)
|
|
||||||
out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
|
out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
|
||||||
|
|
||||||
if not torch.allclose(out, out_ref):
|
if not torch.allclose(out, out_ref):
|
||||||
@ -66,107 +63,54 @@ def bench_int8(
|
|||||||
timers = []
|
timers = []
|
||||||
# pytorch impl - bfloat16
|
# pytorch impl - bfloat16
|
||||||
timers.append(
|
timers.append(
|
||||||
bench_fn(
|
bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales",
|
||||||
label,
|
torch.mm, a.to(dtype=torch.bfloat16),
|
||||||
sub_label,
|
b.to(dtype=torch.bfloat16)))
|
||||||
"pytorch_bf16_bf16_bf16_matmul-no-scales",
|
|
||||||
torch.mm,
|
|
||||||
a.to(dtype=torch.bfloat16),
|
|
||||||
b.to(dtype=torch.bfloat16),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# pytorch impl - float16
|
# pytorch impl - float16
|
||||||
timers.append(
|
timers.append(
|
||||||
bench_fn(
|
bench_fn(label, sub_label,
|
||||||
label,
|
"pytorch_fp16_fp16_fp16_matmul-no-scales", torch.mm,
|
||||||
sub_label,
|
a.to(dtype=torch.float16), b.to(dtype=torch.float16)))
|
||||||
"pytorch_fp16_fp16_fp16_matmul-no-scales",
|
|
||||||
torch.mm,
|
|
||||||
a.to(dtype=torch.float16),
|
|
||||||
b.to(dtype=torch.float16),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# cutlass impl
|
# cutlass impl
|
||||||
timers.append(
|
timers.append(
|
||||||
bench_fn(
|
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm",
|
||||||
label,
|
ops.cutlass_scaled_mm, a, b, scale_a, scale_b,
|
||||||
sub_label,
|
torch.bfloat16))
|
||||||
"cutlass_i8_i8_bf16_scaled_mm",
|
|
||||||
ops.cutlass_scaled_mm,
|
|
||||||
a,
|
|
||||||
b,
|
|
||||||
scale_a,
|
|
||||||
scale_b,
|
|
||||||
torch.bfloat16,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# cutlass with bias
|
# cutlass with bias
|
||||||
timers.append(
|
timers.append(
|
||||||
bench_fn(
|
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_bias",
|
||||||
label,
|
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16,
|
||||||
sub_label,
|
bias))
|
||||||
"cutlass_i8_i8_bf16_scaled_mm_bias",
|
|
||||||
ops.cutlass_scaled_mm,
|
|
||||||
a,
|
|
||||||
b,
|
|
||||||
scale_a,
|
|
||||||
scale_b,
|
|
||||||
torch.bfloat16,
|
|
||||||
bias,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# cutlass sparse impl
|
# cutlass sparse impl
|
||||||
timers.append(
|
timers.append(
|
||||||
bench_fn(
|
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm",
|
||||||
label,
|
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
|
||||||
sub_label,
|
scale_b, torch.bfloat16))
|
||||||
"cutlass_i8_i8_bf16_scaled_sparse_mm",
|
|
||||||
ops.cutlass_scaled_sparse_mm,
|
|
||||||
a,
|
|
||||||
b_compressed,
|
|
||||||
e,
|
|
||||||
scale_a,
|
|
||||||
scale_b,
|
|
||||||
torch.bfloat16,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# cutlass sparse with bias
|
# cutlass sparse with bias
|
||||||
timers.append(
|
timers.append(
|
||||||
bench_fn(
|
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm_bias",
|
||||||
label,
|
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
|
||||||
sub_label,
|
scale_b, torch.bfloat16, bias))
|
||||||
"cutlass_i8_i8_bf16_scaled_sparse_mm_bias",
|
|
||||||
ops.cutlass_scaled_sparse_mm,
|
|
||||||
a,
|
|
||||||
b_compressed,
|
|
||||||
e,
|
|
||||||
scale_a,
|
|
||||||
scale_b,
|
|
||||||
torch.bfloat16,
|
|
||||||
bias,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return timers
|
return timers
|
||||||
|
|
||||||
|
|
||||||
def bench_fp8(
|
def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
|
||||||
dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str
|
sub_label: str) -> Iterable[TMeasurement]:
|
||||||
) -> Iterable[TMeasurement]:
|
|
||||||
assert dtype == torch.float8_e4m3fn
|
assert dtype == torch.float8_e4m3fn
|
||||||
b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k)
|
b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n,
|
||||||
|
k)
|
||||||
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||||
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||||
bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)
|
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
||||||
out = ops.cutlass_scaled_sparse_mm(
|
out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b,
|
||||||
a, b_compressed, e, scale_a, scale_b, torch.bfloat16
|
torch.bfloat16)
|
||||||
)
|
|
||||||
out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
|
out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
|
||||||
|
|
||||||
if not torch.allclose(out, out_ref):
|
if not torch.allclose(out, out_ref):
|
||||||
@ -180,165 +124,97 @@ def bench_fp8(
|
|||||||
|
|
||||||
# pytorch impl w. bf16
|
# pytorch impl w. bf16
|
||||||
timers.append(
|
timers.append(
|
||||||
bench_fn(
|
bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales",
|
||||||
label,
|
torch.mm, a.to(dtype=torch.bfloat16, device="cuda"),
|
||||||
sub_label,
|
b.to(dtype=torch.bfloat16, device="cuda")))
|
||||||
"pytorch_bf16_bf16_bf16_matmul-no-scales",
|
|
||||||
torch.mm,
|
|
||||||
a.to(dtype=torch.bfloat16, device="cuda"),
|
|
||||||
b.to(dtype=torch.bfloat16, device="cuda"),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# pytorch impl: bf16 output, without fp8 fast accum
|
# pytorch impl: bf16 output, without fp8 fast accum
|
||||||
timers.append(
|
timers.append(
|
||||||
bench_fn(
|
bench_fn(label,
|
||||||
label,
|
sub_label,
|
||||||
sub_label,
|
"pytorch_fp8_fp8_bf16_scaled_mm",
|
||||||
"pytorch_fp8_fp8_bf16_scaled_mm",
|
torch._scaled_mm,
|
||||||
torch._scaled_mm,
|
a,
|
||||||
a,
|
b,
|
||||||
b,
|
scale_a=scale_a,
|
||||||
scale_a=scale_a,
|
scale_b=scale_b,
|
||||||
scale_b=scale_b,
|
out_dtype=torch.bfloat16))
|
||||||
out_dtype=torch.bfloat16,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# pytorch impl: bf16 output, with fp8 fast accum
|
# pytorch impl: bf16 output, with fp8 fast accum
|
||||||
timers.append(
|
timers.append(
|
||||||
bench_fn(
|
bench_fn(label,
|
||||||
label,
|
sub_label,
|
||||||
sub_label,
|
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum",
|
||||||
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum",
|
torch._scaled_mm,
|
||||||
torch._scaled_mm,
|
a,
|
||||||
a,
|
b,
|
||||||
b,
|
scale_a=scale_a,
|
||||||
scale_a=scale_a,
|
scale_b=scale_b,
|
||||||
scale_b=scale_b,
|
out_dtype=torch.bfloat16,
|
||||||
out_dtype=torch.bfloat16,
|
use_fast_accum=True))
|
||||||
use_fast_accum=True,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# pytorch impl: fp16 output, without fp8 fast accum
|
# pytorch impl: fp16 output, without fp8 fast accum
|
||||||
timers.append(
|
timers.append(
|
||||||
bench_fn(
|
bench_fn(label,
|
||||||
label,
|
sub_label,
|
||||||
sub_label,
|
"pytorch_fp8_fp8_fp16_scaled_mm",
|
||||||
"pytorch_fp8_fp8_fp16_scaled_mm",
|
torch._scaled_mm,
|
||||||
torch._scaled_mm,
|
a,
|
||||||
a,
|
b,
|
||||||
b,
|
scale_a=scale_a,
|
||||||
scale_a=scale_a,
|
scale_b=scale_b,
|
||||||
scale_b=scale_b,
|
out_dtype=torch.float16))
|
||||||
out_dtype=torch.float16,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# pytorch impl: fp16 output, with fp8 fast accum
|
# pytorch impl: fp16 output, with fp8 fast accum
|
||||||
timers.append(
|
timers.append(
|
||||||
bench_fn(
|
bench_fn(label,
|
||||||
label,
|
sub_label,
|
||||||
sub_label,
|
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum",
|
||||||
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum",
|
torch._scaled_mm,
|
||||||
torch._scaled_mm,
|
a,
|
||||||
a,
|
b,
|
||||||
b,
|
scale_a=scale_a,
|
||||||
scale_a=scale_a,
|
scale_b=scale_b,
|
||||||
scale_b=scale_b,
|
out_dtype=torch.float16,
|
||||||
out_dtype=torch.float16,
|
use_fast_accum=True))
|
||||||
use_fast_accum=True,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# cutlass impl: bf16 output
|
# cutlass impl: bf16 output
|
||||||
timers.append(
|
timers.append(
|
||||||
bench_fn(
|
bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm",
|
||||||
label,
|
ops.cutlass_scaled_mm, a, b, scale_a, scale_b,
|
||||||
sub_label,
|
torch.bfloat16))
|
||||||
"cutlass_fp8_fp8_bf16_scaled_mm",
|
|
||||||
ops.cutlass_scaled_mm,
|
|
||||||
a,
|
|
||||||
b,
|
|
||||||
scale_a,
|
|
||||||
scale_b,
|
|
||||||
torch.bfloat16,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# cutlass impl: bf16 output
|
# cutlass impl: bf16 output
|
||||||
timers.append(
|
timers.append(
|
||||||
bench_fn(
|
bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_sparse_mm",
|
||||||
label,
|
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
|
||||||
sub_label,
|
scale_b, torch.bfloat16))
|
||||||
"cutlass_fp8_fp8_bf16_scaled_sparse_mm",
|
|
||||||
ops.cutlass_scaled_sparse_mm,
|
|
||||||
a,
|
|
||||||
b_compressed,
|
|
||||||
e,
|
|
||||||
scale_a,
|
|
||||||
scale_b,
|
|
||||||
torch.bfloat16,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# cutlass impl: fp16 output
|
# cutlass impl: fp16 output
|
||||||
timers.append(
|
timers.append(
|
||||||
bench_fn(
|
bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_sparse_mm",
|
||||||
label,
|
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
|
||||||
sub_label,
|
scale_b, torch.float16))
|
||||||
"cutlass_fp8_fp8_fp16_scaled_sparse_mm",
|
|
||||||
ops.cutlass_scaled_sparse_mm,
|
|
||||||
a,
|
|
||||||
b_compressed,
|
|
||||||
e,
|
|
||||||
scale_a,
|
|
||||||
scale_b,
|
|
||||||
torch.float16,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# cutlass impl: bf16 output, with bias
|
# cutlass impl: bf16 output, with bias
|
||||||
timers.append(
|
timers.append(
|
||||||
bench_fn(
|
bench_fn(label, sub_label,
|
||||||
label,
|
"cutlass_fp8_fp8_bf16_scaled_sparse_mm_bias",
|
||||||
sub_label,
|
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
|
||||||
"cutlass_fp8_fp8_bf16_scaled_sparse_mm_bias",
|
scale_b, torch.bfloat16, bias))
|
||||||
ops.cutlass_scaled_sparse_mm,
|
|
||||||
a,
|
|
||||||
b_compressed,
|
|
||||||
e,
|
|
||||||
scale_a,
|
|
||||||
scale_b,
|
|
||||||
torch.bfloat16,
|
|
||||||
bias,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# cutlass impl: fp16 output, with bias
|
# cutlass impl: fp16 output, with bias
|
||||||
timers.append(
|
timers.append(
|
||||||
bench_fn(
|
bench_fn(label, sub_label,
|
||||||
label,
|
"cutlass_fp8_fp8_fp16_scaled_sparse_mm_bias",
|
||||||
sub_label,
|
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
|
||||||
"cutlass_fp8_fp8_fp16_scaled_sparse_mm_bias",
|
scale_b, torch.float16, bias.to(dtype=torch.float16)))
|
||||||
ops.cutlass_scaled_sparse_mm,
|
|
||||||
a,
|
|
||||||
b_compressed,
|
|
||||||
e,
|
|
||||||
scale_a,
|
|
||||||
scale_b,
|
|
||||||
torch.float16,
|
|
||||||
bias.to(dtype=torch.float16),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return timers
|
return timers
|
||||||
|
|
||||||
|
|
||||||
def bench(
|
def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str,
|
||||||
dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str
|
sub_label: str) -> Iterable[TMeasurement]:
|
||||||
) -> Iterable[TMeasurement]:
|
|
||||||
if dtype == torch.int8:
|
if dtype == torch.int8:
|
||||||
return bench_int8(dtype, m, k, n, label, sub_label)
|
return bench_int8(dtype, m, k, n, label, sub_label)
|
||||||
if dtype == torch.float8_e4m3fn:
|
if dtype == torch.float8_e4m3fn:
|
||||||
@ -352,12 +228,12 @@ def print_timers(timers: Iterable[TMeasurement]):
|
|||||||
compare.print()
|
compare.print()
|
||||||
|
|
||||||
|
|
||||||
def run(
|
def run(dtype: torch.dtype,
|
||||||
dtype: torch.dtype, MKNs: Iterable[tuple[int, int, int]]
|
MKNs: Iterable[tuple[int, int, int]]) -> Iterable[TMeasurement]:
|
||||||
) -> Iterable[TMeasurement]:
|
|
||||||
results = []
|
results = []
|
||||||
for m, k, n in MKNs:
|
for m, k, n in MKNs:
|
||||||
timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm", f"MKN=({m}x{k}x{n})")
|
timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm",
|
||||||
|
f"MKN=({m}x{k}x{n})")
|
||||||
print_timers(timers)
|
print_timers(timers)
|
||||||
results.extend(timers)
|
results.extend(timers)
|
||||||
|
|
||||||
@ -365,12 +241,10 @@ def run(
|
|||||||
|
|
||||||
|
|
||||||
# output makers
|
# output makers
|
||||||
def make_output(
|
def make_output(data: Iterable[TMeasurement],
|
||||||
data: Iterable[TMeasurement],
|
MKNs: Iterable[tuple[int, int, int]],
|
||||||
MKNs: Iterable[tuple[int, int, int]],
|
base_description: str,
|
||||||
base_description: str,
|
timestamp=None):
|
||||||
timestamp=None,
|
|
||||||
):
|
|
||||||
print(f"== All Results {base_description} ====")
|
print(f"== All Results {base_description} ====")
|
||||||
print_timers(data)
|
print_timers(data)
|
||||||
|
|
||||||
@ -384,7 +258,8 @@ def make_output(
|
|||||||
|
|
||||||
|
|
||||||
def run_square_bench(args):
|
def run_square_bench(args):
|
||||||
dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment))
|
dim_sizes = list(
|
||||||
|
range(args.dim_start, args.dim_end + 1, args.dim_increment))
|
||||||
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
|
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
|
||||||
data = run(args.dtype, MKNs)
|
data = run(args.dtype, MKNs)
|
||||||
|
|
||||||
@ -444,7 +319,7 @@ def run_model_bench(args):
|
|||||||
pkl.dump(all_data, f)
|
pkl.dump(all_data, f)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == '__main__':
|
||||||
|
|
||||||
def to_torch_dtype(dt):
|
def to_torch_dtype(dt):
|
||||||
if dt == "int8":
|
if dt == "int8":
|
||||||
@ -469,15 +344,12 @@ Benchmark Cutlass GEMM.
|
|||||||
Output:
|
Output:
|
||||||
- a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs.
|
- a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs.
|
||||||
""", # noqa: E501
|
""", # noqa: E501
|
||||||
formatter_class=argparse.RawTextHelpFormatter,
|
formatter_class=argparse.RawTextHelpFormatter)
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument("--dtype",
|
||||||
"--dtype",
|
type=to_torch_dtype,
|
||||||
type=to_torch_dtype,
|
required=True,
|
||||||
required=True,
|
help="Available options are ['int8', 'fp8']")
|
||||||
help="Available options are ['int8', 'fp8']",
|
|
||||||
)
|
|
||||||
subparsers = parser.add_subparsers(dest="cmd")
|
subparsers = parser.add_subparsers(dest="cmd")
|
||||||
|
|
||||||
square_parser = subparsers.add_parser("square_bench")
|
square_parser = subparsers.add_parser("square_bench")
|
||||||
@ -496,19 +368,19 @@ Benchmark Cutlass GEMM.
|
|||||||
range_parser.set_defaults(func=run_range_bench)
|
range_parser.set_defaults(func=run_range_bench)
|
||||||
|
|
||||||
model_parser = subparsers.add_parser("model_bench")
|
model_parser = subparsers.add_parser("model_bench")
|
||||||
model_parser.add_argument(
|
model_parser.add_argument("--models",
|
||||||
"--models",
|
nargs="+",
|
||||||
nargs="+",
|
type=str,
|
||||||
type=str,
|
default=DEFAULT_MODELS,
|
||||||
default=DEFAULT_MODELS,
|
choices=WEIGHT_SHAPES.keys())
|
||||||
choices=WEIGHT_SHAPES.keys(),
|
model_parser.add_argument("--tp-sizes",
|
||||||
)
|
nargs="+",
|
||||||
model_parser.add_argument(
|
type=int,
|
||||||
"--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES
|
default=DEFAULT_TP_SIZES)
|
||||||
)
|
model_parser.add_argument("--batch-sizes",
|
||||||
model_parser.add_argument(
|
nargs="+",
|
||||||
"--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
|
type=int,
|
||||||
)
|
default=DEFAULT_BATCH_SIZES)
|
||||||
model_parser.set_defaults(func=run_model_bench)
|
model_parser.set_defaults(func=run_model_bench)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|||||||
@ -10,9 +10,8 @@ import vllm._custom_ops as ops
|
|||||||
|
|
||||||
def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
|
def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
|
||||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||||
return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to(
|
return torch.round(tensor.clamp(
|
||||||
dtype=torch.float8_e4m3fn
|
min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def to_int8(tensor: torch.Tensor) -> torch.Tensor:
|
def to_int8(tensor: torch.Tensor) -> torch.Tensor:
|
||||||
@ -27,11 +26,10 @@ def to_fp16(tensor: torch.Tensor) -> torch.Tensor:
|
|||||||
return tensor.to(dtype=torch.float16)
|
return tensor.to(dtype=torch.float16)
|
||||||
|
|
||||||
|
|
||||||
def make_rand_tensors(
|
def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
|
||||||
dtype: torch.dtype, m: int, n: int, k: int
|
k: int) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
a = torch.randn((m, k), device='cuda') * 5
|
||||||
a = torch.randn((m, k), device="cuda") * 5
|
b = torch.randn((n, k), device='cuda').t() * 5
|
||||||
b = torch.randn((n, k), device="cuda").t() * 5
|
|
||||||
|
|
||||||
if dtype == torch.int8:
|
if dtype == torch.int8:
|
||||||
return to_int8(a), to_int8(b)
|
return to_int8(a), to_int8(b)
|
||||||
@ -51,7 +49,9 @@ def prune_to_2_4(tensor):
|
|||||||
|
|
||||||
# Create binary mask
|
# Create binary mask
|
||||||
mask = torch.zeros_like(reshaped)
|
mask = torch.zeros_like(reshaped)
|
||||||
mask.scatter_(dim=1, index=indices, src=torch.ones_like(indices, dtype=mask.dtype))
|
mask.scatter_(dim=1,
|
||||||
|
index=indices,
|
||||||
|
src=torch.ones_like(indices, dtype=mask.dtype))
|
||||||
|
|
||||||
# Apply mask and reshape back
|
# Apply mask and reshape back
|
||||||
pruned = reshaped * mask
|
pruned = reshaped * mask
|
||||||
@ -62,11 +62,10 @@ def prune_to_2_4(tensor):
|
|||||||
return pruned.reshape(original_shape)
|
return pruned.reshape(original_shape)
|
||||||
|
|
||||||
|
|
||||||
def make_rand_sparse_tensors(
|
def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int,
|
||||||
dtype: torch.dtype, m: int, n: int, k: int
|
k: int) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
a = torch.randn((m, k), device='cuda') * 5
|
||||||
a = torch.randn((m, k), device="cuda") * 5
|
b = torch.randn((n, k), device='cuda').t() * 5
|
||||||
b = torch.randn((n, k), device="cuda").t() * 5
|
|
||||||
|
|
||||||
b = prune_to_2_4(b.t()).t()
|
b = prune_to_2_4(b.t()).t()
|
||||||
|
|
||||||
@ -87,9 +86,9 @@ def make_rand_sparse_tensors(
|
|||||||
return b_compressed, e, a, b
|
return b_compressed, e, a, b
|
||||||
|
|
||||||
|
|
||||||
def make_n_rand_sparse_tensors(
|
def make_n_rand_sparse_tensors(num_tensors: int, dtype: torch.dtype,
|
||||||
num_tensors: int, dtype: torch.dtype, m: int, n: int, k: int
|
m: int, n: int, k: int) -> \
|
||||||
) -> tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]:
|
tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]:
|
||||||
ABs = []
|
ABs = []
|
||||||
for _ in range(num_tensors):
|
for _ in range(num_tensors):
|
||||||
b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k)
|
b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k)
|
||||||
|
|||||||
@ -16,8 +16,7 @@ from weight_shapes import WEIGHT_SHAPES
|
|||||||
|
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||||
w8a8_block_fp8_matmul,
|
w8a8_block_fp8_matmul)
|
||||||
)
|
|
||||||
from vllm.utils import FlexibleArgumentParser
|
from vllm.utils import FlexibleArgumentParser
|
||||||
|
|
||||||
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
|
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
|
||||||
@ -26,9 +25,8 @@ DEFAULT_TP_SIZES = [1]
|
|||||||
|
|
||||||
|
|
||||||
# bench
|
# bench
|
||||||
def bench_fn(
|
def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
|
||||||
label: str, sub_label: str, description: str, fn: Callable, *args, **kwargs
|
**kwargs) -> TMeasurement:
|
||||||
) -> TMeasurement:
|
|
||||||
min_run_time = 1
|
min_run_time = 1
|
||||||
|
|
||||||
globals = {
|
globals = {
|
||||||
@ -46,48 +44,45 @@ def bench_fn(
|
|||||||
|
|
||||||
|
|
||||||
def bench_int8(
|
def bench_int8(
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
m: int,
|
m: int,
|
||||||
k: int,
|
k: int,
|
||||||
n: int,
|
n: int,
|
||||||
label: str,
|
label: str,
|
||||||
sub_label: str,
|
sub_label: str,
|
||||||
bench_kernels: Optional[list[str]] = None,
|
bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]:
|
||||||
) -> Iterable[TMeasurement]:
|
|
||||||
"""Benchmark INT8-based kernels."""
|
"""Benchmark INT8-based kernels."""
|
||||||
assert dtype == torch.int8
|
assert dtype == torch.int8
|
||||||
a, b = make_rand_tensors(torch.int8, m, n, k)
|
a, b = make_rand_tensors(torch.int8, m, n, k)
|
||||||
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||||
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||||
bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)
|
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
|
||||||
azp = torch.zeros((m,), device="cuda", dtype=torch.int32)
|
azp = torch.zeros((m, ), device="cuda", dtype=torch.int32)
|
||||||
azp_adj = torch.zeros((n,), device="cuda", dtype=torch.int32)
|
azp_adj = torch.zeros((n, ), device="cuda", dtype=torch.int32)
|
||||||
|
|
||||||
bench_fns = {
|
bench_fns = {
|
||||||
"pytorch_bf16_bf16_bf16_matmul-no-scales": lambda: torch.mm(
|
"pytorch_bf16_bf16_bf16_matmul-no-scales":
|
||||||
a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16)
|
lambda: torch.mm(a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16)
|
||||||
),
|
),
|
||||||
"pytorch_fp16_fp16_fp16_matmul-no-scales": lambda: torch.mm(
|
"pytorch_fp16_fp16_fp16_matmul-no-scales":
|
||||||
a.to(dtype=torch.float16), b.to(dtype=torch.float16)
|
lambda: torch.mm(a.to(dtype=torch.float16), b.to(dtype=torch.float16)),
|
||||||
),
|
"cutlass_i8_i8_bf16_scaled_mm":
|
||||||
"cutlass_i8_i8_bf16_scaled_mm": lambda: ops.cutlass_scaled_mm(
|
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16),
|
||||||
a, b, scale_a, scale_b, torch.bfloat16
|
"cutlass_i8_i8_bf16_scaled_mm_bias":
|
||||||
),
|
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16,
|
||||||
"cutlass_i8_i8_bf16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm(
|
bias),
|
||||||
a, b, scale_a, scale_b, torch.bfloat16, bias
|
"cutlass_i8_i8_bf16_scaled_mm_azp":
|
||||||
),
|
lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch.
|
||||||
"cutlass_i8_i8_bf16_scaled_mm_azp": lambda: ops.cutlass_scaled_mm_azp(
|
bfloat16, azp_adj),
|
||||||
a, b, scale_a, scale_b, torch.bfloat16, azp_adj
|
"cutlass_i8_i8_bf16_scaled_mm_azp_bias":
|
||||||
),
|
lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch.
|
||||||
"cutlass_i8_i8_bf16_scaled_mm_azp_bias": lambda: ops.cutlass_scaled_mm_azp(
|
bfloat16, azp_adj, None, bias),
|
||||||
a, b, scale_a, scale_b, torch.bfloat16, azp_adj, None, bias
|
"cutlass_i8_i8_bf16_scaled_mm_azp_pt":
|
||||||
),
|
lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch.
|
||||||
"cutlass_i8_i8_bf16_scaled_mm_azp_pt": lambda: ops.cutlass_scaled_mm_azp(
|
bfloat16, azp_adj, azp),
|
||||||
a, b, scale_a, scale_b, torch.bfloat16, azp_adj, azp
|
"cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias":
|
||||||
),
|
lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch.
|
||||||
"cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias": lambda: ops.cutlass_scaled_mm_azp(
|
bfloat16, azp_adj, azp, bias),
|
||||||
a, b, scale_a, scale_b, torch.bfloat16, azp_adj, azp, bias
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
timers = []
|
timers = []
|
||||||
@ -101,73 +96,73 @@ def bench_int8(
|
|||||||
|
|
||||||
|
|
||||||
def bench_fp8(
|
def bench_fp8(
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
m: int,
|
m: int,
|
||||||
k: int,
|
k: int,
|
||||||
n: int,
|
n: int,
|
||||||
label: str,
|
label: str,
|
||||||
sub_label: str,
|
sub_label: str,
|
||||||
bench_kernels: Optional[list[str]] = None,
|
bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]:
|
||||||
) -> Iterable[TMeasurement]:
|
|
||||||
"""Benchmark FP8-based kernels."""
|
"""Benchmark FP8-based kernels."""
|
||||||
assert dtype == torch.float8_e4m3fn
|
assert dtype == torch.float8_e4m3fn
|
||||||
a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k)
|
a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k)
|
||||||
a_cont = a.contiguous()
|
a_cont = a.contiguous()
|
||||||
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||||
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||||
|
block_scale_a = torch.rand((m, k // 128),
|
||||||
def ceil_div(x: int, y: int) -> int:
|
device="cuda",
|
||||||
return (x + y - 1) // y
|
dtype=torch.float32)
|
||||||
|
block_scale_b = torch.rand((k // 128, n // 128),
|
||||||
block_scale_a = torch.rand(
|
device="cuda",
|
||||||
(m, ceil_div(k, 128)), device="cuda", dtype=torch.float32
|
dtype=torch.float32)
|
||||||
)
|
|
||||||
block_scale_b = torch.rand(
|
|
||||||
ceil_div(k, 128), ceil_div(n, 128), device="cuda", dtype=torch.float32
|
|
||||||
)
|
|
||||||
block_scale_a_M_major = block_scale_a.t().contiguous().t()
|
block_scale_a_M_major = block_scale_a.t().contiguous().t()
|
||||||
block_scale_b_K_major = block_scale_b.t().contiguous().t()
|
block_scale_b_K_major = block_scale_b.t().contiguous().t()
|
||||||
bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)
|
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
||||||
print(m, k, n)
|
print(m, k, n)
|
||||||
|
|
||||||
bench_fns = {
|
bench_fns = {
|
||||||
"pytorch_bf16_bf16_bf16_matmul-no-scales": lambda: torch.mm(
|
"pytorch_bf16_bf16_bf16_matmul-no-scales":
|
||||||
a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16)
|
lambda: torch.mm(a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16)
|
||||||
),
|
),
|
||||||
"pytorch_fp16_fp16_fp16_matmul-no-scales": lambda: torch.mm(
|
"pytorch_fp16_fp16_fp16_matmul-no-scales":
|
||||||
a.to(dtype=torch.float16), b.to(dtype=torch.float16)
|
lambda: torch.mm(a.to(dtype=torch.float16), b.to(dtype=torch.float16)),
|
||||||
),
|
"pytorch_fp8_fp8_fp16_scaled_mm":
|
||||||
"pytorch_fp8_fp8_fp16_scaled_mm": lambda: torch._scaled_mm(
|
lambda: torch._scaled_mm(
|
||||||
a, b, scale_a, scale_b, out_dtype=torch.float16
|
a, b, scale_a, scale_b, out_dtype=torch.float16),
|
||||||
),
|
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum":
|
||||||
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum": lambda: torch._scaled_mm(
|
lambda: torch._scaled_mm(a,
|
||||||
a, b, scale_a, scale_b, out_dtype=torch.float16, use_fast_accum=True
|
b,
|
||||||
),
|
scale_a,
|
||||||
"pytorch_fp8_fp8_bf16_scaled_mm": lambda: torch._scaled_mm(
|
scale_b,
|
||||||
a, b, scale_a, scale_b, out_dtype=torch.bfloat16
|
out_dtype=torch.float16,
|
||||||
),
|
use_fast_accum=True),
|
||||||
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum": lambda: torch._scaled_mm(
|
"pytorch_fp8_fp8_bf16_scaled_mm":
|
||||||
a, b, scale_a, scale_b, out_dtype=torch.bfloat16, use_fast_accum=True
|
lambda: torch._scaled_mm(
|
||||||
),
|
a, b, scale_a, scale_b, out_dtype=torch.bfloat16),
|
||||||
"cutlass_fp8_fp8_bf16_scaled_mm": lambda: ops.cutlass_scaled_mm(
|
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum":
|
||||||
a, b, scale_a, scale_b, torch.bfloat16
|
lambda: torch._scaled_mm(a,
|
||||||
),
|
b,
|
||||||
"cutlass_fp8_fp8_fp16_scaled_mm": lambda: ops.cutlass_scaled_mm(
|
scale_a,
|
||||||
a, b, scale_a, scale_b, torch.float16
|
scale_b,
|
||||||
),
|
out_dtype=torch.bfloat16,
|
||||||
"cutlass_fp8_fp8_bf16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm(
|
use_fast_accum=True),
|
||||||
a, b, scale_a, scale_b, torch.bfloat16, bias
|
"cutlass_fp8_fp8_bf16_scaled_mm":
|
||||||
),
|
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16),
|
||||||
"cutlass_fp8_fp8_fp16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm(
|
"cutlass_fp8_fp8_fp16_scaled_mm":
|
||||||
a, b, scale_a, scale_b, torch.float16, bias.to(dtype=torch.float16)
|
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.float16),
|
||||||
),
|
"cutlass_fp8_fp8_bf16_scaled_mm_bias":
|
||||||
"triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_block_fp8_matmul(
|
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16,
|
||||||
a_cont, b.t(), block_scale_a, block_scale_b.t(), (128, 128)
|
bias),
|
||||||
),
|
"cutlass_fp8_fp8_fp16_scaled_mm_bias":
|
||||||
"cutlass_fp8_fp8_fp16_scaled_mm_blockwise": lambda: ops.cutlass_scaled_mm(
|
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.float16,
|
||||||
a, b, block_scale_a_M_major, block_scale_b_K_major, torch.float16
|
bias.to(dtype=torch.float16)),
|
||||||
),
|
"triton_fp8_fp8_fp16_scaled_mm_blockwise":
|
||||||
|
lambda: w8a8_block_fp8_matmul(a_cont, b.t(), block_scale_a,
|
||||||
|
block_scale_b.t(), (128, 128)),
|
||||||
|
"cutlass_fp8_fp8_fp16_scaled_mm_blockwise":
|
||||||
|
lambda: ops.cutlass_scaled_mm(a, b, block_scale_a_M_major,
|
||||||
|
block_scale_b_K_major, torch.float16),
|
||||||
}
|
}
|
||||||
|
|
||||||
timers = []
|
timers = []
|
||||||
@ -180,15 +175,13 @@ def bench_fp8(
|
|||||||
return timers
|
return timers
|
||||||
|
|
||||||
|
|
||||||
def bench(
|
def bench(dtype: torch.dtype,
|
||||||
dtype: torch.dtype,
|
m: int,
|
||||||
m: int,
|
k: int,
|
||||||
k: int,
|
n: int,
|
||||||
n: int,
|
label: str,
|
||||||
label: str,
|
sub_label: str,
|
||||||
sub_label: str,
|
bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]:
|
||||||
bench_kernels: Optional[list[str]] = None,
|
|
||||||
) -> Iterable[TMeasurement]:
|
|
||||||
if dtype == torch.int8:
|
if dtype == torch.int8:
|
||||||
return bench_int8(dtype, m, k, n, label, sub_label, bench_kernels)
|
return bench_int8(dtype, m, k, n, label, sub_label, bench_kernels)
|
||||||
if dtype == torch.float8_e4m3fn:
|
if dtype == torch.float8_e4m3fn:
|
||||||
@ -202,33 +195,27 @@ def print_timers(timers: Iterable[TMeasurement]):
|
|||||||
compare.print()
|
compare.print()
|
||||||
|
|
||||||
|
|
||||||
def run(
|
def run(dtype: torch.dtype,
|
||||||
dtype: torch.dtype,
|
MKNs: Iterable[tuple[int, int, int]],
|
||||||
MKNs: Iterable[tuple[int, int, int]],
|
bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]:
|
||||||
bench_kernels: Optional[list[str]] = None,
|
|
||||||
) -> Iterable[TMeasurement]:
|
|
||||||
results = []
|
results = []
|
||||||
for m, k, n in MKNs:
|
for m, k, n in MKNs:
|
||||||
timers = bench(
|
timers = bench(dtype,
|
||||||
dtype,
|
m,
|
||||||
m,
|
k,
|
||||||
k,
|
n,
|
||||||
n,
|
f"scaled-{dtype}-gemm",
|
||||||
f"scaled-{dtype}-gemm",
|
f"MKN=({m}x{k}x{n})",
|
||||||
f"MKN=({m}x{k}x{n})",
|
bench_kernels=bench_kernels)
|
||||||
bench_kernels=bench_kernels,
|
|
||||||
)
|
|
||||||
print_timers(timers)
|
print_timers(timers)
|
||||||
results.extend(timers)
|
results.extend(timers)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
def make_output(
|
def make_output(data: Iterable[TMeasurement],
|
||||||
data: Iterable[TMeasurement],
|
MKNs: Iterable[tuple[int, int, int]],
|
||||||
MKNs: Iterable[tuple[int, int, int]],
|
base_description: str,
|
||||||
base_description: str,
|
timestamp=None):
|
||||||
timestamp=None,
|
|
||||||
):
|
|
||||||
print(f"== All Results {base_description} ====")
|
print(f"== All Results {base_description} ====")
|
||||||
print_timers(data)
|
print_timers(data)
|
||||||
|
|
||||||
@ -239,7 +226,8 @@ def make_output(
|
|||||||
|
|
||||||
|
|
||||||
def run_square_bench(args):
|
def run_square_bench(args):
|
||||||
dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment))
|
dim_sizes = list(
|
||||||
|
range(args.dim_start, args.dim_end + 1, args.dim_increment))
|
||||||
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
|
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
|
||||||
data = run(args.dtype, MKNs, bench_kernels=args.kernels)
|
data = run(args.dtype, MKNs, bench_kernels=args.kernels)
|
||||||
make_output(data, MKNs, f"square_bench-{args.dtype}")
|
make_output(data, MKNs, f"square_bench-{args.dtype}")
|
||||||
@ -297,7 +285,7 @@ def run_model_bench(args):
|
|||||||
pkl.dump(all_data, f)
|
pkl.dump(all_data, f)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == '__main__':
|
||||||
|
|
||||||
def to_torch_dtype(dt):
|
def to_torch_dtype(dt):
|
||||||
if dt == "int8":
|
if dt == "int8":
|
||||||
@ -322,21 +310,19 @@ Benchmark Cutlass GEMM.
|
|||||||
Output:
|
Output:
|
||||||
- a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs.
|
- a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs.
|
||||||
""", # noqa: E501
|
""", # noqa: E501
|
||||||
formatter_class=argparse.RawTextHelpFormatter,
|
formatter_class=argparse.RawTextHelpFormatter)
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument("--dtype",
|
||||||
"--dtype",
|
type=to_torch_dtype,
|
||||||
type=to_torch_dtype,
|
required=True,
|
||||||
required=True,
|
help="Available options are ['int8', 'fp8']")
|
||||||
help="Available options are ['int8', 'fp8']",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--kernels",
|
"--kernels",
|
||||||
nargs="+",
|
nargs="+",
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
help="Exact names of the kernels to benchmark. If not set, runs all kernels.",
|
help=
|
||||||
|
"Exact names of the kernels to benchmark. If not set, runs all kernels."
|
||||||
)
|
)
|
||||||
|
|
||||||
subparsers = parser.add_subparsers(dest="cmd")
|
subparsers = parser.add_subparsers(dest="cmd")
|
||||||
@ -357,19 +343,19 @@ Benchmark Cutlass GEMM.
|
|||||||
range_parser.set_defaults(func=run_range_bench)
|
range_parser.set_defaults(func=run_range_bench)
|
||||||
|
|
||||||
model_parser = subparsers.add_parser("model_bench")
|
model_parser = subparsers.add_parser("model_bench")
|
||||||
model_parser.add_argument(
|
model_parser.add_argument("--models",
|
||||||
"--models",
|
nargs="+",
|
||||||
nargs="+",
|
type=str,
|
||||||
type=str,
|
default=DEFAULT_MODELS,
|
||||||
default=DEFAULT_MODELS,
|
choices=WEIGHT_SHAPES.keys())
|
||||||
choices=WEIGHT_SHAPES.keys(),
|
model_parser.add_argument("--tp-sizes",
|
||||||
)
|
nargs="+",
|
||||||
model_parser.add_argument(
|
type=int,
|
||||||
"--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES
|
default=DEFAULT_TP_SIZES)
|
||||||
)
|
model_parser.add_argument("--batch-sizes",
|
||||||
model_parser.add_argument(
|
nargs="+",
|
||||||
"--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
|
type=int,
|
||||||
)
|
default=DEFAULT_BATCH_SIZES)
|
||||||
model_parser.set_defaults(func=run_model_bench)
|
model_parser.set_defaults(func=run_model_bench)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|||||||
@ -42,4 +42,4 @@ WEIGHT_SHAPES = {
|
|||||||
([8192, 57344], 1),
|
([8192, 57344], 1),
|
||||||
([28672, 8192], 0),
|
([28672, 8192], 0),
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
@ -12,37 +12,39 @@ app = Quart(__name__)
|
|||||||
|
|
||||||
async def forward_request(url, data):
|
async def forward_request(url, data):
|
||||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||||
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
|
headers = {
|
||||||
async with session.post(url=url, json=data, headers=headers) as response:
|
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
|
||||||
|
}
|
||||||
|
async with session.post(url=url, json=data,
|
||||||
|
headers=headers) as response:
|
||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
# if response.headers.get('Transfer-Encoding') == 'chunked':
|
# if response.headers.get('Transfer-Encoding') == 'chunked':
|
||||||
if True:
|
if True:
|
||||||
async for chunk_bytes in response.content.iter_chunked(1024):
|
async for chunk_bytes in response.content.iter_chunked(
|
||||||
|
1024):
|
||||||
yield chunk_bytes
|
yield chunk_bytes
|
||||||
else:
|
else:
|
||||||
content = await response.read()
|
content = await response.read()
|
||||||
yield content
|
yield content
|
||||||
|
|
||||||
|
|
||||||
@app.route("/v1/completions", methods=["POST"])
|
@app.route('/v1/completions', methods=['POST'])
|
||||||
async def handle_request():
|
async def handle_request():
|
||||||
try:
|
try:
|
||||||
original_request_data = await request.get_json()
|
original_request_data = await request.get_json()
|
||||||
|
|
||||||
prefill_request = original_request_data.copy()
|
prefill_request = original_request_data.copy()
|
||||||
# change max_tokens = 1 to let it only do prefill
|
# change max_tokens = 1 to let it only do prefill
|
||||||
prefill_request["max_tokens"] = 1
|
prefill_request['max_tokens'] = 1
|
||||||
|
|
||||||
# finish prefill
|
# finish prefill
|
||||||
async for _ in forward_request(
|
async for _ in forward_request('http://localhost:8100/v1/completions',
|
||||||
"http://localhost:8100/v1/completions", prefill_request
|
prefill_request):
|
||||||
):
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# return decode
|
# return decode
|
||||||
generator = forward_request(
|
generator = forward_request('http://localhost:8200/v1/completions',
|
||||||
"http://localhost:8200/v1/completions", original_request_data
|
original_request_data)
|
||||||
)
|
|
||||||
response = await make_response(generator)
|
response = await make_response(generator)
|
||||||
response.timeout = None
|
response.timeout = None
|
||||||
|
|
||||||
@ -51,12 +53,11 @@ async def handle_request():
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
exc_info = sys.exc_info()
|
exc_info = sys.exc_info()
|
||||||
print("Error occurred in disagg prefill proxy server")
|
print("Error occurred in disagg prefill proxy server")
|
||||||
print(e)
|
print(e)
|
||||||
print("".join(traceback.format_exception(*exc_info)))
|
print("".join(traceback.format_exception(*exc_info)))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == '__main__':
|
||||||
app.run(port=8000)
|
app.run(port=8000)
|
||||||
|
|||||||
@ -8,6 +8,7 @@ from aiohttp import web
|
|||||||
|
|
||||||
|
|
||||||
class RoundRobinProxy:
|
class RoundRobinProxy:
|
||||||
|
|
||||||
def __init__(self, target_ports):
|
def __init__(self, target_ports):
|
||||||
self.target_ports = target_ports
|
self.target_ports = target_ports
|
||||||
self.port_cycle = itertools.cycle(self.target_ports)
|
self.port_cycle = itertools.cycle(self.target_ports)
|
||||||
@ -20,15 +21,14 @@ class RoundRobinProxy:
|
|||||||
try:
|
try:
|
||||||
# Forward the request
|
# Forward the request
|
||||||
async with session.request(
|
async with session.request(
|
||||||
method=request.method,
|
method=request.method,
|
||||||
url=target_url,
|
url=target_url,
|
||||||
headers=request.headers,
|
headers=request.headers,
|
||||||
data=request.content,
|
data=request.content,
|
||||||
) as response:
|
) as response:
|
||||||
# Start sending the response
|
# Start sending the response
|
||||||
resp = web.StreamResponse(
|
resp = web.StreamResponse(status=response.status,
|
||||||
status=response.status, headers=response.headers
|
headers=response.headers)
|
||||||
)
|
|
||||||
await resp.prepare(request)
|
await resp.prepare(request)
|
||||||
|
|
||||||
# Stream the response content
|
# Stream the response content
|
||||||
@ -45,11 +45,11 @@ class RoundRobinProxy:
|
|||||||
async def main():
|
async def main():
|
||||||
proxy = RoundRobinProxy([8100, 8200])
|
proxy = RoundRobinProxy([8100, 8200])
|
||||||
app = web.Application()
|
app = web.Application()
|
||||||
app.router.add_route("*", "/{path:.*}", proxy.handle_request)
|
app.router.add_route('*', '/{path:.*}', proxy.handle_request)
|
||||||
|
|
||||||
runner = web.AppRunner(app)
|
runner = web.AppRunner(app)
|
||||||
await runner.setup()
|
await runner.setup()
|
||||||
site = web.TCPSite(runner, "localhost", 8000)
|
site = web.TCPSite(runner, 'localhost', 8000)
|
||||||
await site.start()
|
await site.start()
|
||||||
|
|
||||||
print("Proxy server started on http://localhost:8000")
|
print("Proxy server started on http://localhost:8000")
|
||||||
@ -58,5 +58,5 @@ async def main():
|
|||||||
await asyncio.Event().wait()
|
await asyncio.Event().wait()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == '__main__':
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
|
|||||||
@ -6,41 +6,43 @@ import matplotlib.pyplot as plt
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
data = []
|
data = []
|
||||||
for name in ["disagg_prefill", "chunked_prefill"]:
|
for name in ['disagg_prefill', 'chunked_prefill']:
|
||||||
for qps in [2, 4, 6, 8]:
|
for qps in [2, 4, 6, 8]:
|
||||||
with open(f"results/{name}-qps-{qps}.json") as f:
|
with open(f"results/{name}-qps-{qps}.json") as f:
|
||||||
x = json.load(f)
|
x = json.load(f)
|
||||||
x["name"] = name
|
x['name'] = name
|
||||||
x["qps"] = qps
|
x['qps'] = qps
|
||||||
data.append(x)
|
data.append(x)
|
||||||
|
|
||||||
df = pd.DataFrame.from_dict(data)
|
df = pd.DataFrame.from_dict(data)
|
||||||
dis_df = df[df["name"] == "disagg_prefill"]
|
dis_df = df[df['name'] == 'disagg_prefill']
|
||||||
chu_df = df[df["name"] == "chunked_prefill"]
|
chu_df = df[df['name'] == 'chunked_prefill']
|
||||||
|
|
||||||
plt.style.use("bmh")
|
plt.style.use('bmh')
|
||||||
plt.rcParams["font.size"] = 20
|
plt.rcParams['font.size'] = 20
|
||||||
|
|
||||||
for key in [
|
for key in [
|
||||||
"mean_ttft_ms",
|
'mean_ttft_ms', 'median_ttft_ms', 'p99_ttft_ms', 'mean_itl_ms',
|
||||||
"median_ttft_ms",
|
'median_itl_ms', 'p99_itl_ms'
|
||||||
"p99_ttft_ms",
|
|
||||||
"mean_itl_ms",
|
|
||||||
"median_itl_ms",
|
|
||||||
"p99_itl_ms",
|
|
||||||
]:
|
]:
|
||||||
|
|
||||||
fig, ax = plt.subplots(figsize=(11, 7))
|
fig, ax = plt.subplots(figsize=(11, 7))
|
||||||
plt.plot(
|
plt.plot(dis_df['qps'],
|
||||||
dis_df["qps"], dis_df[key], label="disagg_prefill", marker="o", linewidth=4
|
dis_df[key],
|
||||||
)
|
label='disagg_prefill',
|
||||||
plt.plot(
|
marker='o',
|
||||||
chu_df["qps"], chu_df[key], label="chunked_prefill", marker="o", linewidth=4
|
linewidth=4)
|
||||||
)
|
plt.plot(chu_df['qps'],
|
||||||
|
chu_df[key],
|
||||||
|
label='chunked_prefill',
|
||||||
|
marker='o',
|
||||||
|
linewidth=4)
|
||||||
ax.legend()
|
ax.legend()
|
||||||
|
|
||||||
ax.set_xlabel("QPS")
|
ax.set_xlabel('QPS')
|
||||||
ax.set_ylabel(key)
|
ax.set_ylabel(key)
|
||||||
ax.set_ylim(bottom=0)
|
ax.set_ylim(bottom=0)
|
||||||
fig.savefig(f"results/{key}.png")
|
fig.savefig(f'results/{key}.png')
|
||||||
plt.close(fig)
|
plt.close(fig)
|
||||||
|
|||||||
@ -24,12 +24,10 @@ class bench_params_t:
|
|||||||
dtype: torch.dtype
|
dtype: torch.dtype
|
||||||
|
|
||||||
def description(self):
|
def description(self):
|
||||||
return (
|
return (f'N {self.num_tokens} '
|
||||||
f"N {self.num_tokens} "
|
f'x D {self.hidden_size} '
|
||||||
f"x D {self.hidden_size} "
|
f'x R {self.add_residual} '
|
||||||
f"x R {self.add_residual} "
|
f'x DT {self.dtype}')
|
||||||
f"x DT {self.dtype}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_bench_params() -> list[bench_params_t]:
|
def get_bench_params() -> list[bench_params_t]:
|
||||||
@ -40,19 +38,15 @@ def get_bench_params() -> list[bench_params_t]:
|
|||||||
DTYPES = [torch.bfloat16, torch.float]
|
DTYPES = [torch.bfloat16, torch.float]
|
||||||
|
|
||||||
combinations = product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES)
|
combinations = product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES)
|
||||||
bench_params = list(
|
bench_params = list(map(lambda x: \
|
||||||
map(lambda x: bench_params_t(x[0], x[1], x[2], x[3]), combinations)
|
bench_params_t(x[0], x[1], x[2], x[3]), combinations))
|
||||||
)
|
|
||||||
return bench_params
|
return bench_params
|
||||||
|
|
||||||
|
|
||||||
# Reference impls
|
# Reference impls
|
||||||
def unfused_int8_impl(
|
def unfused_int8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor,
|
||||||
rms_norm_layer: RMSNorm,
|
residual: Optional[torch.Tensor],
|
||||||
x: torch.Tensor,
|
quant_dtype: torch.dtype):
|
||||||
residual: Optional[torch.Tensor],
|
|
||||||
quant_dtype: torch.dtype,
|
|
||||||
):
|
|
||||||
# Norm
|
# Norm
|
||||||
torch_out = None
|
torch_out = None
|
||||||
if residual is None:
|
if residual is None:
|
||||||
@ -64,12 +58,9 @@ def unfused_int8_impl(
|
|||||||
torch_out, _, _ = ops.scaled_int8_quant(torch_out)
|
torch_out, _, _ = ops.scaled_int8_quant(torch_out)
|
||||||
|
|
||||||
|
|
||||||
def unfused_fp8_impl(
|
def unfused_fp8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor,
|
||||||
rms_norm_layer: RMSNorm,
|
residual: Optional[torch.Tensor],
|
||||||
x: torch.Tensor,
|
quant_dtype: torch.dtype):
|
||||||
residual: Optional[torch.Tensor],
|
|
||||||
quant_dtype: torch.dtype,
|
|
||||||
):
|
|
||||||
# Norm
|
# Norm
|
||||||
torch_out = None
|
torch_out = None
|
||||||
if residual is None:
|
if residual is None:
|
||||||
@ -82,27 +73,22 @@ def unfused_fp8_impl(
|
|||||||
|
|
||||||
|
|
||||||
def fused_impl(
|
def fused_impl(
|
||||||
rms_norm_layer: RMSNorm, # this stores the weights
|
rms_norm_layer: RMSNorm, # this stores the weights
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
residual: Optional[torch.Tensor],
|
residual: Optional[torch.Tensor],
|
||||||
quant_dtype: torch.dtype,
|
quant_dtype: torch.dtype):
|
||||||
):
|
out, _ = ops.rms_norm_dynamic_per_token_quant(x,
|
||||||
out, _ = ops.rms_norm_dynamic_per_token_quant(
|
rms_norm_layer.weight,
|
||||||
x, rms_norm_layer.weight, 1e-6, quant_dtype, residual=residual
|
1e-6,
|
||||||
)
|
quant_dtype,
|
||||||
|
residual=residual)
|
||||||
|
|
||||||
|
|
||||||
# Bench functions
|
# Bench functions
|
||||||
def bench_fn(
|
def bench_fn(rms_norm_layer: RMSNorm, x: torch.Tensor, residual: torch.Tensor,
|
||||||
rms_norm_layer: RMSNorm,
|
quant_dtype: torch.dtype, label: str, sub_label: str,
|
||||||
x: torch.Tensor,
|
fn: Callable, description: str) -> TMeasurement:
|
||||||
residual: torch.Tensor,
|
|
||||||
quant_dtype: torch.dtype,
|
|
||||||
label: str,
|
|
||||||
sub_label: str,
|
|
||||||
fn: Callable,
|
|
||||||
description: str,
|
|
||||||
) -> TMeasurement:
|
|
||||||
min_run_time = 1
|
min_run_time = 1
|
||||||
|
|
||||||
globals = {
|
globals = {
|
||||||
@ -120,81 +106,43 @@ def bench_fn(
|
|||||||
description=description,
|
description=description,
|
||||||
).blocked_autorange(min_run_time=min_run_time)
|
).blocked_autorange(min_run_time=min_run_time)
|
||||||
|
|
||||||
|
def bench(params: bench_params_t, label: str, sub_label: str) \
|
||||||
|
-> Iterable[TMeasurement]:
|
||||||
|
|
||||||
def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasurement]:
|
|
||||||
# Make inputs
|
# Make inputs
|
||||||
layer = RMSNorm(params.hidden_size, 1e-6).to(dtype=params.dtype)
|
layer = RMSNorm(params.hidden_size, 1e-6).to(dtype=params.dtype)
|
||||||
# Make weights
|
# Make weights
|
||||||
layer.weight.data.normal_(mean=1.0, std=0.1)
|
layer.weight.data.normal_(mean=1.0, std=0.1)
|
||||||
# Make inputs
|
# Make inputs
|
||||||
scale = 1 / params.hidden_size
|
scale = 1 / params.hidden_size
|
||||||
x = (
|
x = torch.randn(params.num_tokens,
|
||||||
torch.randn(
|
params.hidden_size,
|
||||||
params.num_tokens, params.hidden_size, dtype=params.dtype, device="cuda"
|
dtype=params.dtype,
|
||||||
)
|
device='cuda') * scale
|
||||||
* scale
|
residual = (torch.randn_like(x) * scale).to(device='cuda') \
|
||||||
)
|
if params.add_residual else None
|
||||||
residual = (
|
|
||||||
(torch.randn_like(x) * scale).to(device="cuda") if params.add_residual else None
|
|
||||||
)
|
|
||||||
|
|
||||||
timers = []
|
timers = []
|
||||||
|
|
||||||
# unfused int8 impl.
|
# unfused int8 impl.
|
||||||
timers.append(
|
timers.append(
|
||||||
bench_fn(
|
bench_fn(layer, x, residual, torch.int8, label, sub_label,
|
||||||
layer,
|
unfused_int8_impl, "unfused_int8_impl"))
|
||||||
x,
|
|
||||||
residual,
|
|
||||||
torch.int8,
|
|
||||||
label,
|
|
||||||
sub_label,
|
|
||||||
unfused_int8_impl,
|
|
||||||
"unfused_int8_impl",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# unfused fp8 impl.
|
# unfused fp8 impl.
|
||||||
timers.append(
|
timers.append(
|
||||||
bench_fn(
|
bench_fn(layer, x, residual, torch.float8_e4m3fn, label, sub_label,
|
||||||
layer,
|
unfused_fp8_impl, "unfused_fp8_impl"))
|
||||||
x,
|
|
||||||
residual,
|
|
||||||
torch.float8_e4m3fn,
|
|
||||||
label,
|
|
||||||
sub_label,
|
|
||||||
unfused_fp8_impl,
|
|
||||||
"unfused_fp8_impl",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# fused int8 impl.
|
# fused int8 impl.
|
||||||
timers.append(
|
timers.append(
|
||||||
bench_fn(
|
bench_fn(layer, x, residual, torch.int8, label, sub_label, fused_impl,
|
||||||
layer,
|
"fused_int8_impl"))
|
||||||
x,
|
|
||||||
residual,
|
|
||||||
torch.int8,
|
|
||||||
label,
|
|
||||||
sub_label,
|
|
||||||
fused_impl,
|
|
||||||
"fused_int8_impl",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# fused fp8 impl.
|
# fused fp8 impl.
|
||||||
timers.append(
|
timers.append(
|
||||||
bench_fn(
|
bench_fn(layer, x, residual, torch.float8_e4m3fn, label, sub_label,
|
||||||
layer,
|
fused_impl, "fused_fp8_impl"))
|
||||||
x,
|
|
||||||
residual,
|
|
||||||
torch.float8_e4m3fn,
|
|
||||||
label,
|
|
||||||
sub_label,
|
|
||||||
fused_impl,
|
|
||||||
"fused_fp8_impl",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
print_timers(timers)
|
print_timers(timers)
|
||||||
|
|
||||||
@ -209,12 +157,13 @@ def print_timers(timers: Iterable[TMeasurement]):
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
torch.set_default_device("cuda")
|
torch.set_default_device('cuda')
|
||||||
bench_params = get_bench_params()
|
bench_params = get_bench_params()
|
||||||
|
|
||||||
timers = []
|
timers = []
|
||||||
for bp in tqdm(bench_params):
|
for bp in tqdm(bench_params):
|
||||||
timers.extend(bench(bp, "rms-norm-dynamic-per-token-quant", bp.description()))
|
timers.extend(
|
||||||
|
bench(bp, "rms-norm-dynamic-per-token-quant", bp.description()))
|
||||||
print_timers(timers)
|
print_timers(timers)
|
||||||
|
|
||||||
# pickle all the results
|
# pickle all the results
|
||||||
@ -223,5 +172,5 @@ def main():
|
|||||||
pkl.dump(timers, f)
|
pkl.dump(timers, f)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
|||||||
@ -9,39 +9,32 @@ import torch.nn.functional as F
|
|||||||
|
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.model_executor.layers.quantization.aqlm import (
|
from vllm.model_executor.layers.quantization.aqlm import (
|
||||||
dequantize_weight,
|
dequantize_weight, generic_dequantize_gemm, get_int_dtype,
|
||||||
generic_dequantize_gemm,
|
optimized_dequantize_gemm)
|
||||||
get_int_dtype,
|
|
||||||
optimized_dequantize_gemm,
|
|
||||||
)
|
|
||||||
from vllm.utils import FlexibleArgumentParser
|
from vllm.utils import FlexibleArgumentParser
|
||||||
|
|
||||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
||||||
|
|
||||||
|
|
||||||
def torch_mult(
|
def torch_mult(
|
||||||
# [..., in_features]
|
input: torch.Tensor, # [..., in_features]
|
||||||
input: torch.Tensor,
|
weights: torch.Tensor,
|
||||||
weights: torch.Tensor,
|
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
|
||||||
# [num_out_groups, 1, 1, 1]
|
|
||||||
scales: torch.Tensor,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
output = F.linear(input, weights)
|
output = F.linear(input, weights)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
def dequant_out_scale(
|
def dequant_out_scale(
|
||||||
# [..., in_features]
|
input: torch.Tensor, # [..., in_features]
|
||||||
input: torch.Tensor,
|
codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
|
||||||
# [num_out_groups, num_in_groups, num_codebooks]
|
codebooks: torch.
|
||||||
codes: torch.IntTensor,
|
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
|
||||||
# [num_codebooks, codebook_size, out_group_size, in_group_size]
|
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
|
||||||
codebooks: torch.Tensor,
|
|
||||||
# [num_out_groups, 1, 1, 1]
|
|
||||||
scales: torch.Tensor,
|
|
||||||
output_partition_sizes: torch.IntTensor,
|
output_partition_sizes: torch.IntTensor,
|
||||||
bias: Optional[torch.Tensor],
|
bias: Optional[torch.Tensor],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
|
weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
|
||||||
|
|
||||||
if bias is None:
|
if bias is None:
|
||||||
@ -53,42 +46,40 @@ def dequant_out_scale(
|
|||||||
flattened_output *= b_scales
|
flattened_output *= b_scales
|
||||||
return flattened_output.view(orig_shape)
|
return flattened_output.view(orig_shape)
|
||||||
else:
|
else:
|
||||||
b_scales = scales.view(scales.shape[:-3] + (-1,)).expand(-1, weights.shape[1])
|
b_scales = scales.view(scales.shape[:-3] + (-1, )).expand(
|
||||||
|
-1, weights.shape[1])
|
||||||
weights *= b_scales
|
weights *= b_scales
|
||||||
return F.linear(input, weights, bias)
|
return F.linear(input, weights, bias)
|
||||||
|
|
||||||
|
|
||||||
def dequant_weight_scale(
|
def dequant_weight_scale(
|
||||||
# [..., in_features]
|
input: torch.Tensor, # [..., in_features]
|
||||||
input: torch.Tensor,
|
codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
|
||||||
# [num_out_groups, num_in_groups, num_codebooks]
|
codebooks: torch.
|
||||||
codes: torch.IntTensor,
|
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
|
||||||
# [num_codebooks, codebook_size, out_group_size, in_group_size]
|
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
|
||||||
codebooks: torch.Tensor,
|
|
||||||
# [num_out_groups, 1, 1, 1]
|
|
||||||
scales: torch.Tensor,
|
|
||||||
output_partition_sizes: torch.IntTensor,
|
output_partition_sizes: torch.IntTensor,
|
||||||
bias: Optional[torch.Tensor],
|
bias: Optional[torch.Tensor],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
|
weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
|
||||||
|
|
||||||
b_scales = scales.view(scales.shape[:-3] + (-1,)).expand(-1, weights.shape[1])
|
b_scales = scales.view(scales.shape[:-3] + (-1, )).expand(
|
||||||
|
-1, weights.shape[1])
|
||||||
weights *= b_scales
|
weights *= b_scales
|
||||||
return F.linear(input, weights, bias)
|
return F.linear(input, weights, bias)
|
||||||
|
|
||||||
|
|
||||||
def dequant_no_scale(
|
def dequant_no_scale(
|
||||||
# [..., in_features]
|
input: torch.Tensor, # [..., in_features]
|
||||||
input: torch.Tensor,
|
codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
|
||||||
# [num_out_groups, num_in_groups, num_codebooks]
|
codebooks: torch.
|
||||||
codes: torch.IntTensor,
|
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
|
||||||
# [num_codebooks, codebook_size, out_group_size, in_group_size]
|
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
|
||||||
codebooks: torch.Tensor,
|
|
||||||
# [num_out_groups, 1, 1, 1]
|
|
||||||
scales: torch.Tensor,
|
|
||||||
output_partition_sizes: torch.IntTensor,
|
output_partition_sizes: torch.IntTensor,
|
||||||
bias: Optional[torch.Tensor],
|
bias: Optional[torch.Tensor],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
|
weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
|
||||||
|
|
||||||
return F.linear(input, weights, bias)
|
return F.linear(input, weights, bias)
|
||||||
@ -98,26 +89,23 @@ def dequant_no_scale(
|
|||||||
# the generic pytorch version.
|
# the generic pytorch version.
|
||||||
# Just visual comparison.
|
# Just visual comparison.
|
||||||
def dequant_test(k: int, parts: torch.Tensor, nbooks: int, bits: int) -> None:
|
def dequant_test(k: int, parts: torch.Tensor, nbooks: int, bits: int) -> None:
|
||||||
|
|
||||||
n = int(parts.sum().item())
|
n = int(parts.sum().item())
|
||||||
|
|
||||||
device = torch.device("cuda:0")
|
device = torch.device('cuda:0')
|
||||||
|
|
||||||
code_range = (1 << bits) // 2
|
code_range = (1 << bits) // 2
|
||||||
ingroups = 8
|
ingroups = 8
|
||||||
|
|
||||||
codes = torch.randint(
|
codes = torch.randint(-code_range,
|
||||||
-code_range,
|
code_range,
|
||||||
code_range,
|
size=(n, k // ingroups, nbooks),
|
||||||
size=(n, k // ingroups, nbooks),
|
dtype=get_int_dtype(bits),
|
||||||
dtype=get_int_dtype(bits),
|
device=device)
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
|
|
||||||
codebooks = torch.randn(
|
codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
|
||||||
size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
|
dtype=torch.float16,
|
||||||
dtype=torch.float16,
|
device=device)
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
|
|
||||||
count = 0
|
count = 0
|
||||||
for index in range(16):
|
for index in range(16):
|
||||||
@ -150,25 +138,24 @@ def dequant_test(k: int, parts: torch.Tensor, nbooks: int, bits: int) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|
||||||
parser = FlexibleArgumentParser(description="Benchmark aqlm performance.")
|
parser = FlexibleArgumentParser(description="Benchmark aqlm performance.")
|
||||||
|
|
||||||
# Add arguments
|
# Add arguments
|
||||||
parser.add_argument(
|
parser.add_argument("--nbooks",
|
||||||
"--nbooks", type=int, default=1, help="Number of codebooks (default: 1)"
|
type=int,
|
||||||
)
|
default=1,
|
||||||
parser.add_argument(
|
help="Number of codebooks (default: 1)")
|
||||||
"--bits",
|
parser.add_argument("--bits",
|
||||||
type=int,
|
type=int,
|
||||||
default=16,
|
default=16,
|
||||||
help="Number of bits per code element (default: 16)",
|
help="Number of bits per code element (default: 16)")
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--test",
|
"--test",
|
||||||
type=bool,
|
type=bool,
|
||||||
default=False,
|
default=False,
|
||||||
help="Run the decompression/dequant tester rather than benchmarking "
|
help="Run the decompression/dequant tester rather than benchmarking "
|
||||||
"(default: False)",
|
"(default: False)")
|
||||||
)
|
|
||||||
|
|
||||||
# Parse the arguments
|
# Parse the arguments
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@ -178,7 +165,7 @@ def main():
|
|||||||
bits = args.bits
|
bits = args.bits
|
||||||
|
|
||||||
if args.test:
|
if args.test:
|
||||||
dequant_test(4096, torch.tensor((4096,)), nbooks, bits)
|
dequant_test(4096, torch.tensor((4096, )), nbooks, bits)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Otherwise, benchmark.
|
# Otherwise, benchmark.
|
||||||
@ -197,54 +184,31 @@ def main():
|
|||||||
with open(filename, "w") as f:
|
with open(filename, "w") as f:
|
||||||
sys.stdout = f
|
sys.stdout = f
|
||||||
|
|
||||||
print("m | k | n | n parts", end="")
|
print('m | k | n | n parts', end='')
|
||||||
for method in methods:
|
for method in methods:
|
||||||
print(f" | {method.__name__.replace('_', ' ')} (µs)", end="")
|
print(f" | {method.__name__.replace('_', ' ')} (µs)", end='')
|
||||||
print("")
|
print('')
|
||||||
|
|
||||||
# These are reasonable prefill sizes.
|
# These are reasonable prefill sizes.
|
||||||
ksandpartions = (
|
ksandpartions = ((4096, (4096, 4096, 4096)), (4096, (4096, )),
|
||||||
(4096, (4096, 4096, 4096)),
|
(4096, (11008, 11008)), (11008, (4096, )))
|
||||||
(4096, (4096,)),
|
|
||||||
(4096, (11008, 11008)),
|
|
||||||
(11008, (4096,)),
|
|
||||||
)
|
|
||||||
|
|
||||||
# reasonable ranges for m.
|
# reasonable ranges for m.
|
||||||
for m in [
|
for m in [
|
||||||
1,
|
1, 2, 4, 8, 10, 12, 14, 16, 24, 32, 48, 52, 56, 64, 96, 112,
|
||||||
2,
|
128, 256, 512, 1024, 1536, 2048, 3072, 4096
|
||||||
4,
|
|
||||||
8,
|
|
||||||
10,
|
|
||||||
12,
|
|
||||||
14,
|
|
||||||
16,
|
|
||||||
24,
|
|
||||||
32,
|
|
||||||
48,
|
|
||||||
52,
|
|
||||||
56,
|
|
||||||
64,
|
|
||||||
96,
|
|
||||||
112,
|
|
||||||
128,
|
|
||||||
256,
|
|
||||||
512,
|
|
||||||
1024,
|
|
||||||
1536,
|
|
||||||
2048,
|
|
||||||
3072,
|
|
||||||
4096,
|
|
||||||
]:
|
]:
|
||||||
print(f"{m}", file=sys.__stdout__)
|
print(f'{m}', file=sys.__stdout__)
|
||||||
for ksp in ksandpartions:
|
for ksp in ksandpartions:
|
||||||
run_grid(m, ksp[0], torch.tensor(ksp[1]), nbooks, bits, methods)
|
run_grid(m, ksp[0], torch.tensor(ksp[1]), nbooks, bits,
|
||||||
|
methods)
|
||||||
|
|
||||||
sys.stdout = sys.__stdout__
|
sys.stdout = sys.__stdout__
|
||||||
|
|
||||||
|
|
||||||
def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int, methods):
|
def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int,
|
||||||
|
methods):
|
||||||
|
|
||||||
# I didn't see visible improvements from increasing these, but feel free :)
|
# I didn't see visible improvements from increasing these, but feel free :)
|
||||||
num_warmup_trials = 1
|
num_warmup_trials = 1
|
||||||
num_trials = 1
|
num_trials = 1
|
||||||
@ -265,7 +229,7 @@ def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int, method
|
|||||||
)
|
)
|
||||||
|
|
||||||
n = parts.sum().item()
|
n = parts.sum().item()
|
||||||
print(f"{m} | {k} | {n} | {parts.tolist()}", end="")
|
print(f'{m} | {k} | {n} | {parts.tolist()}', end='')
|
||||||
|
|
||||||
for method in methods:
|
for method in methods:
|
||||||
best_time_us = 1e20
|
best_time_us = 1e20
|
||||||
@ -285,36 +249,32 @@ def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int, method
|
|||||||
if kernel_dur_us < best_time_us:
|
if kernel_dur_us < best_time_us:
|
||||||
best_time_us = kernel_dur_us
|
best_time_us = kernel_dur_us
|
||||||
|
|
||||||
print(f" | {kernel_dur_us:.0f}", end="")
|
print(f' | {kernel_dur_us:.0f}', end='')
|
||||||
|
|
||||||
print("")
|
print('')
|
||||||
|
|
||||||
|
|
||||||
def run_timing(
|
def run_timing(num_calls: int, m: int, k: int, parts: torch.Tensor,
|
||||||
num_calls: int, m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int, method
|
nbooks: int, bits: int, method) -> float:
|
||||||
) -> float:
|
|
||||||
n = int(parts.sum().item())
|
n = int(parts.sum().item())
|
||||||
|
|
||||||
device = torch.device("cuda:0")
|
device = torch.device('cuda:0')
|
||||||
|
|
||||||
input = torch.randn((1, m, k), dtype=torch.float16, device=device)
|
input = torch.randn((1, m, k), dtype=torch.float16, device=device)
|
||||||
|
|
||||||
code_range = (1 << bits) // 2
|
code_range = (1 << bits) // 2
|
||||||
ingroups = 8
|
ingroups = 8
|
||||||
|
|
||||||
codes = torch.randint(
|
codes = torch.randint(-code_range,
|
||||||
-code_range,
|
code_range,
|
||||||
code_range,
|
size=(n, k // ingroups, nbooks),
|
||||||
size=(n, k // ingroups, nbooks),
|
dtype=get_int_dtype(bits),
|
||||||
dtype=get_int_dtype(bits),
|
device=device)
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
|
|
||||||
codebooks = torch.randn(
|
codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
|
||||||
size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
|
dtype=torch.float16,
|
||||||
dtype=torch.float16,
|
device=device)
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
|
|
||||||
scales = torch.randn(size=(n, 1, 1, 1), dtype=torch.float16, device=device)
|
scales = torch.randn(size=(n, 1, 1, 1), dtype=torch.float16, device=device)
|
||||||
|
|
||||||
|
|||||||
@ -3,33 +3,27 @@
|
|||||||
# Licensed under the MIT License.
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
|
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
|
||||||
MINIMUM_BITBLAS_VERSION,
|
MINIMUM_BITBLAS_VERSION)
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import bitblas
|
import bitblas
|
||||||
|
|
||||||
if bitblas.__version__ < MINIMUM_BITBLAS_VERSION:
|
if bitblas.__version__ < MINIMUM_BITBLAS_VERSION:
|
||||||
raise ImportError(
|
raise ImportError("bitblas version is wrong. Please "
|
||||||
"bitblas version is wrong. Please "
|
f"install bitblas>={MINIMUM_BITBLAS_VERSION}")
|
||||||
f"install bitblas>={MINIMUM_BITBLAS_VERSION}"
|
|
||||||
)
|
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
bitblas_import_exception = e
|
bitblas_import_exception = e
|
||||||
raise ValueError(
|
raise ValueError("Trying to use the bitblas backend, but could not import"
|
||||||
"Trying to use the bitblas backend, but could not import"
|
f"with the following error: {bitblas_import_exception}. "
|
||||||
f"with the following error: {bitblas_import_exception}. "
|
"Please install bitblas through the following command: "
|
||||||
"Please install bitblas through the following command: "
|
f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`"
|
||||||
f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`"
|
) from bitblas_import_exception
|
||||||
) from bitblas_import_exception
|
|
||||||
|
|
||||||
from bitblas import Matmul, MatmulConfig, auto_detect_nvidia_target
|
from bitblas import Matmul, MatmulConfig, auto_detect_nvidia_target
|
||||||
|
|
||||||
from vllm.utils import FlexibleArgumentParser
|
from vllm.utils import FlexibleArgumentParser
|
||||||
|
|
||||||
parser = FlexibleArgumentParser(
|
parser = FlexibleArgumentParser(
|
||||||
description="Benchmark BitBLAS int4 on a specific target."
|
description="Benchmark BitBLAS int4 on a specific target.")
|
||||||
)
|
|
||||||
|
|
||||||
# Add arguments to the parser
|
# Add arguments to the parser
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -38,9 +32,10 @@ parser.add_argument(
|
|||||||
default=auto_detect_nvidia_target(),
|
default=auto_detect_nvidia_target(),
|
||||||
help="Specify the target device for benchmarking.",
|
help="Specify the target device for benchmarking.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument("--group_size",
|
||||||
"--group_size", type=int, default=None, help="Group size for grouped quantization."
|
type=int,
|
||||||
)
|
default=None,
|
||||||
|
help="Group size for grouped quantization.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--A_dtype",
|
"--A_dtype",
|
||||||
type=str,
|
type=str,
|
||||||
@ -87,17 +82,17 @@ parser.add_argument(
|
|||||||
choices=["nt", "nn"],
|
choices=["nt", "nn"],
|
||||||
help="Matrix layout, 'nt' for non-transpose A and transpose W.",
|
help="Matrix layout, 'nt' for non-transpose A and transpose W.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument("--with_bias",
|
||||||
"--with_bias", action="store_true", help="Include bias in the benchmark."
|
action="store_true",
|
||||||
)
|
help="Include bias in the benchmark.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--with_scaling",
|
"--with_scaling",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Include scaling factor in the quantization.",
|
help="Include scaling factor in the quantization.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument("--with_zeros",
|
||||||
"--with_zeros", action="store_true", help="Include zeros in the quantization."
|
action="store_true",
|
||||||
)
|
help="Include zeros in the quantization.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--zeros_mode",
|
"--zeros_mode",
|
||||||
type=str,
|
type=str,
|
||||||
@ -175,7 +170,8 @@ shapes = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
# Build test shapes with all the shared arguments
|
# Build test shapes with all the shared arguments
|
||||||
test_shapes = [(MatmulConfig, Matmul, (*shape, *shared_args)) for shape in shapes]
|
test_shapes = [(MatmulConfig, Matmul, (*shape, *shared_args))
|
||||||
|
for shape in shapes]
|
||||||
|
|
||||||
benchmark_sets = []
|
benchmark_sets = []
|
||||||
benchmark_sets.extend(test_shapes)
|
benchmark_sets.extend(test_shapes)
|
||||||
@ -210,12 +206,12 @@ for config_key, values in benchmark_results.items():
|
|||||||
func_name = args_split[0]
|
func_name = args_split[0]
|
||||||
input_args_str = "-".join(args_split[1:])
|
input_args_str = "-".join(args_split[1:])
|
||||||
col_widths[0] = max(col_widths[0], len(func_name) + 2, len(headers[0]) + 2)
|
col_widths[0] = max(col_widths[0], len(func_name) + 2, len(headers[0]) + 2)
|
||||||
col_widths[1] = max(col_widths[1], len(input_args_str) + 2, len(headers[1]) + 2)
|
col_widths[1] = max(col_widths[1],
|
||||||
col_widths[2] = max(
|
len(input_args_str) + 2,
|
||||||
col_widths[2],
|
len(headers[1]) + 2)
|
||||||
len(f"{values['BitBLAS_top20_latency']:.3f} ms") + 2,
|
col_widths[2] = max(col_widths[2],
|
||||||
len(headers[2]) + 2,
|
len(f"{values['BitBLAS_top20_latency']:.3f} ms") + 2,
|
||||||
)
|
len(headers[2]) + 2)
|
||||||
# break only if you want to measure widths from a single example;
|
# break only if you want to measure widths from a single example;
|
||||||
# otherwise, let it loop over all items.
|
# otherwise, let it loop over all items.
|
||||||
|
|
||||||
@ -236,6 +232,5 @@ for config_key, values in benchmark_results.items():
|
|||||||
f"{values['BitBLAS_top20_latency']:.3f} ms",
|
f"{values['BitBLAS_top20_latency']:.3f} ms",
|
||||||
]
|
]
|
||||||
row_str = "".join(
|
row_str = "".join(
|
||||||
[str(cell).ljust(col_widths[idx]) for idx, cell in enumerate(row)]
|
[str(cell).ljust(col_widths[idx]) for idx, cell in enumerate(row)])
|
||||||
)
|
|
||||||
print(row_str)
|
print(row_str)
|
||||||
|
|||||||
@ -1,489 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
"""
|
|
||||||
Benchmark the performance of the cutlass_moe_fp4 kernel vs the triton_moe
|
|
||||||
kernel. The cutlass_moe_fp4 kernel takes in fp4 quantized weights and 16-bit
|
|
||||||
activations. The triton_moe kernel takes in fp8 weights(tensor scaled to fp8)
|
|
||||||
and 16-bit activations.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import nvtx
|
|
||||||
import torch
|
|
||||||
import torch.utils.benchmark as benchmark
|
|
||||||
|
|
||||||
from vllm import _custom_ops as ops
|
|
||||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
|
||||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
|
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
|
|
||||||
from vllm.scalar_type import scalar_types
|
|
||||||
from vllm.utils import FlexibleArgumentParser
|
|
||||||
|
|
||||||
WEIGHT_SHAPES_MOE = {
|
|
||||||
"nvidia/DeepSeek-R1-FP4": [
|
|
||||||
[256, 8, 2048, 7168],
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
DEFAULT_MODELS = [
|
|
||||||
"nvidia/DeepSeek-R1-FP4",
|
|
||||||
]
|
|
||||||
|
|
||||||
DEFAULT_BATCH_SIZES = [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
|
|
||||||
DEFAULT_TP_SIZES = [1]
|
|
||||||
|
|
||||||
PER_ACT_TOKEN_OPTS = [False]
|
|
||||||
PER_OUT_CH_OPTS = [False]
|
|
||||||
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
|
|
||||||
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
|
|
||||||
|
|
||||||
|
|
||||||
def to_fp8(tensor: torch.Tensor):
|
|
||||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
|
||||||
return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to(
|
|
||||||
dtype=torch.float8_e4m3fn
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def bench_run(
|
|
||||||
results: list[benchmark.Measurement],
|
|
||||||
model: str,
|
|
||||||
num_experts: int,
|
|
||||||
topk: int,
|
|
||||||
per_act_token: bool,
|
|
||||||
per_out_ch: bool,
|
|
||||||
mkn: tuple[int, int, int],
|
|
||||||
):
|
|
||||||
label = "NVFP4 Blockscaled CUTLASS MOE vs FP8 Tensor Scaled Triton"
|
|
||||||
|
|
||||||
sub_label = (
|
|
||||||
"{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, MKN=({})".format(
|
|
||||||
model, num_experts, topk, per_act_token, per_out_ch, mkn
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"Testing: {sub_label}")
|
|
||||||
|
|
||||||
(m, k, n) = mkn
|
|
||||||
|
|
||||||
dtype = torch.half
|
|
||||||
device = "cuda"
|
|
||||||
a = torch.randn((m, k), device=device, dtype=dtype) / 10
|
|
||||||
w1 = torch.randn((num_experts, 2 * n, k), device=device, dtype=dtype) / 10
|
|
||||||
w2 = torch.randn((num_experts, k, n), device=device, dtype=dtype) / 10
|
|
||||||
|
|
||||||
_, a_fp8_scale = ops.scaled_fp8_quant(a)
|
|
||||||
|
|
||||||
w1_fp8q = torch.empty(
|
|
||||||
(num_experts, 2 * n, k), device=device, dtype=torch.float8_e4m3fn
|
|
||||||
)
|
|
||||||
w2_fp8q = torch.empty((num_experts, k, n), device=device, dtype=torch.float8_e4m3fn)
|
|
||||||
w1_fp8scale = torch.empty((num_experts, 1, 1), device=device, dtype=torch.float32)
|
|
||||||
w2_fp8scale = torch.empty((num_experts, 1, 1), device=device, dtype=torch.float32)
|
|
||||||
|
|
||||||
for expert in range(num_experts):
|
|
||||||
w1_fp8q[expert], w1_fp8scale[expert] = ops.scaled_fp8_quant(w1[expert])
|
|
||||||
w2_fp8q[expert], w2_fp8scale[expert] = ops.scaled_fp8_quant(w2[expert])
|
|
||||||
|
|
||||||
w1_fp8q_notransp = w1_fp8q.clone()
|
|
||||||
w2_fp8q_notransp = w2_fp8q.clone()
|
|
||||||
w1_fp8q = w1_fp8q.transpose(1, 2)
|
|
||||||
w2_fp8q = w2_fp8q.transpose(1, 2)
|
|
||||||
|
|
||||||
score = torch.randn((m, num_experts), device=device, dtype=dtype)
|
|
||||||
|
|
||||||
topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False)
|
|
||||||
|
|
||||||
quant_blocksize = 16
|
|
||||||
w1_blockscale = torch.empty(
|
|
||||||
(num_experts, 2 * n, k // quant_blocksize),
|
|
||||||
device=device,
|
|
||||||
dtype=torch.float8_e4m3fn,
|
|
||||||
)
|
|
||||||
w2_blockscale = torch.empty(
|
|
||||||
(num_experts, k, n // quant_blocksize), device=device, dtype=torch.float8_e4m3fn
|
|
||||||
)
|
|
||||||
|
|
||||||
# n_b_scales = 2 * n if per_out_ch else 1
|
|
||||||
# k_b_scales = k if per_out_ch else 1
|
|
||||||
w1_fp4 = torch.empty((num_experts, 2 * n, k // 2), device=device, dtype=torch.uint8)
|
|
||||||
w2_fp4 = torch.empty((num_experts, k, n // 2), device=device, dtype=torch.uint8)
|
|
||||||
|
|
||||||
w1_gs = torch.empty((num_experts,), device=device, dtype=torch.float32)
|
|
||||||
w2_gs = torch.empty((num_experts,), device=device, dtype=torch.float32)
|
|
||||||
a1_gs = torch.ones((num_experts,), device=device, dtype=torch.float32)
|
|
||||||
a2_gs = torch.ones((num_experts,), device=device, dtype=torch.float32)
|
|
||||||
|
|
||||||
for expert in range(num_experts):
|
|
||||||
w1_e = w1[expert]
|
|
||||||
w2_e = w2[expert]
|
|
||||||
w1_amax = torch.abs(w1_e).max().to(torch.float32)
|
|
||||||
w2_amax = torch.abs(w2_e).max().to(torch.float32)
|
|
||||||
w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
|
|
||||||
w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax
|
|
||||||
|
|
||||||
w1_fp4[expert], w1_blockscale[expert] = ops.scaled_fp4_quant(
|
|
||||||
w1_e, w1_gs[expert]
|
|
||||||
)
|
|
||||||
|
|
||||||
w2_fp4[expert], w2_blockscale[expert] = ops.scaled_fp4_quant(
|
|
||||||
w2_e, w2_gs[expert]
|
|
||||||
)
|
|
||||||
|
|
||||||
def run_triton_moe(
|
|
||||||
a: torch.Tensor,
|
|
||||||
w1: torch.Tensor,
|
|
||||||
w2: torch.Tensor,
|
|
||||||
topk_weights: torch.Tensor,
|
|
||||||
topk_ids: torch.Tensor,
|
|
||||||
w1_scale: torch.Tensor,
|
|
||||||
w2_scale: torch.Tensor,
|
|
||||||
a_fp8_scale: torch.Tensor,
|
|
||||||
num_repeats: int,
|
|
||||||
):
|
|
||||||
for _ in range(num_repeats):
|
|
||||||
fused_experts(
|
|
||||||
a,
|
|
||||||
w1,
|
|
||||||
w2,
|
|
||||||
topk_weights,
|
|
||||||
topk_ids,
|
|
||||||
use_fp8_w8a8=True,
|
|
||||||
w1_scale=w1_scale,
|
|
||||||
w2_scale=w2_scale,
|
|
||||||
a1_scale=a_fp8_scale,
|
|
||||||
)
|
|
||||||
|
|
||||||
def run_cutlass_moe_fp4(
|
|
||||||
a: torch.Tensor,
|
|
||||||
w1_fp4: torch.Tensor,
|
|
||||||
w2_fp4: torch.Tensor,
|
|
||||||
w1_blockscale: torch.Tensor,
|
|
||||||
w2_blockscale: torch.Tensor,
|
|
||||||
w1_gs: torch.Tensor,
|
|
||||||
w2_gs: torch.Tensor,
|
|
||||||
a1_gs: torch.Tensor,
|
|
||||||
a2_gs: torch.Tensor,
|
|
||||||
topk_weights: torch.Tensor,
|
|
||||||
topk_ids: torch.Tensor,
|
|
||||||
m: int,
|
|
||||||
n: int,
|
|
||||||
k: int,
|
|
||||||
e: int,
|
|
||||||
device: torch.device,
|
|
||||||
num_repeats: int,
|
|
||||||
):
|
|
||||||
for _ in range(num_repeats):
|
|
||||||
with nvtx.annotate("cutlass_moe_fp4", color="green"):
|
|
||||||
cutlass_moe_fp4(
|
|
||||||
a=a,
|
|
||||||
a1_gscale=a1_gs,
|
|
||||||
a2_gscale=a2_gs,
|
|
||||||
w1_fp4=w1_fp4,
|
|
||||||
w1_blockscale=w1_blockscale,
|
|
||||||
w1_alphas=w1_gs,
|
|
||||||
w2_fp4=w2_fp4,
|
|
||||||
w2_blockscale=w2_blockscale,
|
|
||||||
w2_alphas=w2_gs,
|
|
||||||
topk_weights=topk_weights,
|
|
||||||
topk_ids=topk_ids,
|
|
||||||
m=m,
|
|
||||||
n=n,
|
|
||||||
k=k,
|
|
||||||
e=num_experts,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
|
|
||||||
def run_cutlass_from_graph(
|
|
||||||
a: torch.Tensor,
|
|
||||||
a1_gscale: torch.Tensor,
|
|
||||||
w1_fp4: torch.Tensor,
|
|
||||||
w1_blockscale: torch.Tensor,
|
|
||||||
w1_alphas: torch.Tensor,
|
|
||||||
a2_gscale: torch.Tensor,
|
|
||||||
w2_fp4: torch.Tensor,
|
|
||||||
w2_blockscale: torch.Tensor,
|
|
||||||
w2_alphas: torch.Tensor,
|
|
||||||
topk_weights: torch.Tensor,
|
|
||||||
topk_ids: torch.Tensor,
|
|
||||||
m: int,
|
|
||||||
n: int,
|
|
||||||
k: int,
|
|
||||||
e: int,
|
|
||||||
device: torch.device,
|
|
||||||
):
|
|
||||||
with set_current_vllm_config(
|
|
||||||
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
|
|
||||||
):
|
|
||||||
return cutlass_moe_fp4(
|
|
||||||
a=a,
|
|
||||||
a1_gscale=a1_gs,
|
|
||||||
w1_fp4=w1_fp4,
|
|
||||||
w1_blockscale=w1_blockscale,
|
|
||||||
w1_alphas=w1_alphas,
|
|
||||||
a2_gscale=a2_gs,
|
|
||||||
w2_fp4=w2_fp4,
|
|
||||||
w2_blockscale=w2_blockscale,
|
|
||||||
w2_alphas=w2_alphas,
|
|
||||||
topk_weights=topk_weights,
|
|
||||||
topk_ids=topk_ids,
|
|
||||||
m=m,
|
|
||||||
n=n,
|
|
||||||
k=k,
|
|
||||||
e=num_experts,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
|
|
||||||
def run_triton_from_graph(
|
|
||||||
a: torch.Tensor,
|
|
||||||
w1: torch.Tensor,
|
|
||||||
w2: torch.Tensor,
|
|
||||||
topk_weights: torch.Tensor,
|
|
||||||
topk_ids: torch.Tensor,
|
|
||||||
w1_scale: torch.Tensor,
|
|
||||||
w2_scale: torch.Tensor,
|
|
||||||
a_fp8_scale: torch.Tensor,
|
|
||||||
):
|
|
||||||
with set_current_vllm_config(
|
|
||||||
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
|
|
||||||
):
|
|
||||||
return fused_experts(
|
|
||||||
a,
|
|
||||||
w1,
|
|
||||||
w2,
|
|
||||||
topk_weights,
|
|
||||||
topk_ids,
|
|
||||||
use_fp8_w8a8=True,
|
|
||||||
w1_scale=w1_scale,
|
|
||||||
w2_scale=w2_scale,
|
|
||||||
a1_scale=a_fp8_scale,
|
|
||||||
)
|
|
||||||
|
|
||||||
def replay_graph(graph, num_repeats):
|
|
||||||
for _ in range(num_repeats):
|
|
||||||
graph.replay()
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
cutlass_stream = torch.cuda.Stream()
|
|
||||||
cutlass_graph = torch.cuda.CUDAGraph()
|
|
||||||
with torch.cuda.graph(cutlass_graph, stream=cutlass_stream):
|
|
||||||
run_cutlass_from_graph(
|
|
||||||
a=a,
|
|
||||||
a1_gscale=a1_gs,
|
|
||||||
w1_fp4=w1_fp4,
|
|
||||||
w1_blockscale=w1_blockscale,
|
|
||||||
w1_alphas=w1_gs,
|
|
||||||
a2_gscale=a2_gs,
|
|
||||||
w2_fp4=w2_fp4,
|
|
||||||
w2_blockscale=w2_blockscale,
|
|
||||||
w2_alphas=w2_gs,
|
|
||||||
topk_weights=topk_weights,
|
|
||||||
topk_ids=topk_ids,
|
|
||||||
m=m,
|
|
||||||
n=n,
|
|
||||||
k=k,
|
|
||||||
e=num_experts,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
triton_stream = torch.cuda.Stream()
|
|
||||||
triton_graph = torch.cuda.CUDAGraph()
|
|
||||||
with torch.cuda.graph(triton_graph, stream=triton_stream):
|
|
||||||
run_triton_from_graph(
|
|
||||||
a,
|
|
||||||
w1_fp8q_notransp,
|
|
||||||
w2_fp8q_notransp,
|
|
||||||
topk_weights,
|
|
||||||
topk_ids,
|
|
||||||
w1_fp8scale,
|
|
||||||
w2_fp8scale,
|
|
||||||
a_fp8_scale,
|
|
||||||
)
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
min_run_time = 5
|
|
||||||
num_warmup = 5
|
|
||||||
num_runs = 25
|
|
||||||
|
|
||||||
globals = {
|
|
||||||
# Baseline params
|
|
||||||
"w1": w1,
|
|
||||||
"w2": w2,
|
|
||||||
"score": score,
|
|
||||||
"topk": topk,
|
|
||||||
"w1_fp8q_notransp": w1_fp8q_notransp,
|
|
||||||
"w2_fp8q_notransp": w2_fp8q_notransp,
|
|
||||||
"w1_fp8scale": w1_fp8scale,
|
|
||||||
"w2_fp8scale": w2_fp8scale,
|
|
||||||
"a_fp8_scale": a_fp8_scale,
|
|
||||||
# Cutlass params
|
|
||||||
"a": a,
|
|
||||||
"a1_gscale": a1_gs,
|
|
||||||
"w1_fp4": w1_fp4,
|
|
||||||
"w1_blockscale": w1_blockscale,
|
|
||||||
"w1_alphas": w1_gs,
|
|
||||||
"a2_gscale": a2_gs,
|
|
||||||
"w2_fp4": w2_fp4,
|
|
||||||
"w2_blockscale": w2_blockscale,
|
|
||||||
"w2_alphas": w2_gs,
|
|
||||||
"topk_weights": topk_weights,
|
|
||||||
"topk_ids": topk_ids,
|
|
||||||
"m": m,
|
|
||||||
"n": n,
|
|
||||||
"k": k,
|
|
||||||
"e": num_experts,
|
|
||||||
"device": device,
|
|
||||||
# cuda graph params
|
|
||||||
"cutlass_graph": cutlass_graph,
|
|
||||||
"triton_graph": triton_graph,
|
|
||||||
# Gen params
|
|
||||||
"num_runs": num_runs,
|
|
||||||
# Kernels
|
|
||||||
"run_triton_moe": run_triton_moe,
|
|
||||||
"run_cutlass_moe_fp4": run_cutlass_moe_fp4,
|
|
||||||
"replay_graph": replay_graph,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Warmup
|
|
||||||
run_triton_moe(
|
|
||||||
a,
|
|
||||||
w1_fp8q_notransp,
|
|
||||||
w2_fp8q_notransp,
|
|
||||||
topk_weights,
|
|
||||||
topk_ids,
|
|
||||||
w1_fp8scale,
|
|
||||||
w2_fp8scale,
|
|
||||||
a_fp8_scale,
|
|
||||||
num_warmup,
|
|
||||||
)
|
|
||||||
|
|
||||||
results.append(
|
|
||||||
benchmark.Timer(
|
|
||||||
stmt="run_triton_moe(a, w1_fp8q_notransp, w2_fp8q_notransp, topk_weights, topk_ids, w1_fp8scale, w2_fp8scale, a_fp8_scale, num_runs)", # noqa: E501
|
|
||||||
globals=globals,
|
|
||||||
label=label,
|
|
||||||
sub_label=sub_label,
|
|
||||||
description="triton_moe",
|
|
||||||
).blocked_autorange(min_run_time=min_run_time)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Warmup
|
|
||||||
replay_graph(triton_graph, num_warmup)
|
|
||||||
|
|
||||||
results.append(
|
|
||||||
benchmark.Timer(
|
|
||||||
stmt="replay_graph(triton_graph, num_runs)",
|
|
||||||
globals=globals,
|
|
||||||
label=label,
|
|
||||||
sub_label=sub_label,
|
|
||||||
description="triton_moe_cuda_graphs",
|
|
||||||
).blocked_autorange(min_run_time=min_run_time)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Warmup
|
|
||||||
|
|
||||||
run_cutlass_moe_fp4(
|
|
||||||
a,
|
|
||||||
w1_fp4,
|
|
||||||
w2_fp4,
|
|
||||||
w1_blockscale,
|
|
||||||
w2_blockscale,
|
|
||||||
w1_gs,
|
|
||||||
w2_gs,
|
|
||||||
a1_gs,
|
|
||||||
a2_gs,
|
|
||||||
topk_weights,
|
|
||||||
topk_ids,
|
|
||||||
m,
|
|
||||||
n,
|
|
||||||
k,
|
|
||||||
num_experts,
|
|
||||||
device,
|
|
||||||
num_warmup,
|
|
||||||
)
|
|
||||||
|
|
||||||
results.append(
|
|
||||||
benchmark.Timer(
|
|
||||||
stmt="run_cutlass_moe_fp4(a, w1_fp4, w2_fp4, w1_blockscale, w2_blockscale, w1_alphas, w2_alphas, a1_gscale, a2_gscale, topk_weights, topk_ids, m, n, k, e, device, num_runs)", # noqa: E501
|
|
||||||
globals=globals,
|
|
||||||
label=label,
|
|
||||||
sub_label=sub_label,
|
|
||||||
description="cutlass_moe_fp4",
|
|
||||||
).blocked_autorange(min_run_time=min_run_time)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Warmup
|
|
||||||
replay_graph(cutlass_graph, num_warmup)
|
|
||||||
|
|
||||||
results.append(
|
|
||||||
benchmark.Timer(
|
|
||||||
stmt="replay_graph(cutlass_graph, num_runs)",
|
|
||||||
globals=globals,
|
|
||||||
label=label,
|
|
||||||
sub_label=sub_label,
|
|
||||||
description="cutlass_moe_fp4_cuda_graphs",
|
|
||||||
).blocked_autorange(min_run_time=min_run_time)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
|
||||||
print("Benchmarking models:")
|
|
||||||
for i, model in enumerate(args.models):
|
|
||||||
print(f"[{i}] {model}")
|
|
||||||
|
|
||||||
results: list[benchmark.Measurement] = []
|
|
||||||
|
|
||||||
for model in args.models:
|
|
||||||
for tp in args.tp_sizes:
|
|
||||||
for layer in WEIGHT_SHAPES_MOE[model]:
|
|
||||||
num_experts = layer[0]
|
|
||||||
topk = layer[1]
|
|
||||||
size_k = layer[2]
|
|
||||||
size_n = layer[3] // tp
|
|
||||||
|
|
||||||
if len(args.limit_k) > 0 and size_k not in args.limit_k:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if len(args.limit_n) > 0 and size_n not in args.limit_n:
|
|
||||||
continue
|
|
||||||
|
|
||||||
for per_act_token in PER_ACT_TOKEN_OPTS:
|
|
||||||
for per_out_ch in PER_OUT_CH_OPTS:
|
|
||||||
for size_m in args.batch_sizes:
|
|
||||||
mkn = (size_m, size_k, size_n)
|
|
||||||
bench_run(
|
|
||||||
results,
|
|
||||||
model,
|
|
||||||
num_experts,
|
|
||||||
topk,
|
|
||||||
per_act_token,
|
|
||||||
per_out_ch,
|
|
||||||
mkn,
|
|
||||||
)
|
|
||||||
|
|
||||||
compare = benchmark.Compare(results)
|
|
||||||
compare.print()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = FlexibleArgumentParser(
|
|
||||||
description="Benchmark NVFP4 CUTLASS MOE across specified models/shapes/batches"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--models",
|
|
||||||
nargs="+",
|
|
||||||
type=str,
|
|
||||||
default=DEFAULT_MODELS,
|
|
||||||
choices=WEIGHT_SHAPES_MOE.keys(),
|
|
||||||
)
|
|
||||||
parser.add_argument("--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES)
|
|
||||||
parser.add_argument(
|
|
||||||
"--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
|
|
||||||
)
|
|
||||||
parser.add_argument("--limit-k", nargs="+", type=int, default=[])
|
|
||||||
parser.add_argument("--limit-n", nargs="+", type=int, default=[])
|
|
||||||
parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[])
|
|
||||||
parser.add_argument("--limit-per-act-token", nargs="+", type=int, default=[])
|
|
||||||
parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[])
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
main(args)
|
|
||||||
@ -6,18 +6,14 @@ from benchmark_shapes import WEIGHT_SHAPES_MOE
|
|||||||
|
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe_fp8,
|
||||||
cutlass_moe_fp8,
|
fused_experts,
|
||||||
fused_experts,
|
fused_topk)
|
||||||
fused_topk,
|
|
||||||
)
|
|
||||||
from vllm.utils import FlexibleArgumentParser
|
from vllm.utils import FlexibleArgumentParser
|
||||||
|
|
||||||
DEFAULT_MODELS = [
|
DEFAULT_MODELS = [
|
||||||
"nm-testing/Mixtral-8x7B-Instruct-v0.1",
|
"nm-testing/Mixtral-8x7B-Instruct-v0.1", "nm-testing/deepseekv2-lite",
|
||||||
"nm-testing/deepseekv2-lite",
|
"ibm-granite/granite-3.0-1b-a400m", "ibm-granite/granite-3.0-3b-a800m"
|
||||||
"ibm-granite/granite-3.0-1b-a400m",
|
|
||||||
"ibm-granite/granite-3.0-3b-a800m",
|
|
||||||
]
|
]
|
||||||
DEFAULT_BATCH_SIZES = [1, 4, 8, 16, 32, 64, 128, 256, 512]
|
DEFAULT_BATCH_SIZES = [1, 4, 8, 16, 32, 64, 128, 256, 512]
|
||||||
DEFAULT_TP_SIZES = [1]
|
DEFAULT_TP_SIZES = [1]
|
||||||
@ -28,27 +24,19 @@ PER_OUT_CH_OPTS = [False]
|
|||||||
|
|
||||||
def to_fp8(tensor: torch.Tensor):
|
def to_fp8(tensor: torch.Tensor):
|
||||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||||
return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to(
|
return torch.round(tensor.clamp(
|
||||||
dtype=torch.float8_e4m3fn
|
min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def bench_run(
|
def bench_run(results: list[benchmark.Measurement], model: str,
|
||||||
results: list[benchmark.Measurement],
|
num_experts: int, topk: int, per_act_token: bool,
|
||||||
model: str,
|
per_out_ch: bool, mkn: tuple[int, int, int]):
|
||||||
num_experts: int,
|
|
||||||
topk: int,
|
|
||||||
per_act_token: bool,
|
|
||||||
per_out_ch: bool,
|
|
||||||
mkn: tuple[int, int, int],
|
|
||||||
):
|
|
||||||
label = "Quant Matmul"
|
label = "Quant Matmul"
|
||||||
|
|
||||||
sub_label = (
|
sub_label = (
|
||||||
"{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, MKN=({})".format(
|
"{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, "
|
||||||
model, num_experts, topk, per_act_token, per_out_ch, mkn
|
"MKN=({})".format(model, num_experts, topk, per_act_token, per_out_ch,
|
||||||
)
|
mkn))
|
||||||
)
|
|
||||||
|
|
||||||
print(f"Testing: {sub_label}")
|
print(f"Testing: {sub_label}")
|
||||||
|
|
||||||
@ -62,17 +50,35 @@ def bench_run(
|
|||||||
|
|
||||||
_, a_scale = ops.scaled_fp8_quant(a)
|
_, a_scale = ops.scaled_fp8_quant(a)
|
||||||
|
|
||||||
w1_q = torch.empty(
|
w1_q = torch.empty((num_experts, 2 * n, k),
|
||||||
(num_experts, 2 * n, k), device="cuda", dtype=torch.float8_e4m3fn
|
device="cuda",
|
||||||
)
|
dtype=torch.float8_e4m3fn)
|
||||||
w2_q = torch.empty((num_experts, k, n), device="cuda", dtype=torch.float8_e4m3fn)
|
w2_q = torch.empty((num_experts, k, n),
|
||||||
w1_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)
|
device="cuda",
|
||||||
w2_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)
|
dtype=torch.float8_e4m3fn)
|
||||||
|
w1_scale = torch.empty((num_experts, 1, 1),
|
||||||
|
device="cuda",
|
||||||
|
dtype=torch.float32)
|
||||||
|
w2_scale = torch.empty((num_experts, 1, 1),
|
||||||
|
device="cuda",
|
||||||
|
dtype=torch.float32)
|
||||||
|
|
||||||
ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
|
ab_strides1 = torch.full((num_experts, ),
|
||||||
c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64)
|
k,
|
||||||
ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64)
|
device="cuda",
|
||||||
c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
|
dtype=torch.int64)
|
||||||
|
c_strides1 = torch.full((num_experts, ),
|
||||||
|
2 * n,
|
||||||
|
device="cuda",
|
||||||
|
dtype=torch.int64)
|
||||||
|
ab_strides2 = torch.full((num_experts, ),
|
||||||
|
n,
|
||||||
|
device="cuda",
|
||||||
|
dtype=torch.int64)
|
||||||
|
c_strides2 = torch.full((num_experts, ),
|
||||||
|
k,
|
||||||
|
device="cuda",
|
||||||
|
dtype=torch.int64)
|
||||||
|
|
||||||
for expert in range(num_experts):
|
for expert in range(num_experts):
|
||||||
w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert])
|
w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert])
|
||||||
@ -85,120 +91,82 @@ def bench_run(
|
|||||||
score = torch.randn((m, num_experts), device="cuda", dtype=dtype)
|
score = torch.randn((m, num_experts), device="cuda", dtype=dtype)
|
||||||
|
|
||||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||||
a, score, topk, renormalize=False
|
a, score, topk, renormalize=False)
|
||||||
)
|
|
||||||
|
|
||||||
def run_triton_moe(
|
def run_triton_moe(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
|
||||||
a: torch.Tensor,
|
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
|
||||||
w1: torch.Tensor,
|
w1_scale: torch.Tensor, w2_scale: torch.Tensor,
|
||||||
w2: torch.Tensor,
|
a_scale: torch.Tensor, num_repeats: int):
|
||||||
topk_weights: torch.Tensor,
|
|
||||||
topk_ids: torch.Tensor,
|
|
||||||
w1_scale: torch.Tensor,
|
|
||||||
w2_scale: torch.Tensor,
|
|
||||||
a_scale: torch.Tensor,
|
|
||||||
num_repeats: int,
|
|
||||||
):
|
|
||||||
for _ in range(num_repeats):
|
for _ in range(num_repeats):
|
||||||
fused_experts(
|
fused_experts(a,
|
||||||
a,
|
w1,
|
||||||
w1,
|
w2,
|
||||||
w2,
|
topk_weights,
|
||||||
topk_weights,
|
topk_ids,
|
||||||
topk_ids,
|
use_fp8_w8a8=True,
|
||||||
use_fp8_w8a8=True,
|
w1_scale=w1_scale,
|
||||||
w1_scale=w1_scale,
|
w2_scale=w2_scale,
|
||||||
w2_scale=w2_scale,
|
a1_scale=a_scale)
|
||||||
a1_scale=a_scale,
|
|
||||||
)
|
|
||||||
|
|
||||||
def run_cutlass_moe(
|
def run_cutlass_moe(a: torch.Tensor, a_scale: torch.Tensor,
|
||||||
a: torch.Tensor,
|
w1: torch.Tensor, w2: torch.Tensor,
|
||||||
a_scale: torch.Tensor,
|
w1_scale: torch.Tensor, w2_scale: torch.Tensor,
|
||||||
w1: torch.Tensor,
|
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
|
||||||
w2: torch.Tensor,
|
ab_strides1: torch.Tensor, c_strides1: torch.Tensor,
|
||||||
w1_scale: torch.Tensor,
|
ab_strides2: torch.Tensor, c_strides2: torch.Tensor,
|
||||||
w2_scale: torch.Tensor,
|
num_repeats: int):
|
||||||
topk_weights: torch.Tensor,
|
|
||||||
topk_ids: torch.Tensor,
|
|
||||||
ab_strides1: torch.Tensor,
|
|
||||||
c_strides1: torch.Tensor,
|
|
||||||
ab_strides2: torch.Tensor,
|
|
||||||
c_strides2: torch.Tensor,
|
|
||||||
num_repeats: int,
|
|
||||||
):
|
|
||||||
for _ in range(num_repeats):
|
for _ in range(num_repeats):
|
||||||
cutlass_moe_fp8(
|
cutlass_moe_fp8(a,
|
||||||
a,
|
w1,
|
||||||
w1,
|
w2,
|
||||||
w2,
|
w1_scale,
|
||||||
w1_scale,
|
w2_scale,
|
||||||
w2_scale,
|
topk_weights,
|
||||||
topk_weights,
|
topk_ids,
|
||||||
topk_ids,
|
ab_strides1,
|
||||||
ab_strides1,
|
c_strides1,
|
||||||
c_strides1,
|
ab_strides2,
|
||||||
ab_strides2,
|
c_strides2,
|
||||||
c_strides2,
|
a1_scale=a_scale)
|
||||||
a1_scale=a_scale,
|
|
||||||
)
|
|
||||||
|
|
||||||
def run_cutlass_from_graph(
|
def run_cutlass_from_graph(
|
||||||
a: torch.Tensor,
|
a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor,
|
||||||
a_scale: torch.Tensor,
|
w2_q: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor,
|
||||||
w1_q: torch.Tensor,
|
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
|
||||||
w2_q: torch.Tensor,
|
ab_strides1: torch.Tensor, c_strides1: torch.Tensor,
|
||||||
w1_scale: torch.Tensor,
|
ab_strides2: torch.Tensor, c_strides2: torch.Tensor):
|
||||||
w2_scale: torch.Tensor,
|
|
||||||
topk_weights: torch.Tensor,
|
|
||||||
topk_ids: torch.Tensor,
|
|
||||||
ab_strides1: torch.Tensor,
|
|
||||||
c_strides1: torch.Tensor,
|
|
||||||
ab_strides2: torch.Tensor,
|
|
||||||
c_strides2: torch.Tensor,
|
|
||||||
):
|
|
||||||
with set_current_vllm_config(
|
with set_current_vllm_config(
|
||||||
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
|
VllmConfig(parallel_config=ParallelConfig(
|
||||||
):
|
pipeline_parallel_size=1))):
|
||||||
return cutlass_moe_fp8(
|
return cutlass_moe_fp8(a,
|
||||||
a,
|
w1_q,
|
||||||
w1_q,
|
w2_q,
|
||||||
w2_q,
|
w1_scale,
|
||||||
w1_scale,
|
w2_scale,
|
||||||
w2_scale,
|
topk_weights,
|
||||||
topk_weights,
|
topk_ids,
|
||||||
topk_ids,
|
ab_strides1,
|
||||||
ab_strides1,
|
c_strides1,
|
||||||
c_strides1,
|
ab_strides2,
|
||||||
ab_strides2,
|
c_strides2,
|
||||||
c_strides2,
|
a1_scale=a_scale)
|
||||||
a1_scale=a_scale,
|
|
||||||
)
|
|
||||||
|
|
||||||
def run_triton_from_graph(
|
def run_triton_from_graph(a: torch.Tensor, w1: torch.Tensor,
|
||||||
a: torch.Tensor,
|
w2: torch.Tensor, topk_weights: torch.Tensor,
|
||||||
w1: torch.Tensor,
|
topk_ids: torch.Tensor, w1_scale: torch.Tensor,
|
||||||
w2: torch.Tensor,
|
w2_scale: torch.Tensor, a_scale: torch.Tensor):
|
||||||
topk_weights: torch.Tensor,
|
|
||||||
topk_ids: torch.Tensor,
|
|
||||||
w1_scale: torch.Tensor,
|
|
||||||
w2_scale: torch.Tensor,
|
|
||||||
a_scale: torch.Tensor,
|
|
||||||
):
|
|
||||||
with set_current_vllm_config(
|
with set_current_vllm_config(
|
||||||
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
|
VllmConfig(parallel_config=ParallelConfig(
|
||||||
):
|
pipeline_parallel_size=1))):
|
||||||
return fused_experts(
|
return fused_experts(a,
|
||||||
a,
|
w1,
|
||||||
w1,
|
w2,
|
||||||
w2,
|
topk_weights,
|
||||||
topk_weights,
|
topk_ids,
|
||||||
topk_ids,
|
use_fp8_w8a8=True,
|
||||||
use_fp8_w8a8=True,
|
w1_scale=w1_scale,
|
||||||
w1_scale=w1_scale,
|
w2_scale=w2_scale,
|
||||||
w2_scale=w2_scale,
|
a1_scale=a_scale)
|
||||||
a1_scale=a_scale,
|
|
||||||
)
|
|
||||||
|
|
||||||
def replay_graph(graph, num_repeats):
|
def replay_graph(graph, num_repeats):
|
||||||
for _ in range(num_repeats):
|
for _ in range(num_repeats):
|
||||||
@ -208,35 +176,16 @@ def bench_run(
|
|||||||
cutlass_stream = torch.cuda.Stream()
|
cutlass_stream = torch.cuda.Stream()
|
||||||
cutlass_graph = torch.cuda.CUDAGraph()
|
cutlass_graph = torch.cuda.CUDAGraph()
|
||||||
with torch.cuda.graph(cutlass_graph, stream=cutlass_stream):
|
with torch.cuda.graph(cutlass_graph, stream=cutlass_stream):
|
||||||
run_cutlass_from_graph(
|
run_cutlass_from_graph(a, a_scale, w1_q, w2_q, w1_scale, w2_scale,
|
||||||
a,
|
topk_weights, topk_ids, ab_strides1, c_strides1,
|
||||||
a_scale,
|
ab_strides2, c_strides2)
|
||||||
w1_q,
|
|
||||||
w2_q,
|
|
||||||
w1_scale,
|
|
||||||
w2_scale,
|
|
||||||
topk_weights,
|
|
||||||
topk_ids,
|
|
||||||
ab_strides1,
|
|
||||||
c_strides1,
|
|
||||||
ab_strides2,
|
|
||||||
c_strides2,
|
|
||||||
)
|
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
triton_stream = torch.cuda.Stream()
|
triton_stream = torch.cuda.Stream()
|
||||||
triton_graph = torch.cuda.CUDAGraph()
|
triton_graph = torch.cuda.CUDAGraph()
|
||||||
with torch.cuda.graph(triton_graph, stream=triton_stream):
|
with torch.cuda.graph(triton_graph, stream=triton_stream):
|
||||||
run_triton_from_graph(
|
run_triton_from_graph(a, w1_q_notransp, w2_q_notransp, topk_weights,
|
||||||
a,
|
topk_ids, w1_scale, w2_scale, a_scale)
|
||||||
w1_q_notransp,
|
|
||||||
w2_q_notransp,
|
|
||||||
topk_weights,
|
|
||||||
topk_ids,
|
|
||||||
w1_scale,
|
|
||||||
w2_scale,
|
|
||||||
a_scale,
|
|
||||||
)
|
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
min_run_time = 5
|
min_run_time = 5
|
||||||
@ -276,27 +225,18 @@ def bench_run(
|
|||||||
}
|
}
|
||||||
|
|
||||||
# Warmup
|
# Warmup
|
||||||
run_triton_moe(
|
run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids,
|
||||||
a,
|
w1_scale, w2_scale, a_scale, num_warmup)
|
||||||
w1_q_notransp,
|
|
||||||
w2_q_notransp,
|
|
||||||
topk_weights,
|
|
||||||
topk_ids,
|
|
||||||
w1_scale,
|
|
||||||
w2_scale,
|
|
||||||
a_scale,
|
|
||||||
num_warmup,
|
|
||||||
)
|
|
||||||
|
|
||||||
results.append(
|
results.append(
|
||||||
benchmark.Timer(
|
benchmark.Timer(
|
||||||
stmt="run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)", # noqa: E501
|
stmt=
|
||||||
|
"run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)", # noqa: E501
|
||||||
globals=globals,
|
globals=globals,
|
||||||
label=label,
|
label=label,
|
||||||
sub_label=sub_label,
|
sub_label=sub_label,
|
||||||
description="triton_moe",
|
description="triton_moe",
|
||||||
).blocked_autorange(min_run_time=min_run_time)
|
).blocked_autorange(min_run_time=min_run_time))
|
||||||
)
|
|
||||||
|
|
||||||
# Warmup
|
# Warmup
|
||||||
replay_graph(triton_graph, num_warmup)
|
replay_graph(triton_graph, num_warmup)
|
||||||
@ -308,35 +248,22 @@ def bench_run(
|
|||||||
label=label,
|
label=label,
|
||||||
sub_label=sub_label,
|
sub_label=sub_label,
|
||||||
description="triton_moe_cuda_graphs",
|
description="triton_moe_cuda_graphs",
|
||||||
).blocked_autorange(min_run_time=min_run_time)
|
).blocked_autorange(min_run_time=min_run_time))
|
||||||
)
|
|
||||||
|
|
||||||
# Warmup
|
# Warmup
|
||||||
run_cutlass_moe(
|
run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights,
|
||||||
a,
|
topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2,
|
||||||
a_scale,
|
num_warmup)
|
||||||
w1_q,
|
|
||||||
w2_q,
|
|
||||||
w1_scale,
|
|
||||||
w2_scale,
|
|
||||||
topk_weights,
|
|
||||||
topk_ids,
|
|
||||||
ab_strides1,
|
|
||||||
c_strides1,
|
|
||||||
ab_strides2,
|
|
||||||
c_strides2,
|
|
||||||
num_warmup,
|
|
||||||
)
|
|
||||||
|
|
||||||
results.append(
|
results.append(
|
||||||
benchmark.Timer(
|
benchmark.Timer(
|
||||||
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, num_runs)", # noqa: E501
|
stmt=
|
||||||
|
"run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, num_runs)", # noqa: E501
|
||||||
globals=globals,
|
globals=globals,
|
||||||
label=label,
|
label=label,
|
||||||
sub_label=sub_label,
|
sub_label=sub_label,
|
||||||
description="grouped_gemm_moe",
|
description="grouped_gemm_moe",
|
||||||
).blocked_autorange(min_run_time=min_run_time)
|
).blocked_autorange(min_run_time=min_run_time))
|
||||||
)
|
|
||||||
|
|
||||||
# Warmup
|
# Warmup
|
||||||
replay_graph(cutlass_graph, num_warmup)
|
replay_graph(cutlass_graph, num_warmup)
|
||||||
@ -348,8 +275,7 @@ def bench_run(
|
|||||||
label=label,
|
label=label,
|
||||||
sub_label=sub_label,
|
sub_label=sub_label,
|
||||||
description="grouped_gemm_moe_cuda_graphs",
|
description="grouped_gemm_moe_cuda_graphs",
|
||||||
).blocked_autorange(min_run_time=min_run_time)
|
).blocked_autorange(min_run_time=min_run_time))
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
@ -377,15 +303,8 @@ def main(args):
|
|||||||
for per_out_ch in PER_OUT_CH_OPTS:
|
for per_out_ch in PER_OUT_CH_OPTS:
|
||||||
for size_m in DEFAULT_BATCH_SIZES:
|
for size_m in DEFAULT_BATCH_SIZES:
|
||||||
mkn = (size_m, size_k, size_n)
|
mkn = (size_m, size_k, size_n)
|
||||||
bench_run(
|
bench_run(results, model, num_experts, topk,
|
||||||
results,
|
per_act_token, per_out_ch, mkn)
|
||||||
model,
|
|
||||||
num_experts,
|
|
||||||
topk,
|
|
||||||
per_act_token,
|
|
||||||
per_out_ch,
|
|
||||||
mkn,
|
|
||||||
)
|
|
||||||
|
|
||||||
compare = benchmark.Compare(results)
|
compare = benchmark.Compare(results)
|
||||||
compare.print()
|
compare.print()
|
||||||
@ -393,8 +312,7 @@ def main(args):
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = FlexibleArgumentParser(
|
parser = FlexibleArgumentParser(
|
||||||
description="Benchmark Marlin across specified models/shapes/batches"
|
description="Benchmark Marlin across specified models/shapes/batches")
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--models",
|
"--models",
|
||||||
nargs="+",
|
nargs="+",
|
||||||
@ -402,14 +320,21 @@ if __name__ == "__main__":
|
|||||||
default=DEFAULT_MODELS,
|
default=DEFAULT_MODELS,
|
||||||
choices=WEIGHT_SHAPES_MOE.keys(),
|
choices=WEIGHT_SHAPES_MOE.keys(),
|
||||||
)
|
)
|
||||||
parser.add_argument("--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES)
|
parser.add_argument("--tp-sizes",
|
||||||
parser.add_argument(
|
nargs="+",
|
||||||
"--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
|
type=int,
|
||||||
)
|
default=DEFAULT_TP_SIZES)
|
||||||
|
parser.add_argument("--batch-sizes",
|
||||||
|
nargs="+",
|
||||||
|
type=int,
|
||||||
|
default=DEFAULT_BATCH_SIZES)
|
||||||
parser.add_argument("--limit-k", nargs="+", type=int, default=[])
|
parser.add_argument("--limit-k", nargs="+", type=int, default=[])
|
||||||
parser.add_argument("--limit-n", nargs="+", type=int, default=[])
|
parser.add_argument("--limit-n", nargs="+", type=int, default=[])
|
||||||
parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[])
|
parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[])
|
||||||
parser.add_argument("--limit-per-act-token", nargs="+", type=int, default=[])
|
parser.add_argument("--limit-per-act-token",
|
||||||
|
nargs="+",
|
||||||
|
type=int,
|
||||||
|
default=[])
|
||||||
parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[])
|
parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[])
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|||||||
@ -10,16 +10,14 @@ from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
|
|||||||
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def main(
|
def main(num_tokens: int,
|
||||||
num_tokens: int,
|
hidden_size: int,
|
||||||
hidden_size: int,
|
add_residual: bool,
|
||||||
add_residual: bool,
|
dtype: torch.dtype,
|
||||||
dtype: torch.dtype,
|
seed: int = 0,
|
||||||
seed: int = 0,
|
do_profile: bool = False,
|
||||||
do_profile: bool = False,
|
num_warmup_iters: int = 5,
|
||||||
num_warmup_iters: int = 5,
|
num_iters: int = 100) -> None:
|
||||||
num_iters: int = 100,
|
|
||||||
) -> None:
|
|
||||||
current_platform.seed_everything(seed)
|
current_platform.seed_everything(seed)
|
||||||
torch.set_default_device("cuda")
|
torch.set_default_device("cuda")
|
||||||
|
|
||||||
@ -58,35 +56,33 @@ def main(
|
|||||||
print(f"Kernel running time: {latency * 1000000:.3f} us")
|
print(f"Kernel running time: {latency * 1000000:.3f} us")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == '__main__':
|
||||||
parser = FlexibleArgumentParser(description="Benchmark the layernorm kernel.")
|
parser = FlexibleArgumentParser(
|
||||||
|
description="Benchmark the layernorm kernel.")
|
||||||
parser.add_argument("--num-tokens", type=int, default=4096)
|
parser.add_argument("--num-tokens", type=int, default=4096)
|
||||||
parser.add_argument("--hidden-size", type=int, default=8192)
|
parser.add_argument("--hidden-size", type=int, default=8192)
|
||||||
parser.add_argument("--add-residual", action="store_true")
|
parser.add_argument("--add-residual", action="store_true")
|
||||||
parser.add_argument(
|
parser.add_argument("--dtype",
|
||||||
"--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
|
type=str,
|
||||||
)
|
choices=["half", "bfloat16", "float"],
|
||||||
|
default="half")
|
||||||
parser.add_argument("--seed", type=int, default=0)
|
parser.add_argument("--seed", type=int, default=0)
|
||||||
parser.add_argument("--profile", action="store_true")
|
parser.add_argument("--profile", action="store_true")
|
||||||
parser.add_argument("--num-warmup-iters", type=int, default=5)
|
parser.add_argument("--num-warmup-iters", type=int, default=5)
|
||||||
parser.add_argument(
|
parser.add_argument("--num-iters",
|
||||||
"--num-iters",
|
type=int,
|
||||||
type=int,
|
default=100,
|
||||||
default=100,
|
help="Number of benchmark iterations. "
|
||||||
help="Number of benchmark iterations. "
|
"If --profile is set, this number is ignored")
|
||||||
"If --profile is set, this number is ignored",
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
print(args)
|
print(args)
|
||||||
|
|
||||||
main(
|
main(num_tokens=args.num_tokens,
|
||||||
num_tokens=args.num_tokens,
|
hidden_size=args.hidden_size,
|
||||||
hidden_size=args.hidden_size,
|
add_residual=args.add_residual,
|
||||||
add_residual=args.add_residual,
|
dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
|
||||||
dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
|
seed=args.seed,
|
||||||
seed=args.seed,
|
do_profile=args.profile,
|
||||||
do_profile=args.profile,
|
num_warmup_iters=args.num_warmup_iters,
|
||||||
num_warmup_iters=args.num_warmup_iters,
|
num_iters=args.num_iters)
|
||||||
num_iters=args.num_iters,
|
|
||||||
)
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@ -20,18 +20,12 @@ from weight_shapes import WEIGHT_SHAPES
|
|||||||
|
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||||
GPTQ_MARLIN_MAX_PARALLEL,
|
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, marlin_permute_scales,
|
||||||
GPTQ_MARLIN_MIN_THREAD_N,
|
marlin_zero_points)
|
||||||
marlin_permute_scales,
|
|
||||||
marlin_zero_points,
|
|
||||||
)
|
|
||||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
||||||
MarlinWorkspace,
|
MarlinWorkspace)
|
||||||
)
|
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
pack_rows,
|
pack_rows, quantize_weights)
|
||||||
quantize_weights,
|
|
||||||
)
|
|
||||||
from vllm.scalar_type import ScalarType, scalar_types
|
from vllm.scalar_type import ScalarType, scalar_types
|
||||||
from vllm.utils import FlexibleArgumentParser
|
from vllm.utils import FlexibleArgumentParser
|
||||||
|
|
||||||
@ -88,14 +82,12 @@ def rand_data(shape, dtype=torch.float16, scale=1):
|
|||||||
return torch.randint(-15, 15, shape, dtype=dtype, device="cuda")
|
return torch.randint(-15, 15, shape, dtype=dtype, device="cuda")
|
||||||
|
|
||||||
|
|
||||||
def quantize_and_pack(
|
def quantize_and_pack(atype: torch.dtype,
|
||||||
atype: torch.dtype,
|
w: torch.Tensor,
|
||||||
w: torch.Tensor,
|
wtype: ScalarType,
|
||||||
wtype: ScalarType,
|
stype: Optional[torch.dtype],
|
||||||
stype: Optional[torch.dtype],
|
group_size: Optional[int],
|
||||||
group_size: Optional[int],
|
zero_points: bool = False):
|
||||||
zero_points: bool = False,
|
|
||||||
):
|
|
||||||
assert wtype.is_integer(), "TODO: support floating point weights"
|
assert wtype.is_integer(), "TODO: support floating point weights"
|
||||||
|
|
||||||
w_ref, w_q, w_s, w_zp = quantize_weights(
|
w_ref, w_q, w_s, w_zp = quantize_weights(
|
||||||
@ -104,24 +96,21 @@ def quantize_and_pack(
|
|||||||
group_size=group_size,
|
group_size=group_size,
|
||||||
zero_points=zero_points,
|
zero_points=zero_points,
|
||||||
# to match how the kernel applies zps
|
# to match how the kernel applies zps
|
||||||
ref_zero_points_after_scales=True,
|
ref_zero_points_after_scales=True)
|
||||||
)
|
|
||||||
|
|
||||||
w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape)
|
w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape)
|
||||||
return w_ref, w_q, w_s, w_zp
|
return w_ref, w_q, w_s, w_zp
|
||||||
|
|
||||||
|
|
||||||
def create_bench_tensors(
|
def create_bench_tensors(shape: tuple[int, int, int], types: TypeConfig,
|
||||||
shape: tuple[int, int, int], types: TypeConfig, group_size: Optional[int]
|
group_size: Optional[int]) -> list[BenchmarkTensors]:
|
||||||
) -> list[BenchmarkTensors]:
|
|
||||||
m, n, k = shape
|
m, n, k = shape
|
||||||
|
|
||||||
# we want to make sure that weights don't fit into L2 cache between runs so
|
# we want to make sure that weights don't fit into L2 cache between runs so
|
||||||
# we construct enough weights to exceed L2 cache, which is 50mb on a H100
|
# we construct enough weights to exceed L2 cache, which is 50mb on a H100
|
||||||
# so we target total weight size > 2*50mb
|
# so we target total weight size > 2*50mb
|
||||||
num_weights = math.ceil(
|
num_weights = math.ceil(2 * 50 * 1024**2 * 8 /
|
||||||
2 * 50 * 1024**2 * 8 / (k * n * types.weight_type.size_bits)
|
(k * n * types.weight_type.size_bits))
|
||||||
)
|
|
||||||
|
|
||||||
a = rand_data((m, k), types.act_type, scale=5)
|
a = rand_data((m, k), types.act_type, scale=5)
|
||||||
|
|
||||||
@ -135,13 +124,8 @@ def create_bench_tensors(
|
|||||||
w = w.to(torch.float16)
|
w = w.to(torch.float16)
|
||||||
|
|
||||||
w_ref, w_q_packed, w_s, w_zp = quantize_and_pack(
|
w_ref, w_q_packed, w_s, w_zp = quantize_and_pack(
|
||||||
a.dtype,
|
a.dtype, w, types.weight_type, types.group_scale_type, group_size,
|
||||||
w,
|
types.group_zero_type is not None)
|
||||||
types.weight_type,
|
|
||||||
types.group_scale_type,
|
|
||||||
group_size,
|
|
||||||
types.group_zero_type is not None,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not a.dtype.is_floating_point:
|
if not a.dtype.is_floating_point:
|
||||||
aiinfo = torch.iinfo(a.dtype)
|
aiinfo = torch.iinfo(a.dtype)
|
||||||
@ -149,30 +133,21 @@ def create_bench_tensors(
|
|||||||
|
|
||||||
w_ref = w_ref.to(torch.float32)
|
w_ref = w_ref.to(torch.float32)
|
||||||
|
|
||||||
w_ch_s = (
|
w_ch_s = None if types.channel_scale_type is None else\
|
||||||
None
|
rand_data((n,), types.channel_scale_type)
|
||||||
if types.channel_scale_type is None
|
w_tok_s = None if types.token_scale_type is None else\
|
||||||
else rand_data((n,), types.channel_scale_type)
|
rand_data((m,), types.token_scale_type)
|
||||||
)
|
|
||||||
w_tok_s = (
|
|
||||||
None
|
|
||||||
if types.token_scale_type is None
|
|
||||||
else rand_data((m,), types.token_scale_type)
|
|
||||||
)
|
|
||||||
|
|
||||||
benchmark_tensors.append(
|
benchmark_tensors.append(
|
||||||
BenchmarkTensors(
|
BenchmarkTensors(w_ref=w_ref,
|
||||||
w_ref=w_ref,
|
a=a,
|
||||||
a=a,
|
w_q=w_q_packed,
|
||||||
w_q=w_q_packed,
|
wtype=types.weight_type,
|
||||||
wtype=types.weight_type,
|
w_g_s=w_s,
|
||||||
w_g_s=w_s,
|
w_g_zp=w_zp,
|
||||||
w_g_zp=w_zp,
|
group_size=group_size,
|
||||||
group_size=group_size,
|
w_ch_s=w_ch_s,
|
||||||
w_ch_s=w_ch_s,
|
w_tok_s=w_tok_s))
|
||||||
w_tok_s=w_tok_s,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return benchmark_tensors
|
return benchmark_tensors
|
||||||
|
|
||||||
@ -195,57 +170,50 @@ def cutlass_scaled_mm_create_bench_fn(bt: BenchmarkTensors) -> Callable:
|
|||||||
scale_b = torch.tensor(1.0, dtype=torch.float32, device=bt.a.device)
|
scale_b = torch.tensor(1.0, dtype=torch.float32, device=bt.a.device)
|
||||||
w_col_major = bt.w_ref.to(bt.a.dtype).t().contiguous().t()
|
w_col_major = bt.w_ref.to(bt.a.dtype).t().contiguous().t()
|
||||||
return lambda: ops.cutlass_scaled_mm(
|
return lambda: ops.cutlass_scaled_mm(
|
||||||
bt.a, w_col_major, scale_a, scale_b, out_dtype=torch.float16
|
bt.a, w_col_major, scale_a, scale_b, out_dtype=torch.float16)
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
|
def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
|
||||||
device = bt.a.device
|
device = bt.a.device
|
||||||
|
|
||||||
workspace = MarlinWorkspace(
|
workspace = MarlinWorkspace(bt.w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N,
|
||||||
bt.w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL
|
GPTQ_MARLIN_MAX_PARALLEL)
|
||||||
)
|
|
||||||
|
|
||||||
if bt.w_g_zp is None:
|
if bt.w_g_zp is None:
|
||||||
w_zp = torch.empty(0, dtype=torch.int, device=device)
|
w_zp = torch.empty(0, dtype=torch.int, device=device)
|
||||||
else:
|
else:
|
||||||
w_zp = marlin_zero_points(
|
w_zp = marlin_zero_points(bt.w_g_zp, bt.w_ref.shape[0],
|
||||||
bt.w_g_zp, bt.w_ref.shape[0], bt.w_ref.shape[1], bt.wtype.size_bits
|
bt.w_ref.shape[1], bt.wtype.size_bits)
|
||||||
)
|
|
||||||
|
|
||||||
if bt.group_size is None:
|
if bt.group_size is None:
|
||||||
w_s = torch.tensor([], device="cuda", dtype=torch.half)
|
w_s = torch.tensor([], device="cuda", dtype=torch.half)
|
||||||
else:
|
else:
|
||||||
w_s = marlin_permute_scales(
|
w_s = marlin_permute_scales(bt.w_g_s, bt.w_ref.shape[0],
|
||||||
bt.w_g_s, bt.w_ref.shape[0], bt.w_ref.shape[1], bt.group_size
|
bt.w_ref.shape[1], bt.group_size)
|
||||||
)
|
|
||||||
|
|
||||||
sort_indices = torch.empty(0, dtype=torch.int, device=device)
|
sort_indices = torch.empty(0, dtype=torch.int, device=device)
|
||||||
g_idx = torch.empty(0, dtype=torch.int, device=device)
|
g_idx = torch.empty(0, dtype=torch.int, device=device)
|
||||||
w_q = ops.gptq_marlin_repack(
|
w_q = ops.gptq_marlin_repack(bt.w_q, sort_indices, bt.w_ref.shape[0],
|
||||||
bt.w_q, sort_indices, bt.w_ref.shape[0], bt.w_ref.shape[1], bt.wtype.size_bits
|
bt.w_ref.shape[1], bt.wtype.size_bits)
|
||||||
)
|
|
||||||
|
|
||||||
if bt.a.dtype.is_floating_point:
|
if bt.a.dtype.is_floating_point:
|
||||||
assert bt.w_ch_s is None
|
assert bt.w_ch_s is None
|
||||||
assert bt.w_tok_s is None
|
assert bt.w_tok_s is None
|
||||||
assert bt.group_size is not None
|
assert bt.group_size is not None
|
||||||
|
|
||||||
fn = lambda: ops.gptq_marlin_gemm(
|
fn = lambda: ops.gptq_marlin_gemm(a=bt.a,
|
||||||
a=bt.a,
|
b_q_weight=w_q,
|
||||||
b_q_weight=w_q,
|
b_scales=w_s,
|
||||||
b_scales=w_s,
|
b_zeros=w_zp,
|
||||||
b_zeros=w_zp,
|
g_idx=g_idx,
|
||||||
g_idx=g_idx,
|
perm=sort_indices,
|
||||||
perm=sort_indices,
|
workspace=workspace.scratch,
|
||||||
workspace=workspace.scratch,
|
b_q_type=bt.wtype,
|
||||||
b_q_type=bt.wtype,
|
size_m=bt.a.shape[0],
|
||||||
size_m=bt.a.shape[0],
|
size_n=bt.w_ref.shape[1],
|
||||||
size_n=bt.w_ref.shape[1],
|
size_k=bt.w_ref.shape[0],
|
||||||
size_k=bt.w_ref.shape[0],
|
is_k_full=True,
|
||||||
is_k_full=True,
|
is_zp_float=False)
|
||||||
is_zp_float=False,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
assert bt.a.dtype == torch.int8
|
assert bt.a.dtype == torch.int8
|
||||||
assert bt.wtype == scalar_types.uint4b8
|
assert bt.wtype == scalar_types.uint4b8
|
||||||
@ -253,35 +221,36 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
|
|||||||
if bt.w_ch_s is not None:
|
if bt.w_ch_s is not None:
|
||||||
s_ch = bt.w_ch_s.to(torch.float32)
|
s_ch = bt.w_ch_s.to(torch.float32)
|
||||||
else:
|
else:
|
||||||
s_ch = torch.ones(bt.w_ref.shape[1], dtype=torch.float32, device=device)
|
s_ch = torch.ones(bt.w_ref.shape[1],
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=device)
|
||||||
|
|
||||||
if bt.w_tok_s is not None:
|
if bt.w_tok_s is not None:
|
||||||
s_tok = bt.w_tok_s.to(torch.float32)
|
s_tok = bt.w_tok_s.to(torch.float32)
|
||||||
else:
|
else:
|
||||||
s_tok = torch.ones(bt.a.shape[0], dtype=torch.float32, device=device)
|
s_tok = torch.ones(bt.a.shape[0],
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=device)
|
||||||
|
|
||||||
fn = lambda: ops.marlin_qqq_gemm(
|
fn = lambda: ops.marlin_qqq_gemm(a=bt.a,
|
||||||
a=bt.a,
|
b_q_weight=w_q,
|
||||||
b_q_weight=w_q,
|
s_group=w_s,
|
||||||
s_group=w_s,
|
s_tok=s_tok,
|
||||||
s_tok=s_tok,
|
s_ch=s_ch,
|
||||||
s_ch=s_ch,
|
workspace=workspace.scratch,
|
||||||
workspace=workspace.scratch,
|
size_m=bt.a.shape[0],
|
||||||
size_m=bt.a.shape[0],
|
size_n=bt.w_ref.shape[1],
|
||||||
size_n=bt.w_ref.shape[1],
|
size_k=bt.w_ref.shape[0])
|
||||||
size_k=bt.w_ref.shape[0],
|
|
||||||
)
|
|
||||||
|
|
||||||
return fn
|
return fn
|
||||||
|
|
||||||
|
|
||||||
def machete_create_bench_fn(
|
def machete_create_bench_fn(bt: BenchmarkTensors,
|
||||||
bt: BenchmarkTensors, out_type=torch.dtype, schedule=None
|
out_type=torch.dtype,
|
||||||
) -> Callable:
|
schedule=None) -> Callable:
|
||||||
w_q = bt.w_q.t().contiguous().t() # make col major
|
w_q = bt.w_q.t().contiguous().t() # make col major
|
||||||
w_q = ops.machete_prepack_B(
|
w_q = ops.machete_prepack_B(w_q, bt.a.dtype, bt.wtype,
|
||||||
w_q, bt.a.dtype, bt.wtype, None if bt.w_g_s is None else bt.w_g_s.dtype
|
None if bt.w_g_s is None else bt.w_g_s.dtype)
|
||||||
)
|
|
||||||
|
|
||||||
w_g_zp = bt.w_g_zp
|
w_g_zp = bt.w_g_zp
|
||||||
if w_g_zp is not None:
|
if w_g_zp is not None:
|
||||||
@ -306,24 +275,26 @@ def machete_create_bench_fn(
|
|||||||
# bench
|
# bench
|
||||||
|
|
||||||
|
|
||||||
def bench_fns(label: str, sub_label: str, description: str, fns: list[Callable]):
|
def bench_fns(label: str, sub_label: str, description: str,
|
||||||
|
fns: list[Callable]):
|
||||||
|
|
||||||
min_run_time = 1 if not NVTX_PROFILE else 0.1
|
min_run_time = 1 if not NVTX_PROFILE else 0.1
|
||||||
res = TBenchmark.Timer(
|
res = TBenchmark.Timer(
|
||||||
stmt="""
|
stmt="""
|
||||||
for fn in fns:
|
for fn in fns:
|
||||||
fn()
|
fn()
|
||||||
""",
|
""",
|
||||||
globals={"fns": fns},
|
globals={
|
||||||
|
"fns": fns
|
||||||
|
},
|
||||||
label=label,
|
label=label,
|
||||||
sub_label=sub_label,
|
sub_label=sub_label,
|
||||||
description=description,
|
description=description,
|
||||||
).blocked_autorange(min_run_time=min_run_time)
|
).blocked_autorange(min_run_time=min_run_time)
|
||||||
|
|
||||||
if NVTX_PROFILE:
|
if NVTX_PROFILE:
|
||||||
with (
|
with nvtx.annotate("mm-bench"), nvtx.annotate(
|
||||||
nvtx.annotate("mm-bench"),
|
f"{label}|{sub_label}|{description}"):
|
||||||
nvtx.annotate(f"{label}|{sub_label}|{description}"),
|
|
||||||
):
|
|
||||||
fns[0]()
|
fns[0]()
|
||||||
|
|
||||||
return res
|
return res
|
||||||
@ -333,20 +304,19 @@ _SWEEP_SCHEDULES_RESULTS: Optional[pd.DataFrame] = None
|
|||||||
_SWEEP_SCHEDULES_RESULTS_CSV: Optional[str] = None
|
_SWEEP_SCHEDULES_RESULTS_CSV: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
def bench(
|
def bench(types: TypeConfig,
|
||||||
types: TypeConfig,
|
group_size: int,
|
||||||
group_size: int,
|
m: int,
|
||||||
m: int,
|
k: int,
|
||||||
k: int,
|
n: int,
|
||||||
n: int,
|
label: str,
|
||||||
label: str,
|
sub_label: str,
|
||||||
sub_label: str,
|
sweep_schedules: bool = True) -> list[TMeasurement]:
|
||||||
sweep_schedules: bool = True,
|
|
||||||
) -> list[TMeasurement]:
|
|
||||||
benchmark_tensors = create_bench_tensors((m, n, k), types, group_size)
|
benchmark_tensors = create_bench_tensors((m, n, k), types, group_size)
|
||||||
sub_label += f", L={len(benchmark_tensors)}"
|
sub_label += f", L={len(benchmark_tensors)}"
|
||||||
|
|
||||||
name_type_string = f"W{types.weight_type}" + f"-A{terse_type_name(types.act_type)}"
|
name_type_string = f"W{types.weight_type}"+\
|
||||||
|
f"-A{terse_type_name(types.act_type)}"
|
||||||
if types.group_scale_type is not None:
|
if types.group_scale_type is not None:
|
||||||
name_type_string += f"-GS{terse_type_name(types.group_scale_type)}"
|
name_type_string += f"-GS{terse_type_name(types.group_scale_type)}"
|
||||||
if types.group_zero_type is not None:
|
if types.group_zero_type is not None:
|
||||||
@ -362,45 +332,31 @@ def bench(
|
|||||||
# pytorch impl
|
# pytorch impl
|
||||||
timers.append(
|
timers.append(
|
||||||
bench_fns(
|
bench_fns(
|
||||||
label,
|
label, sub_label, "torch.matmul (fp16)",
|
||||||
sub_label,
|
[torch_matmul_f16_create_bench_fn(bt)
|
||||||
"torch.matmul (fp16)",
|
for bt in benchmark_tensors]))
|
||||||
[torch_matmul_f16_create_bench_fn(bt) for bt in benchmark_tensors],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if types.act_type == torch.int8 or types.act_type == torch.float8_e4m3fn:
|
if types.act_type == torch.int8 or types.act_type == torch.float8_e4m3fn:
|
||||||
timers.append(
|
timers.append(
|
||||||
bench_fns(
|
bench_fns(
|
||||||
label,
|
label, sub_label,
|
||||||
sub_label,
|
f"cutlass_scaled_mm ({terse_type_name(types.act_type)})", [
|
||||||
f"cutlass_scaled_mm ({terse_type_name(types.act_type)})",
|
cutlass_scaled_mm_create_bench_fn(bt)
|
||||||
[cutlass_scaled_mm_create_bench_fn(bt) for bt in benchmark_tensors],
|
for bt in benchmark_tensors
|
||||||
)
|
]))
|
||||||
)
|
|
||||||
|
|
||||||
if types.act_type != torch.float8_e4m3fn:
|
if types.act_type != torch.float8_e4m3fn:
|
||||||
timers.append(
|
timers.append(
|
||||||
bench_fns(
|
bench_fns(label, sub_label, f"marlin ({name_type_string})",
|
||||||
label,
|
[marlin_create_bench_fn(bt)
|
||||||
sub_label,
|
for bt in benchmark_tensors]))
|
||||||
f"marlin ({name_type_string})",
|
|
||||||
[marlin_create_bench_fn(bt) for bt in benchmark_tensors],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# machete
|
# machete
|
||||||
timers.append(
|
timers.append(
|
||||||
bench_fns(
|
bench_fns(label, sub_label, f"machete ({name_type_string})", [
|
||||||
label,
|
machete_create_bench_fn(bt, out_type=types.output_type)
|
||||||
sub_label,
|
for bt in benchmark_tensors
|
||||||
f"machete ({name_type_string})",
|
]))
|
||||||
[
|
|
||||||
machete_create_bench_fn(bt, out_type=types.output_type)
|
|
||||||
for bt in benchmark_tensors
|
|
||||||
],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if sweep_schedules:
|
if sweep_schedules:
|
||||||
global _SWEEP_SCHEDULES_RESULTS
|
global _SWEEP_SCHEDULES_RESULTS
|
||||||
@ -415,8 +371,7 @@ def bench(
|
|||||||
group_zeros_type=types.group_zero_type,
|
group_zeros_type=types.group_zero_type,
|
||||||
token_scales_type=types.token_scale_type,
|
token_scales_type=types.token_scale_type,
|
||||||
channel_scales_type=types.channel_scale_type,
|
channel_scales_type=types.channel_scale_type,
|
||||||
out_type=types.output_type,
|
out_type=types.output_type)
|
||||||
)
|
|
||||||
|
|
||||||
if schedules is None or len(schedules) == 0:
|
if schedules is None or len(schedules) == 0:
|
||||||
raise ValueError("No schedules found to sweep")
|
raise ValueError("No schedules found to sweep")
|
||||||
@ -428,17 +383,11 @@ def bench(
|
|||||||
if schedule_M >= 2 * max(m, 16) or schedule_M < m // 4:
|
if schedule_M >= 2 * max(m, 16) or schedule_M < m // 4:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
res = bench_fns(
|
res = bench_fns(label, sub_label, "machete_best", [
|
||||||
label,
|
machete_create_bench_fn(
|
||||||
sub_label,
|
bt, out_type=types.output_type, schedule=schedule)
|
||||||
"machete_best",
|
for bt in benchmark_tensors
|
||||||
[
|
])
|
||||||
machete_create_bench_fn(
|
|
||||||
bt, out_type=types.output_type, schedule=schedule
|
|
||||||
)
|
|
||||||
for bt in benchmark_tensors
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
results_row = {
|
results_row = {
|
||||||
"M": m,
|
"M": m,
|
||||||
@ -449,8 +398,10 @@ def bench(
|
|||||||
"median": res.median,
|
"median": res.median,
|
||||||
}
|
}
|
||||||
if _SWEEP_SCHEDULES_RESULTS is None:
|
if _SWEEP_SCHEDULES_RESULTS is None:
|
||||||
_SWEEP_SCHEDULES_RESULTS = pd.DataFrame(columns=results_row.keys())
|
_SWEEP_SCHEDULES_RESULTS = pd.DataFrame(
|
||||||
_SWEEP_SCHEDULES_RESULTS.loc[len(_SWEEP_SCHEDULES_RESULTS)] = results_row
|
columns=results_row.keys())
|
||||||
|
_SWEEP_SCHEDULES_RESULTS.\
|
||||||
|
loc[len(_SWEEP_SCHEDULES_RESULTS)] = results_row
|
||||||
|
|
||||||
print(f" {res.median:5.5} ", schedule)
|
print(f" {res.median:5.5} ", schedule)
|
||||||
if not best or res.median < best.median:
|
if not best or res.median < best.median:
|
||||||
@ -471,9 +422,8 @@ def print_timers(timers: list[TMeasurement]):
|
|||||||
def run(args, MKNs: Iterable[tuple[int, int, int]]) -> Iterable[TMeasurement]:
|
def run(args, MKNs: Iterable[tuple[int, int, int]]) -> Iterable[TMeasurement]:
|
||||||
types = TypeConfig(
|
types = TypeConfig(
|
||||||
act_type=args.act_type,
|
act_type=args.act_type,
|
||||||
weight_type=scalar_types.uint4b8
|
weight_type=scalar_types.uint4b8 if args.group_zero_type is None \
|
||||||
if args.group_zero_type is None
|
else scalar_types.uint4,
|
||||||
else scalar_types.uint4,
|
|
||||||
output_type=args.out_type,
|
output_type=args.out_type,
|
||||||
group_scale_type=args.group_scale_type,
|
group_scale_type=args.group_scale_type,
|
||||||
group_zero_type=args.group_zero_type,
|
group_zero_type=args.group_zero_type,
|
||||||
@ -483,16 +433,14 @@ def run(args, MKNs: Iterable[tuple[int, int, int]]) -> Iterable[TMeasurement]:
|
|||||||
|
|
||||||
results: list[TMeasurement] = []
|
results: list[TMeasurement] = []
|
||||||
for m, k, n in MKNs:
|
for m, k, n in MKNs:
|
||||||
timers = bench(
|
timers = bench(types,
|
||||||
types,
|
args.group_size,
|
||||||
args.group_size,
|
m,
|
||||||
m,
|
k,
|
||||||
k,
|
n,
|
||||||
n,
|
f"{args.act_type}-gemm",
|
||||||
f"{args.act_type}-gemm",
|
f"MKN=({m}x{k}x{n})",
|
||||||
f"MKN=({m}x{k}x{n})",
|
sweep_schedules=args.sweep_schedules)
|
||||||
sweep_schedules=args.sweep_schedules,
|
|
||||||
)
|
|
||||||
print_timers(timers)
|
print_timers(timers)
|
||||||
results.extend(timers)
|
results.extend(timers)
|
||||||
|
|
||||||
@ -506,6 +454,7 @@ def make_output(
|
|||||||
base_description: str,
|
base_description: str,
|
||||||
timestamp=None,
|
timestamp=None,
|
||||||
):
|
):
|
||||||
|
|
||||||
print(f"== All Results {base_description} ====")
|
print(f"== All Results {base_description} ====")
|
||||||
print_timers(data)
|
print_timers(data)
|
||||||
|
|
||||||
@ -519,7 +468,8 @@ def make_output(
|
|||||||
|
|
||||||
|
|
||||||
def run_square_bench(args):
|
def run_square_bench(args):
|
||||||
dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment))
|
dim_sizes = list(
|
||||||
|
range(args.dim_start, args.dim_end + 1, args.dim_increment))
|
||||||
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
|
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
|
||||||
data = run(args.dtype, args.sweep_schedules, MKNs)
|
data = run(args.dtype, args.sweep_schedules, MKNs)
|
||||||
|
|
||||||
@ -529,9 +479,8 @@ def run_square_bench(args):
|
|||||||
def run_range_bench(args):
|
def run_range_bench(args):
|
||||||
m_start, k_start, n_start = (int(x) for x in args.dim_start.split(","))
|
m_start, k_start, n_start = (int(x) for x in args.dim_start.split(","))
|
||||||
m_end, k_end, n_end = (int(x) for x in args.dim_end.split(","))
|
m_end, k_end, n_end = (int(x) for x in args.dim_end.split(","))
|
||||||
m_increment, k_increment, n_increment = (
|
m_increment, k_increment, n_increment = \
|
||||||
int(x) for x in args.dim_increment.split(",")
|
(int(x) for x in args.dim_increment.split(","))
|
||||||
)
|
|
||||||
Ms = list(range(m_start, m_end + 1, m_increment))
|
Ms = list(range(m_start, m_end + 1, m_increment))
|
||||||
Ks = list(range(k_start, k_end + 1, k_increment))
|
Ks = list(range(k_start, k_end + 1, k_increment))
|
||||||
Ns = list(range(n_start, n_end + 1, n_increment))
|
Ns = list(range(n_start, n_end + 1, n_increment))
|
||||||
@ -543,6 +492,7 @@ def run_range_bench(args):
|
|||||||
|
|
||||||
|
|
||||||
def run_model_bench(args):
|
def run_model_bench(args):
|
||||||
|
|
||||||
print("Benchmarking models:")
|
print("Benchmarking models:")
|
||||||
for i, model in enumerate(args.models):
|
for i, model in enumerate(args.models):
|
||||||
print(f"[{i}] {model}")
|
print(f"[{i}] {model}")
|
||||||
@ -585,13 +535,10 @@ def run_model_bench(args):
|
|||||||
with open(f"model_bench-{type_string}-{timestr}.pkl", "wb") as f:
|
with open(f"model_bench-{type_string}-{timestr}.pkl", "wb") as f:
|
||||||
args_dict = vars(args)
|
args_dict = vars(args)
|
||||||
args_dict.pop("func")
|
args_dict.pop("func")
|
||||||
pkl.dump(
|
pkl.dump({
|
||||||
{
|
"args": args_dict,
|
||||||
"args": args_dict,
|
"results": all_results,
|
||||||
"results": all_results,
|
}, f)
|
||||||
},
|
|
||||||
f,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@ -607,6 +554,7 @@ if __name__ == "__main__":
|
|||||||
}[dt]
|
}[dt]
|
||||||
|
|
||||||
class ToTorchDtype(argparse.Action):
|
class ToTorchDtype(argparse.Action):
|
||||||
|
|
||||||
def __call__(self, parser, namespace, values, option_string=None):
|
def __call__(self, parser, namespace, values, option_string=None):
|
||||||
setattr(namespace, self.dest, to_torch_dtype(values))
|
setattr(namespace, self.dest, to_torch_dtype(values))
|
||||||
|
|
||||||
@ -632,32 +580,32 @@ Benchmark Machete GEMM.
|
|||||||
"--act-type",
|
"--act-type",
|
||||||
action=ToTorchDtype,
|
action=ToTorchDtype,
|
||||||
required=True,
|
required=True,
|
||||||
choices=["bfloat16", "float16", "int8", "float8_e4m3fn"],
|
choices=['bfloat16', 'float16', 'int8', 'float8_e4m3fn'],
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--group-scale-type",
|
"--group-scale-type",
|
||||||
action=ToTorchDtype,
|
action=ToTorchDtype,
|
||||||
choices=["bfloat16", "float16"],
|
choices=['bfloat16', 'float16'],
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--group-zero-type",
|
"--group-zero-type",
|
||||||
type=to_torch_dtype,
|
type=to_torch_dtype,
|
||||||
choices=["bfloat16", "float16"],
|
choices=['bfloat16', 'float16'],
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--channel-scale-type",
|
"--channel-scale-type",
|
||||||
action=ToTorchDtype,
|
action=ToTorchDtype,
|
||||||
choices=["float"],
|
choices=['float'],
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--token-scale-type",
|
"--token-scale-type",
|
||||||
action=ToTorchDtype,
|
action=ToTorchDtype,
|
||||||
choices=["float"],
|
choices=['float'],
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--out-type",
|
"--out-type",
|
||||||
action=ToTorchDtype,
|
action=ToTorchDtype,
|
||||||
choices=["bfloat16", "float16"],
|
choices=['bfloat16', 'float16'],
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--group-size",
|
"--group-size",
|
||||||
@ -670,11 +618,9 @@ Benchmark Machete GEMM.
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Run a sweep over all supported schedules",
|
help="Run a sweep over all supported schedules",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument("--sweep-csv-out",
|
||||||
"--sweep-csv-out",
|
help="CSV to store sweep results",
|
||||||
help="CSV to store sweep results",
|
default="sch_sweep_results.csv")
|
||||||
default="sch_sweep_results.csv",
|
|
||||||
)
|
|
||||||
subparsers = parser.add_subparsers(dest="cmd", required=True)
|
subparsers = parser.add_subparsers(dest="cmd", required=True)
|
||||||
|
|
||||||
square_parser = subparsers.add_parser("square_bench")
|
square_parser = subparsers.add_parser("square_bench")
|
||||||
@ -688,20 +634,17 @@ Benchmark Machete GEMM.
|
|||||||
"--dim-start",
|
"--dim-start",
|
||||||
type=str,
|
type=str,
|
||||||
required=True,
|
required=True,
|
||||||
help="Start value for M,K,N as common separated list",
|
help="Start value for M,K,N as common separated list")
|
||||||
)
|
|
||||||
range_parser.add_argument(
|
range_parser.add_argument(
|
||||||
"--dim-end",
|
"--dim-end",
|
||||||
type=str,
|
type=str,
|
||||||
required=True,
|
required=True,
|
||||||
help="End value (inclusive) for M,K,N as common separated list",
|
help="End value (inclusive) for M,K,N as common separated list")
|
||||||
)
|
|
||||||
range_parser.add_argument(
|
range_parser.add_argument(
|
||||||
"--dim-increment",
|
"--dim-increment",
|
||||||
type=str,
|
type=str,
|
||||||
required=True,
|
required=True,
|
||||||
help="Increment value for M,K,N as common separated list",
|
help="Increment value for M,K,N as common separated list")
|
||||||
)
|
|
||||||
range_parser.set_defaults(func=run_range_bench)
|
range_parser.set_defaults(func=run_range_bench)
|
||||||
|
|
||||||
model_parser = subparsers.add_parser("model_bench")
|
model_parser = subparsers.add_parser("model_bench")
|
||||||
@ -712,12 +655,14 @@ Benchmark Machete GEMM.
|
|||||||
default=DEFAULT_MODELS,
|
default=DEFAULT_MODELS,
|
||||||
choices=WEIGHT_SHAPES.keys(),
|
choices=WEIGHT_SHAPES.keys(),
|
||||||
)
|
)
|
||||||
model_parser.add_argument(
|
model_parser.add_argument("--tp-sizes",
|
||||||
"--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES
|
nargs="+",
|
||||||
)
|
type=int,
|
||||||
model_parser.add_argument(
|
default=DEFAULT_TP_SIZES)
|
||||||
"--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
|
model_parser.add_argument("--batch-sizes",
|
||||||
)
|
nargs="+",
|
||||||
|
type=int,
|
||||||
|
default=DEFAULT_BATCH_SIZES)
|
||||||
model_parser.set_defaults(func=run_model_bench)
|
model_parser.set_defaults(func=run_model_bench)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|||||||
@ -6,34 +6,19 @@ from benchmark_shapes import WEIGHT_SHAPES
|
|||||||
|
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
|
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
|
||||||
GPTQ_MARLIN_24_MAX_PARALLEL,
|
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
|
||||||
GPTQ_MARLIN_24_MIN_THREAD_N,
|
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
|
||||||
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES,
|
|
||||||
GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES,
|
|
||||||
)
|
|
||||||
from vllm.model_executor.layers.quantization.utils.allspark_utils import (
|
from vllm.model_executor.layers.quantization.utils.allspark_utils import (
|
||||||
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
|
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, ALLSPARK_SUPPORTED_QUANT_TYPES)
|
||||||
ALLSPARK_SUPPORTED_QUANT_TYPES,
|
|
||||||
)
|
|
||||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||||
GPTQ_MARLIN_MAX_PARALLEL,
|
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
|
||||||
GPTQ_MARLIN_MIN_THREAD_N,
|
MARLIN_SUPPORTED_GROUP_SIZES, query_marlin_supported_quant_types)
|
||||||
MARLIN_SUPPORTED_GROUP_SIZES,
|
|
||||||
query_marlin_supported_quant_types,
|
|
||||||
)
|
|
||||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
||||||
MarlinWorkspace,
|
MarlinWorkspace, marlin_quantize)
|
||||||
marlin_quantize,
|
|
||||||
)
|
|
||||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
|
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
|
||||||
marlin_24_quantize,
|
marlin_24_quantize)
|
||||||
)
|
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
gptq_pack,
|
gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights)
|
||||||
gptq_quantize_weights,
|
|
||||||
quantize_weights,
|
|
||||||
sort_weights,
|
|
||||||
)
|
|
||||||
from vllm.scalar_type import ScalarType
|
from vllm.scalar_type import ScalarType
|
||||||
from vllm.utils import FlexibleArgumentParser
|
from vllm.utils import FlexibleArgumentParser
|
||||||
|
|
||||||
@ -44,29 +29,22 @@ ACT_ORDER_OPTS = [False, True]
|
|||||||
K_FULL_OPTS = [False, True]
|
K_FULL_OPTS = [False, True]
|
||||||
|
|
||||||
|
|
||||||
def bench_run(
|
def bench_run(results: list[benchmark.Measurement], model: str,
|
||||||
results: list[benchmark.Measurement],
|
act_order: bool, is_k_full: bool, quant_type: ScalarType,
|
||||||
model: str,
|
group_size: int, size_m: int, size_k: int, size_n: int):
|
||||||
act_order: bool,
|
|
||||||
is_k_full: bool,
|
|
||||||
quant_type: ScalarType,
|
|
||||||
group_size: int,
|
|
||||||
size_m: int,
|
|
||||||
size_k: int,
|
|
||||||
size_n: int,
|
|
||||||
):
|
|
||||||
label = "Quant Matmul"
|
label = "Quant Matmul"
|
||||||
|
|
||||||
sub_label = "{}, act={} k_full={}, q={}, g={}, MKN=({}x{}x{})".format(
|
sub_label = ("{}, act={} k_full={}, q={}, g={}, "
|
||||||
model, act_order, is_k_full, str(quant_type), group_size, size_m, size_k, size_n
|
"MKN=({}x{}x{})".format(model, act_order, is_k_full,
|
||||||
)
|
str(quant_type), group_size, size_m,
|
||||||
|
size_k, size_n))
|
||||||
|
|
||||||
print(f"Testing: {sub_label}")
|
print(f"Testing: {sub_label}")
|
||||||
|
|
||||||
a = torch.randn(size_m, size_k).to(torch.half).cuda()
|
a = torch.randn(size_m, size_k).to(torch.half).cuda()
|
||||||
b = torch.rand(size_k, size_n).to(torch.half).cuda()
|
b = torch.rand(size_k, size_n).to(torch.half).cuda()
|
||||||
|
|
||||||
a_tmp = torch.zeros(size_m, size_k).to(torch.half).cuda()
|
a_tmp = (torch.zeros(size_m, size_k).to(torch.half).cuda())
|
||||||
|
|
||||||
# Marlin quant
|
# Marlin quant
|
||||||
(
|
(
|
||||||
@ -79,16 +57,14 @@ def bench_run(
|
|||||||
) = marlin_quantize(b, quant_type, group_size, act_order)
|
) = marlin_quantize(b, quant_type, group_size, act_order)
|
||||||
|
|
||||||
# Marlin_24 quant
|
# Marlin_24 quant
|
||||||
(marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = (
|
(marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta,
|
||||||
marlin_24_quantize(b, quant_type, group_size)
|
marlin_24_s) = marlin_24_quantize(b, quant_type, group_size)
|
||||||
)
|
|
||||||
|
|
||||||
marlin_zp = torch.empty(0, dtype=torch.int, device=b.device)
|
marlin_zp = torch.empty(0, dtype=torch.int, device=b.device)
|
||||||
|
|
||||||
# GPTQ quant
|
# GPTQ quant
|
||||||
(w_ref, q_w, s, g_idx, rand_perm) = gptq_quantize_weights(
|
(w_ref, q_w, s, g_idx,
|
||||||
b, quant_type, group_size, act_order
|
rand_perm) = gptq_quantize_weights(b, quant_type, group_size, act_order)
|
||||||
)
|
|
||||||
q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)
|
q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)
|
||||||
|
|
||||||
# For act_order, sort the "weights" and "g_idx"
|
# For act_order, sort the "weights" and "g_idx"
|
||||||
@ -98,37 +74,32 @@ def bench_run(
|
|||||||
(q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx)
|
(q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx)
|
||||||
|
|
||||||
# Prepare
|
# Prepare
|
||||||
marlin_workspace = MarlinWorkspace(
|
marlin_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
|
||||||
size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL
|
GPTQ_MARLIN_MAX_PARALLEL)
|
||||||
)
|
|
||||||
|
|
||||||
marlin_24_workspace = MarlinWorkspace(
|
marlin_24_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N,
|
||||||
size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL
|
GPTQ_MARLIN_24_MAX_PARALLEL)
|
||||||
)
|
|
||||||
marlin_zp = torch.zeros_like(marlin_s, dtype=torch.int)
|
marlin_zp = torch.zeros_like(marlin_s, dtype=torch.int)
|
||||||
|
|
||||||
# AllSpark W8A16 quant
|
# AllSpark W8A16 quant
|
||||||
as_supported_case = (
|
as_supported_case = (quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES
|
||||||
quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES
|
and group_size == -1 and not act_order and is_k_full)
|
||||||
and group_size == -1
|
|
||||||
and not act_order
|
|
||||||
and is_k_full
|
|
||||||
)
|
|
||||||
if as_supported_case:
|
if as_supported_case:
|
||||||
properties = torch.cuda.get_device_properties(b.device.index)
|
properties = torch.cuda.get_device_properties(b.device.index)
|
||||||
sm_count = properties.multi_processor_count
|
sm_count = properties.multi_processor_count
|
||||||
sm_version = properties.major * 10 + properties.minor
|
sm_version = properties.major * 10 + properties.minor
|
||||||
|
|
||||||
supported_arch = sm_version >= 80 and sm_version < 90
|
supported_arch = (sm_version >= 80 and sm_version < 90)
|
||||||
as_supported_case = as_supported_case and supported_arch
|
as_supported_case = as_supported_case and supported_arch
|
||||||
if supported_arch:
|
if supported_arch:
|
||||||
has_zp = False
|
has_zp = False
|
||||||
w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size, has_zp)
|
w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size,
|
||||||
|
has_zp)
|
||||||
qw = qw.to(torch.uint8)
|
qw = qw.to(torch.uint8)
|
||||||
|
|
||||||
qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight(
|
qw_reorder, s_reorder, zp_reorder = \
|
||||||
qw, s, zp, has_zp
|
ops.allspark_repack_weight(
|
||||||
)
|
qw, s, zp, has_zp)
|
||||||
CUBLAS_M_THRESHOLD = ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD
|
CUBLAS_M_THRESHOLD = ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD
|
||||||
|
|
||||||
globals = {
|
globals = {
|
||||||
@ -165,7 +136,8 @@ def bench_run(
|
|||||||
"zp_reorder": zp_reorder if as_supported_case else None,
|
"zp_reorder": zp_reorder if as_supported_case else None,
|
||||||
"sm_count": sm_count if as_supported_case else None,
|
"sm_count": sm_count if as_supported_case else None,
|
||||||
"sm_version": sm_version if as_supported_case else None,
|
"sm_version": sm_version if as_supported_case else None,
|
||||||
"CUBLAS_M_THRESHOLD": CUBLAS_M_THRESHOLD if as_supported_case else None,
|
"CUBLAS_M_THRESHOLD":
|
||||||
|
CUBLAS_M_THRESHOLD if as_supported_case else None,
|
||||||
# Kernels
|
# Kernels
|
||||||
"gptq_marlin_gemm": ops.gptq_marlin_gemm,
|
"gptq_marlin_gemm": ops.gptq_marlin_gemm,
|
||||||
"gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm,
|
"gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm,
|
||||||
@ -186,63 +158,60 @@ def bench_run(
|
|||||||
label=label,
|
label=label,
|
||||||
sub_label=sub_label,
|
sub_label=sub_label,
|
||||||
description="pytorch_gemm",
|
description="pytorch_gemm",
|
||||||
).blocked_autorange(min_run_time=min_run_time)
|
).blocked_autorange(min_run_time=min_run_time))
|
||||||
)
|
|
||||||
|
|
||||||
results.append(
|
results.append(
|
||||||
benchmark.Timer(
|
benchmark.Timer(
|
||||||
stmt="output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501
|
stmt=
|
||||||
|
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501
|
||||||
globals=globals,
|
globals=globals,
|
||||||
label=label,
|
label=label,
|
||||||
sub_label=sub_label,
|
sub_label=sub_label,
|
||||||
description="gptq_marlin_gemm_fp16",
|
description="gptq_marlin_gemm_fp16",
|
||||||
).blocked_autorange(min_run_time=min_run_time)
|
).blocked_autorange(min_run_time=min_run_time))
|
||||||
)
|
|
||||||
|
|
||||||
results.append(
|
results.append(
|
||||||
benchmark.Timer(
|
benchmark.Timer(
|
||||||
stmt="output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501
|
stmt=
|
||||||
|
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501
|
||||||
globals=globals,
|
globals=globals,
|
||||||
label=label,
|
label=label,
|
||||||
sub_label=sub_label,
|
sub_label=sub_label,
|
||||||
description="gptq_marlin_gemm_fp32",
|
description="gptq_marlin_gemm_fp32",
|
||||||
).blocked_autorange(min_run_time=min_run_time)
|
).blocked_autorange(min_run_time=min_run_time))
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
if (quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
|
||||||
quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
|
and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES):
|
||||||
and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
|
|
||||||
):
|
|
||||||
results.append(
|
results.append(
|
||||||
benchmark.Timer(
|
benchmark.Timer(
|
||||||
stmt="output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)", # noqa: E501
|
stmt=
|
||||||
|
"output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)", # noqa: E501
|
||||||
globals=globals,
|
globals=globals,
|
||||||
label=label,
|
label=label,
|
||||||
sub_label=sub_label,
|
sub_label=sub_label,
|
||||||
description="gptq_marlin_24_gemm",
|
description="gptq_marlin_24_gemm",
|
||||||
).blocked_autorange(min_run_time=min_run_time)
|
).blocked_autorange(min_run_time=min_run_time))
|
||||||
)
|
|
||||||
|
|
||||||
results.append(
|
results.append(
|
||||||
benchmark.Timer(
|
benchmark.Timer(
|
||||||
stmt="q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)", # noqa: E501
|
stmt=
|
||||||
|
"q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)", # noqa: E501
|
||||||
globals=globals,
|
globals=globals,
|
||||||
label=label,
|
label=label,
|
||||||
sub_label=sub_label,
|
sub_label=sub_label,
|
||||||
description="gptq_marlin_repack",
|
description="gptq_marlin_repack",
|
||||||
).blocked_autorange(min_run_time=min_run_time)
|
).blocked_autorange(min_run_time=min_run_time))
|
||||||
)
|
|
||||||
|
|
||||||
if as_supported_case:
|
if as_supported_case:
|
||||||
results.append(
|
results.append(
|
||||||
benchmark.Timer(
|
benchmark.Timer(
|
||||||
stmt="output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)", # noqa: E501
|
stmt=
|
||||||
|
"output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)", # noqa: E501
|
||||||
globals=globals,
|
globals=globals,
|
||||||
label=label,
|
label=label,
|
||||||
sub_label=sub_label,
|
sub_label=sub_label,
|
||||||
description="allspark_w8a16_gemm_fp32",
|
description="allspark_w8a16_gemm_fp32",
|
||||||
).blocked_autorange(min_run_time=min_run_time)
|
).blocked_autorange(min_run_time=min_run_time))
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
@ -264,50 +233,37 @@ def main(args):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
for act_order in ACT_ORDER_OPTS:
|
for act_order in ACT_ORDER_OPTS:
|
||||||
if (
|
if len(args.limit_act_order
|
||||||
len(args.limit_act_order) > 0
|
) > 0 and act_order not in args.limit_act_order:
|
||||||
and act_order not in args.limit_act_order
|
|
||||||
):
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for is_k_full in K_FULL_OPTS:
|
for is_k_full in K_FULL_OPTS:
|
||||||
if (
|
if len(args.limit_k_full
|
||||||
len(args.limit_k_full) > 0
|
) > 0 and is_k_full not in args.limit_k_full:
|
||||||
and is_k_full not in args.limit_k_full
|
|
||||||
):
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for quant_type in query_marlin_supported_quant_types(False):
|
for quant_type in query_marlin_supported_quant_types(
|
||||||
if (
|
False):
|
||||||
len(args.limit_num_bits) > 0
|
if len(args.limit_num_bits) > 0 and \
|
||||||
and quant_type.size_bits not in args.limit_num_bits
|
quant_type.size_bits not in args.limit_num_bits:
|
||||||
):
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for group_size in MARLIN_SUPPORTED_GROUP_SIZES:
|
for group_size in MARLIN_SUPPORTED_GROUP_SIZES:
|
||||||
if (
|
if len(
|
||||||
len(args.limit_group_size) > 0
|
args.limit_group_size
|
||||||
and group_size not in args.limit_group_size
|
) > 0 and group_size not in args.limit_group_size:
|
||||||
):
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# For act_order, the group_size must be less than
|
# For act_order, the group_size must be less than
|
||||||
# size_k
|
# size_k
|
||||||
if act_order and (group_size == size_k or group_size == -1):
|
if act_order and (group_size == size_k
|
||||||
|
or group_size == -1):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for size_m in args.batch_sizes:
|
for size_m in args.batch_sizes:
|
||||||
bench_run(
|
bench_run(results, model, act_order, is_k_full,
|
||||||
results,
|
quant_type, group_size, size_m,
|
||||||
model,
|
size_k, size_n)
|
||||||
act_order,
|
|
||||||
is_k_full,
|
|
||||||
quant_type,
|
|
||||||
group_size,
|
|
||||||
size_m,
|
|
||||||
size_k,
|
|
||||||
size_n,
|
|
||||||
)
|
|
||||||
|
|
||||||
compare = benchmark.Compare(results)
|
compare = benchmark.Compare(results)
|
||||||
compare.print()
|
compare.print()
|
||||||
@ -318,8 +274,7 @@ def main(args):
|
|||||||
#
|
#
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = FlexibleArgumentParser(
|
parser = FlexibleArgumentParser(
|
||||||
description="Benchmark Marlin across specified models/shapes/batches"
|
description="Benchmark Marlin across specified models/shapes/batches")
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--models",
|
"--models",
|
||||||
nargs="+",
|
nargs="+",
|
||||||
@ -327,9 +282,10 @@ if __name__ == "__main__":
|
|||||||
default=DEFAULT_MODELS,
|
default=DEFAULT_MODELS,
|
||||||
choices=WEIGHT_SHAPES.keys(),
|
choices=WEIGHT_SHAPES.keys(),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument("--batch-sizes",
|
||||||
"--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
|
nargs="+",
|
||||||
)
|
type=int,
|
||||||
|
default=DEFAULT_BATCH_SIZES)
|
||||||
parser.add_argument("--limit-k", nargs="+", type=int, default=[])
|
parser.add_argument("--limit-k", nargs="+", type=int, default=[])
|
||||||
parser.add_argument("--limit-n", nargs="+", type=int, default=[])
|
parser.add_argument("--limit-n", nargs="+", type=int, default=[])
|
||||||
parser.add_argument("--limit-group-size", nargs="+", type=int, default=[])
|
parser.add_argument("--limit-group-size", nargs="+", type=int, default=[])
|
||||||
|
|||||||
@ -31,60 +31,56 @@ class BenchmarkConfig(TypedDict):
|
|||||||
num_stages: int
|
num_stages: int
|
||||||
|
|
||||||
|
|
||||||
def benchmark_config(
|
def benchmark_config(config: BenchmarkConfig,
|
||||||
config: BenchmarkConfig,
|
num_tokens: int,
|
||||||
num_tokens: int,
|
num_experts: int,
|
||||||
num_experts: int,
|
shard_intermediate_size: int,
|
||||||
shard_intermediate_size: int,
|
hidden_size: int,
|
||||||
hidden_size: int,
|
topk: int,
|
||||||
topk: int,
|
dtype: torch.dtype,
|
||||||
dtype: torch.dtype,
|
use_fp8_w8a8: bool,
|
||||||
use_fp8_w8a8: bool,
|
use_int8_w8a16: bool,
|
||||||
use_int8_w8a16: bool,
|
num_iters: int = 100,
|
||||||
num_iters: int = 100,
|
block_quant_shape: List[int] = None,
|
||||||
block_quant_shape: List[int] = None,
|
use_deep_gemm: bool = False) -> float:
|
||||||
use_deep_gemm: bool = False,
|
|
||||||
) -> float:
|
|
||||||
init_dtype = torch.float16 if use_fp8_w8a8 else dtype
|
init_dtype = torch.float16 if use_fp8_w8a8 else dtype
|
||||||
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
||||||
if use_int8_w8a16:
|
if use_int8_w8a16:
|
||||||
w1 = torch.randint(
|
w1 = torch.randint(-127,
|
||||||
-127,
|
127, (
|
||||||
127,
|
num_experts,
|
||||||
(
|
shard_intermediate_size,
|
||||||
num_experts,
|
hidden_size,
|
||||||
shard_intermediate_size,
|
),
|
||||||
hidden_size,
|
dtype=torch.int8)
|
||||||
),
|
w2 = torch.randint(-127,
|
||||||
dtype=torch.int8,
|
127, (
|
||||||
)
|
num_experts,
|
||||||
w2 = torch.randint(
|
hidden_size,
|
||||||
-127,
|
shard_intermediate_size // 2,
|
||||||
127,
|
),
|
||||||
(
|
dtype=torch.int8)
|
||||||
num_experts,
|
|
||||||
hidden_size,
|
|
||||||
shard_intermediate_size // 2,
|
|
||||||
),
|
|
||||||
dtype=torch.int8,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
w1 = torch.randn(
|
w1 = torch.randn(num_experts,
|
||||||
num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
|
shard_intermediate_size,
|
||||||
)
|
hidden_size,
|
||||||
w2 = torch.randn(
|
dtype=init_dtype)
|
||||||
num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
|
w2 = torch.randn(num_experts,
|
||||||
)
|
hidden_size,
|
||||||
gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)
|
shard_intermediate_size // 2,
|
||||||
|
dtype=init_dtype)
|
||||||
|
gating_output = torch.randn(num_iters,
|
||||||
|
num_tokens,
|
||||||
|
num_experts,
|
||||||
|
dtype=torch.float32)
|
||||||
|
|
||||||
w1_scale = None
|
w1_scale = None
|
||||||
w2_scale = None
|
w2_scale = None
|
||||||
a1_scale = None
|
a1_scale = None
|
||||||
a2_scale = None
|
a2_scale = None
|
||||||
if use_int8_w8a16:
|
if use_int8_w8a16:
|
||||||
w1_scale = torch.randn(
|
w1_scale = torch.randn((num_experts, 2 * shard_intermediate_size),
|
||||||
(num_experts, 2 * shard_intermediate_size), dtype=torch.float32
|
dtype=torch.float32)
|
||||||
)
|
|
||||||
w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
|
w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
|
||||||
if use_fp8_w8a8:
|
if use_fp8_w8a8:
|
||||||
if block_quant_shape:
|
if block_quant_shape:
|
||||||
@ -97,14 +93,10 @@ def benchmark_config(
|
|||||||
n_tiles_w2 = (K + block_n - 1) // block_n
|
n_tiles_w2 = (K + block_n - 1) // block_n
|
||||||
k_tiles_w1 = (K + block_k - 1) // block_k
|
k_tiles_w1 = (K + block_k - 1) // block_k
|
||||||
k_tiles_w2 = (N + block_k - 1) // block_k
|
k_tiles_w2 = (N + block_k - 1) // block_k
|
||||||
w1_scale = (
|
w1_scale = torch.rand((E, n_tiles_w1, k_tiles_w1),
|
||||||
torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
|
dtype=torch.float32) * factor_for_scale
|
||||||
* factor_for_scale
|
w2_scale = torch.rand((E, n_tiles_w2, k_tiles_w2),
|
||||||
)
|
dtype=torch.float32) * factor_for_scale
|
||||||
w2_scale = (
|
|
||||||
torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32)
|
|
||||||
* factor_for_scale
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
w1_scale = torch.randn(num_experts, dtype=torch.float32)
|
w1_scale = torch.randn(num_experts, dtype=torch.float32)
|
||||||
w2_scale = torch.randn(num_experts, dtype=torch.float32)
|
w2_scale = torch.randn(num_experts, dtype=torch.float32)
|
||||||
@ -122,12 +114,10 @@ def benchmark_config(
|
|||||||
|
|
||||||
def run():
|
def run():
|
||||||
from vllm.model_executor.layers.fused_moe import override_config
|
from vllm.model_executor.layers.fused_moe import override_config
|
||||||
|
|
||||||
with override_config(config):
|
with override_config(config):
|
||||||
if use_deep_gemm:
|
if use_deep_gemm:
|
||||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||||
x, input_gating, topk, False
|
x, input_gating, topk, False)
|
||||||
)
|
|
||||||
return fused_experts(
|
return fused_experts(
|
||||||
x,
|
x,
|
||||||
w1,
|
w1,
|
||||||
@ -223,7 +213,8 @@ def get_rocm_tuning_space(use_fp16):
|
|||||||
return param_ranges
|
return param_ranges
|
||||||
|
|
||||||
|
|
||||||
def get_configs_compute_bound(use_fp16, block_quant_shape) -> list[dict[str, int]]:
|
def get_configs_compute_bound(use_fp16,
|
||||||
|
block_quant_shape) -> list[dict[str, int]]:
|
||||||
configs: list[BenchmarkConfig] = []
|
configs: list[BenchmarkConfig] = []
|
||||||
|
|
||||||
if current_platform.is_rocm():
|
if current_platform.is_rocm():
|
||||||
@ -259,25 +250,20 @@ def get_configs_compute_bound(use_fp16, block_quant_shape) -> list[dict[str, int
|
|||||||
if block_quant_shape is not None and not use_fp16:
|
if block_quant_shape is not None and not use_fp16:
|
||||||
block_n, block_k = block_quant_shape[0], block_quant_shape[1]
|
block_n, block_k = block_quant_shape[0], block_quant_shape[1]
|
||||||
for config in configs[:]:
|
for config in configs[:]:
|
||||||
if (
|
if config["BLOCK_SIZE_K"] % block_k != 0 or config[
|
||||||
config["BLOCK_SIZE_K"] % block_k != 0
|
"BLOCK_SIZE_N"] % block_n != 0:
|
||||||
or config["BLOCK_SIZE_N"] % block_n != 0
|
|
||||||
):
|
|
||||||
configs.remove(config)
|
configs.remove(config)
|
||||||
return configs
|
return configs
|
||||||
|
|
||||||
|
|
||||||
def prune_rocm_search_space(
|
def prune_rocm_search_space(num_tokens, shard_intermediate_size, hidden_size,
|
||||||
num_tokens, shard_intermediate_size, hidden_size, search_space, is_fp16, topk
|
search_space, is_fp16, topk):
|
||||||
):
|
|
||||||
N1, K1 = shard_intermediate_size, hidden_size
|
N1, K1 = shard_intermediate_size, hidden_size
|
||||||
N2, K2 = hidden_size, shard_intermediate_size // 2
|
N2, K2 = hidden_size, shard_intermediate_size // 2
|
||||||
pruned_space_1 = prune_rocm_configs(
|
pruned_space_1 = prune_rocm_configs(num_tokens * topk, N1, K1,
|
||||||
num_tokens * topk, N1, K1, search_space, is_fp16
|
search_space, is_fp16)
|
||||||
)
|
pruned_space_2 = prune_rocm_configs(num_tokens * topk, N2, K2,
|
||||||
pruned_space_2 = prune_rocm_configs(
|
search_space, is_fp16)
|
||||||
num_tokens * topk, N2, K2, search_space, is_fp16
|
|
||||||
)
|
|
||||||
search_space = merge_unique_dicts(pruned_space_1, pruned_space_2)
|
search_space = merge_unique_dicts(pruned_space_1, pruned_space_2)
|
||||||
return search_space
|
return search_space
|
||||||
|
|
||||||
@ -315,14 +301,14 @@ def prune_rocm_configs(M, N, K, configs, is_fp16=True):
|
|||||||
SPLIT_K = config.get("SPLIT_K", 1)
|
SPLIT_K = config.get("SPLIT_K", 1)
|
||||||
GROUP_M = config.get("GROUP_SIZE_M")
|
GROUP_M = config.get("GROUP_SIZE_M")
|
||||||
if is_fp16:
|
if is_fp16:
|
||||||
if (
|
if (matrix_instr_nonkdim > BLOCK_SIZE_M
|
||||||
matrix_instr_nonkdim > BLOCK_SIZE_M
|
or matrix_instr_nonkdim > BLOCK_SIZE_N):
|
||||||
or matrix_instr_nonkdim > BLOCK_SIZE_N
|
|
||||||
):
|
|
||||||
continue
|
continue
|
||||||
if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M:
|
if (matrix_instr_nonkdim >= M
|
||||||
|
and matrix_instr_nonkdim != BLOCK_SIZE_M):
|
||||||
continue
|
continue
|
||||||
if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N:
|
if (matrix_instr_nonkdim >= N
|
||||||
|
and matrix_instr_nonkdim != BLOCK_SIZE_N):
|
||||||
continue
|
continue
|
||||||
# Skip BLOCK_SIZE that is too large compare to M/N
|
# Skip BLOCK_SIZE that is too large compare to M/N
|
||||||
# unless BLOCK_SIZE is already small enough
|
# unless BLOCK_SIZE is already small enough
|
||||||
@ -343,10 +329,8 @@ def prune_rocm_configs(M, N, K, configs, is_fp16=True):
|
|||||||
continue
|
continue
|
||||||
# out of shared memory resource
|
# out of shared memory resource
|
||||||
# TODO (zhanglx): This does not consider the LDS usage in the epilogue
|
# TODO (zhanglx): This does not consider the LDS usage in the epilogue
|
||||||
LDS = (
|
LDS = (BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a +
|
||||||
BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a
|
BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b)
|
||||||
+ BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b
|
|
||||||
)
|
|
||||||
if LDS > 65536:
|
if LDS > 65536:
|
||||||
continue
|
continue
|
||||||
# Skip small block sizes and num_warps for large gemm
|
# Skip small block sizes and num_warps for large gemm
|
||||||
@ -380,6 +364,7 @@ def merge_unique_dicts(list1, list2):
|
|||||||
|
|
||||||
@ray.remote(num_gpus=1)
|
@ray.remote(num_gpus=1)
|
||||||
class BenchmarkWorker:
|
class BenchmarkWorker:
|
||||||
|
|
||||||
def __init__(self, seed: int) -> None:
|
def __init__(self, seed: int) -> None:
|
||||||
torch.set_default_device("cuda")
|
torch.set_default_device("cuda")
|
||||||
current_platform.seed_everything(seed)
|
current_platform.seed_everything(seed)
|
||||||
@ -403,40 +388,36 @@ class BenchmarkWorker:
|
|||||||
use_deep_gemm: bool = False,
|
use_deep_gemm: bool = False,
|
||||||
) -> tuple[dict[str, int], float]:
|
) -> tuple[dict[str, int], float]:
|
||||||
current_platform.seed_everything(self.seed)
|
current_platform.seed_everything(self.seed)
|
||||||
dtype_str = get_config_dtype_str(
|
dtype_str = get_config_dtype_str(dtype,
|
||||||
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
|
use_int8_w8a16=use_int8_w8a16,
|
||||||
)
|
use_fp8_w8a8=use_fp8_w8a8)
|
||||||
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
|
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
|
||||||
# is the intermediate size after silu_and_mul.
|
# is the intermediate size after silu_and_mul.
|
||||||
op_config = get_moe_configs(
|
op_config = get_moe_configs(num_experts, shard_intermediate_size // 2,
|
||||||
num_experts, shard_intermediate_size // 2, dtype_str
|
dtype_str)
|
||||||
)
|
|
||||||
if op_config is None:
|
if op_config is None:
|
||||||
config = get_default_config(
|
config = get_default_config(num_tokens,
|
||||||
num_tokens,
|
num_experts,
|
||||||
num_experts,
|
shard_intermediate_size,
|
||||||
shard_intermediate_size,
|
hidden_size,
|
||||||
hidden_size,
|
topk,
|
||||||
topk,
|
dtype_str,
|
||||||
dtype_str,
|
is_marlin=False)
|
||||||
is_marlin=False,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))]
|
config = op_config[min(op_config.keys(),
|
||||||
kernel_time = benchmark_config(
|
key=lambda x: abs(x - num_tokens))]
|
||||||
config,
|
kernel_time = benchmark_config(config,
|
||||||
num_tokens,
|
num_tokens,
|
||||||
num_experts,
|
num_experts,
|
||||||
shard_intermediate_size,
|
shard_intermediate_size,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
topk,
|
topk,
|
||||||
dtype,
|
dtype,
|
||||||
use_fp8_w8a8,
|
use_fp8_w8a8,
|
||||||
use_int8_w8a16,
|
use_int8_w8a16,
|
||||||
num_iters=100,
|
num_iters=100,
|
||||||
block_quant_shape=block_quant_shape,
|
block_quant_shape=block_quant_shape,
|
||||||
use_deep_gemm=use_deep_gemm,
|
use_deep_gemm=use_deep_gemm)
|
||||||
)
|
|
||||||
return config, kernel_time
|
return config, kernel_time
|
||||||
|
|
||||||
def tune(
|
def tune(
|
||||||
@ -457,14 +438,10 @@ class BenchmarkWorker:
|
|||||||
best_time = float("inf")
|
best_time = float("inf")
|
||||||
if current_platform.is_rocm():
|
if current_platform.is_rocm():
|
||||||
is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
|
is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
|
||||||
search_space = prune_rocm_search_space(
|
search_space = prune_rocm_search_space(num_tokens,
|
||||||
num_tokens,
|
shard_intermediate_size,
|
||||||
shard_intermediate_size,
|
hidden_size, search_space,
|
||||||
hidden_size,
|
is_fp16, topk)
|
||||||
search_space,
|
|
||||||
is_fp16,
|
|
||||||
topk,
|
|
||||||
)
|
|
||||||
|
|
||||||
need_device_guard = False
|
need_device_guard = False
|
||||||
if current_platform.is_rocm():
|
if current_platform.is_rocm():
|
||||||
@ -472,7 +449,8 @@ class BenchmarkWorker:
|
|||||||
if visible_device != f"{self.device_id}":
|
if visible_device != f"{self.device_id}":
|
||||||
need_device_guard = True
|
need_device_guard = True
|
||||||
|
|
||||||
with torch.cuda.device(self.device_id) if need_device_guard else nullcontext():
|
with torch.cuda.device(
|
||||||
|
self.device_id) if need_device_guard else nullcontext():
|
||||||
for config in tqdm(search_space):
|
for config in tqdm(search_space):
|
||||||
try:
|
try:
|
||||||
kernel_time = benchmark_config(
|
kernel_time = benchmark_config(
|
||||||
@ -487,8 +465,7 @@ class BenchmarkWorker:
|
|||||||
use_int8_w8a16,
|
use_int8_w8a16,
|
||||||
num_iters=20,
|
num_iters=20,
|
||||||
block_quant_shape=block_quant_shape,
|
block_quant_shape=block_quant_shape,
|
||||||
use_deep_gemm=use_deep_gemm,
|
use_deep_gemm=use_deep_gemm)
|
||||||
)
|
|
||||||
except triton.runtime.autotuner.OutOfResources:
|
except triton.runtime.autotuner.OutOfResources:
|
||||||
# Some configurations may be invalid and fail to compile.
|
# Some configurations may be invalid and fail to compile.
|
||||||
continue
|
continue
|
||||||
@ -504,44 +481,42 @@ class BenchmarkWorker:
|
|||||||
|
|
||||||
def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
|
def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
|
||||||
return {
|
return {
|
||||||
"BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
|
"BLOCK_SIZE_M":
|
||||||
"BLOCK_SIZE_N": config["BLOCK_SIZE_N"],
|
config["BLOCK_SIZE_M"],
|
||||||
"BLOCK_SIZE_K": config["BLOCK_SIZE_K"],
|
"BLOCK_SIZE_N":
|
||||||
"GROUP_SIZE_M": config["GROUP_SIZE_M"],
|
config["BLOCK_SIZE_N"],
|
||||||
"num_warps": config["num_warps"],
|
"BLOCK_SIZE_K":
|
||||||
"num_stages": config["num_stages"],
|
config["BLOCK_SIZE_K"],
|
||||||
**(
|
"GROUP_SIZE_M":
|
||||||
{"waves_per_eu": config["waves_per_eu"]} if "waves_per_eu" in config else {}
|
config["GROUP_SIZE_M"],
|
||||||
),
|
"num_warps":
|
||||||
**(
|
config["num_warps"],
|
||||||
{"matrix_instr_nonkdim": config["matrix_instr_nonkdim"]}
|
"num_stages":
|
||||||
if "matrix_instr_nonkdim" in config
|
config["num_stages"],
|
||||||
else {}
|
**({
|
||||||
),
|
"waves_per_eu": config["waves_per_eu"]
|
||||||
**({"kpack": config["kpack"]} if "kpack" in config else {}),
|
} if "waves_per_eu" in config else {}),
|
||||||
|
**({
|
||||||
|
"matrix_instr_nonkdim": config["matrix_instr_nonkdim"]
|
||||||
|
} if "matrix_instr_nonkdim" in config else {}),
|
||||||
|
**({
|
||||||
|
"kpack": config["kpack"]
|
||||||
|
} if "kpack" in config else {}),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def save_configs(
|
def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int,
|
||||||
configs: dict[int, BenchmarkConfig],
|
shard_intermediate_size: int, hidden_size: int, topk: int,
|
||||||
num_experts: int,
|
dtype: torch.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool,
|
||||||
shard_intermediate_size: int,
|
block_quant_shape: List[int]) -> None:
|
||||||
hidden_size: int,
|
dtype_str = get_config_dtype_str(dtype,
|
||||||
topk: int,
|
use_int8_w8a16=use_int8_w8a16,
|
||||||
dtype: torch.dtype,
|
use_fp8_w8a8=use_fp8_w8a8)
|
||||||
use_fp8_w8a8: bool,
|
|
||||||
use_int8_w8a16: bool,
|
|
||||||
block_quant_shape: List[int],
|
|
||||||
) -> None:
|
|
||||||
dtype_str = get_config_dtype_str(
|
|
||||||
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
|
|
||||||
)
|
|
||||||
|
|
||||||
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
|
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
|
||||||
# is the intermediate size after silu_and_mul.
|
# is the intermediate size after silu_and_mul.
|
||||||
filename = get_config_file_name(
|
filename = get_config_file_name(num_experts, shard_intermediate_size // 2,
|
||||||
num_experts, shard_intermediate_size // 2, dtype_str, block_quant_shape
|
dtype_str, block_quant_shape)
|
||||||
)
|
|
||||||
|
|
||||||
print(f"Writing best config to {filename}...")
|
print(f"Writing best config to {filename}...")
|
||||||
with open(filename, "w") as f:
|
with open(filename, "w") as f:
|
||||||
@ -550,16 +525,18 @@ def save_configs(
|
|||||||
|
|
||||||
|
|
||||||
def get_weight_block_size_safety(config, default_value=None):
|
def get_weight_block_size_safety(config, default_value=None):
|
||||||
quantization_config = getattr(config, "quantization_config", {})
|
|
||||||
|
quantization_config = getattr(config, 'quantization_config', {})
|
||||||
if isinstance(quantization_config, dict):
|
if isinstance(quantization_config, dict):
|
||||||
return quantization_config.get("weight_block_size", default_value)
|
return quantization_config.get('weight_block_size', default_value)
|
||||||
return default_value
|
return default_value
|
||||||
|
|
||||||
|
|
||||||
def main(args: argparse.Namespace):
|
def main(args: argparse.Namespace):
|
||||||
print(args)
|
print(args)
|
||||||
|
|
||||||
config = get_config(model=args.model, trust_remote_code=args.trust_remote_code)
|
config = get_config(model=args.model,
|
||||||
|
trust_remote_code=args.trust_remote_code)
|
||||||
if args.model_prefix:
|
if args.model_prefix:
|
||||||
config = getattr(config, args.model_prefix)
|
config = getattr(config, args.model_prefix)
|
||||||
config = SimpleNamespace(**config)
|
config = SimpleNamespace(**config)
|
||||||
@ -574,12 +551,14 @@ def main(args: argparse.Namespace):
|
|||||||
topk = config.num_experts_per_tok
|
topk = config.num_experts_per_tok
|
||||||
intermediate_size = config.intermediate_size
|
intermediate_size = config.intermediate_size
|
||||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||||
elif config.architectures[0] in ("DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"):
|
elif (config.architectures[0]
|
||||||
|
in ("DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM")):
|
||||||
E = config.n_routed_experts
|
E = config.n_routed_experts
|
||||||
topk = config.num_experts_per_tok
|
topk = config.num_experts_per_tok
|
||||||
intermediate_size = config.moe_intermediate_size
|
intermediate_size = config.moe_intermediate_size
|
||||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||||
elif config.architectures[0] in ("Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"):
|
elif config.architectures[0] in ("Qwen2MoeForCausalLM",
|
||||||
|
"Qwen3MoeForCausalLM"):
|
||||||
E = config.num_experts
|
E = config.num_experts
|
||||||
topk = config.num_experts_per_tok
|
topk = config.num_experts_per_tok
|
||||||
intermediate_size = config.moe_intermediate_size
|
intermediate_size = config.moe_intermediate_size
|
||||||
@ -594,35 +573,16 @@ def main(args: argparse.Namespace):
|
|||||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||||
|
|
||||||
hidden_size = config.hidden_size
|
hidden_size = config.hidden_size
|
||||||
dtype = (
|
dtype = torch.float16 if current_platform.is_rocm() else getattr(
|
||||||
torch.float16
|
torch, config.torch_dtype)
|
||||||
if current_platform.is_rocm()
|
|
||||||
else getattr(torch, config.torch_dtype)
|
|
||||||
)
|
|
||||||
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
|
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
|
||||||
use_int8_w8a16 = args.dtype == "int8_w8a16"
|
use_int8_w8a16 = args.dtype == "int8_w8a16"
|
||||||
block_quant_shape = get_weight_block_size_safety(config)
|
block_quant_shape = get_weight_block_size_safety(config)
|
||||||
|
|
||||||
if args.batch_size is None:
|
if args.batch_size is None:
|
||||||
batch_sizes = [
|
batch_sizes = [
|
||||||
1,
|
1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
|
||||||
2,
|
2048, 3072, 4096
|
||||||
4,
|
|
||||||
8,
|
|
||||||
16,
|
|
||||||
24,
|
|
||||||
32,
|
|
||||||
48,
|
|
||||||
64,
|
|
||||||
96,
|
|
||||||
128,
|
|
||||||
256,
|
|
||||||
512,
|
|
||||||
1024,
|
|
||||||
1536,
|
|
||||||
2048,
|
|
||||||
3072,
|
|
||||||
4096,
|
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
batch_sizes = [args.batch_size]
|
batch_sizes = [args.batch_size]
|
||||||
@ -633,8 +593,7 @@ def main(args: argparse.Namespace):
|
|||||||
# Ray will set ROCR_VISIBLE_DEVICES for device visibility
|
# Ray will set ROCR_VISIBLE_DEVICES for device visibility
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Ray uses ROCR_VISIBLE_DEVICES to control device accessibility."
|
"Ray uses ROCR_VISIBLE_DEVICES to control device accessibility."
|
||||||
"Replacing HIP_VISIBLE_DEVICES with ROCR_VISIBLE_DEVICES."
|
"Replacing HIP_VISIBLE_DEVICES with ROCR_VISIBLE_DEVICES.")
|
||||||
)
|
|
||||||
val = os.environ["HIP_VISIBLE_DEVICES"]
|
val = os.environ["HIP_VISIBLE_DEVICES"]
|
||||||
os.environ["ROCR_VISIBLE_DEVICES"] = val
|
os.environ["ROCR_VISIBLE_DEVICES"] = val
|
||||||
del os.environ["HIP_VISIBLE_DEVICES"]
|
del os.environ["HIP_VISIBLE_DEVICES"]
|
||||||
@ -661,59 +620,25 @@ def main(args: argparse.Namespace):
|
|||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
configs = _distribute(
|
configs = _distribute(
|
||||||
"tune",
|
"tune", [(batch_size, E, shard_intermediate_size, hidden_size,
|
||||||
[
|
topk, dtype, use_fp8_w8a8, use_int8_w8a16, search_space,
|
||||||
(
|
block_quant_shape, use_deep_gemm)
|
||||||
batch_size,
|
for batch_size in batch_sizes])
|
||||||
E,
|
|
||||||
shard_intermediate_size,
|
|
||||||
hidden_size,
|
|
||||||
topk,
|
|
||||||
dtype,
|
|
||||||
use_fp8_w8a8,
|
|
||||||
use_int8_w8a16,
|
|
||||||
search_space,
|
|
||||||
block_quant_shape,
|
|
||||||
use_deep_gemm,
|
|
||||||
)
|
|
||||||
for batch_size in batch_sizes
|
|
||||||
],
|
|
||||||
)
|
|
||||||
best_configs = {
|
best_configs = {
|
||||||
M: sort_config(config) for M, config in zip(batch_sizes, configs)
|
M: sort_config(config)
|
||||||
|
for M, config in zip(batch_sizes, configs)
|
||||||
}
|
}
|
||||||
save_configs(
|
save_configs(best_configs, E, shard_intermediate_size, hidden_size,
|
||||||
best_configs,
|
topk, dtype, use_fp8_w8a8, use_int8_w8a16,
|
||||||
E,
|
block_quant_shape)
|
||||||
shard_intermediate_size,
|
|
||||||
hidden_size,
|
|
||||||
topk,
|
|
||||||
dtype,
|
|
||||||
use_fp8_w8a8,
|
|
||||||
use_int8_w8a16,
|
|
||||||
block_quant_shape,
|
|
||||||
)
|
|
||||||
end = time.time()
|
end = time.time()
|
||||||
print(f"Tuning took {end - start:.2f} seconds")
|
print(f"Tuning took {end - start:.2f} seconds")
|
||||||
else:
|
else:
|
||||||
outputs = _distribute(
|
outputs = _distribute(
|
||||||
"benchmark",
|
"benchmark",
|
||||||
[
|
[(batch_size, E, shard_intermediate_size, hidden_size, topk, dtype,
|
||||||
(
|
use_fp8_w8a8, use_int8_w8a16, block_quant_shape, use_deep_gemm)
|
||||||
batch_size,
|
for batch_size in batch_sizes])
|
||||||
E,
|
|
||||||
shard_intermediate_size,
|
|
||||||
hidden_size,
|
|
||||||
topk,
|
|
||||||
dtype,
|
|
||||||
use_fp8_w8a8,
|
|
||||||
use_int8_w8a16,
|
|
||||||
block_quant_shape,
|
|
||||||
use_deep_gemm,
|
|
||||||
)
|
|
||||||
for batch_size in batch_sizes
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
|
for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
|
||||||
print(f"Batch size: {batch_size}, config: {config}")
|
print(f"Batch size: {batch_size}, config: {config}")
|
||||||
@ -722,15 +647,18 @@ def main(args: argparse.Namespace):
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = FlexibleArgumentParser()
|
parser = FlexibleArgumentParser()
|
||||||
parser.add_argument(
|
parser.add_argument("--model",
|
||||||
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
|
type=str,
|
||||||
)
|
default="mistralai/Mixtral-8x7B-Instruct-v0.1")
|
||||||
parser.add_argument(
|
parser.add_argument("--tp-size",
|
||||||
"--tp-size", "-tp", "--tensor-parallel-size", type=int, default=2
|
"-tp",
|
||||||
)
|
"--tensor-parallel-size",
|
||||||
parser.add_argument(
|
type=int,
|
||||||
"--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto"
|
default=2)
|
||||||
)
|
parser.add_argument("--dtype",
|
||||||
|
type=str,
|
||||||
|
choices=["auto", "fp8_w8a8", "int8_w8a16"],
|
||||||
|
default="auto")
|
||||||
parser.add_argument("--use-deep-gemm", action="store_true")
|
parser.add_argument("--use-deep-gemm", action="store_true")
|
||||||
parser.add_argument("--seed", type=int, default=0)
|
parser.add_argument("--seed", type=int, default=0)
|
||||||
parser.add_argument("--batch-size", type=int, required=False)
|
parser.add_argument("--batch-size", type=int, required=False)
|
||||||
|
|||||||
@ -8,9 +8,7 @@ import torch
|
|||||||
from transformers import AutoConfig
|
from transformers import AutoConfig
|
||||||
|
|
||||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
||||||
_moe_permute,
|
_moe_permute, _moe_unpermute_and_reduce)
|
||||||
_moe_unpermute_and_reduce,
|
|
||||||
)
|
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import *
|
from vllm.model_executor.layers.fused_moe.fused_moe import *
|
||||||
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import *
|
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import *
|
||||||
from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize
|
from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize
|
||||||
@ -29,17 +27,15 @@ class BenchmarkConfig(TypedDict):
|
|||||||
num_stages: int
|
num_stages: int
|
||||||
|
|
||||||
|
|
||||||
def benchmark_permute(
|
def benchmark_permute(num_tokens: int,
|
||||||
num_tokens: int,
|
num_experts: int,
|
||||||
num_experts: int,
|
hidden_size: int,
|
||||||
hidden_size: int,
|
topk: int,
|
||||||
topk: int,
|
dtype: torch.dtype,
|
||||||
dtype: torch.dtype,
|
use_fp8_w8a8: bool,
|
||||||
use_fp8_w8a8: bool,
|
use_int8_w8a16: bool,
|
||||||
use_int8_w8a16: bool,
|
num_iters: int = 100,
|
||||||
num_iters: int = 100,
|
use_customized_permute: bool = False) -> float:
|
||||||
use_customized_permute: bool = False,
|
|
||||||
) -> float:
|
|
||||||
# init_dtype = torch.float16 if use_fp8_w8a8 else dtype
|
# init_dtype = torch.float16 if use_fp8_w8a8 else dtype
|
||||||
hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
||||||
# output_hidden_states = torch.empty_like(hidden_states)
|
# output_hidden_states = torch.empty_like(hidden_states)
|
||||||
@ -50,41 +46,36 @@ def benchmark_permute(
|
|||||||
align_block_size = None
|
align_block_size = None
|
||||||
qhidden_states = hidden_states
|
qhidden_states = hidden_states
|
||||||
|
|
||||||
gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)
|
gating_output = torch.randn(num_iters,
|
||||||
|
num_tokens,
|
||||||
|
num_experts,
|
||||||
|
dtype=torch.float32)
|
||||||
|
|
||||||
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
|
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
|
||||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||||
qhidden_states, input_gating, topk, False
|
qhidden_states, input_gating, topk, False)
|
||||||
)
|
|
||||||
|
|
||||||
def prepare(i: int):
|
def prepare(i: int):
|
||||||
input_gating.copy_(gating_output[i])
|
input_gating.copy_(gating_output[i])
|
||||||
|
|
||||||
def run():
|
def run():
|
||||||
if use_customized_permute:
|
if use_customized_permute:
|
||||||
(permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = (
|
(permuted_hidden_states, first_token_off, inv_perm_idx,
|
||||||
moe_permute(
|
m_indices) = moe_permute(
|
||||||
qhidden_states,
|
qhidden_states,
|
||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
token_expert_indices=token_expert_indices,
|
token_expert_indices=token_expert_indices,
|
||||||
topk=topk,
|
topk=topk,
|
||||||
n_expert=num_experts,
|
n_expert=num_experts,
|
||||||
n_local_expert=num_experts,
|
n_local_expert=num_experts,
|
||||||
expert_map=None,
|
expert_map=None,
|
||||||
align_block_size=align_block_size,
|
align_block_size=align_block_size,
|
||||||
)
|
)
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
(
|
(permuted_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
|
||||||
permuted_hidden_states,
|
inv_perm) = _moe_permute(qhidden_states, None, topk_ids,
|
||||||
a1q_scale,
|
num_experts, None, align_block_size)
|
||||||
sorted_token_ids,
|
|
||||||
expert_ids,
|
|
||||||
inv_perm,
|
|
||||||
) = _moe_permute(
|
|
||||||
qhidden_states, None, topk_ids, num_experts, None, align_block_size
|
|
||||||
)
|
|
||||||
|
|
||||||
# JIT compilation & warmup
|
# JIT compilation & warmup
|
||||||
run()
|
run()
|
||||||
@ -120,17 +111,15 @@ def benchmark_permute(
|
|||||||
return avg
|
return avg
|
||||||
|
|
||||||
|
|
||||||
def benchmark_unpermute(
|
def benchmark_unpermute(num_tokens: int,
|
||||||
num_tokens: int,
|
num_experts: int,
|
||||||
num_experts: int,
|
hidden_size: int,
|
||||||
hidden_size: int,
|
topk: int,
|
||||||
topk: int,
|
dtype: torch.dtype,
|
||||||
dtype: torch.dtype,
|
use_fp8_w8a8: bool,
|
||||||
use_fp8_w8a8: bool,
|
use_int8_w8a16: bool,
|
||||||
use_int8_w8a16: bool,
|
num_iters: int = 100,
|
||||||
num_iters: int = 100,
|
use_customized_permute: bool = False) -> float:
|
||||||
use_customized_permute: bool = False,
|
|
||||||
) -> float:
|
|
||||||
# init_dtype = torch.float16 if use_fp8_w8a8 else dtype
|
# init_dtype = torch.float16 if use_fp8_w8a8 else dtype
|
||||||
hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
||||||
output_hidden_states = torch.empty_like(hidden_states)
|
output_hidden_states = torch.empty_like(hidden_states)
|
||||||
@ -144,74 +133,46 @@ def benchmark_unpermute(
|
|||||||
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
|
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
|
||||||
|
|
||||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||||
qhidden_states, input_gating, topk, False
|
qhidden_states, input_gating, topk, False)
|
||||||
)
|
|
||||||
|
|
||||||
def prepare():
|
def prepare():
|
||||||
if use_customized_permute:
|
if use_customized_permute:
|
||||||
(permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = (
|
(permuted_hidden_states, first_token_off, inv_perm_idx,
|
||||||
moe_permute(
|
m_indices) = moe_permute(
|
||||||
qhidden_states,
|
qhidden_states,
|
||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
token_expert_indices=token_expert_indices,
|
token_expert_indices=token_expert_indices,
|
||||||
topk=topk,
|
topk=topk,
|
||||||
n_expert=num_experts,
|
n_expert=num_experts,
|
||||||
n_local_expert=num_experts,
|
n_local_expert=num_experts,
|
||||||
expert_map=None,
|
expert_map=None,
|
||||||
align_block_size=align_block_size,
|
align_block_size=align_block_size,
|
||||||
)
|
)
|
||||||
)
|
|
||||||
# convert to fp16/bf16 as gemm output
|
# convert to fp16/bf16 as gemm output
|
||||||
return (
|
return (permuted_hidden_states.to(dtype), first_token_off,
|
||||||
permuted_hidden_states.to(dtype),
|
inv_perm_idx, m_indices)
|
||||||
first_token_off,
|
|
||||||
inv_perm_idx,
|
|
||||||
m_indices,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
(
|
(permuted_qhidden_states, a1q_scale, sorted_token_ids, expert_ids,
|
||||||
permuted_qhidden_states,
|
inv_perm) = _moe_permute(qhidden_states, None, topk_ids,
|
||||||
a1q_scale,
|
num_experts, None, align_block_size)
|
||||||
sorted_token_ids,
|
|
||||||
expert_ids,
|
|
||||||
inv_perm,
|
|
||||||
) = _moe_permute(
|
|
||||||
qhidden_states, None, topk_ids, num_experts, None, align_block_size
|
|
||||||
)
|
|
||||||
# convert to fp16/bf16 as gemm output
|
# convert to fp16/bf16 as gemm output
|
||||||
return (
|
return (permuted_qhidden_states.to(dtype), a1q_scale,
|
||||||
permuted_qhidden_states.to(dtype),
|
sorted_token_ids, expert_ids, inv_perm)
|
||||||
a1q_scale,
|
|
||||||
sorted_token_ids,
|
|
||||||
expert_ids,
|
|
||||||
inv_perm,
|
|
||||||
)
|
|
||||||
|
|
||||||
def run(input: tuple):
|
def run(input: tuple):
|
||||||
if use_customized_permute:
|
if use_customized_permute:
|
||||||
(permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = input
|
(permuted_hidden_states, first_token_off, inv_perm_idx,
|
||||||
moe_unpermute(
|
m_indices) = input
|
||||||
permuted_hidden_states,
|
moe_unpermute(permuted_hidden_states, topk_weights, topk_ids,
|
||||||
topk_weights,
|
inv_perm_idx, first_token_off, topk, num_experts,
|
||||||
topk_ids,
|
num_experts)
|
||||||
inv_perm_idx,
|
|
||||||
first_token_off,
|
|
||||||
topk,
|
|
||||||
num_experts,
|
|
||||||
num_experts,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
(
|
(permuted_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
|
||||||
permuted_hidden_states,
|
inv_perm) = input
|
||||||
a1q_scale,
|
_moe_unpermute_and_reduce(output_hidden_states,
|
||||||
sorted_token_ids,
|
permuted_hidden_states, inv_perm,
|
||||||
expert_ids,
|
topk_weights)
|
||||||
inv_perm,
|
|
||||||
) = input
|
|
||||||
_moe_unpermute_and_reduce(
|
|
||||||
output_hidden_states, permuted_hidden_states, inv_perm, topk_weights
|
|
||||||
)
|
|
||||||
|
|
||||||
# JIT compilation & warmup
|
# JIT compilation & warmup
|
||||||
input = prepare()
|
input = prepare()
|
||||||
@ -248,6 +209,7 @@ def benchmark_unpermute(
|
|||||||
|
|
||||||
@ray.remote(num_gpus=1)
|
@ray.remote(num_gpus=1)
|
||||||
class BenchmarkWorker:
|
class BenchmarkWorker:
|
||||||
|
|
||||||
def __init__(self, seed: int) -> None:
|
def __init__(self, seed: int) -> None:
|
||||||
torch.set_default_device("cuda")
|
torch.set_default_device("cuda")
|
||||||
current_platform.seed_everything(seed)
|
current_platform.seed_everything(seed)
|
||||||
@ -279,8 +241,7 @@ class BenchmarkWorker:
|
|||||||
use_fp8_w8a8,
|
use_fp8_w8a8,
|
||||||
use_int8_w8a16,
|
use_int8_w8a16,
|
||||||
num_iters=100,
|
num_iters=100,
|
||||||
use_customized_permute=use_customized_permute,
|
use_customized_permute=use_customized_permute)
|
||||||
)
|
|
||||||
unpermute_time = benchmark_unpermute(
|
unpermute_time = benchmark_unpermute(
|
||||||
num_tokens,
|
num_tokens,
|
||||||
num_experts,
|
num_experts,
|
||||||
@ -290,15 +251,15 @@ class BenchmarkWorker:
|
|||||||
use_fp8_w8a8,
|
use_fp8_w8a8,
|
||||||
use_int8_w8a16,
|
use_int8_w8a16,
|
||||||
num_iters=100,
|
num_iters=100,
|
||||||
use_customized_permute=use_customized_permute,
|
use_customized_permute=use_customized_permute)
|
||||||
)
|
|
||||||
return permute_time, unpermute_time
|
return permute_time, unpermute_time
|
||||||
|
|
||||||
|
|
||||||
def get_weight_block_size_safety(config, default_value=None):
|
def get_weight_block_size_safety(config, default_value=None):
|
||||||
quantization_config = getattr(config, "quantization_config", {})
|
|
||||||
|
quantization_config = getattr(config, 'quantization_config', {})
|
||||||
if isinstance(quantization_config, dict):
|
if isinstance(quantization_config, dict):
|
||||||
return quantization_config.get("weight_block_size", default_value)
|
return quantization_config.get('weight_block_size', default_value)
|
||||||
return default_value
|
return default_value
|
||||||
|
|
||||||
|
|
||||||
@ -306,21 +267,20 @@ def main(args: argparse.Namespace):
|
|||||||
print(args)
|
print(args)
|
||||||
|
|
||||||
config = AutoConfig.from_pretrained(
|
config = AutoConfig.from_pretrained(
|
||||||
args.model, trust_remote_code=args.trust_remote_code
|
args.model, trust_remote_code=args.trust_remote_code)
|
||||||
)
|
|
||||||
if config.architectures[0] == "DbrxForCausalLM":
|
if config.architectures[0] == "DbrxForCausalLM":
|
||||||
E = config.ffn_config.moe_num_experts
|
E = config.ffn_config.moe_num_experts
|
||||||
topk = config.ffn_config.moe_top_k
|
topk = config.ffn_config.moe_top_k
|
||||||
elif config.architectures[0] == "JambaForCausalLM":
|
elif config.architectures[0] == "JambaForCausalLM":
|
||||||
E = config.num_experts
|
E = config.num_experts
|
||||||
topk = config.num_experts_per_tok
|
topk = config.num_experts_per_tok
|
||||||
elif (
|
elif (config.architectures[0] == "DeepseekV3ForCausalLM"
|
||||||
config.architectures[0] == "DeepseekV3ForCausalLM"
|
or config.architectures[0] == "DeepseekV2ForCausalLM"):
|
||||||
or config.architectures[0] == "DeepseekV2ForCausalLM"
|
|
||||||
):
|
|
||||||
E = config.n_routed_experts
|
E = config.n_routed_experts
|
||||||
topk = config.num_experts_per_tok
|
topk = config.num_experts_per_tok
|
||||||
elif config.architectures[0] in ["Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"]:
|
elif config.architectures[0] in [
|
||||||
|
"Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"
|
||||||
|
]:
|
||||||
E = config.num_experts
|
E = config.num_experts
|
||||||
topk = config.num_experts_per_tok
|
topk = config.num_experts_per_tok
|
||||||
|
|
||||||
@ -339,24 +299,8 @@ def main(args: argparse.Namespace):
|
|||||||
|
|
||||||
if args.batch_size is None:
|
if args.batch_size is None:
|
||||||
batch_sizes = [
|
batch_sizes = [
|
||||||
1,
|
1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
|
||||||
2,
|
2048, 3072, 4096
|
||||||
4,
|
|
||||||
8,
|
|
||||||
16,
|
|
||||||
24,
|
|
||||||
32,
|
|
||||||
48,
|
|
||||||
64,
|
|
||||||
96,
|
|
||||||
128,
|
|
||||||
256,
|
|
||||||
512,
|
|
||||||
1024,
|
|
||||||
1536,
|
|
||||||
2048,
|
|
||||||
3072,
|
|
||||||
4096,
|
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
batch_sizes = [args.batch_size]
|
batch_sizes = [args.batch_size]
|
||||||
@ -377,21 +321,9 @@ def main(args: argparse.Namespace):
|
|||||||
return ray.get(outputs)
|
return ray.get(outputs)
|
||||||
|
|
||||||
outputs = _distribute(
|
outputs = _distribute(
|
||||||
"benchmark",
|
"benchmark", [(batch_size, E, hidden_size, topk, dtype, use_fp8_w8a8,
|
||||||
[
|
use_int8_w8a16, use_customized_permute)
|
||||||
(
|
for batch_size in batch_sizes])
|
||||||
batch_size,
|
|
||||||
E,
|
|
||||||
hidden_size,
|
|
||||||
topk,
|
|
||||||
dtype,
|
|
||||||
use_fp8_w8a8,
|
|
||||||
use_int8_w8a16,
|
|
||||||
use_customized_permute,
|
|
||||||
)
|
|
||||||
for batch_size in batch_sizes
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
for batch_size, (permute, unpermute) in zip(batch_sizes, outputs):
|
for batch_size, (permute, unpermute) in zip(batch_sizes, outputs):
|
||||||
print(f"Batch size: {batch_size}")
|
print(f"Batch size: {batch_size}")
|
||||||
@ -401,12 +333,13 @@ def main(args: argparse.Namespace):
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = FlexibleArgumentParser()
|
parser = FlexibleArgumentParser()
|
||||||
parser.add_argument(
|
parser.add_argument("--model",
|
||||||
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
|
type=str,
|
||||||
)
|
default="mistralai/Mixtral-8x7B-Instruct-v0.1")
|
||||||
parser.add_argument(
|
parser.add_argument("--dtype",
|
||||||
"--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto"
|
type=str,
|
||||||
)
|
choices=["auto", "fp8_w8a8", "int8_w8a16"],
|
||||||
|
default="auto")
|
||||||
parser.add_argument("--use-customized-permute", action="store_true")
|
parser.add_argument("--use-customized-permute", action="store_true")
|
||||||
parser.add_argument("--seed", type=int, default=0)
|
parser.add_argument("--seed", type=int, default=0)
|
||||||
parser.add_argument("--batch-size", type=int, required=False)
|
parser.add_argument("--batch-size", type=int, required=False)
|
||||||
|
|||||||
@ -9,11 +9,8 @@ import torch
|
|||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import (
|
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
|
||||||
STR_DTYPE_TO_TORCH_DTYPE,
|
create_kv_caches_with_random)
|
||||||
FlexibleArgumentParser,
|
|
||||||
create_kv_caches_with_random,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -41,15 +38,19 @@ def main(
|
|||||||
current_platform.seed_everything(seed)
|
current_platform.seed_everything(seed)
|
||||||
|
|
||||||
scale = float(1.0 / (head_size**0.5))
|
scale = float(1.0 / (head_size**0.5))
|
||||||
query = torch.empty(
|
query = torch.empty(num_seqs,
|
||||||
num_seqs, num_query_heads, head_size, dtype=dtype, device=device
|
num_query_heads,
|
||||||
)
|
head_size,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device)
|
||||||
query.uniform_(-scale, scale)
|
query.uniform_(-scale, scale)
|
||||||
|
|
||||||
assert num_query_heads % num_kv_heads == 0
|
assert num_query_heads % num_kv_heads == 0
|
||||||
alibi_slopes = None
|
alibi_slopes = None
|
||||||
if use_alibi:
|
if use_alibi:
|
||||||
alibi_slopes = torch.randn(num_query_heads, dtype=torch.float, device=device)
|
alibi_slopes = torch.randn(num_query_heads,
|
||||||
|
dtype=torch.float,
|
||||||
|
device=device)
|
||||||
|
|
||||||
seq_lens = [seq_len for _ in range(num_seqs)]
|
seq_lens = [seq_len for _ in range(num_seqs)]
|
||||||
max_seq_len = max(seq_lens)
|
max_seq_len = max(seq_lens)
|
||||||
@ -60,23 +61,24 @@ def main(
|
|||||||
block_tables_lst: list[list[int]] = []
|
block_tables_lst: list[list[int]] = []
|
||||||
for _ in range(num_seqs):
|
for _ in range(num_seqs):
|
||||||
block_table = [
|
block_table = [
|
||||||
random.randint(0, NUM_BLOCKS - 1) for _ in range(max_num_blocks_per_seq)
|
random.randint(0, NUM_BLOCKS - 1)
|
||||||
|
for _ in range(max_num_blocks_per_seq)
|
||||||
]
|
]
|
||||||
block_tables_lst.append(block_table)
|
block_tables_lst.append(block_table)
|
||||||
|
|
||||||
block_tables = torch.tensor(block_tables_lst, dtype=torch.int, device=device)
|
block_tables = torch.tensor(block_tables_lst,
|
||||||
|
dtype=torch.int,
|
||||||
|
device=device)
|
||||||
|
|
||||||
# Create the KV cache.
|
# Create the KV cache.
|
||||||
key_caches, value_caches = create_kv_caches_with_random(
|
key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS,
|
||||||
NUM_BLOCKS,
|
block_size,
|
||||||
block_size,
|
1,
|
||||||
1,
|
num_kv_heads,
|
||||||
num_kv_heads,
|
head_size,
|
||||||
head_size,
|
kv_cache_dtype,
|
||||||
kv_cache_dtype,
|
dtype,
|
||||||
dtype,
|
device=device)
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
key_cache, value_cache = key_caches[0], value_caches[0]
|
key_cache, value_cache = key_caches[0], value_caches[0]
|
||||||
|
|
||||||
# Prepare for the paged attention kernel.
|
# Prepare for the paged attention kernel.
|
||||||
@ -84,11 +86,11 @@ def main(
|
|||||||
if version == "v2":
|
if version == "v2":
|
||||||
if current_platform.is_rocm():
|
if current_platform.is_rocm():
|
||||||
global PARTITION_SIZE
|
global PARTITION_SIZE
|
||||||
if not args.custom_paged_attn and not current_platform.is_navi():
|
if not args.custom_paged_attn:
|
||||||
PARTITION_SIZE = 1024
|
PARTITION_SIZE = 1024
|
||||||
else:
|
else:
|
||||||
PARTITION_SIZE = PARTITION_SIZE_ROCM
|
PARTITION_SIZE = PARTITION_SIZE_ROCM
|
||||||
num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE
|
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
|
||||||
tmp_output = torch.empty(
|
tmp_output = torch.empty(
|
||||||
size=(num_seqs, num_query_heads, num_partitions, head_size),
|
size=(num_seqs, num_query_heads, num_partitions, head_size),
|
||||||
dtype=output.dtype,
|
dtype=output.dtype,
|
||||||
@ -108,7 +110,9 @@ def main(
|
|||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
# Using default kv_scale
|
# Using default kv_scale
|
||||||
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
|
k_scale = v_scale = torch.tensor(1.0,
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=device)
|
||||||
|
|
||||||
for _ in range(num_iters):
|
for _ in range(num_iters):
|
||||||
if version == "v1":
|
if version == "v1":
|
||||||
@ -162,7 +166,6 @@ def main(
|
|||||||
scale,
|
scale,
|
||||||
block_tables,
|
block_tables,
|
||||||
seq_lens,
|
seq_lens,
|
||||||
None,
|
|
||||||
block_size,
|
block_size,
|
||||||
max_seq_len,
|
max_seq_len,
|
||||||
alibi_slopes,
|
alibi_slopes,
|
||||||
@ -192,29 +195,30 @@ def main(
|
|||||||
print(f"Kernel running time: {latency * 1000000:.3f} us")
|
print(f"Kernel running time: {latency * 1000000:.3f} us")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == '__main__':
|
||||||
logger.warning(
|
logger.warning("This script benchmarks the paged attention kernel. "
|
||||||
"This script benchmarks the paged attention kernel. "
|
"By default this is no longer used in vLLM inference.")
|
||||||
"By default this is no longer used in vLLM inference."
|
|
||||||
)
|
|
||||||
|
|
||||||
parser = FlexibleArgumentParser(description="Benchmark the paged attention kernel.")
|
parser = FlexibleArgumentParser(
|
||||||
parser.add_argument("--version", type=str, choices=["v1", "v2"], default="v2")
|
description="Benchmark the paged attention kernel.")
|
||||||
|
parser.add_argument("--version",
|
||||||
|
type=str,
|
||||||
|
choices=["v1", "v2"],
|
||||||
|
default="v2")
|
||||||
parser.add_argument("--batch-size", type=int, default=8)
|
parser.add_argument("--batch-size", type=int, default=8)
|
||||||
parser.add_argument("--seq-len", type=int, default=4096)
|
parser.add_argument("--seq-len", type=int, default=4096)
|
||||||
parser.add_argument("--num-query-heads", type=int, default=64)
|
parser.add_argument("--num-query-heads", type=int, default=64)
|
||||||
parser.add_argument("--num-kv-heads", type=int, default=8)
|
parser.add_argument("--num-kv-heads", type=int, default=8)
|
||||||
parser.add_argument(
|
parser.add_argument("--head-size",
|
||||||
"--head-size",
|
type=int,
|
||||||
type=int,
|
choices=[64, 80, 96, 112, 120, 128, 192, 256],
|
||||||
choices=[64, 80, 96, 112, 120, 128, 192, 256],
|
default=128)
|
||||||
default=128,
|
|
||||||
)
|
|
||||||
parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
|
parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
|
||||||
parser.add_argument("--use-alibi", action="store_true")
|
parser.add_argument("--use-alibi", action="store_true")
|
||||||
parser.add_argument(
|
parser.add_argument("--dtype",
|
||||||
"--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
|
type=str,
|
||||||
)
|
choices=["half", "bfloat16", "float"],
|
||||||
|
default="half")
|
||||||
parser.add_argument("--seed", type=int, default=0)
|
parser.add_argument("--seed", type=int, default=0)
|
||||||
parser.add_argument("--profile", action="store_true")
|
parser.add_argument("--profile", action="store_true")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -224,11 +228,10 @@ if __name__ == "__main__":
|
|||||||
default="auto",
|
default="auto",
|
||||||
help="Data type for kv cache storage. If 'auto', will use model "
|
help="Data type for kv cache storage. If 'auto', will use model "
|
||||||
"data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
|
"data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
|
||||||
"ROCm (AMD GPU) supports fp8 (=fp8_e4m3)",
|
"ROCm (AMD GPU) supports fp8 (=fp8_e4m3)")
|
||||||
)
|
parser.add_argument("--custom-paged-attn",
|
||||||
parser.add_argument(
|
action="store_true",
|
||||||
"--custom-paged-attn", action="store_true", help="Use custom paged attention"
|
help="Use custom paged attention")
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
print(args)
|
print(args)
|
||||||
|
|
||||||
|
|||||||
@ -10,17 +10,15 @@ from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
|
|||||||
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def main(
|
def main(num_tokens: int,
|
||||||
num_tokens: int,
|
hidden_size: int,
|
||||||
hidden_size: int,
|
static_scale: bool,
|
||||||
static_scale: bool,
|
quant_dtype: torch.dtype,
|
||||||
quant_dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
dtype: torch.dtype,
|
seed: int = 0,
|
||||||
seed: int = 0,
|
do_profile: bool = False,
|
||||||
do_profile: bool = False,
|
num_warmup_iters: int = 5,
|
||||||
num_warmup_iters: int = 5,
|
num_iters: int = 100) -> None:
|
||||||
num_iters: int = 100,
|
|
||||||
) -> None:
|
|
||||||
current_platform.seed_everything(seed)
|
current_platform.seed_everything(seed)
|
||||||
torch.set_default_device("cuda")
|
torch.set_default_device("cuda")
|
||||||
|
|
||||||
@ -58,7 +56,7 @@ def main(
|
|||||||
print(f"Kernel running time: {latency * 1000000:.3f} us")
|
print(f"Kernel running time: {latency * 1000000:.3f} us")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == '__main__':
|
||||||
|
|
||||||
def to_torch_dtype(dt):
|
def to_torch_dtype(dt):
|
||||||
if dt == "int8":
|
if dt == "int8":
|
||||||
@ -68,40 +66,37 @@ if __name__ == "__main__":
|
|||||||
raise ValueError(f"Unsupported dtype: {dt}")
|
raise ValueError(f"Unsupported dtype: {dt}")
|
||||||
|
|
||||||
parser = FlexibleArgumentParser(
|
parser = FlexibleArgumentParser(
|
||||||
description="Benchmark the quantization (fp8 or int8) kernel."
|
description="Benchmark the quantization (fp8 or int8) kernel.")
|
||||||
)
|
|
||||||
parser.add_argument("--num-tokens", type=int, default=4096)
|
parser.add_argument("--num-tokens", type=int, default=4096)
|
||||||
parser.add_argument("--hidden-size", type=int, default=8192)
|
parser.add_argument("--hidden-size", type=int, default=8192)
|
||||||
parser.add_argument("--static-scale", action="store_true")
|
parser.add_argument("--static-scale", action="store_true")
|
||||||
parser.add_argument(
|
parser.add_argument("--quant-dtype",
|
||||||
"--quant-dtype", type=str, choices=["fp8", "int8"], default="int8"
|
type=str,
|
||||||
)
|
choices=["fp8", "int8"],
|
||||||
parser.add_argument(
|
default="int8")
|
||||||
"--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
|
parser.add_argument("--dtype",
|
||||||
)
|
type=str,
|
||||||
|
choices=["half", "bfloat16", "float"],
|
||||||
|
default="half")
|
||||||
|
|
||||||
parser.add_argument("--seed", type=int, default=0)
|
parser.add_argument("--seed", type=int, default=0)
|
||||||
parser.add_argument("--profile", action="store_true")
|
parser.add_argument("--profile", action="store_true")
|
||||||
parser.add_argument("--num-warmup-iters", type=int, default=5)
|
parser.add_argument("--num-warmup-iters", type=int, default=5)
|
||||||
parser.add_argument(
|
parser.add_argument("--num-iters",
|
||||||
"--num-iters",
|
type=int,
|
||||||
type=int,
|
default=100,
|
||||||
default=100,
|
help="Number of benchmark iterations. "
|
||||||
help="Number of benchmark iterations. "
|
"If --profile is set, this number is ignored")
|
||||||
"If --profile is set, this number is ignored",
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
print(args)
|
print(args)
|
||||||
|
|
||||||
main(
|
main(num_tokens=args.num_tokens,
|
||||||
num_tokens=args.num_tokens,
|
hidden_size=args.hidden_size,
|
||||||
hidden_size=args.hidden_size,
|
static_scale=args.static_scale,
|
||||||
static_scale=args.static_scale,
|
quant_dtype=to_torch_dtype(args.quant_dtype),
|
||||||
quant_dtype=to_torch_dtype(args.quant_dtype),
|
dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
|
||||||
dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
|
seed=args.seed,
|
||||||
seed=args.seed,
|
do_profile=args.profile,
|
||||||
do_profile=args.profile,
|
num_warmup_iters=args.num_warmup_iters,
|
||||||
num_warmup_iters=args.num_warmup_iters,
|
num_iters=args.num_iters)
|
||||||
num_iters=args.num_iters,
|
|
||||||
)
|
|
||||||
|
|||||||
@ -12,6 +12,7 @@ from vllm.triton_utils import triton
|
|||||||
|
|
||||||
|
|
||||||
class HuggingFaceRMSNorm(nn.Module):
|
class HuggingFaceRMSNorm(nn.Module):
|
||||||
|
|
||||||
def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
|
def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||||
@ -113,19 +114,23 @@ def rmsnorm_vllm(
|
|||||||
|
|
||||||
def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
|
def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
|
||||||
dtype = torch.bfloat16
|
dtype = torch.bfloat16
|
||||||
x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
|
x = torch.randn(batch_size,
|
||||||
|
seq_len,
|
||||||
|
hidden_size,
|
||||||
|
dtype=dtype,
|
||||||
|
device="cuda")
|
||||||
weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
|
weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
|
||||||
residual = torch.randn_like(x) if use_residual else None
|
residual = torch.randn_like(x) if use_residual else None
|
||||||
|
|
||||||
output_naive = rmsnorm_naive(
|
output_naive = rmsnorm_naive(
|
||||||
x.clone(), weight, residual.clone() if residual is not None else None
|
x.clone(), weight,
|
||||||
)
|
residual.clone() if residual is not None else None)
|
||||||
output_flashinfer = rmsnorm_flashinfer(
|
output_flashinfer = rmsnorm_flashinfer(
|
||||||
x.clone(), weight, residual.clone() if residual is not None else None
|
x.clone(), weight,
|
||||||
)
|
residual.clone() if residual is not None else None)
|
||||||
output_vllm = rmsnorm_vllm(
|
output_vllm = rmsnorm_vllm(
|
||||||
x.clone(), weight, residual.clone() if residual is not None else None
|
x.clone(), weight,
|
||||||
)
|
residual.clone() if residual is not None else None)
|
||||||
|
|
||||||
if use_residual:
|
if use_residual:
|
||||||
output_naive = output_naive[0]
|
output_naive = output_naive[0]
|
||||||
@ -136,9 +141,9 @@ def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
|
|||||||
print(f"FlashInfer output={output_flashinfer}")
|
print(f"FlashInfer output={output_flashinfer}")
|
||||||
print(f"vLLM output={output_vllm}")
|
print(f"vLLM output={output_vllm}")
|
||||||
|
|
||||||
if torch.allclose(
|
if torch.allclose(output_naive, output_flashinfer, atol=1e-2,
|
||||||
output_naive, output_flashinfer, atol=1e-2, rtol=1e-2
|
rtol=1e-2) and torch.allclose(
|
||||||
) and torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2):
|
output_naive, output_vllm, atol=1e-2, rtol=1e-2):
|
||||||
print("✅ All implementations match")
|
print("✅ All implementations match")
|
||||||
else:
|
else:
|
||||||
print("❌ Implementations differ")
|
print("❌ Implementations differ")
|
||||||
@ -147,10 +152,12 @@ def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
|
|||||||
batch_size_range = [2**i for i in range(0, 7, 2)]
|
batch_size_range = [2**i for i in range(0, 7, 2)]
|
||||||
seq_length_range = [2**i for i in range(6, 11, 1)]
|
seq_length_range = [2**i for i in range(6, 11, 1)]
|
||||||
head_num_range = [32, 48]
|
head_num_range = [32, 48]
|
||||||
configs = list(itertools.product(head_num_range, batch_size_range, seq_length_range))
|
configs = list(
|
||||||
|
itertools.product(head_num_range, batch_size_range, seq_length_range))
|
||||||
|
|
||||||
|
|
||||||
def get_benchmark(use_residual):
|
def get_benchmark(use_residual):
|
||||||
|
|
||||||
@triton.testing.perf_report(
|
@triton.testing.perf_report(
|
||||||
triton.testing.Benchmark(
|
triton.testing.Benchmark(
|
||||||
x_names=["head_num", "batch_size", "seq_len"],
|
x_names=["head_num", "batch_size", "seq_len"],
|
||||||
@ -160,15 +167,19 @@ def get_benchmark(use_residual):
|
|||||||
line_names=["HuggingFace", "FlashInfer", "vLLM"],
|
line_names=["HuggingFace", "FlashInfer", "vLLM"],
|
||||||
styles=[("blue", "-"), ("green", "-"), ("red", "-")],
|
styles=[("blue", "-"), ("green", "-"), ("red", "-")],
|
||||||
ylabel="us",
|
ylabel="us",
|
||||||
plot_name=f"rmsnorm-perf-{'with' if use_residual else 'without'}-residual",
|
plot_name=
|
||||||
|
f"rmsnorm-perf-{'with' if use_residual else 'without'}-residual",
|
||||||
args={},
|
args={},
|
||||||
)
|
))
|
||||||
)
|
|
||||||
def benchmark(head_num, batch_size, seq_len, provider):
|
def benchmark(head_num, batch_size, seq_len, provider):
|
||||||
dtype = torch.bfloat16
|
dtype = torch.bfloat16
|
||||||
hidden_size = head_num * 128 # assuming head_dim = 128
|
hidden_size = head_num * 128 # assuming head_dim = 128
|
||||||
|
|
||||||
x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
|
x = torch.randn(batch_size,
|
||||||
|
seq_len,
|
||||||
|
hidden_size,
|
||||||
|
dtype=dtype,
|
||||||
|
device="cuda")
|
||||||
weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
|
weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
|
||||||
residual = torch.randn_like(x) if use_residual else None
|
residual = torch.randn_like(x) if use_residual else None
|
||||||
|
|
||||||
@ -229,9 +240,9 @@ if __name__ == "__main__":
|
|||||||
default=4096,
|
default=4096,
|
||||||
help="Hidden size (2nd dimension) of the sequence",
|
help="Hidden size (2nd dimension) of the sequence",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument("--use-residual",
|
||||||
"--use-residual", action="store_true", help="Whether to use residual connection"
|
action="store_true",
|
||||||
)
|
help="Whether to use residual connection")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--save-path",
|
"--save-path",
|
||||||
type=str,
|
type=str,
|
||||||
@ -242,12 +253,10 @@ if __name__ == "__main__":
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Run correctness test
|
# Run correctness test
|
||||||
calculate_diff(
|
calculate_diff(batch_size=args.batch_size,
|
||||||
batch_size=args.batch_size,
|
seq_len=args.seq_len,
|
||||||
seq_len=args.seq_len,
|
hidden_size=args.hidden_size,
|
||||||
hidden_size=args.hidden_size,
|
use_residual=args.use_residual)
|
||||||
use_residual=args.use_residual,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get the benchmark function with proper use_residual setting
|
# Get the benchmark function with proper use_residual setting
|
||||||
benchmark = get_benchmark(args.use_residual)
|
benchmark = get_benchmark(args.use_residual)
|
||||||
|
|||||||
@ -6,7 +6,8 @@ from typing import Optional
|
|||||||
import nvtx
|
import nvtx
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding, get_rope
|
from vllm.model_executor.layers.rotary_embedding import (RotaryEmbedding,
|
||||||
|
get_rope)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import FlexibleArgumentParser
|
from vllm.utils import FlexibleArgumentParser
|
||||||
|
|
||||||
@ -31,49 +32,40 @@ def benchmark_rope_kernels_multi_lora(
|
|||||||
# silulating serving 4 LoRAs
|
# silulating serving 4 LoRAs
|
||||||
scaling_factors = [1, 2, 4, 8]
|
scaling_factors = [1, 2, 4, 8]
|
||||||
# batched RoPE can take multiple scaling factors
|
# batched RoPE can take multiple scaling factors
|
||||||
batched_rope = get_rope(
|
batched_rope = get_rope(head_size, rotary_dim, max_position, base,
|
||||||
head_size,
|
is_neox_style, {
|
||||||
rotary_dim,
|
"rope_type": "linear",
|
||||||
max_position,
|
"factor": tuple(scaling_factors)
|
||||||
base,
|
})
|
||||||
is_neox_style,
|
|
||||||
{"rope_type": "linear", "factor": tuple(scaling_factors)},
|
|
||||||
)
|
|
||||||
# non-batched RoPE takes only one scaling factor, we create multiple
|
# non-batched RoPE takes only one scaling factor, we create multiple
|
||||||
# instances to simulate the same behavior
|
# instances to simulate the same behavior
|
||||||
non_batched_ropes: list[RotaryEmbedding] = []
|
non_batched_ropes: list[RotaryEmbedding] = []
|
||||||
for scaling_factor in scaling_factors:
|
for scaling_factor in scaling_factors:
|
||||||
non_batched_ropes.append(
|
non_batched_ropes.append(
|
||||||
get_rope(
|
get_rope(head_size, rotary_dim, max_position, base, is_neox_style,
|
||||||
head_size,
|
{
|
||||||
rotary_dim,
|
"rope_type": "linear",
|
||||||
max_position,
|
"factor": (scaling_factor, )
|
||||||
base,
|
}))
|
||||||
is_neox_style,
|
|
||||||
{"rope_type": "linear", "factor": (scaling_factor,)},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
positions = torch.randint(0, max_position, (batch_size, seq_len))
|
positions = torch.randint(0, max_position, (batch_size, seq_len))
|
||||||
query = torch.randn(batch_size, seq_len, num_heads * head_size, dtype=dtype)
|
query = torch.randn(batch_size,
|
||||||
|
seq_len,
|
||||||
|
num_heads * head_size,
|
||||||
|
dtype=dtype)
|
||||||
key = torch.randn_like(query)
|
key = torch.randn_like(query)
|
||||||
|
|
||||||
# create query offsets for batched RoPE, we concat multiple kv cache
|
# create query offsets for batched RoPE, we concat multiple kv cache
|
||||||
# together and each query needs to find the right kv cache of its type
|
# together and each query needs to find the right kv cache of its type
|
||||||
offset_map = torch.tensor(
|
offset_map = torch.tensor(
|
||||||
list(
|
list(
|
||||||
accumulate(
|
accumulate([0] + [
|
||||||
[0]
|
max_position * scaling_factor * 2
|
||||||
+ [
|
for scaling_factor in scaling_factors[:-1]
|
||||||
max_position * scaling_factor * 2
|
])))
|
||||||
for scaling_factor in scaling_factors[:-1]
|
query_types = torch.randint(0,
|
||||||
]
|
len(scaling_factors), (batch_size, seq_len),
|
||||||
)
|
device=device)
|
||||||
)
|
|
||||||
)
|
|
||||||
query_types = torch.randint(
|
|
||||||
0, len(scaling_factors), (batch_size, seq_len), device=device
|
|
||||||
)
|
|
||||||
# map query types to offsets
|
# map query types to offsets
|
||||||
query_offsets = offset_map[query_types]
|
query_offsets = offset_map[query_types]
|
||||||
# the kernel takes flattened offsets
|
# the kernel takes flattened offsets
|
||||||
@ -94,28 +86,27 @@ def benchmark_rope_kernels_multi_lora(
|
|||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == '__main__':
|
||||||
parser = FlexibleArgumentParser(
|
parser = FlexibleArgumentParser(
|
||||||
description="Benchmark the rotary embedding kernels."
|
description="Benchmark the rotary embedding kernels.")
|
||||||
)
|
|
||||||
parser.add_argument("--is-neox-style", type=bool, default=True)
|
parser.add_argument("--is-neox-style", type=bool, default=True)
|
||||||
parser.add_argument("--batch-size", type=int, default=16)
|
parser.add_argument("--batch-size", type=int, default=16)
|
||||||
parser.add_argument("--seq-len", type=int, default=512)
|
parser.add_argument("--seq-len", type=int, default=512)
|
||||||
parser.add_argument("--num-heads", type=int, default=8)
|
parser.add_argument("--num-heads", type=int, default=8)
|
||||||
parser.add_argument(
|
parser.add_argument("--head-size",
|
||||||
"--head-size",
|
type=int,
|
||||||
type=int,
|
choices=[64, 80, 96, 112, 120, 128, 192, 256],
|
||||||
choices=[64, 80, 96, 112, 120, 128, 192, 256],
|
default=128)
|
||||||
default=128,
|
|
||||||
)
|
|
||||||
parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32)
|
parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32)
|
||||||
parser.add_argument(
|
parser.add_argument("--dtype",
|
||||||
"--dtype", type=str, choices=["bfloat16", "float"], default="float"
|
type=str,
|
||||||
)
|
choices=["bfloat16", "float"],
|
||||||
|
default="float")
|
||||||
parser.add_argument("--seed", type=int, default=0)
|
parser.add_argument("--seed", type=int, default=0)
|
||||||
parser.add_argument(
|
parser.add_argument("--device",
|
||||||
"--device", type=str, choices=["cuda:0", "cuda:1"], default="cuda:0"
|
type=str,
|
||||||
)
|
choices=["cuda:0", "cuda:1"],
|
||||||
|
default="cuda:0")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
print(args)
|
print(args)
|
||||||
|
|
||||||
|
|||||||
@ -14,16 +14,14 @@ import tqdm
|
|||||||
import triton
|
import triton
|
||||||
|
|
||||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||||
_w8a8_block_fp8_matmul,
|
_w8a8_block_fp8_matmul)
|
||||||
)
|
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import FlexibleArgumentParser
|
from vllm.utils import FlexibleArgumentParser
|
||||||
|
|
||||||
mp.set_start_method("spawn", force=True)
|
mp.set_start_method("spawn", force=True)
|
||||||
|
|
||||||
assert current_platform.is_cuda(), (
|
assert current_platform.is_cuda(
|
||||||
"Only support tune w8a8 block fp8 kernel on CUDA device."
|
), "Only support tune w8a8 block fp8 kernel on CUDA device."
|
||||||
)
|
|
||||||
|
|
||||||
DTYPE_MAP = {
|
DTYPE_MAP = {
|
||||||
"float32": torch.float32,
|
"float32": torch.float32,
|
||||||
@ -42,7 +40,7 @@ def w8a8_block_matmul(
|
|||||||
config: dict[str, Any],
|
config: dict[str, Any],
|
||||||
output_dtype: torch.dtype = torch.float16,
|
output_dtype: torch.dtype = torch.float16,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""This function performs matrix multiplication with
|
"""This function performs matrix multiplication with
|
||||||
block-wise quantization.
|
block-wise quantization.
|
||||||
|
|
||||||
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
|
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
|
||||||
@ -53,7 +51,7 @@ def w8a8_block_matmul(
|
|||||||
B: The input tensor, e.g., weight.
|
B: The input tensor, e.g., weight.
|
||||||
As: The per-token-group quantization scale for `A`.
|
As: The per-token-group quantization scale for `A`.
|
||||||
Bs: The per-block quantization scale for `B`.
|
Bs: The per-block quantization scale for `B`.
|
||||||
block_size: The block size for per-block quantization.
|
block_size: The block size for per-block quantization.
|
||||||
It should be 2-dim, e.g., [128, 128].
|
It should be 2-dim, e.g., [128, 128].
|
||||||
output_dytpe: The dtype of the returned tensor.
|
output_dytpe: The dtype of the returned tensor.
|
||||||
|
|
||||||
@ -73,18 +71,18 @@ def w8a8_block_matmul(
|
|||||||
assert triton.cdiv(N, block_n) == Bs.shape[0]
|
assert triton.cdiv(N, block_n) == Bs.shape[0]
|
||||||
assert triton.cdiv(K, block_k) == Bs.shape[1]
|
assert triton.cdiv(K, block_k) == Bs.shape[1]
|
||||||
|
|
||||||
C_shape = A.shape[:-1] + (N,)
|
C_shape = A.shape[:-1] + (N, )
|
||||||
C = A.new_empty(C_shape, dtype=output_dtype)
|
C = A.new_empty(C_shape, dtype=output_dtype)
|
||||||
|
|
||||||
def grid(META):
|
def grid(META):
|
||||||
return (
|
return (triton.cdiv(M, META["BLOCK_SIZE_M"]) *
|
||||||
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
|
triton.cdiv(N, META["BLOCK_SIZE_N"]), )
|
||||||
)
|
|
||||||
|
|
||||||
if A.dtype == torch.float8_e4m3fn:
|
if A.dtype == torch.float8_e4m3fn:
|
||||||
kernel = _w8a8_block_fp8_matmul
|
kernel = _w8a8_block_fp8_matmul
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("Currently, only support tune w8a8 block fp8 kernel.")
|
raise RuntimeError(
|
||||||
|
"Currently, only support tune w8a8 block fp8 kernel.")
|
||||||
|
|
||||||
kernel[grid](
|
kernel[grid](
|
||||||
A,
|
A,
|
||||||
@ -121,16 +119,14 @@ def get_configs_compute_bound():
|
|||||||
for block_n in [32, 64, 128, 256]:
|
for block_n in [32, 64, 128, 256]:
|
||||||
for num_warps in [4, 8]:
|
for num_warps in [4, 8]:
|
||||||
for group_size in [1, 16, 32, 64]:
|
for group_size in [1, 16, 32, 64]:
|
||||||
configs.append(
|
configs.append({
|
||||||
{
|
"BLOCK_SIZE_M": block_m,
|
||||||
"BLOCK_SIZE_M": block_m,
|
"BLOCK_SIZE_N": block_n,
|
||||||
"BLOCK_SIZE_N": block_n,
|
"BLOCK_SIZE_K": block_k,
|
||||||
"BLOCK_SIZE_K": block_k,
|
"GROUP_SIZE_M": group_size,
|
||||||
"GROUP_SIZE_M": group_size,
|
"num_warps": num_warps,
|
||||||
"num_warps": num_warps,
|
"num_stages": num_stages,
|
||||||
"num_stages": num_stages,
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
return configs
|
return configs
|
||||||
|
|
||||||
|
|
||||||
@ -169,9 +165,15 @@ def get_weight_shapes(tp_size):
|
|||||||
return weight_shapes
|
return weight_shapes
|
||||||
|
|
||||||
|
|
||||||
def benchmark_config(
|
def benchmark_config(A,
|
||||||
A, B, As, Bs, block_size, config, out_dtype=torch.float16, num_iters=10
|
B,
|
||||||
):
|
As,
|
||||||
|
Bs,
|
||||||
|
block_size,
|
||||||
|
config,
|
||||||
|
out_dtype=torch.float16,
|
||||||
|
num_iters=10):
|
||||||
|
|
||||||
def run():
|
def run():
|
||||||
w8a8_block_matmul(A, B, As, Bs, block_size, config, out_dtype)
|
w8a8_block_matmul(A, B, As, Bs, block_size, config, out_dtype)
|
||||||
|
|
||||||
@ -204,26 +206,26 @@ def tune(M, N, K, block_size, out_dtype, search_space, input_type):
|
|||||||
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
||||||
|
|
||||||
A_fp32 = (
|
A_fp32 = (
|
||||||
(torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
|
(torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 *
|
||||||
)
|
fp8_max)
|
||||||
A = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
A = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
||||||
|
|
||||||
B_fp32 = (
|
B_fp32 = (
|
||||||
(torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
|
(torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 *
|
||||||
)
|
fp8_max)
|
||||||
B = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
B = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("Currently, only support tune w8a8 block fp8 kernel.")
|
raise RuntimeError(
|
||||||
|
"Currently, only support tune w8a8 block fp8 kernel.")
|
||||||
|
|
||||||
block_n, block_k = block_size[0], block_size[1]
|
block_n, block_k = block_size[0], block_size[1]
|
||||||
n_tiles = (N + block_n - 1) // block_n
|
n_tiles = (N + block_n - 1) // block_n
|
||||||
k_tiles = (K + block_k - 1) // block_k
|
k_tiles = (K + block_k - 1) // block_k
|
||||||
|
|
||||||
As = torch.rand(M, k_tiles, dtype=torch.float32, device="cuda") * factor_for_scale
|
As = torch.rand(M, k_tiles, dtype=torch.float32,
|
||||||
Bs = (
|
device="cuda") * factor_for_scale
|
||||||
torch.rand(n_tiles, k_tiles, dtype=torch.float32, device="cuda")
|
Bs = (torch.rand(n_tiles, k_tiles, dtype=torch.float32, device="cuda") *
|
||||||
* factor_for_scale
|
factor_for_scale)
|
||||||
)
|
|
||||||
|
|
||||||
best_config = None
|
best_config = None
|
||||||
best_time = float("inf")
|
best_time = float("inf")
|
||||||
@ -265,8 +267,7 @@ def save_configs(
|
|||||||
device_name = current_platform.get_device_name().replace(" ", "_")
|
device_name = current_platform.get_device_name().replace(" ", "_")
|
||||||
json_file_name = (
|
json_file_name = (
|
||||||
f"N={N},K={K},device_name={device_name},dtype={input_type}_w8a8,"
|
f"N={N},K={K},device_name={device_name},dtype={input_type}_w8a8,"
|
||||||
f"block_shape=[{block_n},{block_k}].json"
|
f"block_shape=[{block_n},{block_k}].json")
|
||||||
)
|
|
||||||
|
|
||||||
config_file_path = os.path.join(save_path, json_file_name)
|
config_file_path = os.path.join(save_path, json_file_name)
|
||||||
print(f"Writing best config to {config_file_path}...")
|
print(f"Writing best config to {config_file_path}...")
|
||||||
@ -294,7 +295,8 @@ def tune_on_gpu(args_dict):
|
|||||||
|
|
||||||
search_space = get_configs_compute_bound()
|
search_space = get_configs_compute_bound()
|
||||||
search_space = [
|
search_space = [
|
||||||
config for config in search_space if block_k % config["BLOCK_SIZE_K"] == 0
|
config for config in search_space
|
||||||
|
if block_k % config["BLOCK_SIZE_K"] == 0
|
||||||
]
|
]
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
@ -310,11 +312,15 @@ def tune_on_gpu(args_dict):
|
|||||||
out_dtype,
|
out_dtype,
|
||||||
search_space,
|
search_space,
|
||||||
input_type,
|
input_type,
|
||||||
)
|
) for batch_size in tqdm(batch_sizes,
|
||||||
for batch_size in tqdm(batch_sizes, desc=f"GPU {gpu_id} - Batch sizes")
|
desc=f"GPU {gpu_id} - Batch sizes")
|
||||||
]
|
]
|
||||||
best_configs = {M: config for M, config in zip(batch_sizes, benchmark_results)}
|
best_configs = {
|
||||||
save_configs(N, K, block_n, block_k, best_configs, save_path, input_type)
|
M: config
|
||||||
|
for M, config in zip(batch_sizes, benchmark_results)
|
||||||
|
}
|
||||||
|
save_configs(N, K, block_n, block_k, best_configs, save_path,
|
||||||
|
input_type)
|
||||||
|
|
||||||
end = time.time()
|
end = time.time()
|
||||||
print(f"Tuning on GPU {gpu_id} took {end - start:.2f} seconds")
|
print(f"Tuning on GPU {gpu_id} took {end - start:.2f} seconds")
|
||||||
@ -370,14 +376,13 @@ def main(args):
|
|||||||
|
|
||||||
process_args = []
|
process_args = []
|
||||||
for gpu_id in range(num_gpus):
|
for gpu_id in range(num_gpus):
|
||||||
process_args.append(
|
process_args.append({
|
||||||
{
|
"gpu_id": gpu_id,
|
||||||
"gpu_id": gpu_id,
|
"batch_sizes": batches_per_gpu[gpu_id],
|
||||||
"batch_sizes": batches_per_gpu[gpu_id],
|
"weight_shapes":
|
||||||
"weight_shapes": weight_shapes, # Each GPU processes all weight shapes
|
weight_shapes, # Each GPU processes all weight shapes
|
||||||
"args": args,
|
"args": args,
|
||||||
}
|
})
|
||||||
)
|
|
||||||
|
|
||||||
ctx = mp.get_context("spawn")
|
ctx = mp.get_context("spawn")
|
||||||
with ctx.Pool(num_gpus) as pool:
|
with ctx.Pool(num_gpus) as pool:
|
||||||
@ -393,11 +398,13 @@ Tune triton w8a8 block fp8 for DeepSeek-V3/DeepSeek-R1:
|
|||||||
python3 benchmark_w8a8_block_fp8.py --tp-size 8 --input-type fp8
|
python3 benchmark_w8a8_block_fp8.py --tp-size 8 --input-type fp8
|
||||||
Then copy to model_executor/layers/quantization/utils/configs
|
Then copy to model_executor/layers/quantization/utils/configs
|
||||||
""",
|
""",
|
||||||
formatter_class=argparse.RawTextHelpFormatter,
|
formatter_class=argparse.RawTextHelpFormatter)
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument("--tp-size", "-tp", type=int, default=8)
|
parser.add_argument("--tp-size", "-tp", type=int, default=8)
|
||||||
parser.add_argument("--input-type", type=str, choices=["fp8"], default="fp8")
|
parser.add_argument("--input-type",
|
||||||
|
type=str,
|
||||||
|
choices=["fp8"],
|
||||||
|
default="fp8")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--out-dtype",
|
"--out-dtype",
|
||||||
type=str,
|
type=str,
|
||||||
|
|||||||
@ -11,9 +11,7 @@ from deep_gemm import calc_diff, ceil_div, get_col_major_tma_aligned_tensor
|
|||||||
# Import vLLM functions
|
# Import vLLM functions
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||||
per_token_group_quant_fp8,
|
per_token_group_quant_fp8, w8a8_block_fp8_matmul)
|
||||||
w8a8_block_fp8_matmul,
|
|
||||||
)
|
|
||||||
from vllm.triton_utils import triton
|
from vllm.triton_utils import triton
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -2,11 +2,11 @@
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
import pickle
|
import pickle
|
||||||
|
import re
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import regex as re
|
|
||||||
import seaborn as sns
|
import seaborn as sns
|
||||||
from torch.utils.benchmark import Measurement as TMeasurement
|
from torch.utils.benchmark import Measurement as TMeasurement
|
||||||
|
|
||||||
@ -14,14 +14,13 @@ from vllm.utils import FlexibleArgumentParser
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = FlexibleArgumentParser(
|
parser = FlexibleArgumentParser(
|
||||||
description="Benchmark the latency of processing a single batch of "
|
description='Benchmark the latency of processing a single batch of '
|
||||||
"requests till completion."
|
'requests till completion.')
|
||||||
)
|
parser.add_argument('filename', type=str)
|
||||||
parser.add_argument("filename", type=str)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
with open(args.filename, "rb") as f:
|
with open(args.filename, 'rb') as f:
|
||||||
data = pickle.load(f)
|
data = pickle.load(f)
|
||||||
raw_results: list[TMeasurement] = data["results"]
|
raw_results: list[TMeasurement] = data["results"]
|
||||||
|
|
||||||
@ -39,7 +38,11 @@ if __name__ == "__main__":
|
|||||||
raise Exception("MKN not found")
|
raise Exception("MKN not found")
|
||||||
|
|
||||||
kernel = v.task_spec.description
|
kernel = v.task_spec.description
|
||||||
results[KN].append({"kernel": kernel, "batch_size": M, "median": v.median})
|
results[KN].append({
|
||||||
|
"kernel": kernel,
|
||||||
|
"batch_size": M,
|
||||||
|
"median": v.median
|
||||||
|
})
|
||||||
|
|
||||||
rows = int(math.ceil(len(results) / 2))
|
rows = int(math.ceil(len(results) / 2))
|
||||||
fig, axs = plt.subplots(rows, 2, figsize=(12, 5 * rows))
|
fig, axs = plt.subplots(rows, 2, figsize=(12, 5 * rows))
|
||||||
@ -47,16 +50,14 @@ if __name__ == "__main__":
|
|||||||
for axs_idx, (shape, data) in enumerate(results.items()):
|
for axs_idx, (shape, data) in enumerate(results.items()):
|
||||||
plt.sca(axs[axs_idx])
|
plt.sca(axs[axs_idx])
|
||||||
df = pd.DataFrame(data)
|
df = pd.DataFrame(data)
|
||||||
sns.lineplot(
|
sns.lineplot(data=df,
|
||||||
data=df,
|
x="batch_size",
|
||||||
x="batch_size",
|
y="median",
|
||||||
y="median",
|
hue="kernel",
|
||||||
hue="kernel",
|
style="kernel",
|
||||||
style="kernel",
|
markers=True,
|
||||||
markers=True,
|
dashes=False,
|
||||||
dashes=False,
|
palette="Dark2")
|
||||||
palette="Dark2",
|
|
||||||
)
|
|
||||||
plt.title(f"Shape: {shape}")
|
plt.title(f"Shape: {shape}")
|
||||||
plt.ylabel("time (median, s)")
|
plt.ylabel("time (median, s)")
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
|
|||||||
@ -23,7 +23,6 @@ class ArgPool:
|
|||||||
For every invocation during a benchmarking run, it will choose a
|
For every invocation during a benchmarking run, it will choose a
|
||||||
different value from the list.
|
different value from the list.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
values: Iterable[Any]
|
values: Iterable[Any]
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
@ -31,7 +30,9 @@ class ArgPool:
|
|||||||
|
|
||||||
|
|
||||||
class Bench:
|
class Bench:
|
||||||
|
|
||||||
class ArgsIterator:
|
class ArgsIterator:
|
||||||
|
|
||||||
def __init__(self, args_list, kwargs_list):
|
def __init__(self, args_list, kwargs_list):
|
||||||
assert len(args_list) == len(kwargs_list)
|
assert len(args_list) == len(kwargs_list)
|
||||||
self.args_list = args_list
|
self.args_list = args_list
|
||||||
@ -52,16 +53,10 @@ class Bench:
|
|||||||
def n_args(self):
|
def n_args(self):
|
||||||
return self.n
|
return self.n
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, cuda_graph_params: Optional[CudaGraphBenchParams],
|
||||||
self,
|
label: str, sub_label: str, description: str, fn: Callable,
|
||||||
cuda_graph_params: Optional[CudaGraphBenchParams],
|
*args, **kwargs):
|
||||||
label: str,
|
|
||||||
sub_label: str,
|
|
||||||
description: str,
|
|
||||||
fn: Callable,
|
|
||||||
*args,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
self.cuda_graph_params = cuda_graph_params
|
self.cuda_graph_params = cuda_graph_params
|
||||||
self.use_cuda_graph = self.cuda_graph_params is not None
|
self.use_cuda_graph = self.cuda_graph_params is not None
|
||||||
self.label = label
|
self.label = label
|
||||||
@ -72,8 +67,10 @@ class Bench:
|
|||||||
# Process args
|
# Process args
|
||||||
self._args = args
|
self._args = args
|
||||||
self._kwargs = kwargs
|
self._kwargs = kwargs
|
||||||
self.args_list, self.kwargs_list = self.collapse_argpool(*args, **kwargs)
|
self.args_list, self.kwargs_list = self.collapse_argpool(
|
||||||
self.args_iterator = self.ArgsIterator(self.args_list, self.kwargs_list)
|
*args, **kwargs)
|
||||||
|
self.args_iterator = self.ArgsIterator(self.args_list,
|
||||||
|
self.kwargs_list)
|
||||||
|
|
||||||
# Cudagraph runner
|
# Cudagraph runner
|
||||||
self.g = None
|
self.g = None
|
||||||
@ -103,13 +100,16 @@ class Bench:
|
|||||||
|
|
||||||
for i in range(argpool_size):
|
for i in range(argpool_size):
|
||||||
# collapse args; Just pick the ith value
|
# collapse args; Just pick the ith value
|
||||||
args_list[i] = tuple(
|
args_list[i] = tuple([
|
||||||
[arg[i] if isinstance(arg, ArgPool) else arg for arg in args_list[i]]
|
arg[i] if isinstance(arg, ArgPool) else arg
|
||||||
)
|
for arg in args_list[i]
|
||||||
|
])
|
||||||
|
|
||||||
# collapse kwargs
|
# collapse kwargs
|
||||||
kwargs_i = kwargs_list[i]
|
kwargs_i = kwargs_list[i]
|
||||||
arg_pool_keys = [k for k, v in kwargs_i.items() if isinstance(v, ArgPool)]
|
arg_pool_keys = [
|
||||||
|
k for k, v in kwargs_i.items() if isinstance(v, ArgPool)
|
||||||
|
]
|
||||||
for k in arg_pool_keys:
|
for k in arg_pool_keys:
|
||||||
# again just pick the ith value
|
# again just pick the ith value
|
||||||
kwargs_i[k] = kwargs_i[k][i]
|
kwargs_i[k] = kwargs_i[k][i]
|
||||||
@ -142,7 +142,7 @@ class Bench:
|
|||||||
|
|
||||||
def run_cudagrah(self) -> TMeasurement:
|
def run_cudagrah(self) -> TMeasurement:
|
||||||
assert self.use_cuda_graph
|
assert self.use_cuda_graph
|
||||||
globals = {"g": self.g}
|
globals = {'g': self.g}
|
||||||
|
|
||||||
return TBenchmark.Timer(
|
return TBenchmark.Timer(
|
||||||
stmt="g.replay()",
|
stmt="g.replay()",
|
||||||
@ -162,15 +162,15 @@ class Bench:
|
|||||||
|
|
||||||
has_arg_pool = self.args_iterator.n_args > 1
|
has_arg_pool = self.args_iterator.n_args > 1
|
||||||
if has_arg_pool:
|
if has_arg_pool:
|
||||||
setup = """
|
setup = '''
|
||||||
args_iterator.reset()
|
args_iterator.reset()
|
||||||
args_it = args_iterator.__next__()
|
args_it = args_iterator.__next__()
|
||||||
"""
|
'''
|
||||||
stmt = """
|
stmt = '''
|
||||||
args, kwargs = next(args_it)
|
args, kwargs = next(args_it)
|
||||||
fn(*args, **kwargs)
|
fn(*args, **kwargs)
|
||||||
"""
|
'''
|
||||||
globals = {"fn": self.fn, "args_iterator": self.args_iterator}
|
globals = {'fn': self.fn, 'args_iterator': self.args_iterator}
|
||||||
else:
|
else:
|
||||||
# no arg pool. Just use the args and kwargs directly
|
# no arg pool. Just use the args and kwargs directly
|
||||||
self.args_iterator.reset()
|
self.args_iterator.reset()
|
||||||
@ -178,10 +178,10 @@ class Bench:
|
|||||||
args, kwargs = next(args_it)
|
args, kwargs = next(args_it)
|
||||||
|
|
||||||
setup = ""
|
setup = ""
|
||||||
stmt = """
|
stmt = '''
|
||||||
fn(*args, **kwargs)
|
fn(*args, **kwargs)
|
||||||
"""
|
'''
|
||||||
globals = {"fn": self.fn, "args": args, "kwargs": kwargs}
|
globals = {'fn': self.fn, 'args': args, 'kwargs': kwargs}
|
||||||
|
|
||||||
return TBenchmark.Timer(
|
return TBenchmark.Timer(
|
||||||
stmt=stmt,
|
stmt=stmt,
|
||||||
|
|||||||
@ -7,8 +7,9 @@ from vllm import LLM, SamplingParams
|
|||||||
from vllm.utils import FlexibleArgumentParser
|
from vllm.utils import FlexibleArgumentParser
|
||||||
|
|
||||||
# A very long prompt, total number of tokens is about 15k.
|
# A very long prompt, total number of tokens is about 15k.
|
||||||
LONG_PROMPT = ["You are an expert in large language models, aren't you?"] * 1000
|
LONG_PROMPT = ["You are an expert in large language models, aren't you?"
|
||||||
LONG_PROMPT = " ".join(LONG_PROMPT)
|
] * 1000
|
||||||
|
LONG_PROMPT = ' '.join(LONG_PROMPT)
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
@ -29,35 +30,32 @@ def main(args):
|
|||||||
|
|
||||||
print("------start generating------")
|
print("------start generating------")
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
profiler.runctx(
|
profiler.runctx('llm.generate(LONG_PROMPT, sampling_params)',
|
||||||
"llm.generate(LONG_PROMPT, sampling_params)", globals(), locals()
|
globals(), locals())
|
||||||
)
|
|
||||||
|
|
||||||
# analyze the runtime of hashing function
|
# analyze the runtime of hashing function
|
||||||
stats = pstats.Stats(profiler)
|
stats = pstats.Stats(profiler)
|
||||||
stats.sort_stats("cumulative")
|
stats.sort_stats('cumulative')
|
||||||
total_time = 0
|
total_time = 0
|
||||||
total_calls = 0
|
total_calls = 0
|
||||||
for func in stats.stats:
|
for func in stats.stats:
|
||||||
if "hash_of_block" in func[2]:
|
if 'hash_of_block' in func[2]:
|
||||||
total_time = stats.stats[func][3]
|
total_time = stats.stats[func][3]
|
||||||
total_calls = stats.stats[func][0]
|
total_calls = stats.stats[func][0]
|
||||||
percentage = (total_time / stats.total_tt) * 100
|
percentage = (total_time / stats.total_tt) * 100
|
||||||
print(
|
print(f"Hashing took {total_time:.2f} seconds,"
|
||||||
f"Hashing took {total_time:.2f} seconds,{percentage:.2f}% of the total runtime."
|
f"{percentage:.2f}% of the total runtime.")
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = FlexibleArgumentParser(
|
parser = FlexibleArgumentParser(
|
||||||
description="Benchmark the performance of hashing function in"
|
description='Benchmark the performance of hashing function in'
|
||||||
"automatic prefix caching."
|
'automatic prefix caching.')
|
||||||
)
|
parser.add_argument('--model', type=str, default='lmsys/longchat-7b-16k')
|
||||||
parser.add_argument("--model", type=str, default="lmsys/longchat-7b-16k")
|
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
|
||||||
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
|
parser.add_argument('--output-len', type=int, default=10)
|
||||||
parser.add_argument("--output-len", type=int, default=10)
|
parser.add_argument('--enable-prefix-caching',
|
||||||
parser.add_argument(
|
action='store_true',
|
||||||
"--enable-prefix-caching", action="store_true", help="enable prefix caching"
|
help='enable prefix caching')
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
main(args)
|
main(args)
|
||||||
|
|||||||
@ -1,49 +0,0 @@
|
|||||||
# This local pyproject file is part of the migration from yapf to ruff format.
|
|
||||||
# It uses the same core rules as the main pyproject.toml file, but with the
|
|
||||||
# following differences:
|
|
||||||
# - ruff line length is overridden to 88
|
|
||||||
# - deprecated typing ignores (UP006, UP035) have been removed
|
|
||||||
|
|
||||||
[tool.ruff]
|
|
||||||
line-length = 88
|
|
||||||
|
|
||||||
[tool.ruff.lint.per-file-ignores]
|
|
||||||
"vllm/third_party/**" = ["ALL"]
|
|
||||||
"vllm/version.py" = ["F401"]
|
|
||||||
"vllm/_version.py" = ["ALL"]
|
|
||||||
|
|
||||||
[tool.ruff.lint]
|
|
||||||
select = [
|
|
||||||
# pycodestyle
|
|
||||||
"E",
|
|
||||||
# Pyflakes
|
|
||||||
"F",
|
|
||||||
# pyupgrade
|
|
||||||
"UP",
|
|
||||||
# flake8-bugbear
|
|
||||||
"B",
|
|
||||||
# flake8-simplify
|
|
||||||
"SIM",
|
|
||||||
# isort
|
|
||||||
"I",
|
|
||||||
# flake8-logging-format
|
|
||||||
"G",
|
|
||||||
]
|
|
||||||
ignore = [
|
|
||||||
# star imports
|
|
||||||
"F405", "F403",
|
|
||||||
# lambda expression assignment
|
|
||||||
"E731",
|
|
||||||
# Loop control variable not used within loop body
|
|
||||||
"B007",
|
|
||||||
# f-string format
|
|
||||||
"UP032",
|
|
||||||
# Can remove once 3.10+ is the minimum Python version
|
|
||||||
"UP007",
|
|
||||||
]
|
|
||||||
|
|
||||||
[tool.ruff.lint.isort]
|
|
||||||
known-first-party = ["vllm"]
|
|
||||||
|
|
||||||
[tool.ruff.format]
|
|
||||||
docstring-code-format = true
|
|
||||||
@ -1,98 +1,32 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
# default values
|
# Define the model to use
|
||||||
MODEL=${MODEL:-"Qwen/Qwen2.5-7B-Instruct"}
|
MODEL=${1:-"Qwen/Qwen2.5-7B-Instruct"}
|
||||||
BACKEND=${BACKEND:-"vllm"}
|
|
||||||
DATASET=${DATASET:-"xgrammar_bench"}
|
# Define the backend to use
|
||||||
|
BACKEND=${2:-"vllm"}
|
||||||
|
|
||||||
|
# Define the dataset to use
|
||||||
|
DATASET=${3:-"xgrammar_bench"}
|
||||||
|
|
||||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||||
OUTPUT_DIR=${OUTPUT_DIR:-"$SCRIPT_DIR/structured_output_benchmark_results"}
|
OUTPUT_DIR=${4:-"$SCRIPT_DIR/structured_output_benchmark_results"}
|
||||||
PORT=${PORT:-8000}
|
|
||||||
STRUCTURED_OUTPUT_RATIO=${STRUCTURED_OUTPUT_RATIO:-1}
|
|
||||||
TOTAL_SECONDS=${TOTAL_SECONDS:-90}
|
|
||||||
MAX_NEW_TOKENS=${MAX_NEW_TOKENS:-300}
|
|
||||||
TOKENIZER_MODE=${TOKENIZER_MODE:-"auto"}
|
|
||||||
|
|
||||||
usage() {
|
GUIDED_RATIO=${5:-0.5}
|
||||||
echo "Usage: $0 [options]"
|
|
||||||
echo "Options:"
|
|
||||||
echo " --model MODEL Model to benchmark (default: $MODEL)"
|
|
||||||
echo " --backend BACKEND Backend to use (default: $BACKEND)"
|
|
||||||
echo " --dataset DATASET Dataset to use (default: $DATASET)"
|
|
||||||
echo " --max-new-tokens N Maximum number of tokens to generate (default: $MAX_NEW_TOKENS)"
|
|
||||||
echo " --output-dir DIR Output directory for results (default: $OUTPUT_DIR)"
|
|
||||||
echo " --port PORT Port to use (default: $PORT)"
|
|
||||||
echo " --structured-output-ratio N Ratio of structured outputs (default: $STRUCTURED_OUTPUT_RATIO)"
|
|
||||||
echo " --tokenizer-mode MODE Tokenizer mode to use (default: $TOKENIZER_MODE)"
|
|
||||||
echo " --total-seconds N Total seconds to run the benchmark (default: $TOTAL_SECONDS)"
|
|
||||||
echo " -h, --help Show this help message and exit"
|
|
||||||
exit 0
|
|
||||||
}
|
|
||||||
|
|
||||||
# parse command line arguments
|
|
||||||
while [[ $# -gt 0 ]]; do
|
|
||||||
case $1 in
|
|
||||||
--model)
|
|
||||||
MODEL="$2"
|
|
||||||
shift 2
|
|
||||||
;;
|
|
||||||
--backend)
|
|
||||||
BACKEND="$2"
|
|
||||||
shift 2
|
|
||||||
;;
|
|
||||||
--dataset)
|
|
||||||
DATASET="$2"
|
|
||||||
shift 2
|
|
||||||
;;
|
|
||||||
--max-new-tokens)
|
|
||||||
MAX_NEW_TOKENS="$2"
|
|
||||||
shift 2
|
|
||||||
;;
|
|
||||||
--output-dir)
|
|
||||||
OUTPUT_DIR="$2"
|
|
||||||
shift 2
|
|
||||||
;;
|
|
||||||
--port)
|
|
||||||
PORT="$2"
|
|
||||||
shift 2
|
|
||||||
;;
|
|
||||||
--structured-output-ratio)
|
|
||||||
STRUCTURED_OUTPUT_RATIO="$2"
|
|
||||||
shift 2
|
|
||||||
;;
|
|
||||||
--tokenizer-mode)
|
|
||||||
TOKENIZER_MODE="$2"
|
|
||||||
shift 2
|
|
||||||
;;
|
|
||||||
--total-seconds)
|
|
||||||
TOTAL_SECONDS="$2"
|
|
||||||
shift 2
|
|
||||||
;;
|
|
||||||
-h|--help)
|
|
||||||
usage
|
|
||||||
;;
|
|
||||||
*)
|
|
||||||
echo "Unknown argument: $1\n"
|
|
||||||
usage
|
|
||||||
;;
|
|
||||||
esac
|
|
||||||
done
|
|
||||||
|
|
||||||
# Create output directory if it doesn't exist
|
# Create output directory if it doesn't exist
|
||||||
mkdir -p "$OUTPUT_DIR"
|
mkdir -p "$OUTPUT_DIR"
|
||||||
|
|
||||||
# Define QPS values to test
|
# Define QPS values to test
|
||||||
QPS_VALUES=(25 20 15 10 5 1)
|
QPS_VALUES=(70 60 50 25 20 15 10)
|
||||||
|
|
||||||
# Common parameters
|
# Common parameters
|
||||||
COMMON_PARAMS="--backend $BACKEND \
|
COMMON_PARAMS="--backend $BACKEND \
|
||||||
--model $MODEL \
|
--model $MODEL \
|
||||||
--dataset $DATASET \
|
--dataset $DATASET \
|
||||||
--structured-output-ratio $STRUCTURED_OUTPUT_RATIO \
|
--structured-output-ratio $GUIDED_RATIO \
|
||||||
--save-results \
|
--save-results \
|
||||||
--result-dir $OUTPUT_DIR \
|
--result-dir $OUTPUT_DIR"
|
||||||
--output-len $MAX_NEW_TOKENS \
|
|
||||||
--port $PORT \
|
|
||||||
--tokenizer-mode $TOKENIZER_MODE"
|
|
||||||
|
|
||||||
echo "Starting structured output benchmark with model: $MODEL"
|
echo "Starting structured output benchmark with model: $MODEL"
|
||||||
echo "Backend: $BACKEND"
|
echo "Backend: $BACKEND"
|
||||||
@ -111,15 +45,12 @@ for qps in "${QPS_VALUES[@]}"; do
|
|||||||
# Construct filename for this run
|
# Construct filename for this run
|
||||||
FILENAME="${BACKEND}_${qps}qps_$(basename $MODEL)_${DATASET}_${GIT_HASH}.json"
|
FILENAME="${BACKEND}_${qps}qps_$(basename $MODEL)_${DATASET}_${GIT_HASH}.json"
|
||||||
|
|
||||||
NUM_PROMPTS=$(echo "$TOTAL_SECONDS * $qps" | bc)
|
|
||||||
NUM_PROMPTS=${NUM_PROMPTS%.*} # Remove fractional part
|
|
||||||
echo "Running benchmark with $NUM_PROMPTS prompts"
|
|
||||||
|
|
||||||
# Run the benchmark
|
# Run the benchmark
|
||||||
python "$SCRIPT_DIR/benchmark_serving_structured_output.py" $COMMON_PARAMS \
|
python "$SCRIPT_DIR/benchmark_serving_structured_output.py" $COMMON_PARAMS \
|
||||||
--request-rate $qps \
|
--request-rate $qps \
|
||||||
--result-filename "$FILENAME" \
|
--result-filename "$FILENAME" \
|
||||||
--num-prompts $NUM_PROMPTS
|
--tokenizer-mode ${TOKENIZER_MODE:-"auto"} \
|
||||||
|
--port ${PORT:-8000}
|
||||||
|
|
||||||
echo "Completed benchmark with QPS: $qps"
|
echo "Completed benchmark with QPS: $qps"
|
||||||
echo "----------------------------------------"
|
echo "----------------------------------------"
|
||||||
|
|||||||
@ -228,26 +228,11 @@ macro(set_gencode_flags_for_srcs)
|
|||||||
"${multiValueArgs}" ${ARGN} )
|
"${multiValueArgs}" ${ARGN} )
|
||||||
|
|
||||||
foreach(_ARCH ${arg_CUDA_ARCHS})
|
foreach(_ARCH ${arg_CUDA_ARCHS})
|
||||||
# handle +PTX suffix: generate both sm and ptx codes if requested
|
string(REPLACE "." "" _ARCH "${_ARCH}")
|
||||||
string(FIND "${_ARCH}" "+PTX" _HAS_PTX)
|
set_gencode_flag_for_srcs(
|
||||||
if(NOT _HAS_PTX EQUAL -1)
|
SRCS ${arg_SRCS}
|
||||||
string(REPLACE "+PTX" "" _BASE_ARCH "${_ARCH}")
|
ARCH "compute_${_ARCH}"
|
||||||
string(REPLACE "." "" _STRIPPED_ARCH "${_BASE_ARCH}")
|
CODE "sm_${_ARCH}")
|
||||||
set_gencode_flag_for_srcs(
|
|
||||||
SRCS ${arg_SRCS}
|
|
||||||
ARCH "compute_${_STRIPPED_ARCH}"
|
|
||||||
CODE "sm_${_STRIPPED_ARCH}")
|
|
||||||
set_gencode_flag_for_srcs(
|
|
||||||
SRCS ${arg_SRCS}
|
|
||||||
ARCH "compute_${_STRIPPED_ARCH}"
|
|
||||||
CODE "compute_${_STRIPPED_ARCH}")
|
|
||||||
else()
|
|
||||||
string(REPLACE "." "" _STRIPPED_ARCH "${_ARCH}")
|
|
||||||
set_gencode_flag_for_srcs(
|
|
||||||
SRCS ${arg_SRCS}
|
|
||||||
ARCH "compute_${_STRIPPED_ARCH}"
|
|
||||||
CODE "sm_${_STRIPPED_ARCH}")
|
|
||||||
endif()
|
|
||||||
endforeach()
|
endforeach()
|
||||||
|
|
||||||
if (${arg_BUILD_PTX_FOR_ARCH})
|
if (${arg_BUILD_PTX_FOR_ARCH})
|
||||||
@ -266,10 +251,7 @@ endmacro()
|
|||||||
#
|
#
|
||||||
# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form
|
# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form
|
||||||
# `<major>.<minor>[letter]` compute the "loose intersection" with the
|
# `<major>.<minor>[letter]` compute the "loose intersection" with the
|
||||||
# `TGT_CUDA_ARCHS` list of gencodes. We also support the `+PTX` suffix in
|
# `TGT_CUDA_ARCHS` list of gencodes.
|
||||||
# `SRC_CUDA_ARCHS` which indicates that the PTX code should be built when there
|
|
||||||
# is a CUDA_ARCH in `TGT_CUDA_ARCHS` that is equal to or larger than the
|
|
||||||
# architecture in `SRC_CUDA_ARCHS`.
|
|
||||||
# The loose intersection is defined as:
|
# The loose intersection is defined as:
|
||||||
# { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} }
|
# { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} }
|
||||||
# where `<=` is the version comparison operator.
|
# where `<=` is the version comparison operator.
|
||||||
@ -286,63 +268,44 @@ endmacro()
|
|||||||
# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
|
# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
|
||||||
# OUT_CUDA_ARCHS="8.0;8.6;9.0;9.0a"
|
# OUT_CUDA_ARCHS="8.0;8.6;9.0;9.0a"
|
||||||
#
|
#
|
||||||
# Example With PTX:
|
|
||||||
# SRC_CUDA_ARCHS="8.0+PTX"
|
|
||||||
# TGT_CUDA_ARCHS="9.0"
|
|
||||||
# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
|
|
||||||
# OUT_CUDA_ARCHS="8.0+PTX"
|
|
||||||
#
|
|
||||||
function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
|
function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
|
||||||
set(_SRC_CUDA_ARCHS "${SRC_CUDA_ARCHS}")
|
list(REMOVE_DUPLICATES SRC_CUDA_ARCHS)
|
||||||
set(_TGT_CUDA_ARCHS ${TGT_CUDA_ARCHS})
|
set(TGT_CUDA_ARCHS_ ${TGT_CUDA_ARCHS})
|
||||||
|
|
||||||
# handle +PTX suffix: separate base arch for matching, record PTX requests
|
|
||||||
set(_PTX_ARCHS)
|
|
||||||
foreach(_arch ${_SRC_CUDA_ARCHS})
|
|
||||||
if(_arch MATCHES "\\+PTX$")
|
|
||||||
string(REPLACE "+PTX" "" _base "${_arch}")
|
|
||||||
list(APPEND _PTX_ARCHS "${_base}")
|
|
||||||
list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}")
|
|
||||||
list(APPEND _SRC_CUDA_ARCHS "${_base}")
|
|
||||||
endif()
|
|
||||||
endforeach()
|
|
||||||
list(REMOVE_DUPLICATES _PTX_ARCHS)
|
|
||||||
list(REMOVE_DUPLICATES _SRC_CUDA_ARCHS)
|
|
||||||
|
|
||||||
# if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should
|
# if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should
|
||||||
# remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS
|
# remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS
|
||||||
set(_CUDA_ARCHS)
|
set(_CUDA_ARCHS)
|
||||||
if ("9.0a" IN_LIST _SRC_CUDA_ARCHS)
|
if ("9.0a" IN_LIST SRC_CUDA_ARCHS)
|
||||||
list(REMOVE_ITEM _SRC_CUDA_ARCHS "9.0a")
|
list(REMOVE_ITEM SRC_CUDA_ARCHS "9.0a")
|
||||||
if ("9.0" IN_LIST TGT_CUDA_ARCHS)
|
if ("9.0" IN_LIST TGT_CUDA_ARCHS_)
|
||||||
list(REMOVE_ITEM _TGT_CUDA_ARCHS "9.0")
|
list(REMOVE_ITEM TGT_CUDA_ARCHS_ "9.0")
|
||||||
set(_CUDA_ARCHS "9.0a")
|
set(_CUDA_ARCHS "9.0a")
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if ("10.0a" IN_LIST _SRC_CUDA_ARCHS)
|
if ("10.0a" IN_LIST SRC_CUDA_ARCHS)
|
||||||
list(REMOVE_ITEM _SRC_CUDA_ARCHS "10.0a")
|
list(REMOVE_ITEM SRC_CUDA_ARCHS "10.0a")
|
||||||
if ("10.0" IN_LIST TGT_CUDA_ARCHS)
|
if ("10.0" IN_LIST TGT_CUDA_ARCHS)
|
||||||
list(REMOVE_ITEM _TGT_CUDA_ARCHS "10.0")
|
list(REMOVE_ITEM TGT_CUDA_ARCHS_ "10.0")
|
||||||
set(_CUDA_ARCHS "10.0a")
|
set(_CUDA_ARCHS "10.0a")
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
list(SORT _SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
|
list(SORT SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
|
||||||
|
|
||||||
# for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that
|
# for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that
|
||||||
# is less or equal to ARCH (but has the same major version since SASS binary
|
# is less or equal to ARCH (but has the same major version since SASS binary
|
||||||
# compatibility is only forward compatible within the same major version).
|
# compatibility is only forward compatible within the same major version).
|
||||||
foreach(_ARCH ${_TGT_CUDA_ARCHS})
|
foreach(_ARCH ${TGT_CUDA_ARCHS_})
|
||||||
set(_TMP_ARCH)
|
set(_TMP_ARCH)
|
||||||
# Extract the major version of the target arch
|
# Extract the major version of the target arch
|
||||||
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" TGT_ARCH_MAJOR "${_ARCH}")
|
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" TGT_ARCH_MAJOR "${_ARCH}")
|
||||||
foreach(_SRC_ARCH ${_SRC_CUDA_ARCHS})
|
foreach(_SRC_ARCH ${SRC_CUDA_ARCHS})
|
||||||
# Extract the major version of the source arch
|
# Extract the major version of the source arch
|
||||||
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" SRC_ARCH_MAJOR "${_SRC_ARCH}")
|
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" SRC_ARCH_MAJOR "${_SRC_ARCH}")
|
||||||
# Check version-less-or-equal, and allow PTX arches to match across majors
|
# Check major-version match AND version-less-or-equal
|
||||||
if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH)
|
if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH)
|
||||||
if (_SRC_ARCH IN_LIST _PTX_ARCHS OR SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR)
|
if (SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR)
|
||||||
set(_TMP_ARCH "${_SRC_ARCH}")
|
set(_TMP_ARCH "${_SRC_ARCH}")
|
||||||
endif()
|
endif()
|
||||||
else()
|
else()
|
||||||
@ -358,18 +321,6 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
|
|||||||
endforeach()
|
endforeach()
|
||||||
|
|
||||||
list(REMOVE_DUPLICATES _CUDA_ARCHS)
|
list(REMOVE_DUPLICATES _CUDA_ARCHS)
|
||||||
|
|
||||||
# reapply +PTX suffix to architectures that requested PTX
|
|
||||||
set(_FINAL_ARCHS)
|
|
||||||
foreach(_arch ${_CUDA_ARCHS})
|
|
||||||
if(_arch IN_LIST _PTX_ARCHS)
|
|
||||||
list(APPEND _FINAL_ARCHS "${_arch}+PTX")
|
|
||||||
else()
|
|
||||||
list(APPEND _FINAL_ARCHS "${_arch}")
|
|
||||||
endif()
|
|
||||||
endforeach()
|
|
||||||
set(_CUDA_ARCHS ${_FINAL_ARCHS})
|
|
||||||
|
|
||||||
set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE)
|
set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE)
|
||||||
endfunction()
|
endfunction()
|
||||||
|
|
||||||
|
|||||||
@ -70,9 +70,6 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
|
|||||||
int64_t num_tokens = input.numel() / input.size(-1); \
|
int64_t num_tokens = input.numel() / input.size(-1); \
|
||||||
dim3 grid(num_tokens); \
|
dim3 grid(num_tokens); \
|
||||||
dim3 block(std::min(d, 1024)); \
|
dim3 block(std::min(d, 1024)); \
|
||||||
if (num_tokens == 0) { \
|
|
||||||
return; \
|
|
||||||
} \
|
|
||||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
|
||||||
VLLM_DISPATCH_FLOATING_TYPES( \
|
VLLM_DISPATCH_FLOATING_TYPES( \
|
||||||
|
|||||||
@ -172,7 +172,7 @@ __device__ void paged_attention_kernel(
|
|||||||
|
|
||||||
// Load the query to registers.
|
// Load the query to registers.
|
||||||
// Each thread in a thread group has a different part of the query.
|
// Each thread in a thread group has a different part of the query.
|
||||||
// For example, if the thread group size is 4, then the first thread in
|
// For example, if the the thread group size is 4, then the first thread in
|
||||||
// the group has 0, 4, 8, ... th vectors of the query, and the second thread
|
// the group has 0, 4, 8, ... th vectors of the query, and the second thread
|
||||||
// has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because
|
// has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because
|
||||||
// q is split from a qkv tensor, it may not be contiguous.
|
// q is split from a qkv tensor, it may not be contiguous.
|
||||||
@ -259,7 +259,7 @@ __device__ void paged_attention_kernel(
|
|||||||
|
|
||||||
// Load a key to registers.
|
// Load a key to registers.
|
||||||
// Each thread in a thread group has a different part of the key.
|
// Each thread in a thread group has a different part of the key.
|
||||||
// For example, if the thread group size is 4, then the first thread in
|
// For example, if the the thread group size is 4, then the first thread in
|
||||||
// the group has 0, 4, 8, ... th vectors of the key, and the second thread
|
// the group has 0, 4, 8, ... th vectors of the key, and the second thread
|
||||||
// has 1, 5, 9, ... th vectors of the key, and so on.
|
// has 1, 5, 9, ... th vectors of the key, and so on.
|
||||||
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
|
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
|
||||||
|
|||||||
@ -143,14 +143,6 @@ void merge_attn_states_launcher(torch::Tensor& output,
|
|||||||
const uint pack_size = 16 / sizeof(scalar_t);
|
const uint pack_size = 16 / sizeof(scalar_t);
|
||||||
TORCH_CHECK(head_size % pack_size == 0,
|
TORCH_CHECK(head_size % pack_size == 0,
|
||||||
"headsize must be multiple of pack_size:", pack_size);
|
"headsize must be multiple of pack_size:", pack_size);
|
||||||
TORCH_CHECK(output.stride(-2) == head_size && output.stride(-1) == 1,
|
|
||||||
"output heads must be contiguous in memory");
|
|
||||||
TORCH_CHECK(
|
|
||||||
prefix_output.stride(-2) == head_size && prefix_output.stride(-1) == 1,
|
|
||||||
"prefix_output heads must be contiguous in memory");
|
|
||||||
TORCH_CHECK(
|
|
||||||
suffix_output.stride(-2) == head_size && suffix_output.stride(-1) == 1,
|
|
||||||
"suffix_output heads must be contiguous in memory");
|
|
||||||
float* output_lse_ptr = nullptr;
|
float* output_lse_ptr = nullptr;
|
||||||
if (output_lse.has_value()) {
|
if (output_lse.has_value()) {
|
||||||
output_lse_ptr = output_lse.value().data_ptr<float>();
|
output_lse_ptr = output_lse.value().data_ptr<float>();
|
||||||
|
|||||||
@ -1,401 +0,0 @@
|
|||||||
// Copyright (c) Microsoft Corporation.
|
|
||||||
// Licensed under the MIT license.
|
|
||||||
|
|
||||||
#include <assert.h>
|
|
||||||
|
|
||||||
#include <cuda.h>
|
|
||||||
|
|
||||||
#include <torch/all.h>
|
|
||||||
|
|
||||||
__device__ int64_t save_blocks(int* block_offset, int64_t range_start,
|
|
||||||
int64_t range_end, int64_t block_size,
|
|
||||||
int64_t input_block_count, int64_t kv_seqlen) {
|
|
||||||
if (range_start >= kv_seqlen) {
|
|
||||||
return input_block_count;
|
|
||||||
}
|
|
||||||
if (range_end > kv_seqlen) {
|
|
||||||
range_end = kv_seqlen;
|
|
||||||
}
|
|
||||||
int64_t current_block_count = input_block_count;
|
|
||||||
for (int idx = range_start; idx < range_end; idx += block_size) {
|
|
||||||
block_offset[current_block_count++] = idx;
|
|
||||||
}
|
|
||||||
return current_block_count;
|
|
||||||
}
|
|
||||||
|
|
||||||
__global__ void convert_vertical_slash_indexes_kernel(
|
|
||||||
const int* q_seqlens, // [BATCH, ]
|
|
||||||
const int* kv_seqlens, // [BATCH, ]
|
|
||||||
const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
|
|
||||||
const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
|
|
||||||
int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
|
|
||||||
int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
|
|
||||||
int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
|
|
||||||
int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
|
|
||||||
int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
|
|
||||||
int64_t NNZ_V, int64_t NNZ_S,
|
|
||||||
bool causal // True for intra, False for succ
|
|
||||||
) {
|
|
||||||
const int batch_idx = blockIdx.y;
|
|
||||||
const int head_idx = blockIdx.x;
|
|
||||||
const int group_idx = blockIdx.z;
|
|
||||||
|
|
||||||
int64_t q_seqlen = q_seqlens[batch_idx];
|
|
||||||
int64_t kv_seqlen = kv_seqlens[batch_idx];
|
|
||||||
int64_t block_idx_m = group_idx * blockDim.x + threadIdx.x;
|
|
||||||
int64_t start_m = block_idx_m * BLOCK_SIZE_M;
|
|
||||||
if (start_m >= q_seqlen) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
int64_t end_m = start_m + BLOCK_SIZE_M;
|
|
||||||
vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V;
|
|
||||||
slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S;
|
|
||||||
int64_t row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m;
|
|
||||||
block_count += row_offset;
|
|
||||||
block_offset += row_offset * NNZ_S;
|
|
||||||
column_count += row_offset;
|
|
||||||
column_index += row_offset * NNZ_V;
|
|
||||||
|
|
||||||
bool has_slash = true;
|
|
||||||
int64_t tmp_col_cnt = 0, tmp_blk_cnt = 0;
|
|
||||||
int64_t s = 0, v = 0;
|
|
||||||
int64_t v_idx = vertical_indexes[v++];
|
|
||||||
int64_t s_idx = slash_indexes[s++];
|
|
||||||
if (causal) {
|
|
||||||
while (s_idx >= end_m + (kv_seqlen - q_seqlen) && s < NNZ_S) {
|
|
||||||
s_idx = slash_indexes[s++];
|
|
||||||
}
|
|
||||||
if (s_idx > end_m + (kv_seqlen - q_seqlen)) has_slash = false;
|
|
||||||
s_idx = max((kv_seqlen - q_seqlen) + end_m - s_idx, BLOCK_SIZE_M);
|
|
||||||
} else {
|
|
||||||
while (s_idx >= end_m + kv_seqlen && s < NNZ_S) {
|
|
||||||
s_idx = slash_indexes[s++];
|
|
||||||
}
|
|
||||||
if (s_idx > end_m + kv_seqlen) has_slash = false;
|
|
||||||
s_idx = max(kv_seqlen + end_m - s_idx, BLOCK_SIZE_M);
|
|
||||||
}
|
|
||||||
|
|
||||||
int64_t range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx;
|
|
||||||
if (!has_slash) {
|
|
||||||
if (causal) {
|
|
||||||
range_start = (kv_seqlen - q_seqlen) + end_m;
|
|
||||||
range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
|
|
||||||
} else {
|
|
||||||
range_start = kv_seqlen;
|
|
||||||
range_end = kv_seqlen + BLOCK_SIZE_N;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bool slash_finished = false;
|
|
||||||
while (1) {
|
|
||||||
if (v_idx < range_end) {
|
|
||||||
if (v_idx < range_start) {
|
|
||||||
column_index[tmp_col_cnt++] = v_idx;
|
|
||||||
}
|
|
||||||
if (v < NNZ_V) {
|
|
||||||
v_idx = vertical_indexes[v++];
|
|
||||||
} else {
|
|
||||||
if (causal)
|
|
||||||
v_idx = end_m + BLOCK_SIZE_N + (kv_seqlen - q_seqlen);
|
|
||||||
else
|
|
||||||
v_idx = end_m + BLOCK_SIZE_N + kv_seqlen;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if ((s < NNZ_S && causal) ||
|
|
||||||
(s < NNZ_S && !causal && slash_indexes[s] >= start_m)) {
|
|
||||||
if (causal)
|
|
||||||
s_idx = max((kv_seqlen - q_seqlen) + end_m - slash_indexes[s++],
|
|
||||||
BLOCK_SIZE_M);
|
|
||||||
else
|
|
||||||
s_idx = max(kv_seqlen + end_m - slash_indexes[s++], BLOCK_SIZE_M);
|
|
||||||
} else {
|
|
||||||
if (v == NNZ_V || (v_idx > range_start && causal)) {
|
|
||||||
// add the last vertical if no more slash
|
|
||||||
if (v == NNZ_V && !causal && v_idx < kv_seqlen) {
|
|
||||||
column_index[tmp_col_cnt++] = v_idx;
|
|
||||||
}
|
|
||||||
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
|
|
||||||
BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
|
|
||||||
break;
|
|
||||||
} else {
|
|
||||||
if (causal) {
|
|
||||||
range_start = (kv_seqlen - q_seqlen) + end_m;
|
|
||||||
range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
|
|
||||||
} else {
|
|
||||||
// if slash_finished but there are vertical left, save current
|
|
||||||
// blocks
|
|
||||||
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
|
|
||||||
BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
|
|
||||||
range_start = kv_seqlen;
|
|
||||||
range_end = kv_seqlen + BLOCK_SIZE_N;
|
|
||||||
}
|
|
||||||
slash_finished = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (!slash_finished) {
|
|
||||||
if (s_idx > range_end + BLOCK_SIZE_M) {
|
|
||||||
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
|
|
||||||
BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
|
|
||||||
range_start = s_idx - BLOCK_SIZE_M;
|
|
||||||
range_end = s_idx;
|
|
||||||
} else if (s_idx > range_end) {
|
|
||||||
range_end += BLOCK_SIZE_M;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
block_count[0] = tmp_blk_cnt;
|
|
||||||
column_count[0] = tmp_col_cnt;
|
|
||||||
}
|
|
||||||
|
|
||||||
void convert_vertical_slash_indexes_64x64(
|
|
||||||
const int* q_seqlens, // [BATCH, ]
|
|
||||||
const int* kv_seqlens, // [BATCH, ]
|
|
||||||
const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
|
|
||||||
const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
|
|
||||||
int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
|
|
||||||
int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
|
|
||||||
int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
|
|
||||||
int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
|
|
||||||
int64_t BATCH_SIZE, int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M,
|
|
||||||
int64_t BLOCK_SIZE_N, int64_t NNZ_V, int64_t NNZ_S, bool causal) {
|
|
||||||
const int N_THREADS = 64;
|
|
||||||
const dim3 dimBlock(N_THREADS);
|
|
||||||
const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS);
|
|
||||||
convert_vertical_slash_indexes_kernel<<<dimGrid, dimBlock>>>(
|
|
||||||
q_seqlens, kv_seqlens, vertical_indexes, slash_indexes, block_count,
|
|
||||||
block_offset, column_count, column_index, N_HEADS, N_ROWS, BLOCK_SIZE_M,
|
|
||||||
BLOCK_SIZE_N, NNZ_V, NNZ_S, causal);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Implements the Algorithm 4 in paper https://arxiv.org/abs/2407.02490.
|
|
||||||
*
|
|
||||||
* This function builds the index of each row of blocks from vertical indices
|
|
||||||
* and slash indices. The vertical indices are treated as points, while the
|
|
||||||
* slash indices are converted as ranges. The output consists of the merged
|
|
||||||
* ranges and separate column indices, where the ranges are represented by
|
|
||||||
* block indices.
|
|
||||||
*
|
|
||||||
* The implementation is referenced from the original MInference repo:
|
|
||||||
* https://github.com/microsoft/MInference/blob/main/csrc/vertical_slash_index.cu.
|
|
||||||
*/
|
|
||||||
void convert_vertical_slash_indexes(
|
|
||||||
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
|
|
||||||
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
|
|
||||||
torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS]
|
|
||||||
torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
|
|
||||||
torch::Tensor q_seqlens, // [BATCH, ]
|
|
||||||
torch::Tensor kv_seqlens, // [BATCH, ]
|
|
||||||
torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
|
|
||||||
torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
|
|
||||||
int64_t context_size, int64_t block_size_M, int64_t block_size_N,
|
|
||||||
bool causal) {
|
|
||||||
cudaSetDevice(q_seqlens.get_device());
|
|
||||||
|
|
||||||
int batch_size = slash_indexes.size(0);
|
|
||||||
int num_heads = slash_indexes.size(1);
|
|
||||||
int nnz_slash = slash_indexes.size(2);
|
|
||||||
int nnz_vertical = vertical_indexes.size(2);
|
|
||||||
int num_rows = (context_size + block_size_M - 1) / block_size_M;
|
|
||||||
|
|
||||||
convert_vertical_slash_indexes_64x64(
|
|
||||||
q_seqlens.data_ptr<int>(), kv_seqlens.data_ptr<int>(),
|
|
||||||
vertical_indexes.data_ptr<int>(), slash_indexes.data_ptr<int>(),
|
|
||||||
block_count.data_ptr<int>(), block_offset.data_ptr<int>(),
|
|
||||||
column_count.data_ptr<int>(), column_index.data_ptr<int>(), batch_size,
|
|
||||||
num_heads, num_rows, block_size_M, block_size_N, nnz_vertical, nnz_slash,
|
|
||||||
causal);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Device kernel: for one (batch, head, query row block) per thread, merges the
// per-head vertical and slash index lists into block offsets (block_offset /
// block_count) and individual column indices (column_index / column_count).
//
// Thread mapping: blockIdx.y selects the batch entry, blockIdx.x the head, and
// blockIdx.z * blockDim.x + threadIdx.x the query row block of BLOCK_SIZE_M
// rows.
//
// NNZ_V / NNZ_S arrive as the buffer sizes of the index tensors and are used
// for pointer arithmetic first; they are then overwritten with the per-head
// counts read from per_head_vertical_topkv / per_head_slash_topkv.
//
// NOTE(review): save_blocks is defined outside this block; its exact contract
// (presumably appending BLOCK_SIZE_N-aligned offsets covering
// [range_start, range_end) and returning the new count) should be confirmed
// there.
__global__ void convert_vertical_slash_indexes_kernel_mergehead(
|
|
||||||
const int* q_seqlens, // [BATCH, ]
|
|
||||||
const int* kv_seqlens, // [BATCH, ]
|
|
||||||
const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
|
|
||||||
const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
|
|
||||||
const int* per_head_vertical_topkv, const int* per_head_slash_topkv,
|
|
||||||
int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
|
|
||||||
int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
|
|
||||||
int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
|
|
||||||
int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
|
|
||||||
int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
|
|
||||||
int64_t NNZ_V, int64_t NNZ_S,
|
|
||||||
bool causal // True for intra, False for succ
|
|
||||||
) {
|
|
||||||
const int batch_idx = blockIdx.y;
|
|
||||||
const int head_idx = blockIdx.x;
|
|
||||||
const int group_idx = blockIdx.z;
|
|
||||||
|
|
||||||
int64_t q_seqlen = q_seqlens[batch_idx];
|
|
||||||
int64_t kv_seqlen = kv_seqlens[batch_idx];
|
|
||||||
// One thread handles one row block of BLOCK_SIZE_M query rows.
int64_t block_idx_m = group_idx * blockDim.x + threadIdx.x;
|
|
||||||
int64_t start_m = block_idx_m * BLOCK_SIZE_M;
|
|
||||||
// Row blocks starting at or past the query length produce no output; the
// thread exits before touching any output buffer.
if (start_m >= q_seqlen) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
int64_t end_m = start_m + BLOCK_SIZE_M;
|
|
||||||
// Advance every pointer to this thread's slice. NNZ_V / NNZ_S are still the
// buffer sizes here, which is what the strides require.
vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V;
|
|
||||||
slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S;
|
|
||||||
int64_t row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m;
|
|
||||||
block_count += row_offset;
|
|
||||||
block_offset += row_offset * NNZ_S;
|
|
||||||
column_count += row_offset;
|
|
||||||
column_index += row_offset * NNZ_V;
|
|
||||||
|
|
||||||
// MergeHead: each head has its own max topk NNZ_V, NNZ_S. (The NNZ_V, NNZ_S
|
|
||||||
// values above are buffer sizes, used only to compute offsets.)
|
|
||||||
NNZ_S = per_head_slash_topkv[head_idx];
|
|
||||||
NNZ_V = per_head_vertical_topkv[head_idx];
|
|
||||||
|
|
||||||
bool has_slash = true;
|
|
||||||
int64_t tmp_col_cnt = 0, tmp_blk_cnt = 0;
|
|
||||||
int64_t s = 0, v = 0;
|
|
||||||
// NOTE(review): element 0 of both lists is read unconditionally, before the
// per-head counts are consulted; this assumes each list holds at least one
// entry - confirm with callers.
int64_t v_idx = vertical_indexes[v++];
|
|
||||||
int64_t s_idx = slash_indexes[s++];
|
|
||||||
// Skip leading slash entries that lie at or beyond this row block's upper
// bound, then convert the first usable one into a range end (clamped below by
// BLOCK_SIZE_M). If none is usable, has_slash is cleared.
if (causal) {
|
|
||||||
while (s_idx >= end_m + (kv_seqlen - q_seqlen) && s < NNZ_S) {
|
|
||||||
s_idx = slash_indexes[s++];
|
|
||||||
}
|
|
||||||
if (s_idx > end_m + (kv_seqlen - q_seqlen)) has_slash = false;
|
|
||||||
s_idx = max((kv_seqlen - q_seqlen) + end_m - s_idx, BLOCK_SIZE_M);
|
|
||||||
} else {
|
|
||||||
while (s_idx >= end_m + kv_seqlen && s < NNZ_S) {
|
|
||||||
s_idx = slash_indexes[s++];
|
|
||||||
}
|
|
||||||
if (s_idx > end_m + kv_seqlen) has_slash = false;
|
|
||||||
s_idx = max(kv_seqlen + end_m - s_idx, BLOCK_SIZE_M);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Current slash-covered range [range_start, range_end). Without a usable
// slash entry the range is placed past the attended region instead.
int64_t range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx;
|
|
||||||
if (!has_slash) {
|
|
||||||
if (causal) {
|
|
||||||
range_start = (kv_seqlen - q_seqlen) + end_m;
|
|
||||||
range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
|
|
||||||
} else {
|
|
||||||
range_start = kv_seqlen;
|
|
||||||
range_end = kv_seqlen + BLOCK_SIZE_N;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Merge loop: consume vertical indices that fall before the current range end
// (emitting those before range_start as individual columns); otherwise advance
// the slash list and extend or flush the current range. Terminates via the
// `break` once both lists are exhausted (or, in the causal case, once the next
// vertical index lies past range_start).
bool slash_finished = false;
|
|
||||||
while (1) {
|
|
||||||
if (v_idx < range_end) {
|
|
||||||
// Vertical index not already covered by the slash range: record it.
if (v_idx < range_start) {
|
|
||||||
column_index[tmp_col_cnt++] = v_idx;
|
|
||||||
}
|
|
||||||
if (v < NNZ_V) {
|
|
||||||
v_idx = vertical_indexes[v++];
|
|
||||||
} else {
|
|
||||||
// Vertical list exhausted: park v_idx on a sentinel beyond the range.
if (causal)
|
|
||||||
v_idx = end_m + BLOCK_SIZE_N + (kv_seqlen - q_seqlen);
|
|
||||||
else
|
|
||||||
v_idx = end_m + BLOCK_SIZE_N + kv_seqlen;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if ((s < NNZ_S && causal) ||
|
|
||||||
(s < NNZ_S && !causal && slash_indexes[s] >= start_m)) {
|
|
||||||
if (causal)
|
|
||||||
s_idx = max((kv_seqlen - q_seqlen) + end_m - slash_indexes[s++],
|
|
||||||
BLOCK_SIZE_M);
|
|
||||||
else
|
|
||||||
s_idx = max(kv_seqlen + end_m - slash_indexes[s++], BLOCK_SIZE_M);
|
|
||||||
} else {
|
|
||||||
if (v == NNZ_V || (v_idx > range_start && causal)) {
|
|
||||||
// add the last vertical if no more slash
|
|
||||||
if (v == NNZ_V && !causal && v_idx < kv_seqlen) {
|
|
||||||
column_index[tmp_col_cnt++] = v_idx;
|
|
||||||
}
|
|
||||||
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
|
|
||||||
BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
|
|
||||||
break;
|
|
||||||
} else {
|
|
||||||
if (causal) {
|
|
||||||
range_start = (kv_seqlen - q_seqlen) + end_m;
|
|
||||||
range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
|
|
||||||
} else {
|
|
||||||
// if slash_finished but there are vertical left, save current
|
|
||||||
// blocks
|
|
||||||
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
|
|
||||||
BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
|
|
||||||
range_start = kv_seqlen;
|
|
||||||
range_end = kv_seqlen + BLOCK_SIZE_N;
|
|
||||||
}
|
|
||||||
slash_finished = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (!slash_finished) {
|
|
||||||
// A new slash position more than one BLOCK_SIZE_M past the current range
// starts a fresh range (flushing the old one); one that is adjacent
// merely extends the current range.
if (s_idx > range_end + BLOCK_SIZE_M) {
|
|
||||||
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
|
|
||||||
BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
|
|
||||||
range_start = s_idx - BLOCK_SIZE_M;
|
|
||||||
range_end = s_idx;
|
|
||||||
} else if (s_idx > range_end) {
|
|
||||||
range_end += BLOCK_SIZE_M;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Publish the number of block offsets and column indices written for this
// (batch, head, row block).
block_count[0] = tmp_blk_cnt;
|
|
||||||
column_count[0] = tmp_col_cnt;
|
|
||||||
}
|
|
||||||
|
|
||||||
void convert_vertical_slash_indexes_64x64_mergehead(
|
|
||||||
const int* q_seqlens, // [BATCH, ]
|
|
||||||
const int* kv_seqlens, // [BATCH, ]
|
|
||||||
const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
|
|
||||||
const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
|
|
||||||
int* per_head_vertical_topkv, int* per_head_slash_topkv,
|
|
||||||
int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
|
|
||||||
int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
|
|
||||||
int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
|
|
||||||
int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
|
|
||||||
int64_t BATCH_SIZE, int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M,
|
|
||||||
int64_t BLOCK_SIZE_N, int64_t NNZ_V, int64_t NNZ_S, bool causal) {
|
|
||||||
const int N_THREADS = 64;
|
|
||||||
const dim3 dimBlock(N_THREADS);
|
|
||||||
const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS);
|
|
||||||
convert_vertical_slash_indexes_kernel_mergehead<<<dimGrid, dimBlock>>>(
|
|
||||||
q_seqlens, kv_seqlens, vertical_indexes, slash_indexes,
|
|
||||||
per_head_vertical_topkv, per_head_slash_topkv, block_count, block_offset,
|
|
||||||
column_count, column_index, N_HEADS, N_ROWS, BLOCK_SIZE_M, BLOCK_SIZE_N,
|
|
||||||
NNZ_V, NNZ_S, causal);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Implements the Algorithm 4 in paper https://arxiv.org/abs/2407.02490.
|
|
||||||
*
|
|
||||||
* Like the above convert_vertical_slash_indexes, but with
|
|
||||||
* pre-computed vertical and slash counts.
|
|
||||||
*/
|
|
||||||
void convert_vertical_slash_indexes_mergehead(
|
|
||||||
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
|
|
||||||
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
|
|
||||||
torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS]
|
|
||||||
torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
|
|
||||||
torch::Tensor q_seqlens, // [BATCH, ]
|
|
||||||
torch::Tensor kv_seqlens, // [BATCH, ]
|
|
||||||
torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
|
|
||||||
torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
|
|
||||||
torch::Tensor vertical_indices_count, // [N_HEADS, ]
|
|
||||||
torch::Tensor slash_indices_count, // [N_HEADS, ]
|
|
||||||
int64_t context_size, int64_t block_size_M, int64_t block_size_N,
|
|
||||||
bool causal) {
|
|
||||||
cudaSetDevice(q_seqlens.get_device());
|
|
||||||
|
|
||||||
int batch_size = slash_indexes.size(0);
|
|
||||||
int num_heads = slash_indexes.size(1);
|
|
||||||
int nnz_slash = slash_indexes.size(2);
|
|
||||||
int nnz_vertical = vertical_indexes.size(2);
|
|
||||||
int num_rows = (context_size + block_size_M - 1) / block_size_M;
|
|
||||||
|
|
||||||
convert_vertical_slash_indexes_64x64_mergehead(
|
|
||||||
q_seqlens.data_ptr<int>(), kv_seqlens.data_ptr<int>(),
|
|
||||||
vertical_indexes.data_ptr<int>(), slash_indexes.data_ptr<int>(),
|
|
||||||
vertical_indices_count.data_ptr<int>(),
|
|
||||||
slash_indices_count.data_ptr<int>(), block_count.data_ptr<int>(),
|
|
||||||
block_offset.data_ptr<int>(), column_count.data_ptr<int>(),
|
|
||||||
column_index.data_ptr<int>(), batch_size, num_heads, num_rows,
|
|
||||||
block_size_M, block_size_N, nnz_vertical, nnz_slash, causal);
|
|
||||||
}
|
|
||||||
@ -315,8 +315,6 @@ static inline constexpr auto kS8 = ScalarType::int_(8);
|
|||||||
static inline constexpr auto kU8 = ScalarType::uint(8);
|
static inline constexpr auto kU8 = ScalarType::uint(8);
|
||||||
static inline constexpr auto kU8B128 = ScalarType::uint(8, 128);
|
static inline constexpr auto kU8B128 = ScalarType::uint(8, 128);
|
||||||
|
|
||||||
static inline constexpr auto kFE2M1f =
|
|
||||||
ScalarType::float_(2, 1, true, ScalarType::NAN_NONE);
|
|
||||||
static inline constexpr auto kFE3M2f =
|
static inline constexpr auto kFE3M2f =
|
||||||
ScalarType::float_(3, 2, true, ScalarType::NAN_NONE);
|
ScalarType::float_(3, 2, true, ScalarType::NAN_NONE);
|
||||||
static inline constexpr auto kFE4M3fn =
|
static inline constexpr auto kFE4M3fn =
|
||||||
@ -334,7 +332,6 @@ static inline constexpr auto kInt8 = kS8;
|
|||||||
static inline constexpr auto kUint8 = kU8;
|
static inline constexpr auto kUint8 = kU8;
|
||||||
static inline constexpr auto kUint8b128 = kU8B128;
|
static inline constexpr auto kUint8b128 = kU8B128;
|
||||||
|
|
||||||
static inline constexpr auto kFloat4_e2m1f = kFE2M1f;
|
|
||||||
static inline constexpr auto kFloat6_e3m2f = kFE3M2f;
|
static inline constexpr auto kFloat6_e3m2f = kFE3M2f;
|
||||||
static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn;
|
static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn;
|
||||||
static inline constexpr auto kFloat8_e5m2 = kFE5M2;
|
static inline constexpr auto kFloat8_e5m2 = kFE5M2;
|
||||||
|
|||||||
@ -19,7 +19,6 @@ namespace vec_op {
|
|||||||
#define VLLM_DISPATCH_CASE_FLOATING_TYPES_FP8(...) \
|
#define VLLM_DISPATCH_CASE_FLOATING_TYPES_FP8(...) \
|
||||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||||
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
|
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
|
||||||
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
|
|
||||||
AT_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__)
|
AT_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__)
|
||||||
|
|
||||||
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
||||||
|
|||||||
@ -15,6 +15,15 @@
|
|||||||
cutlassGetStatusString(error)); \
|
cutlassGetStatusString(error)); \
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Panic wrapper for unwinding CUDA runtime errors
|
||||||
|
*/
|
||||||
|
#define CUDA_CHECK(status) \
|
||||||
|
{ \
|
||||||
|
cudaError_t error = status; \
|
||||||
|
TORCH_CHECK(error == cudaSuccess, cudaGetErrorString(error)); \
|
||||||
|
}
|
||||||
|
|
||||||
inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
|
inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
|
||||||
int max_shared_mem_per_block_opt_in = 0;
|
int max_shared_mem_per_block_opt_in = 0;
|
||||||
cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,
|
cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,
|
||||||
|
|||||||
@ -65,19 +65,5 @@
|
|||||||
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
|
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
|
||||||
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
|
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
|
||||||
|
|
||||||
#define VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(...) \
|
|
||||||
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
|
|
||||||
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
|
|
||||||
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
|
|
||||||
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
|
|
||||||
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \
|
|
||||||
AT_DISPATCH_CASE(at::ScalarType::UInt16, __VA_ARGS__) \
|
|
||||||
AT_DISPATCH_CASE(at::ScalarType::UInt32, __VA_ARGS__) \
|
|
||||||
AT_DISPATCH_CASE(at::ScalarType::UInt64, __VA_ARGS__)
|
|
||||||
|
|
||||||
#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
|
#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
|
||||||
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
|
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
|
||||||
|
|
||||||
#define VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(TYPE, NAME, ...) \
|
|
||||||
AT_DISPATCH_SWITCH( \
|
|
||||||
TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(__VA_ARGS__))
|
|
||||||
|
|||||||
@ -13,10 +13,6 @@
|
|||||||
#include <cub/block/block_load.cuh>
|
#include <cub/block/block_load.cuh>
|
||||||
#include <cub/block/block_store.cuh>
|
#include <cub/block/block_store.cuh>
|
||||||
|
|
||||||
#ifdef USE_ROCM
|
|
||||||
namespace cub = hipcub;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#include "static_switch.h"
|
#include "static_switch.h"
|
||||||
|
|
||||||
|
|
||||||
@ -505,9 +501,15 @@ void causal_conv1d_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) {
|
|||||||
auto kernel = &causal_conv1d_fwd_kernel<Ktraits>;
|
auto kernel = &causal_conv1d_fwd_kernel<Ktraits>;
|
||||||
|
|
||||||
if (kSmemSize >= 48 * 1024) {
|
if (kSmemSize >= 48 * 1024) {
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||||
|
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
||||||
|
#else
|
||||||
|
// There is a slight signature discrepancy in HIP and CUDA "FuncSetAttribute" function.
|
||||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||||
(void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
(void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
||||||
std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
|
std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
||||||
|
|
||||||
|
|||||||
@ -321,7 +321,7 @@ void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) {
|
|||||||
auto kernel = &selective_scan_fwd_kernel<Ktraits>;
|
auto kernel = &selective_scan_fwd_kernel<Ktraits>;
|
||||||
if (kSmemSize >= 48 * 1024) {
|
if (kSmemSize >= 48 * 1024) {
|
||||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||||
(void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
||||||
}
|
}
|
||||||
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
||||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||||
|
|||||||
@ -31,10 +31,7 @@ TEMPLATE = ("template __global__ void Marlin<"
|
|||||||
|
|
||||||
# int8 with zero point case (vllm::kU8) is also supported,
|
# int8 with zero point case (vllm::kU8) is also supported,
|
||||||
# we don't add it to reduce wheel size.
|
# we don't add it to reduce wheel size.
|
||||||
SCALAR_TYPES = [
|
SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn"]
|
||||||
"vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn",
|
|
||||||
"vllm::kFE2M1f"
|
|
||||||
]
|
|
||||||
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]
|
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]
|
||||||
|
|
||||||
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
|
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
|
||||||
@ -42,7 +39,7 @@ THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
|
|||||||
# = 0 : act order case
|
# = 0 : act order case
|
||||||
# = -1 : channelwise quantization
|
# = -1 : channelwise quantization
|
||||||
# > 0 : group_size=16*group_blocks
|
# > 0 : group_size=16*group_blocks
|
||||||
GROUP_BLOCKS = [0, -1, 1, 2, 4, 8]
|
GROUP_BLOCKS = [0, -1, 2, 4, 8]
|
||||||
DTYPES = ["fp16", "bf16"]
|
DTYPES = ["fp16", "bf16"]
|
||||||
|
|
||||||
|
|
||||||
@ -75,12 +72,6 @@ def generate_new_kernels():
|
|||||||
# for fp8
|
# for fp8
|
||||||
if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
|
if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
|
||||||
continue
|
continue
|
||||||
# nvfp4 only supports group_size == 16
|
|
||||||
if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]:
|
|
||||||
continue
|
|
||||||
# other quantization methods don't support group_size = 16
|
|
||||||
if scalar_type != "vllm::kFE2M1f" and group_blocks == 1:
|
|
||||||
continue
|
|
||||||
|
|
||||||
k_blocks = thread_configs[0] // 16
|
k_blocks = thread_configs[0] // 16
|
||||||
n_blocks = thread_configs[1] // 16
|
n_blocks = thread_configs[1] // 16
|
||||||
|
|||||||
@ -7,18 +7,17 @@
|
|||||||
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
|
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
|
||||||
#include "core/scalar_type.hpp"
|
#include "core/scalar_type.hpp"
|
||||||
|
|
||||||
#define MARLIN_KERNEL_PARAMS \
|
#define MARLIN_KERNEL_PARAMS \
|
||||||
const int4 *__restrict__ A, const int4 *__restrict__ B, \
|
const int4 *__restrict__ A, const int4 *__restrict__ B, \
|
||||||
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
|
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
|
||||||
const int4 *__restrict__ scales_ptr, \
|
const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr, \
|
||||||
const uint16_t *__restrict__ scale2_ptr, \
|
const int *__restrict__ g_idx, \
|
||||||
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
|
const int32_t *__restrict__ sorted_token_ids_ptr, \
|
||||||
const int32_t *__restrict__ sorted_token_ids_ptr, \
|
const int32_t *__restrict__ expert_ids_ptr, \
|
||||||
const int32_t *__restrict__ expert_ids_ptr, \
|
const int32_t *__restrict__ num_tokens_past_padded_ptr, \
|
||||||
const int32_t *__restrict__ num_tokens_past_padded_ptr, \
|
const float *__restrict__ topk_weights_ptr, int top_k, \
|
||||||
const float *__restrict__ topk_weights_ptr, int top_k, \
|
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
|
||||||
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
|
int prob_n, int prob_k, int *locks, bool use_atomic_add, \
|
||||||
int prob_n, int prob_k, int *locks, bool use_atomic_add, \
|
|
||||||
bool use_fp32_reduce, int max_shared_mem
|
bool use_fp32_reduce, int max_shared_mem
|
||||||
|
|
||||||
namespace MARLIN_NAMESPACE_NAME {
|
namespace MARLIN_NAMESPACE_NAME {
|
||||||
|
|||||||
@ -301,11 +301,9 @@ __global__ void Marlin(
|
|||||||
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
|
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
|
||||||
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
|
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
|
||||||
// (k/groupsize)xn
|
// (k/groupsize)xn
|
||||||
const uint16_t* __restrict__ scale2_ptr, // fp16 global scale (for nvfp4
|
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
|
||||||
// only)
|
// (k/groupsize)x(n/pack_factor)
|
||||||
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
|
const int* __restrict__ g_idx, // int32 group indices of shape k
|
||||||
// (k/groupsize)x(n/pack_factor)
|
|
||||||
const int* __restrict__ g_idx, // int32 group indices of shape k
|
|
||||||
const int32_t* __restrict__ sorted_token_ids_ptr, // moe sorted_ids
|
const int32_t* __restrict__ sorted_token_ids_ptr, // moe sorted_ids
|
||||||
const int32_t* __restrict__ expert_ids_ptr, // moe expert ids
|
const int32_t* __restrict__ expert_ids_ptr, // moe expert ids
|
||||||
const int32_t* __restrict__ num_tokens_past_padded_ptr, // moe num tokens
|
const int32_t* __restrict__ num_tokens_past_padded_ptr, // moe num tokens
|
||||||
@ -343,16 +341,6 @@ __global__ void Marlin(
|
|||||||
extern __shared__ int4 sh[];
|
extern __shared__ int4 sh[];
|
||||||
static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
|
static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
|
||||||
constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8;
|
constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8;
|
||||||
constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 ||
|
|
||||||
w_type == vllm::kU4B8 || w_type == vllm::kU8B128;
|
|
||||||
// see comments of dequant.h for more details
|
|
||||||
constexpr bool dequant_skip_flop =
|
|
||||||
!is_int_type ||
|
|
||||||
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
|
|
||||||
has_zp && !is_zp_float && !(w_type == vllm::kU8);
|
|
||||||
|
|
||||||
scalar_t2 global_scale;
|
|
||||||
|
|
||||||
constexpr bool has_act_order = group_blocks == 0;
|
constexpr bool has_act_order = group_blocks == 0;
|
||||||
|
|
||||||
constexpr int pack_factor = 32 / w_type.size_bits();
|
constexpr int pack_factor = 32 / w_type.size_bits();
|
||||||
@ -360,8 +348,7 @@ __global__ void Marlin(
|
|||||||
constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks);
|
constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks);
|
||||||
const int group_size =
|
const int group_size =
|
||||||
(!has_act_order && group_blocks == -1) ? prob_k : prob_k / num_groups;
|
(!has_act_order && group_blocks == -1) ? prob_k : prob_k / num_groups;
|
||||||
const int scales_expert_stride =
|
const int scales_expert_stride = prob_n * prob_k / group_size / 8;
|
||||||
prob_n * prob_k / group_size / (w_type == vllm::kFE2M1f ? 16 : 8);
|
|
||||||
const int zp_expert_stride =
|
const int zp_expert_stride =
|
||||||
is_zp_float ? prob_n * prob_k / group_size / 8
|
is_zp_float ? prob_n * prob_k / group_size / 8
|
||||||
: prob_n * prob_k / group_size / (pack_factor * 4);
|
: prob_n * prob_k / group_size / (pack_factor * 4);
|
||||||
@ -473,16 +460,9 @@ __global__ void Marlin(
|
|||||||
if (mul_topk_weights) {
|
if (mul_topk_weights) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < 4; i++) {
|
for (int i = 0; i < 4; i++) {
|
||||||
int idx = tid4 * 4 + i;
|
sh_block_topk_weights[tid4 * 4 + i] =
|
||||||
idx = idx < block_num_valid_tokens ? idx : 0;
|
Dtype::num2num2(Dtype::float2num(
|
||||||
if constexpr (w_type == vllm::kFE2M1f) {
|
topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]]));
|
||||||
sh_block_topk_weights[idx] = __hmul2(
|
|
||||||
global_scale, Dtype::num2num2(Dtype::float2num(
|
|
||||||
topk_weights_ptr[sh_block_sorted_ids[idx]])));
|
|
||||||
} else {
|
|
||||||
sh_block_topk_weights[idx] = Dtype::num2num2(
|
|
||||||
Dtype::float2num(topk_weights_ptr[sh_block_sorted_ids[idx]]));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -513,11 +493,6 @@ __global__ void Marlin(
|
|||||||
expert_id = expert_ids_ptr[block_id];
|
expert_id = expert_ids_ptr[block_id];
|
||||||
}
|
}
|
||||||
|
|
||||||
if constexpr (w_type == vllm::kFE2M1f) {
|
|
||||||
uint16_t val = scale2_ptr[expert_id];
|
|
||||||
global_scale = Dtype::num2num2(*reinterpret_cast<scalar_t*>(&val));
|
|
||||||
}
|
|
||||||
|
|
||||||
B_expert_off = expert_id * prob_n * prob_k / (pack_factor * 4);
|
B_expert_off = expert_id * prob_n * prob_k / (pack_factor * 4);
|
||||||
scales_ptr += (expert_id - old_expert_id) * scales_expert_stride;
|
scales_ptr += (expert_id - old_expert_id) * scales_expert_stride;
|
||||||
if constexpr (has_zp) {
|
if constexpr (has_zp) {
|
||||||
@ -631,7 +606,7 @@ __global__ void Marlin(
|
|||||||
constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
|
constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
|
||||||
constexpr int s_tb_groups =
|
constexpr int s_tb_groups =
|
||||||
!has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
|
!has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
|
||||||
? thread_k_blocks / group_blocks / (w_type == vllm::kFE2M1f ? 2 : 1)
|
? thread_k_blocks / group_blocks
|
||||||
: 1;
|
: 1;
|
||||||
constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
|
constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
|
||||||
int s_gl_rd_delta = s_gl_stride;
|
int s_gl_rd_delta = s_gl_stride;
|
||||||
@ -689,8 +664,7 @@ __global__ void Marlin(
|
|||||||
if constexpr (group_blocks == -1) {
|
if constexpr (group_blocks == -1) {
|
||||||
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
|
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
|
||||||
} else {
|
} else {
|
||||||
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) /
|
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
|
||||||
(w_type == vllm::kFE2M1f ? 2 : 1) +
|
|
||||||
s_sh_stride * slice_col + threadIdx.x;
|
s_sh_stride * slice_col + threadIdx.x;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -714,20 +688,10 @@ __global__ void Marlin(
|
|||||||
// we scale a `half2` tile in column-major layout in the former and in
|
// we scale a `half2` tile in column-major layout in the former and in
|
||||||
// row-major in the latter case.
|
// row-major in the latter case.
|
||||||
int s_sh_rd;
|
int s_sh_rd;
|
||||||
if constexpr (group_blocks != -1 && w_type == vllm::kFE2M1f) {
|
if constexpr (group_blocks != -1)
|
||||||
auto warp_id = threadIdx.x / 32;
|
|
||||||
int n_warps = thread_n_blocks / 4;
|
|
||||||
int warp_row = warp_id / n_warps;
|
|
||||||
|
|
||||||
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||||
(threadIdx.x % 32) / 4;
|
(threadIdx.x % 32) / 4;
|
||||||
s_sh_rd = s_sh_rd * 2 + warp_row % 2;
|
else if constexpr (group_blocks == -1 && (m_block_size_8 || has_zp))
|
||||||
|
|
||||||
} else if constexpr (group_blocks != -1)
|
|
||||||
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
|
||||||
(threadIdx.x % 32) / 4;
|
|
||||||
else if constexpr (group_blocks == -1 &&
|
|
||||||
(m_block_size_8 || (has_zp && !dequant_skip_flop)))
|
|
||||||
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||||
(threadIdx.x % 32) / 8;
|
(threadIdx.x % 32) / 8;
|
||||||
else
|
else
|
||||||
@ -837,7 +801,7 @@ __global__ void Marlin(
|
|||||||
sh_first_group_id = first_group_id;
|
sh_first_group_id = first_group_id;
|
||||||
sh_num_groups = last_group_id - first_group_id + 1;
|
sh_num_groups = last_group_id - first_group_id + 1;
|
||||||
|
|
||||||
if (sh_num_groups > act_s_max_num_groups) {
|
if (sh_num_groups < act_s_max_num_groups) {
|
||||||
sh_num_groups = act_s_max_num_groups;
|
sh_num_groups = act_s_max_num_groups;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1057,19 +1021,12 @@ __global__ void Marlin(
|
|||||||
cur_k += k_iter_size * (k % b_sh_wr_iters);
|
cur_k += k_iter_size * (k % b_sh_wr_iters);
|
||||||
|
|
||||||
int k_blocks = cur_k / 16;
|
int k_blocks = cur_k / 16;
|
||||||
int cur_group_id =
|
int cur_group_id = k_blocks / group_blocks;
|
||||||
k_blocks / (group_blocks * (w_type == vllm::kFE2M1f ? 2 : 1));
|
|
||||||
|
|
||||||
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
|
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
|
||||||
|
|
||||||
if constexpr (w_type_id != vllm::kFE2M1f.id()) {
|
reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
|
||||||
reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
|
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
|
||||||
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
|
|
||||||
} else {
|
|
||||||
reinterpret_cast<int2*>(&frag_s[k % 2])[0] =
|
|
||||||
reinterpret_cast<int2*>(
|
|
||||||
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1242,7 +1199,22 @@ __global__ void Marlin(
|
|||||||
};
|
};
|
||||||
|
|
||||||
auto dequant_data = [&](int q, scalar_t2* frag_b_ptr) {
|
auto dequant_data = [&](int q, scalar_t2* frag_b_ptr) {
|
||||||
dequant<scalar_t2, w_type_id, dequant_skip_flop>(q, frag_b_ptr);
|
if constexpr (has_zp && is_zp_float || !has_zp) {
|
||||||
|
dequant<scalar_t2, w_type_id>(q, frag_b_ptr);
|
||||||
|
} else {
|
||||||
|
static_assert(has_zp && !is_zp_float);
|
||||||
|
static_assert(w_type_id == vllm::kU4.id() || w_type_id == vllm::kU8.id());
|
||||||
|
// If (has_zp && !is_zp_float),
|
||||||
|
// we use not-zp version `dequant` function
|
||||||
|
// to improve numerical accuracy.
|
||||||
|
// Since both weight and zero point are dequanted using this logic,
|
||||||
|
// the final dequanted weight would be correct.
|
||||||
|
if constexpr (w_type_id == vllm::kU4.id()) {
|
||||||
|
dequant<scalar_t2, vllm::kU4B8.id()>(q, frag_b_ptr);
|
||||||
|
} else if constexpr (w_type_id == vllm::kU8.id()) {
|
||||||
|
dequant<scalar_t2, vllm::kU8B128.id()>(q, frag_b_ptr);
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Execute the actual tensor core matmul of a sub-tile.
|
// Execute the actual tensor core matmul of a sub-tile.
|
||||||
@ -1272,23 +1244,13 @@ __global__ void Marlin(
|
|||||||
dequant_data(zp_quant_1, reinterpret_cast<scalar_t2*>(&frag_zp) + 2);
|
dequant_data(zp_quant_1, reinterpret_cast<scalar_t2*>(&frag_zp) + 2);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if constexpr (!dequant_skip_flop && has_zp && is_zp_float) {
|
if constexpr (has_zp && is_zp_float) {
|
||||||
if (is_new_zp) {
|
if (is_new_zp) {
|
||||||
reinterpret_cast<int4*>(&frag_zp)[0] =
|
reinterpret_cast<int4*>(&frag_zp)[0] =
|
||||||
reinterpret_cast<int4*>(&frag_zpf[k2])[0];
|
reinterpret_cast<int4*>(&frag_zpf[k2])[0];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if constexpr (w_type == vllm::kFE2M1f) {
|
|
||||||
int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
|
|
||||||
int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];
|
|
||||||
|
|
||||||
dequant_fp8_scales<scalar_t2>(s_quant_0,
|
|
||||||
reinterpret_cast<scalar_t2*>(&frag_s[k2]));
|
|
||||||
dequant_fp8_scales<scalar_t2>(
|
|
||||||
s_quant_1, reinterpret_cast<scalar_t2*>(&frag_s[k2]) + 2);
|
|
||||||
}
|
|
||||||
|
|
||||||
// We have the m dimension as the inner loop in order to encourage overlapping
|
// We have the m dimension as the inner loop in order to encourage overlapping
|
||||||
// dequantization and matmul operations.
|
// dequantization and matmul operations.
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
@ -1297,10 +1259,7 @@ __global__ void Marlin(
|
|||||||
FragB frag_b1;
|
FragB frag_b1;
|
||||||
int b_quant_0, b_quant_1;
|
int b_quant_0, b_quant_1;
|
||||||
|
|
||||||
if constexpr (w_type_id == vllm::kFE2M1f.id()) {
|
if constexpr (w_type.size_bits() == 4) {
|
||||||
b_quant_1 = frag_b_quant[k2][0][j];
|
|
||||||
b_quant_0 = b_quant_1 << 8;
|
|
||||||
} else if constexpr (w_type.size_bits() == 4) {
|
|
||||||
b_quant_0 = frag_b_quant[k2][0][j];
|
b_quant_0 = frag_b_quant[k2][0][j];
|
||||||
b_quant_1 = b_quant_0 >> 8;
|
b_quant_1 = b_quant_0 >> 8;
|
||||||
} else {
|
} else {
|
||||||
@ -1313,11 +1272,6 @@ __global__ void Marlin(
|
|||||||
dequant_data(b_quant_0, reinterpret_cast<scalar_t2*>(&frag_b0));
|
dequant_data(b_quant_0, reinterpret_cast<scalar_t2*>(&frag_b0));
|
||||||
dequant_data(b_quant_1, reinterpret_cast<scalar_t2*>(&frag_b1));
|
dequant_data(b_quant_1, reinterpret_cast<scalar_t2*>(&frag_b1));
|
||||||
|
|
||||||
if constexpr (dequant_skip_flop && has_zp && !is_zp_float) {
|
|
||||||
sub_zp<scalar_t>(frag_b0, frag_zp[j], 0);
|
|
||||||
sub_zp<scalar_t>(frag_b1, frag_zp[j], 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Apply scale to frag_b0
|
// Apply scale to frag_b0
|
||||||
if constexpr (has_act_order) {
|
if constexpr (has_act_order) {
|
||||||
static_assert(group_blocks != -1);
|
static_assert(group_blocks != -1);
|
||||||
@ -1325,8 +1279,7 @@ __global__ void Marlin(
|
|||||||
act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0);
|
act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0);
|
||||||
scale4<scalar_t>(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j],
|
scale4<scalar_t>(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j],
|
||||||
act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1);
|
act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1);
|
||||||
} else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float &&
|
} else if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
|
||||||
group_blocks == -1) {
|
|
||||||
int idx = (threadIdx.x / 4) % 2;
|
int idx = (threadIdx.x / 4) % 2;
|
||||||
scalar_t2 s2 = Dtype::nums2num2(
|
scalar_t2 s2 = Dtype::nums2num2(
|
||||||
reinterpret_cast<scalar_t*>(&frag_s[j / 2][j % 2 * 2 + 0])[idx],
|
reinterpret_cast<scalar_t*>(&frag_s[j / 2][j % 2 * 2 + 0])[idx],
|
||||||
@ -1334,7 +1287,7 @@ __global__ void Marlin(
|
|||||||
if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2);
|
if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2);
|
||||||
scale_and_sub<scalar_t>(frag_b0, s2.x, frag_zp[j].x);
|
scale_and_sub<scalar_t>(frag_b0, s2.x, frag_zp[j].x);
|
||||||
scale_and_sub<scalar_t>(frag_b1, s2.y, frag_zp[j].y);
|
scale_and_sub<scalar_t>(frag_b1, s2.y, frag_zp[j].y);
|
||||||
} else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1) {
|
} else if constexpr (has_zp && group_blocks != -1) {
|
||||||
if (is_new_zp)
|
if (is_new_zp)
|
||||||
frag_zp[j] = __hmul2(frag_zp[j],
|
frag_zp[j] = __hmul2(frag_zp[j],
|
||||||
*reinterpret_cast<scalar_t2*>(&frag_s[k2][j]));
|
*reinterpret_cast<scalar_t2*>(&frag_s[k2][j]));
|
||||||
@ -1601,17 +1554,10 @@ __global__ void Marlin(
|
|||||||
// For per-column quantization we finally apply the scale here (only for
|
// For per-column quantization we finally apply the scale here (only for
|
||||||
// 4-bit)
|
// 4-bit)
|
||||||
if constexpr (!has_act_order && group_blocks == -1 &&
|
if constexpr (!has_act_order && group_blocks == -1 &&
|
||||||
w_type.size_bits() == 4 &&
|
w_type.size_bits() == 4 && !has_zp) {
|
||||||
(has_zp && dequant_skip_flop || !has_zp)) {
|
|
||||||
res = __hmul2(res, s[0]);
|
res = __hmul2(res, s[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
if constexpr (w_type == vllm::kFE2M1f) {
|
|
||||||
if (!mul_topk_weights) {
|
|
||||||
res = __hmul2(res, global_scale);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if constexpr (m_block_size_8) {
|
if constexpr (m_block_size_8) {
|
||||||
((scalar_t*)sh_red)[idx] = res.x;
|
((scalar_t*)sh_red)[idx] = res.x;
|
||||||
((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y;
|
((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y;
|
||||||
@ -1702,9 +1648,7 @@ __global__ void Marlin(
|
|||||||
if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
|
if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
|
||||||
if (i == 0) {
|
if (i == 0) {
|
||||||
fetch_col_zp_to_shared();
|
fetch_col_zp_to_shared();
|
||||||
if constexpr (!dequant_skip_flop) {
|
fetch_col_scale_to_shared();
|
||||||
fetch_col_scale_to_shared();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
fetch_to_shared(i, i, i < slice_iters, i);
|
fetch_to_shared(i, i, i < slice_iters, i);
|
||||||
@ -1767,20 +1711,17 @@ __global__ void Marlin(
|
|||||||
|
|
||||||
if constexpr (has_act_order) {
|
if constexpr (has_act_order) {
|
||||||
slice_k_start += tb_k * stages;
|
slice_k_start += tb_k * stages;
|
||||||
|
slice_k_start_shared_fetch += tb_k * stages;
|
||||||
if (slice_k_start < prob_k) {
|
int first_group_id = g_idx[slice_k_start];
|
||||||
slice_k_start_shared_fetch += tb_k * stages;
|
int last_g_idx = slice_k_start + stages * tb_k * 2;
|
||||||
int first_group_id = g_idx[slice_k_start];
|
if (last_g_idx >= prob_k) {
|
||||||
int last_g_idx = slice_k_start + stages * tb_k * 2;
|
last_g_idx = prob_k - 1;
|
||||||
if (last_g_idx >= prob_k) {
|
}
|
||||||
last_g_idx = prob_k - 1;
|
int last_group_id = g_idx[last_g_idx];
|
||||||
}
|
if (last_group_id >= sh_first_group_id + sh_num_groups) {
|
||||||
int last_group_id = g_idx[last_g_idx];
|
fetch_act_order_scales_to_shared(false, first_group_id,
|
||||||
if (last_group_id >= sh_first_group_id + sh_num_groups) {
|
last_group_id);
|
||||||
fetch_act_order_scales_to_shared(false, first_group_id,
|
__syncthreads();
|
||||||
last_group_id);
|
|
||||||
__syncthreads();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (slice_iters == 0) {
|
if (slice_iters == 0) {
|
||||||
@ -1796,8 +1737,7 @@ __global__ void Marlin(
|
|||||||
bool last = slice_idx == slice_count - 1;
|
bool last = slice_idx == slice_count - 1;
|
||||||
// For per-column scales, we only fetch them here in the final step before
|
// For per-column scales, we only fetch them here in the final step before
|
||||||
// write-out
|
// write-out
|
||||||
if constexpr (!has_act_order && group_blocks == -1 &&
|
if constexpr (!has_act_order && group_blocks == -1 && !has_zp) {
|
||||||
(has_zp && dequant_skip_flop || !has_zp)) {
|
|
||||||
if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
|
if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
|
||||||
if (s_sh_wr_pred) {
|
if (s_sh_wr_pred) {
|
||||||
cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
|
cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
|
||||||
@ -1807,8 +1747,7 @@ __global__ void Marlin(
|
|||||||
}
|
}
|
||||||
|
|
||||||
thread_block_reduce();
|
thread_block_reduce();
|
||||||
if constexpr (!has_act_order && group_blocks == -1 &&
|
if constexpr (!has_act_order && group_blocks == -1 && !has_zp) {
|
||||||
(has_zp && dequant_skip_flop || !has_zp)) {
|
|
||||||
if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
|
if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
|
||||||
cp_async_wait<0>();
|
cp_async_wait<0>();
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
@ -1832,8 +1771,7 @@ __global__ void Marlin(
|
|||||||
// that converts the fp32 results to fp16 (so that we avoid possible
|
// that converts the fp32 results to fp16 (so that we avoid possible
|
||||||
// overflow in fp16)
|
// overflow in fp16)
|
||||||
if constexpr (!has_act_order && group_blocks == -1 &&
|
if constexpr (!has_act_order && group_blocks == -1 &&
|
||||||
w_type.size_bits() == 8 &&
|
w_type.size_bits() == 8 && !has_zp) {
|
||||||
(has_zp && dequant_skip_flop || !has_zp)) {
|
|
||||||
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < thread_m_blocks; i++) {
|
for (int i = 0; i < thread_m_blocks; i++) {
|
||||||
|
|||||||
@ -291,7 +291,6 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
|
|||||||
// BIGGROUP: cases for big group size (group_blocks in [-1, 8])
|
// BIGGROUP: cases for big group size (group_blocks in [-1, 8])
|
||||||
// FZP: cases for float-zero-point (is_zp_float = true)
|
// FZP: cases for float-zero-point (is_zp_float = true)
|
||||||
// ACT: cases for act order case (group_blocks == 0)
|
// ACT: cases for act order case (group_blocks == 0)
|
||||||
// FP4: cases for nvfp4(e2m1) (group_blocks == 1)
|
|
||||||
#define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
#define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
|
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
|
||||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
|
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
|
||||||
@ -339,21 +338,6 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
|
|||||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
|
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
|
||||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
|
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
|
||||||
|
|
||||||
#define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
|
||||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \
|
|
||||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
|
|
||||||
|
|
||||||
#define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
|
||||||
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
|
|
||||||
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
|
|
||||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
|
|
||||||
|
|
||||||
#define FP4_GET_IF(W_TYPE) \
|
|
||||||
FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
|
|
||||||
FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
|
|
||||||
FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
|
|
||||||
FP4_GET_IF_M234(W_TYPE, 8, 4, 128)
|
|
||||||
|
|
||||||
#define BIGGROUP_GET_IF(W_TYPE) \
|
#define BIGGROUP_GET_IF(W_TYPE) \
|
||||||
BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \
|
BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \
|
||||||
BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \
|
BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \
|
||||||
@ -410,8 +394,6 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
|
|||||||
|
|
||||||
BIGGROUP_GET_IF(vllm::kFE4M3fn)
|
BIGGROUP_GET_IF(vllm::kFE4M3fn)
|
||||||
|
|
||||||
FP4_GET_IF(vllm::kFE2M1f)
|
|
||||||
|
|
||||||
ACT_GET_IF(vllm::kU4B8)
|
ACT_GET_IF(vllm::kU4B8)
|
||||||
ACT_GET_IF(vllm::kU8B128)
|
ACT_GET_IF(vllm::kU8B128)
|
||||||
|
|
||||||
@ -483,7 +465,7 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
|
|||||||
|
|
||||||
template <typename scalar_t>
|
template <typename scalar_t>
|
||||||
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
||||||
void* s2, void* zp, void* g_idx, void* perm, void* a_tmp,
|
void* zp, void* g_idx, void* perm, void* a_tmp,
|
||||||
void* sorted_token_ids, void* expert_ids,
|
void* sorted_token_ids, void* expert_ids,
|
||||||
void* num_tokens_past_padded, void* topk_weights,
|
void* num_tokens_past_padded, void* topk_weights,
|
||||||
int moe_block_size, int top_k, bool mul_topk_weights, bool is_ep,
|
int moe_block_size, int top_k, bool mul_topk_weights, bool is_ep,
|
||||||
@ -497,16 +479,14 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
|||||||
bool m_block_size_8 = moe_block_size == 8;
|
bool m_block_size_8 = moe_block_size == 8;
|
||||||
|
|
||||||
if (has_zp) {
|
if (has_zp) {
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(q_type == vllm::kU4,
|
||||||
q_type == vllm::kU4 || q_type == vllm::kU8,
|
"q_type must be u4 when has_zp = True. Got = ", q_type.str());
|
||||||
"q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str());
|
|
||||||
} else {
|
} else {
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(q_type == vllm::kU4B8 || q_type == vllm::kU8B128 ||
|
||||||
q_type == vllm::kU4B8 || q_type == vllm::kU8B128 ||
|
q_type == vllm::kFE4M3fn,
|
||||||
q_type == vllm::kFE4M3fn || q_type == vllm::kFE2M1f,
|
"q_type must be uint4b8, uint8b128 or fp8e4m3 when has_zp = "
|
||||||
"q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when "
|
"False. Got = ",
|
||||||
"has_zp = False. Got = ",
|
q_type.str());
|
||||||
q_type.str());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
|
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
|
||||||
@ -539,7 +519,6 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
|||||||
int4* C_ptr = (int4*)C;
|
int4* C_ptr = (int4*)C;
|
||||||
int4* C_tmp_ptr = (int4*)C_tmp;
|
int4* C_tmp_ptr = (int4*)C_tmp;
|
||||||
const int4* s_ptr = (const int4*)s;
|
const int4* s_ptr = (const int4*)s;
|
||||||
const uint16_t* s2_ptr = (const uint16_t*)s2;
|
|
||||||
const int4* zp_ptr = (const int4*)zp;
|
const int4* zp_ptr = (const int4*)zp;
|
||||||
const int* g_idx_ptr = (const int*)g_idx;
|
const int* g_idx_ptr = (const int*)g_idx;
|
||||||
const int* perm_ptr = (const int*)perm;
|
const int* perm_ptr = (const int*)perm;
|
||||||
@ -648,7 +627,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
|||||||
// avoid ">>>" being formatted to "> > >"
|
// avoid ">>>" being formatted to "> > >"
|
||||||
// clang-format off
|
// clang-format off
|
||||||
kernel<<<blocks, num_threads, max_shared_mem, stream>>>(
|
kernel<<<blocks, num_threads, max_shared_mem, stream>>>(
|
||||||
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr,
|
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr,
|
||||||
sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr,
|
sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr,
|
||||||
topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m,
|
topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m,
|
||||||
prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce, max_shared_mem);
|
prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce, max_shared_mem);
|
||||||
@ -660,7 +639,6 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
|||||||
torch::Tensor moe_wna16_marlin_gemm(
|
torch::Tensor moe_wna16_marlin_gemm(
|
||||||
torch::Tensor& a, std::optional<torch::Tensor> const& c_or_none,
|
torch::Tensor& a, std::optional<torch::Tensor> const& c_or_none,
|
||||||
torch::Tensor& b_q_weight, torch::Tensor& b_scales,
|
torch::Tensor& b_q_weight, torch::Tensor& b_scales,
|
||||||
std::optional<torch::Tensor> const& global_scale_or_none,
|
|
||||||
std::optional<torch::Tensor> const& b_zeros_or_none,
|
std::optional<torch::Tensor> const& b_zeros_or_none,
|
||||||
std::optional<torch::Tensor> const& g_idx_or_none,
|
std::optional<torch::Tensor> const& g_idx_or_none,
|
||||||
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
|
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
|
||||||
@ -812,17 +790,6 @@ torch::Tensor moe_wna16_marlin_gemm(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
torch::Tensor global_scale;
|
|
||||||
if (global_scale_or_none.has_value()) {
|
|
||||||
global_scale = global_scale_or_none.value();
|
|
||||||
TORCH_CHECK(b_q_type == vllm::kFE2M1f,
|
|
||||||
"global_scale can only be used for float4_e2m1f.");
|
|
||||||
} else {
|
|
||||||
global_scale = torch::empty({0}, options);
|
|
||||||
TORCH_CHECK(!(b_q_type == vllm::kFE2M1f),
|
|
||||||
"the global_scale parameter must be passed for float4_e2m1f.");
|
|
||||||
}
|
|
||||||
|
|
||||||
torch::Tensor b_zeros;
|
torch::Tensor b_zeros;
|
||||||
if (b_zeros_or_none.has_value()) {
|
if (b_zeros_or_none.has_value()) {
|
||||||
b_zeros = b_zeros_or_none.value();
|
b_zeros = b_zeros_or_none.value();
|
||||||
@ -835,14 +802,13 @@ torch::Tensor moe_wna16_marlin_gemm(
|
|||||||
|
|
||||||
if (has_zp) {
|
if (has_zp) {
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(
|
||||||
b_q_type == vllm::kU4 || b_q_type == vllm::kU8,
|
b_q_type == vllm::kU4,
|
||||||
"b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str());
|
"b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str());
|
||||||
} else {
|
} else {
|
||||||
TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128 ||
|
TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128 ||
|
||||||
b_q_type == vllm::kFE4M3fn || b_q_type == vllm::kFE2M1f,
|
b_q_type == vllm::kFE4M3fn,
|
||||||
"b_q_type must be uint4b8, uint8b128, float8_e4m3fn or "
|
"b_q_type must be uint4b8, uint8b128 or fp8e4m3 when has_zp = "
|
||||||
"float4_e2m1f when "
|
"False. Got = ",
|
||||||
"has_zp = False. Got = ",
|
|
||||||
b_q_type.str());
|
b_q_type.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -888,16 +854,9 @@ torch::Tensor moe_wna16_marlin_gemm(
|
|||||||
|
|
||||||
int dev = a.get_device();
|
int dev = a.get_device();
|
||||||
if (a.scalar_type() == at::ScalarType::Half) {
|
if (a.scalar_type() == at::ScalarType::Half) {
|
||||||
void* scales_ptr;
|
|
||||||
if (b_q_type == vllm::kFE2M1f) {
|
|
||||||
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
|
|
||||||
} else {
|
|
||||||
scales_ptr = b_scales.data_ptr<at::Half>();
|
|
||||||
}
|
|
||||||
|
|
||||||
MARLIN_NAMESPACE_NAME::marlin_mm<half>(
|
MARLIN_NAMESPACE_NAME::marlin_mm<half>(
|
||||||
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
|
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
|
||||||
c_tmp.data_ptr<float>(), scales_ptr, global_scale.data_ptr<at::Half>(),
|
c_tmp.data_ptr<float>(), b_scales.data_ptr<at::Half>(),
|
||||||
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
|
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
|
||||||
a_tmp.data_ptr<at::Half>(), sorted_token_ids.data_ptr(),
|
a_tmp.data_ptr<at::Half>(), sorted_token_ids.data_ptr(),
|
||||||
expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(),
|
expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(),
|
||||||
@ -907,18 +866,11 @@ torch::Tensor moe_wna16_marlin_gemm(
|
|||||||
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
|
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
|
||||||
use_atomic_add, use_fp32_reduce, is_zp_float);
|
use_atomic_add, use_fp32_reduce, is_zp_float);
|
||||||
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
|
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
|
||||||
void* scales_ptr;
|
|
||||||
if (b_q_type == vllm::kFE2M1f) {
|
|
||||||
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
|
|
||||||
} else {
|
|
||||||
scales_ptr = b_scales.data_ptr<at::BFloat16>();
|
|
||||||
}
|
|
||||||
|
|
||||||
MARLIN_NAMESPACE_NAME::marlin_mm<nv_bfloat16>(
|
MARLIN_NAMESPACE_NAME::marlin_mm<nv_bfloat16>(
|
||||||
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
|
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
|
||||||
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(), scales_ptr,
|
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
|
||||||
global_scale.data_ptr<at::BFloat16>(), b_zeros.data_ptr(),
|
b_scales.data_ptr<at::BFloat16>(), b_zeros.data_ptr(), g_idx.data_ptr(),
|
||||||
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
|
perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
|
||||||
sorted_token_ids.data_ptr(), expert_ids.data_ptr(),
|
sorted_token_ids.data_ptr(), expert_ids.data_ptr(),
|
||||||
num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(),
|
num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(),
|
||||||
moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k,
|
moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k,
|
||||||
|
|||||||
@ -326,7 +326,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (use_global_memory) {
|
if (use_global_memory) {
|
||||||
VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
|
VLLM_DISPATCH_INTEGRAL_TYPES(
|
||||||
topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
|
topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
|
||||||
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
|
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
|
||||||
// tensors
|
// tensors
|
||||||
@ -351,7 +351,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
|||||||
cumsum_buffer.data_ptr<int32_t>());
|
cumsum_buffer.data_ptr<int32_t>());
|
||||||
});
|
});
|
||||||
} else if (use_i16) {
|
} else if (use_i16) {
|
||||||
VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
|
VLLM_DISPATCH_INTEGRAL_TYPES(
|
||||||
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
|
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
|
||||||
// set dynamic shared mem
|
// set dynamic shared mem
|
||||||
auto kernel =
|
auto kernel =
|
||||||
@ -366,7 +366,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
|||||||
topk_ids.numel());
|
topk_ids.numel());
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
|
VLLM_DISPATCH_INTEGRAL_TYPES(
|
||||||
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
|
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
|
||||||
auto kernel =
|
auto kernel =
|
||||||
vllm::moe::moe_align_block_size_kernel<scalar_t, int32_t>;
|
vllm::moe::moe_align_block_size_kernel<scalar_t, int32_t>;
|
||||||
@ -391,7 +391,7 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
|||||||
TORCH_CHECK(num_experts == 256,
|
TORCH_CHECK(num_experts == 256,
|
||||||
"sgl_moe_align_block_size kernel only supports deepseek v3.");
|
"sgl_moe_align_block_size kernel only supports deepseek v3.");
|
||||||
|
|
||||||
VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
|
VLLM_DISPATCH_INTEGRAL_TYPES(
|
||||||
topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] {
|
topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] {
|
||||||
// calc needed amount of shared mem for `cumsum` tensors
|
// calc needed amount of shared mem for `cumsum` tensors
|
||||||
auto options_int =
|
auto options_int =
|
||||||
|
|||||||
@ -28,6 +28,4 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
|
|||||||
torch::Tensor num_tokens_post_pad, int64_t top_k,
|
torch::Tensor num_tokens_post_pad, int64_t top_k,
|
||||||
int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
|
int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
|
||||||
int64_t BLOCK_SIZE_K, int64_t bit);
|
int64_t BLOCK_SIZE_K, int64_t bit);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
bool moe_permute_unpermute_supported();
|
|
||||||
@ -5,9 +5,6 @@
|
|||||||
#include "permute_unpermute_kernels/dispatch.h"
|
#include "permute_unpermute_kernels/dispatch.h"
|
||||||
#include "core/registration.h"
|
#include "core/registration.h"
|
||||||
|
|
||||||
// moe_permute kernels require at least CUDA 12.0
|
|
||||||
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)
|
|
||||||
|
|
||||||
void moe_permute(
|
void moe_permute(
|
||||||
const torch::Tensor& input, // [n_token, hidden]
|
const torch::Tensor& input, // [n_token, hidden]
|
||||||
const torch::Tensor& topk_weights, //[n_token, topk]
|
const torch::Tensor& topk_weights, //[n_token, topk]
|
||||||
@ -130,45 +127,7 @@ void moe_unpermute(
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
#else
|
|
||||||
|
|
||||||
void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights,
|
|
||||||
torch::Tensor& topk_ids,
|
|
||||||
const torch::Tensor& token_expert_indicies,
|
|
||||||
const std::optional<torch::Tensor>& expert_map,
|
|
||||||
int64_t n_expert, int64_t n_local_expert, int64_t topk,
|
|
||||||
const std::optional<int64_t>& align_block_size,
|
|
||||||
torch::Tensor& permuted_input,
|
|
||||||
torch::Tensor& expert_first_token_offset,
|
|
||||||
torch::Tensor& src_row_id2dst_row_id_map,
|
|
||||||
torch::Tensor& m_indices) {
|
|
||||||
TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0");
|
|
||||||
}
|
|
||||||
|
|
||||||
void moe_unpermute(const torch::Tensor& input,
|
|
||||||
const torch::Tensor& topk_weights, torch::Tensor& topk_ids,
|
|
||||||
const torch::Tensor& token_expert_indicies,
|
|
||||||
const std::optional<torch::Tensor>& expert_map,
|
|
||||||
int64_t n_expert, int64_t n_local_expert, int64_t topk,
|
|
||||||
const std::optional<int64_t>& align_block_size,
|
|
||||||
torch::Tensor& permuted_input,
|
|
||||||
torch::Tensor& expert_first_token_offset,
|
|
||||||
torch::Tensor& src_row_id2dst_row_id_map,
|
|
||||||
torch::Tensor& m_indices) {
|
|
||||||
TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0");
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
bool moe_permute_unpermute_supported() {
|
|
||||||
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)
|
|
||||||
return true;
|
|
||||||
#else
|
|
||||||
return false;
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||||
m.impl("moe_permute", &moe_permute);
|
m.impl("moe_permute", &moe_permute);
|
||||||
m.impl("moe_unpermute", &moe_unpermute);
|
m.impl("moe_unpermute", &moe_unpermute);
|
||||||
}
|
}
|
||||||
@ -1,9 +1,6 @@
|
|||||||
|
|
||||||
#include "moe_permute_unpermute_kernel.h"
|
#include "moe_permute_unpermute_kernel.h"
|
||||||
|
|
||||||
// moe_permute kernels require at least CUDA 12.0
|
|
||||||
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)
|
|
||||||
|
|
||||||
// CubKeyValueSorter definition begin
|
// CubKeyValueSorter definition begin
|
||||||
CubKeyValueSorter::CubKeyValueSorter()
|
CubKeyValueSorter::CubKeyValueSorter()
|
||||||
: num_experts_(0), num_bits_(sizeof(int) * 8) {}
|
: num_experts_(0), num_bits_(sizeof(int) * 8) {}
|
||||||
@ -134,6 +131,9 @@ __global__ void preprocessTopkIdKernel(int* topk_id_ptr, int size,
|
|||||||
int num_experts) {
|
int num_experts) {
|
||||||
auto tidx = threadIdx.x;
|
auto tidx = threadIdx.x;
|
||||||
auto bidx = blockIdx.x;
|
auto bidx = blockIdx.x;
|
||||||
|
auto lidx = tidx & 31;
|
||||||
|
auto widx = tidx >> 5;
|
||||||
|
auto warp_count = (blockDim.x + 31) >> 5;
|
||||||
auto offset = bidx * blockDim.x;
|
auto offset = bidx * blockDim.x;
|
||||||
auto bound = min(offset + blockDim.x, size);
|
auto bound = min(offset + blockDim.x, size);
|
||||||
extern __shared__ int smem_expert_map[];
|
extern __shared__ int smem_expert_map[];
|
||||||
@ -226,6 +226,4 @@ void getMIndices(int64_t* expert_first_token_offset,
|
|||||||
expert_first_token_offset, align_expert_first_token_offset, m_indices,
|
expert_first_token_offset, align_expert_first_token_offset, m_indices,
|
||||||
num_local_expert, align_block_size);
|
num_local_expert, align_block_size);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
|
||||||
@ -108,17 +108,9 @@ __launch_bounds__(TPB) __global__
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <int TPB, typename IndType>
|
template <int TPB>
|
||||||
__launch_bounds__(TPB) __global__ void moeTopK(
|
__launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax, const bool* finished, float* output,
|
||||||
const float* inputs_after_softmax,
|
int* indices, int* source_rows, const int num_experts, const int k, const int start_expert, const int end_expert)
|
||||||
const bool* finished,
|
|
||||||
float* output,
|
|
||||||
IndType* indices,
|
|
||||||
int* source_rows,
|
|
||||||
const int num_experts,
|
|
||||||
const int k,
|
|
||||||
const int start_expert,
|
|
||||||
const int end_expert)
|
|
||||||
{
|
{
|
||||||
|
|
||||||
using cub_kvp = cub::KeyValuePair<int, float>;
|
using cub_kvp = cub::KeyValuePair<int, float>;
|
||||||
@ -190,9 +182,9 @@ __launch_bounds__(TPB) __global__ void moeTopK(
|
|||||||
2) This implementation assumes k is small, but will work for any k.
|
2) This implementation assumes k is small, but will work for any k.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG, typename IndType>
|
template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>
|
||||||
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
|
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
|
||||||
void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, IndType* indices,
|
void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, int* indices,
|
||||||
int* source_rows, const int k, const int start_expert, const int end_expert)
|
int* source_rows, const int k, const int start_expert, const int end_expert)
|
||||||
{
|
{
|
||||||
// We begin by enforcing compile time assertions and setting up compile time constants.
|
// We begin by enforcing compile time assertions and setting up compile time constants.
|
||||||
@ -405,8 +397,8 @@ struct TopkConstants
|
|||||||
};
|
};
|
||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
|
||||||
template <int EXPERTS, int WARPS_PER_TB, typename IndType>
|
template <int EXPERTS, int WARPS_PER_TB>
|
||||||
void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, IndType* indices,
|
void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, int* indices,
|
||||||
int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
|
int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
|
||||||
{
|
{
|
||||||
static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
|
static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
|
||||||
@ -429,11 +421,10 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
|
|||||||
token_expert_indices, num_tokens, topk, 0, num_experts, \
|
token_expert_indices, num_tokens, topk, 0, num_experts, \
|
||||||
stream);
|
stream);
|
||||||
|
|
||||||
template <typename IndType>
|
|
||||||
void topkGatingSoftmaxKernelLauncher(
|
void topkGatingSoftmaxKernelLauncher(
|
||||||
const float* gating_output,
|
const float* gating_output,
|
||||||
float* topk_weights,
|
float* topk_weights,
|
||||||
IndType* topk_indicies,
|
int* topk_indicies,
|
||||||
int* token_expert_indices,
|
int* token_expert_indices,
|
||||||
float* softmax_workspace,
|
float* softmax_workspace,
|
||||||
const int num_tokens,
|
const int num_tokens,
|
||||||
@ -502,32 +493,14 @@ void topk_softmax(
|
|||||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options());
|
torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options());
|
||||||
|
vllm::moe::topkGatingSoftmaxKernelLauncher(
|
||||||
if(topk_indices.scalar_type() == at::ScalarType::Int)
|
gating_output.data_ptr<float>(),
|
||||||
{
|
topk_weights.data_ptr<float>(),
|
||||||
vllm::moe::topkGatingSoftmaxKernelLauncher(
|
topk_indices.data_ptr<int>(),
|
||||||
gating_output.data_ptr<float>(),
|
token_expert_indices.data_ptr<int>(),
|
||||||
topk_weights.data_ptr<float>(),
|
softmax_workspace.data_ptr<float>(),
|
||||||
topk_indices.data_ptr<int>(),
|
num_tokens,
|
||||||
token_expert_indices.data_ptr<int>(),
|
num_experts,
|
||||||
softmax_workspace.data_ptr<float>(),
|
topk,
|
||||||
num_tokens,
|
stream);
|
||||||
num_experts,
|
|
||||||
topk,
|
|
||||||
stream);
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
assert(topk_indices.scalar_type() == at::ScalarType::UInt32);
|
|
||||||
vllm::moe::topkGatingSoftmaxKernelLauncher(
|
|
||||||
gating_output.data_ptr<float>(),
|
|
||||||
topk_weights.data_ptr<float>(),
|
|
||||||
topk_indices.data_ptr<uint32_t>(),
|
|
||||||
token_expert_indices.data_ptr<int>(),
|
|
||||||
softmax_workspace.data_ptr<float>(),
|
|
||||||
num_tokens,
|
|
||||||
num_experts,
|
|
||||||
topk,
|
|
||||||
stream);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -10,7 +10,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
|||||||
|
|
||||||
// Calculate the result of moe by summing up the partial results
|
// Calculate the result of moe by summing up the partial results
|
||||||
// from all selected experts.
|
// from all selected experts.
|
||||||
m.def("moe_sum(Tensor input, Tensor! output) -> ()");
|
m.def("moe_sum(Tensor! input, Tensor output) -> ()");
|
||||||
m.impl("moe_sum", torch::kCUDA, &moe_sum);
|
m.impl("moe_sum", torch::kCUDA, &moe_sum);
|
||||||
|
|
||||||
// Aligning the number of tokens to be processed by each expert such
|
// Aligning the number of tokens to be processed by each expert such
|
||||||
@ -44,8 +44,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
|||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
|
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
|
||||||
"Tensor! b_q_weight, Tensor! b_scales, Tensor? global_scale, Tensor? "
|
"Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none,"
|
||||||
"b_zeros_or_none,"
|
|
||||||
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
|
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
|
||||||
"Tensor sorted_token_ids,"
|
"Tensor sorted_token_ids,"
|
||||||
"Tensor! expert_ids, Tensor! num_tokens_past_padded,"
|
"Tensor! expert_ids, Tensor! num_tokens_past_padded,"
|
||||||
@ -77,9 +76,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
|||||||
"Tensor topk_ids,Tensor src_row_id2dst_row_id_map, Tensor "
|
"Tensor topk_ids,Tensor src_row_id2dst_row_id_map, Tensor "
|
||||||
"expert_first_token_offset, int n_expert, int n_local_expert,int "
|
"expert_first_token_offset, int n_expert, int n_local_expert,int "
|
||||||
"topk, Tensor! hidden_states)->()");
|
"topk, Tensor! hidden_states)->()");
|
||||||
|
// conditionally compiled so impl registration is in source file
|
||||||
m.def("moe_permute_unpermute_supported() -> bool");
|
|
||||||
m.impl("moe_permute_unpermute_supported", &moe_permute_unpermute_supported);
|
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|||||||
37
csrc/ops.h
37
csrc/ops.h
@ -59,31 +59,6 @@ void merge_attn_states(torch::Tensor& output,
|
|||||||
const torch::Tensor& prefix_lse,
|
const torch::Tensor& prefix_lse,
|
||||||
const torch::Tensor& suffix_output,
|
const torch::Tensor& suffix_output,
|
||||||
const torch::Tensor& suffix_lse);
|
const torch::Tensor& suffix_lse);
|
||||||
|
|
||||||
void convert_vertical_slash_indexes(
|
|
||||||
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
|
|
||||||
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
|
|
||||||
torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS]
|
|
||||||
torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
|
|
||||||
torch::Tensor q_seqlens, // [BATCH, ]
|
|
||||||
torch::Tensor kv_seqlens, // [BATCH, ]
|
|
||||||
torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
|
|
||||||
torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
|
|
||||||
int64_t context_size, int64_t block_size_M, int64_t block_size_N,
|
|
||||||
bool causal);
|
|
||||||
|
|
||||||
void convert_vertical_slash_indexes_mergehead(
|
|
||||||
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
|
|
||||||
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
|
|
||||||
torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS]
|
|
||||||
torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
|
|
||||||
torch::Tensor q_seqlens, // [BATCH, ]
|
|
||||||
torch::Tensor kv_seqlens, // [BATCH, ]
|
|
||||||
torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
|
|
||||||
torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
|
|
||||||
torch::Tensor vertical_indices_count, // [N_HEADS, ]
|
|
||||||
torch::Tensor slash_indices_count, int64_t context_size,
|
|
||||||
int64_t block_size_M, int64_t block_size_N, bool causal);
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
|
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
|
||||||
@ -233,12 +208,6 @@ void cutlass_moe_mm(
|
|||||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides);
|
torch::Tensor const& b_strides, torch::Tensor const& c_strides);
|
||||||
|
|
||||||
void cutlass_fp4_group_mm(
|
|
||||||
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
|
|
||||||
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
|
|
||||||
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
|
|
||||||
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets);
|
|
||||||
|
|
||||||
void get_cutlass_moe_mm_data(
|
void get_cutlass_moe_mm_data(
|
||||||
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
|
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
|
||||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||||
@ -266,12 +235,6 @@ std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a);
|
|||||||
void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
|
void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
|
||||||
torch::Tensor& output_scale,
|
torch::Tensor& output_scale,
|
||||||
torch::Tensor const& input_scale);
|
torch::Tensor const& input_scale);
|
||||||
|
|
||||||
void scaled_fp4_experts_quant(
|
|
||||||
torch::Tensor& output, torch::Tensor& output_scale,
|
|
||||||
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
|
||||||
torch::Tensor const& input_offset_by_experts,
|
|
||||||
torch::Tensor const& output_scale_offset_by_experts);
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
|
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||||
|
|||||||
@ -44,8 +44,7 @@ inline __device__ void apply_rotary_embedding(
|
|||||||
// head_size]
|
// head_size]
|
||||||
const scalar_t* cache_ptr, const int head_size, const int num_heads,
|
const scalar_t* cache_ptr, const int head_size, const int num_heads,
|
||||||
const int num_kv_heads, const int rot_dim, const int token_idx,
|
const int num_kv_heads, const int rot_dim, const int token_idx,
|
||||||
const int64_t query_stride, const int64_t key_stride,
|
const int64_t query_stride, const int64_t key_stride) {
|
||||||
const int64_t head_stride) {
|
|
||||||
const int embed_dim = rot_dim / 2;
|
const int embed_dim = rot_dim / 2;
|
||||||
const scalar_t* cos_ptr = cache_ptr;
|
const scalar_t* cos_ptr = cache_ptr;
|
||||||
const scalar_t* sin_ptr = cache_ptr + embed_dim;
|
const scalar_t* sin_ptr = cache_ptr + embed_dim;
|
||||||
@ -53,8 +52,7 @@ inline __device__ void apply_rotary_embedding(
|
|||||||
const int nq = num_heads * embed_dim;
|
const int nq = num_heads * embed_dim;
|
||||||
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
|
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
|
||||||
const int head_idx = i / embed_dim;
|
const int head_idx = i / embed_dim;
|
||||||
const int64_t token_head =
|
const int64_t token_head = token_idx * query_stride + head_idx * head_size;
|
||||||
token_idx * query_stride + head_idx * head_stride;
|
|
||||||
const int rot_offset = i % embed_dim;
|
const int rot_offset = i % embed_dim;
|
||||||
apply_token_rotary_embedding<scalar_t, IS_NEOX>(
|
apply_token_rotary_embedding<scalar_t, IS_NEOX>(
|
||||||
query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
|
query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
|
||||||
@ -64,8 +62,7 @@ inline __device__ void apply_rotary_embedding(
|
|||||||
const int nk = num_kv_heads * embed_dim;
|
const int nk = num_kv_heads * embed_dim;
|
||||||
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
|
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
|
||||||
const int head_idx = i / embed_dim;
|
const int head_idx = i / embed_dim;
|
||||||
const int64_t token_head =
|
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
|
||||||
token_idx * key_stride + head_idx * head_stride;
|
|
||||||
const int rot_offset = i % embed_dim;
|
const int rot_offset = i % embed_dim;
|
||||||
apply_token_rotary_embedding<scalar_t, IS_NEOX>(
|
apply_token_rotary_embedding<scalar_t, IS_NEOX>(
|
||||||
key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
|
key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
|
||||||
@ -87,8 +84,7 @@ __global__ void rotary_embedding_kernel(
|
|||||||
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
|
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
|
||||||
// 2]
|
// 2]
|
||||||
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
|
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
|
||||||
const int64_t head_stride, const int num_heads, const int num_kv_heads,
|
const int num_heads, const int num_kv_heads, const int head_size) {
|
||||||
const int head_size) {
|
|
||||||
// Each thread block is responsible for one token.
|
// Each thread block is responsible for one token.
|
||||||
const int token_idx = blockIdx.x;
|
const int token_idx = blockIdx.x;
|
||||||
int64_t pos = positions[token_idx];
|
int64_t pos = positions[token_idx];
|
||||||
@ -96,7 +92,7 @@ __global__ void rotary_embedding_kernel(
|
|||||||
|
|
||||||
apply_rotary_embedding<scalar_t, IS_NEOX>(
|
apply_rotary_embedding<scalar_t, IS_NEOX>(
|
||||||
query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
|
query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
|
||||||
token_idx, query_stride, key_stride, head_stride);
|
token_idx, query_stride, key_stride);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename scalar_t, bool IS_NEOX>
|
template <typename scalar_t, bool IS_NEOX>
|
||||||
@ -113,9 +109,9 @@ __global__ void batched_rotary_embedding_kernel(
|
|||||||
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
|
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
|
||||||
// 2]
|
// 2]
|
||||||
const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len]
|
const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len]
|
||||||
|
// or [num_tokens]
|
||||||
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
|
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
|
||||||
const int64_t head_stride, const int num_heads, const int num_kv_heads,
|
const int num_heads, const int num_kv_heads, const int head_size) {
|
||||||
const int head_size) {
|
|
||||||
// Each thread block is responsible for one token.
|
// Each thread block is responsible for one token.
|
||||||
const int token_idx = blockIdx.x;
|
const int token_idx = blockIdx.x;
|
||||||
int64_t pos = positions[token_idx];
|
int64_t pos = positions[token_idx];
|
||||||
@ -125,7 +121,7 @@ __global__ void batched_rotary_embedding_kernel(
|
|||||||
|
|
||||||
apply_rotary_embedding<scalar_t, IS_NEOX>(
|
apply_rotary_embedding<scalar_t, IS_NEOX>(
|
||||||
query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
|
query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
|
||||||
token_idx, query_stride, key_stride, head_stride);
|
token_idx, query_stride, key_stride);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
@ -183,12 +179,6 @@ void rotary_embedding(
|
|||||||
int seq_dim_idx = positions_ndim - 1;
|
int seq_dim_idx = positions_ndim - 1;
|
||||||
int64_t query_stride = query.stride(seq_dim_idx);
|
int64_t query_stride = query.stride(seq_dim_idx);
|
||||||
int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
|
int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
|
||||||
// Determine head stride: for [*, heads, head_size] use stride of last dim;
|
|
||||||
// for flat [*, heads*head_size], heads blocks are contiguous of size
|
|
||||||
// head_size
|
|
||||||
int query_ndim = query.dim();
|
|
||||||
int64_t head_stride =
|
|
||||||
(query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size;
|
|
||||||
|
|
||||||
dim3 grid(num_tokens);
|
dim3 grid(num_tokens);
|
||||||
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
|
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
|
||||||
@ -200,14 +190,14 @@ void rotary_embedding(
|
|||||||
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
|
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
|
||||||
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
|
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
|
||||||
cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride, key_stride,
|
cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride, key_stride,
|
||||||
head_stride, num_heads, num_kv_heads, head_size);
|
num_heads, num_kv_heads, head_size);
|
||||||
} else {
|
} else {
|
||||||
vllm::rotary_embedding_kernel<scalar_t, false>
|
vllm::rotary_embedding_kernel<scalar_t, false>
|
||||||
<<<grid, block, 0, stream>>>(
|
<<<grid, block, 0, stream>>>(
|
||||||
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
|
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
|
||||||
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
|
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
|
||||||
cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride,
|
cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride,
|
||||||
key_stride, head_stride, num_heads, num_kv_heads, head_size);
|
key_stride, num_heads, num_kv_heads, head_size);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@ -273,12 +263,6 @@ void batched_rotary_embedding(
|
|||||||
int seq_dim_idx = positions_ndim - 1;
|
int seq_dim_idx = positions_ndim - 1;
|
||||||
int64_t query_stride = query.stride(seq_dim_idx);
|
int64_t query_stride = query.stride(seq_dim_idx);
|
||||||
int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
|
int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
|
||||||
// Determine head stride: for [*, heads, head_size] use stride of last dim;
|
|
||||||
// for flat [*, heads*head_size], heads blocks are contiguous of size
|
|
||||||
// head_size
|
|
||||||
int query_ndim = query.dim();
|
|
||||||
int64_t head_stride =
|
|
||||||
(query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size;
|
|
||||||
|
|
||||||
dim3 grid(num_tokens);
|
dim3 grid(num_tokens);
|
||||||
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
|
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
|
||||||
@ -292,7 +276,7 @@ void batched_rotary_embedding(
|
|||||||
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
|
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
|
||||||
cos_sin_cache.data_ptr<scalar_t>(),
|
cos_sin_cache.data_ptr<scalar_t>(),
|
||||||
cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
|
cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
|
||||||
key_stride, head_stride, num_heads, num_kv_heads, head_size);
|
key_stride, num_heads, num_kv_heads, head_size);
|
||||||
} else {
|
} else {
|
||||||
vllm::batched_rotary_embedding_kernel<scalar_t, false>
|
vllm::batched_rotary_embedding_kernel<scalar_t, false>
|
||||||
<<<grid, block, 0, stream>>>(
|
<<<grid, block, 0, stream>>>(
|
||||||
@ -300,7 +284,7 @@ void batched_rotary_embedding(
|
|||||||
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
|
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
|
||||||
cos_sin_cache.data_ptr<scalar_t>(),
|
cos_sin_cache.data_ptr<scalar_t>(),
|
||||||
cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
|
cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
|
||||||
key_stride, head_stride, num_heads, num_kv_heads, head_size);
|
key_stride, num_heads, num_kv_heads, head_size);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
@ -112,8 +112,7 @@ __global__ void act_and_mul_quant_kernel(
|
|||||||
void silu_and_mul_quant(torch::Tensor& out, // [..., d]
|
void silu_and_mul_quant(torch::Tensor& out, // [..., d]
|
||||||
torch::Tensor& input, // [..., 2 * d]
|
torch::Tensor& input, // [..., 2 * d]
|
||||||
torch::Tensor& scale) {
|
torch::Tensor& scale) {
|
||||||
TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn ||
|
TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn);
|
||||||
out.dtype() == torch::kFloat8_e4m3fnuz);
|
|
||||||
TORCH_CHECK(input.dtype() == torch::kFloat16 ||
|
TORCH_CHECK(input.dtype() == torch::kFloat16 ||
|
||||||
input.dtype() == torch::kBFloat16);
|
input.dtype() == torch::kBFloat16);
|
||||||
TORCH_CHECK(input.size(-1) % 2 == 0);
|
TORCH_CHECK(input.size(-1) % 2 == 0);
|
||||||
|
|||||||
@ -26,13 +26,7 @@ static inline __device__ int8_t float_to_int8_rn(float x) {
|
|||||||
float dst = std::nearbyint(x);
|
float dst = std::nearbyint(x);
|
||||||
|
|
||||||
// saturate
|
// saturate
|
||||||
|
dst = std::clamp(dst, i8_min, i8_max);
|
||||||
// See https://github.com/pytorch/pytorch/issues/127666
|
|
||||||
// See https://github.com/llvm/llvm-project/issues/95183
|
|
||||||
// hip-clang std::clamp __glibcxx_assert_fail host function when building on
|
|
||||||
// Arch/gcc14. The following replaces std::clamp usage with similar logic
|
|
||||||
// dst = std::clamp(dst, i8_min, i8_max);
|
|
||||||
dst = (dst < i8_min) ? i8_min : (dst > i8_max) ? i8_max : dst;
|
|
||||||
return static_cast<int8_t>(dst);
|
return static_cast<int8_t>(dst);
|
||||||
#else
|
#else
|
||||||
// CUDA path
|
// CUDA path
|
||||||
@ -85,13 +79,7 @@ static inline __device__ int8_t int32_to_int8(int32_t x) {
|
|||||||
static_cast<int32_t>(std::numeric_limits<int8_t>::max());
|
static_cast<int32_t>(std::numeric_limits<int8_t>::max());
|
||||||
|
|
||||||
// saturate
|
// saturate
|
||||||
|
int32_t dst = std::clamp(x, i8_min, i8_max);
|
||||||
// See https://github.com/pytorch/pytorch/issues/127666
|
|
||||||
// See https://github.com/llvm/llvm-project/issues/95183
|
|
||||||
// hip-clang std::clamp __glibcxx_assert_fail host function when building on
|
|
||||||
// Arch/gcc14. The following replaces std::clamp usage with similar logic
|
|
||||||
// int32_t dst = std::clamp(x, i8_min, i8_max);
|
|
||||||
int32_t dst = (x < i8_min) ? i8_min : (x > i8_max) ? i8_max : x;
|
|
||||||
return static_cast<int8_t>(dst);
|
return static_cast<int8_t>(dst);
|
||||||
#else
|
#else
|
||||||
// CUDA path
|
// CUDA path
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
#include <torch/all.h>
|
#include <torch/all.h>
|
||||||
#include "cuda_utils.h"
|
#include "cuda_utils.h"
|
||||||
#include "cutlass_extensions/common.hpp"
|
|
||||||
|
|
||||||
template <typename Fp8Func, typename Int8Func, typename BlockwiseFunc>
|
template <typename Fp8Func, typename Int8Func, typename BlockwiseFunc>
|
||||||
void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
|
void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
|
||||||
@ -29,46 +28,29 @@ void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
TORCH_CHECK(a_scales.dim() == 2, "a scale must be 2d tensor.");
|
using GroupShape = std::array<int64_t, 2>;
|
||||||
TORCH_CHECK(b_scales.dim() == 2, "b scale must be 2d tensor.");
|
auto make_group_shape = [](torch::Tensor const& x,
|
||||||
int32_t version_num = get_sm_version_num();
|
torch::Tensor const& s) -> GroupShape {
|
||||||
if (version_num >= 100) {
|
TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
|
||||||
TORCH_CHECK(
|
return {cuda_utils::ceil_div(x.size(0), s.size(0)),
|
||||||
a.size(0) == a_scales.size(0) &&
|
cuda_utils::ceil_div(x.size(1), s.size(1))};
|
||||||
cuda_utils::ceil_div(a.size(1), int64_t(128)) == a_scales.size(1),
|
};
|
||||||
"a_scale_group_shape must be [1, 128].");
|
|
||||||
TORCH_CHECK(
|
|
||||||
cuda_utils::ceil_div(b.size(0), int64_t(128)) == b_scales.size(0) &&
|
|
||||||
cuda_utils::ceil_div(b.size(1), int64_t(128)) == b_scales.size(1),
|
|
||||||
"b_scale_group_shape must be [128, 128].");
|
|
||||||
} else {
|
|
||||||
// TODO: Remove this after using cutlass sm90 blockwise scaling gemm
|
|
||||||
// kernel, or introducing ceil_div to the load_init() of mainloop.
|
|
||||||
using GroupShape = std::array<int64_t, 2>;
|
|
||||||
auto make_group_shape = [](torch::Tensor const& x,
|
|
||||||
torch::Tensor const& s) -> GroupShape {
|
|
||||||
TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
|
|
||||||
return {cuda_utils::ceil_div(x.size(0), s.size(0)),
|
|
||||||
cuda_utils::ceil_div(x.size(1), s.size(1))};
|
|
||||||
};
|
|
||||||
|
|
||||||
GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
|
GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
|
||||||
GroupShape b_scale_group_shape = make_group_shape(b, b_scales);
|
GroupShape b_scale_group_shape = make_group_shape(b, b_scales);
|
||||||
|
|
||||||
// 1x128 per-token group scales for activations
|
|
||||||
// 128x128 blockwise scales for weights
|
|
||||||
TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} &&
|
|
||||||
b_scale_group_shape == GroupShape{128, 128} &&
|
|
||||||
a.dtype() == torch::kFloat8_e4m3fn &&
|
|
||||||
b.dtype() == torch::kFloat8_e4m3fn),
|
|
||||||
"cutlass_scaled_mm only supports datatype float8_e4m3fn.\n"
|
|
||||||
"a_scale_group_shape must be [1, 128]. Got: [",
|
|
||||||
a_scale_group_shape[0], ", ", a_scale_group_shape[1],
|
|
||||||
"]\n"
|
|
||||||
"b_scale_group_shape must be [128, 128]. Got: [",
|
|
||||||
b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
|
|
||||||
}
|
|
||||||
|
|
||||||
|
// 1x128 per-token group scales for activations
|
||||||
|
// 128x128 blockwise scales for weights
|
||||||
|
TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} &&
|
||||||
|
b_scale_group_shape == GroupShape{128, 128} &&
|
||||||
|
a.dtype() == torch::kFloat8_e4m3fn &&
|
||||||
|
b.dtype() == torch::kFloat8_e4m3fn),
|
||||||
|
"cutlass_scaled_mm only supports datatype float8_e4m3fn.\n"
|
||||||
|
"a_scale_group_shape must be [1, 128]. Got: [",
|
||||||
|
a_scale_group_shape[0], ", ", a_scale_group_shape[1],
|
||||||
|
"]\n"
|
||||||
|
"b_scale_group_shape must be [128, 128]. Got: [",
|
||||||
|
b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
|
||||||
TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
|
TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
|
||||||
blockwise_func(c, a, b, a_scales, b_scales);
|
blockwise_func(c, a, b, a_scales, b_scales);
|
||||||
}
|
}
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user