[Benchmark] Add plot utility for parameter sweep (#27168)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung
2025-10-22 11:30:03 +08:00
committed by GitHub
parent bfa59be8f1
commit ceacedc1f9
10 changed files with 1788 additions and 1168 deletions


@@ -7,7 +7,7 @@ toc_depth: 4
vLLM provides comprehensive benchmarking tools for performance testing and evaluation:
- **[Benchmark CLI](#benchmark-cli)**: `vllm bench` CLI tools and specialized benchmark scripts for interactive performance testing
- **[Batch Scripts](#batch-scripts)**: Run `vllm bench` against multiple configurations conveniently
- **[Parameter sweeps](#parameter-sweeps)**: Automate `vllm bench` runs for multiple configurations
- **[Performance benchmarks](#performance-benchmarks)**: Automated CI benchmarks for development
- **[Nightly benchmarks](#nightly-benchmarks)**: Comparative benchmarks against alternatives
@@ -925,15 +925,13 @@ throughput numbers correctly is also adjusted.
</details>
## Batch Scripts
## Parameter Sweeps
### Batch Serving Script
### Online Benchmark
[`vllm/benchmarks/serve_multi.py`](../../vllm/benchmarks/serve_multi.py) automatically starts `vllm serve` and runs `vllm bench serve` over multiple configurations.
[`vllm/benchmarks/sweep/serve.py`](../../vllm/benchmarks/sweep/serve.py) automatically starts `vllm serve` and runs `vllm bench serve` to evaluate vLLM over multiple configurations.
#### Batch Mode
The basic purpose of this script is to evaluate vLLM under different settings. Follows these steps to run the script:
Follow these steps to run the script:
1. Construct the base command to `vllm serve`, and pass it to the `--serve-cmd` option.
2. Construct the base command to `vllm bench serve`, and pass it to the `--bench-cmd` option.
@@ -996,7 +994,7 @@ The basic purpose of this script is to evaluate vLLM under different settings. F
Example command:
```bash
python vllm/benchmarks/serve_multi.py \
python -m vllm.benchmarks.sweep.serve \
--serve-cmd 'vllm serve meta-llama/Llama-2-7b-chat-hf' \
--bench-cmd 'vllm bench serve --model meta-llama/Llama-2-7b-chat-hf --backend vllm --endpoint /v1/completions --dataset-name sharegpt --dataset-path benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json' \
--serve-params benchmarks/serve_hparams.json \
@@ -1018,9 +1016,9 @@ python vllm/benchmarks/serve_multi.py \
!!! tip
You can use the `--resume` option to continue the parameter sweep if one of the runs failed.
#### SLA Mode
### SLA Auto-Tuner
By passing SLA constraints via `--sla-params`, you can run this script in SLA mode, causing it to adjust either the request rate or concurrency (choose using `--sla-variable`) in order to satisfy the SLA constraints.
[`vllm/benchmarks/sweep/serve_sla.py`](../../vllm/benchmarks/sweep/serve_sla.py) is a wrapper over [`vllm/benchmarks/sweep/serve.py`](../../vllm/benchmarks/sweep/serve.py) that tunes either the request rate or concurrency (choose using `--sla-variable`) in order to satisfy the SLA constraints given by `--sla-params`.
For example, to ensure E2E latency within different target values for 99% of requests:
@@ -1044,7 +1042,7 @@ For example, to ensure E2E latency within different target values for 99% of req
Example command:
```bash
python vllm/benchmarks/serve_multi.py \
python -m vllm.benchmarks.sweep.serve_sla \
--serve-cmd 'vllm serve meta-llama/Llama-2-7b-chat-hf' \
--bench-cmd 'vllm bench serve --model meta-llama/Llama-2-7b-chat-hf --backend vllm --endpoint /v1/completions --dataset-name sharegpt --dataset-path benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json' \
--serve-params benchmarks/serve_hparams.json \
@@ -1066,6 +1064,24 @@ The algorithm for adjusting the SLA variable is as follows:
For a given combination of `--serve-params` and `--bench-params`, we share the benchmark results across `--sla-params` to avoid rerunning benchmarks with the same SLA variable value.
### Visualizer
[`vllm/benchmarks/sweep/plot.py`](../../vllm/benchmarks/sweep/plot.py) can be used to plot performance curves from parameter sweep results.
Example command:
```bash
python -m vllm.benchmarks.sweep.plot benchmarks/results/<timestamp> \
--var-x max_concurrency \
--row-by random_input_len \
--col-by random_output_len \
--curve-by api_server_count,max_num_batched_tokens \
--filter-by 'max_concurrency<=1024'
```
!!! tip
You can use `--dry-run` to preview the figures to be plotted.
## Performance Benchmarks
The performance benchmarks are used for development to confirm whether new changes improve performance under various workloads. They are triggered on every commit with both the `perf-benchmarks` and `ready` labels, and when a PR is merged into vLLM.

File diff suppressed because it is too large.


@@ -0,0 +1,91 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import os
from typing import Any
class ParameterSweep(list["ParameterSweepItem"]):
@classmethod
def read_json(cls, filepath: os.PathLike):
with open(filepath, "rb") as f:
records = json.load(f)
return cls.from_records(records)
@classmethod
def from_records(cls, records: list[dict[str, object]]):
if not isinstance(records, list):
raise TypeError(
f"The parameter sweep should be a list of dictionaries, "
f"but found type: {type(records)}"
)
return cls(ParameterSweepItem.from_record(record) for record in records)
class ParameterSweepItem(dict[str, object]):
@classmethod
def from_record(cls, record: dict[str, object]):
if not isinstance(record, dict):
raise TypeError(
f"Each item in the parameter sweep should be a dictionary, "
f"but found type: {type(record)}"
)
return cls(record)
def __or__(self, other: dict[str, Any]):
return type(self)(super().__or__(other))
# In JSON, we prefer "_"
def _iter_param_key_candidates(self, param_key: str):
# Inner config arguments are not converted by the CLI
if "." in param_key:
prefix, rest = param_key.split(".", 1)
for prefix_candidate in self._iter_param_key_candidates(prefix):
yield prefix_candidate + "." + rest
return
yield param_key
yield param_key.replace("-", "_")
yield param_key.replace("_", "-")
# In CLI, we prefer "-"
def _iter_cmd_key_candidates(self, param_key: str):
for k in reversed(tuple(self._iter_param_key_candidates(param_key))):
yield "--" + k
def _normalize_cmd_key(self, param_key: str):
return next(self._iter_cmd_key_candidates(param_key))
def has_param(self, param_key: str) -> bool:
return any(k in self for k in self._iter_param_key_candidates(param_key))
def apply_to_cmd(self, cmd: list[str]) -> list[str]:
cmd = list(cmd)
for k, v in self.items():
for k_candidate in self._iter_cmd_key_candidates(k):
try:
k_idx = cmd.index(k_candidate)
if isinstance(v, bool):
cmd[k_idx] = self._normalize_cmd_key(k if v else "no-" + k)
else:
cmd[k_idx + 1] = str(v)
break
except ValueError:
continue
else:
if isinstance(v, bool):
cmd.append(self._normalize_cmd_key(k if v else "no-" + k))
else:
cmd.extend([self._normalize_cmd_key(k), str(v)])
return cmd
def as_text(self, sep: str = ", ") -> str:
return sep.join(f"{k}={v}" for k, v in self.items())


@@ -0,0 +1,530 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import json
from abc import ABC, abstractmethod
from concurrent.futures import ProcessPoolExecutor
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from types import TracebackType
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from typing_extensions import Self, override
from vllm.utils.collection_utils import full_groupby
from .utils import sanitize_filename
@dataclass
class PlotFilterBase(ABC):
var: str
target: str
@classmethod
def parse_str(cls, s: str):
for op_key in PLOT_FILTERS:
if op_key in s:
key, value = s.split(op_key)
return PLOT_FILTERS[op_key](
key,
value.removeprefix(op_key).strip("'").strip('"'),
)
else:
raise ValueError(
f"Invalid operator for plot filter '{s}'. "
f"Valid operators are: {sorted(PLOT_FILTERS)}",
)
@abstractmethod
def apply(self, df: pd.DataFrame) -> pd.DataFrame:
"""Applies this filter to a DataFrame."""
raise NotImplementedError
@dataclass
class PlotEqualTo(PlotFilterBase):
@override
def apply(self, df: pd.DataFrame) -> pd.DataFrame:
try:
target = float(self.target)
except ValueError:
target = self.target
return df[df[self.var] == target]
@dataclass
class PlotLessThan(PlotFilterBase):
@override
def apply(self, df: pd.DataFrame) -> pd.DataFrame:
return df[df[self.var] < float(self.target)]
@dataclass
class PlotLessThanOrEqualTo(PlotFilterBase):
@override
def apply(self, df: pd.DataFrame) -> pd.DataFrame:
return df[df[self.var] <= float(self.target)]
@dataclass
class PlotGreaterThan(PlotFilterBase):
@override
def apply(self, df: pd.DataFrame) -> pd.DataFrame:
return df[df[self.var] > float(self.target)]
@dataclass
class PlotGreaterThanOrEqualTo(PlotFilterBase):
@override
def apply(self, df: pd.DataFrame) -> pd.DataFrame:
return df[df[self.var] >= float(self.target)]
# NOTE: The ordering is important! Match longer op_keys first
PLOT_FILTERS: dict[str, type[PlotFilterBase]] = {
"==": PlotEqualTo,
"<=": PlotLessThanOrEqualTo,
">=": PlotGreaterThanOrEqualTo,
"<": PlotLessThan,
">": PlotGreaterThan,
}
class PlotFilters(list[PlotFilterBase]):
@classmethod
def parse_str(cls, s: str):
if not s:
return cls()
return cls(PlotFilterBase.parse_str(e) for e in s.split(","))
def apply(self, df: pd.DataFrame) -> pd.DataFrame:
for item in self:
df = item.apply(df)
return df
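# A small worked example (hypothetical values) of how a `--filter-by` string is
# parsed and applied; each comma-separated clause becomes one filter.
#
#   >>> filters = PlotFilters.parse_str("max_concurrency<=1024,request_throughput>5")
#   >>> df = pd.DataFrame(
#   ...     {"max_concurrency": [256, 2048], "request_throughput": [12.0, 3.0]}
#   ... )
#   >>> filters.apply(df)["max_concurrency"].tolist()
#   [256]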
@dataclass
class PlotBinner:
var: str
bin_size: float
@classmethod
def parse_str(cls, s: str):
for op_key in PLOT_BINNERS:
if op_key in s:
key, value = s.split(op_key)
return PLOT_BINNERS[op_key](key, float(value.removeprefix(op_key)))
else:
raise ValueError(
f"Invalid operator for plot binner '{s}'. "
f"Valid operators are: {sorted(PLOT_BINNERS)}",
)
def apply(self, df: pd.DataFrame) -> pd.DataFrame:
"""Applies this binner to a DataFrame."""
df = df.copy()
df[self.var] = df[self.var] // self.bin_size * self.bin_size
return df
PLOT_BINNERS: dict[str, type[PlotBinner]] = {
"%": PlotBinner,
}
class PlotBinners(list[PlotBinner]):
@classmethod
def parse_str(cls, s: str):
if not s:
return cls()
return cls(PlotBinner.parse_str(e) for e in s.split(","))
def apply(self, df: pd.DataFrame) -> pd.DataFrame:
for item in self:
df = item.apply(df)
return df
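# A small worked example (hypothetical values) of `--bin-by`: values are floored
# to multiples of the bin size so that nearby points collapse onto one x position.
#
#   >>> binners = PlotBinners.parse_str("request_throughput%0.5")
#   >>> df = pd.DataFrame({"request_throughput": [1.23, 1.38, 2.9]})
#   >>> binners.apply(df)["request_throughput"].tolist()
#   [1.0, 1.0, 2.5]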
def _json_load_bytes(path: Path) -> list[dict[str, object]]:
with path.open("rb") as f:
return json.load(f)
def _get_metric(run_data: dict[str, object], metric_key: str):
try:
return run_data[metric_key]
except KeyError as exc:
raise ValueError(f"Cannot find metric {metric_key!r} in {run_data=}") from exc
def _get_group(run_data: dict[str, object], group_keys: list[str]):
return tuple((k, str(_get_metric(run_data, k))) for k in group_keys)
def _get_fig_path(fig_dir: Path, group: tuple[tuple[str, str], ...]):
parts = list[str]()
if group:
parts.extend(("FIGURE-", *(f"{k}={v}" for k, v in group)))
else:
parts.append("figure")
return fig_dir / sanitize_filename("-".join(parts) + ".png")
class DummyExecutor:
map = map
def __enter__(self) -> Self:
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
exc_traceback: TracebackType | None,
) -> None:
return None
def _plot_fig(
fig_dir: Path,
fig_group_data: tuple[tuple[tuple[str, str], ...], list[dict[str, object]]],
row_by: list[str],
col_by: list[str],
curve_by: list[str],
*,
var_x: str,
var_y: str,
filter_by: PlotFilters,
bin_by: PlotBinners,
scale_x: str | None,
scale_y: str | None,
dry_run: bool,
):
fig_group, fig_data = fig_group_data
row_groups = full_groupby(
fig_data,
key=lambda item: _get_group(item, row_by),
)
num_rows = len(row_groups)
num_cols = max(
len(full_groupby(row_data, key=lambda item: _get_group(item, col_by)))
for _, row_data in row_groups
)
fig_path = _get_fig_path(fig_dir, fig_group)
print("[BEGIN FIGURE]")
print(f"Group: {dict(fig_group)}")
print(f"Grid: {num_rows} rows x {num_cols} cols")
print(f"Output file: {fig_path}")
if dry_run:
print("[END FIGURE]")
return
df = pd.DataFrame.from_records(fig_data)
if var_x not in df.columns:
raise ValueError(
f"Cannot find {var_x=!r} in parameter sweep results. "
f"Available variables: {df.columns.tolist()}"
)
if var_y not in df.columns:
raise ValueError(
f"Cannot find {var_y=!r} in parameter sweep results. "
f"Available variables: {df.columns.tolist()}"
)
for k in row_by:
if k not in df.columns:
raise ValueError(
f"Cannot find row_by={k!r} in parameter sweep results. "
f"Available variables: {df.columns.tolist()}"
)
for k in col_by:
if k not in df.columns:
raise ValueError(
f"Cannot find col_by={k!r} in parameter sweep results. "
f"Available variables: {df.columns.tolist()}"
)
for k in curve_by:
if k not in df.columns:
raise ValueError(
f"Cannot find curve_by={k!r} in parameter sweep results. "
f"Available variables: {df.columns.tolist()}"
)
df = filter_by.apply(df)
df = bin_by.apply(df)
df["row_group"] = (
pd.concat(
[k + "=" + df[k].astype(str) for k in row_by],
axis=1,
).agg("\n".join, axis=1)
if row_by
else "(All)"
)
df["col_group"] = (
pd.concat(
[k + "=" + df[k].astype(str) for k in col_by],
axis=1,
).agg("\n".join, axis=1)
if col_by
else "(All)"
)
g = sns.FacetGrid(df, row="row_group", col="col_group")
if row_by and col_by:
g.set_titles("{row_name}\n{col_name}")
elif row_by:
g.set_titles("{row_name}")
elif col_by:
g.set_titles("{col_name}")
else:
g.set_titles("")
if scale_x:
g.set(xscale=scale_x)
if scale_y:
g.set(yscale=scale_y)
if len(curve_by) <= 3:
hue, style, size, *_ = (*curve_by, None, None, None)
g.map_dataframe(
sns.lineplot,
x=var_x,
y=var_y,
hue=hue,
style=style,
size=size,
markers=True,
)
g.add_legend(title=hue)
else:
df["curve_group"] = (
pd.concat(
[k + "=" + df[k].astype(str) for k in curve_by],
axis=1,
).agg("\n".join, axis=1)
if curve_by
else "(All)"
)
g.map_dataframe(
sns.lineplot,
x=var_x,
y=var_y,
hue="curve_group",
markers=True,
)
g.add_legend()
g.savefig(fig_path)
plt.close(g.figure)
print("[END FIGURE]")
def plot(
output_dir: Path,
fig_dir: Path,
fig_by: list[str],
row_by: list[str],
col_by: list[str],
curve_by: list[str],
*,
var_x: str,
var_y: str,
filter_by: PlotFilters,
bin_by: PlotBinners,
scale_x: str | None,
scale_y: str | None,
dry_run: bool,
):
all_data = [
run_data
for path in output_dir.rglob("**/summary.json")
for run_data in _json_load_bytes(path)
]
if not all_data:
raise ValueError(f"Did not find any parameter sweep results under {output_dir}")
fig_dir.mkdir(parents=True, exist_ok=True)
fig_groups = full_groupby(
all_data,
key=lambda item: _get_group(item, fig_by),
)
with DummyExecutor() if len(fig_groups) <= 1 else ProcessPoolExecutor() as executor:
# Resolve the iterable to ensure that the workers are run
all(
executor.map(
partial(
_plot_fig,
fig_dir,
row_by=row_by,
col_by=col_by,
curve_by=curve_by,
var_x=var_x,
var_y=var_y,
filter_by=filter_by,
bin_by=bin_by,
scale_x=scale_x,
scale_y=scale_y,
dry_run=dry_run,
),
fig_groups,
)
)
def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument(
"OUTPUT_DIR",
type=str,
default="results",
help="The directory containing the results to plot, "
"i.e., the `--output-dir` argument to the parameter sweep script.",
)
parser.add_argument(
"--fig-dir",
type=str,
default="",
help="The directory to save the figures, relative to `OUTPUT_DIR`. "
"By default, the same directory is used.",
)
parser.add_argument(
"--fig-by",
type=str,
default="",
help="A comma-separated list of variables, such that a separate figure "
"is created for each combination of these variables.",
)
parser.add_argument(
"--row-by",
type=str,
default="",
help="A comma-separated list of variables, such that a separate row "
"is created for each combination of these variables.",
)
parser.add_argument(
"--col-by",
type=str,
default="",
help="A comma-separated list of variables, such that a separate column "
"is created for each combination of these variables.",
)
parser.add_argument(
"--curve-by",
type=str,
default=None,
help="A comma-separated list of variables, such that a separate curve "
"is created for each combination of these variables.",
)
parser.add_argument(
"--var-x",
type=str,
default="request_throughput",
help="The variable for the x-axis.",
)
parser.add_argument(
"--var-y",
type=str,
default="p99_e2el_ms",
help="The variable for the y-axis",
)
parser.add_argument(
"--filter-by",
type=str,
default="",
help="A comma-separated list of statements indicating values to filter by. "
"This is useful to remove outliers. "
"Example: `max_concurrency<1000,max_num_batched_tokens<=4096` means "
"plot only the points where `max_concurrency` is less than 1000 and "
"`max_num_batched_tokens` is no greater than 4096.",
)
parser.add_argument(
"--bin-by",
type=str,
default="",
help="A comma-separated list of statements indicating values to bin by. "
"This is useful to avoid plotting points that are too close together. "
"Example: `request_throughput%1` means "
"use a bin size of 1 for the `request_throughput` variable.",
)
parser.add_argument(
"--scale-x",
type=str,
default=None,
help="The scale to use for the x-axis. "
"Currently only accepts string values such as 'log' and 'sqrt'. "
"See also: https://seaborn.pydata.org/generated/seaborn.objects.Plot.scale.html",
)
parser.add_argument(
"--scale-y",
type=str,
default=None,
help="The scale to use for the y-axis. "
"Currently only accepts string values such as 'log' and 'sqrt'. "
"See also: https://seaborn.pydata.org/generated/seaborn.objects.Plot.scale.html",
)
parser.add_argument(
"--dry-run",
action="store_true",
help="If set, prints the information about each figure to plot, "
"then exits without drawing them.",
)
def main(args: argparse.Namespace):
output_dir = Path(args.OUTPUT_DIR)
if not output_dir.exists():
raise ValueError(f"No parameter sweep results under {output_dir}")
curve_by = [] if not args.curve_by else args.curve_by.split(",")
row_by = [] if not args.row_by else args.row_by.split(",")
col_by = [] if not args.col_by else args.col_by.split(",")
fig_by = [] if not args.fig_by else args.fig_by.split(",")
plot(
output_dir=output_dir,
fig_dir=output_dir / args.fig_dir,
fig_by=fig_by,
row_by=row_by,
col_by=col_by,
curve_by=curve_by,
var_x=args.var_x,
var_y=args.var_y,
filter_by=PlotFilters.parse_str(args.filter_by),
bin_by=PlotBinners.parse_str(args.bin_by),
scale_x=args.scale_x,
scale_y=args.scale_y,
dry_run=args.dry_run,
)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Plot performance curves from parameter sweep results."
)
add_cli_args(parser)
main(parser.parse_args())


@@ -0,0 +1,407 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import contextlib
import json
import shlex
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
import pandas as pd
from .param_sweep import ParameterSweep, ParameterSweepItem
from .server import ServerProcess
from .utils import sanitize_filename
@contextlib.contextmanager
def run_server(
serve_cmd: list[str],
after_bench_cmd: list[str],
*,
show_stdout: bool,
serve_overrides: ParameterSweepItem,
dry_run: bool,
):
server_cmd = serve_overrides.apply_to_cmd(serve_cmd)
print("[BEGIN SERVER]")
print(f"Server overrides: {serve_overrides}")
print(f"Server command: {server_cmd}")
if dry_run:
yield None
print("[END SERVER]")
return
with ServerProcess(server_cmd, after_bench_cmd, show_stdout=show_stdout) as server:
yield server
print("[END SERVER]")
def _update_run_data(
run_data: dict[str, object],
serve_overrides: ParameterSweepItem,
bench_overrides: ParameterSweepItem,
run_number: int,
):
run_data["run_number"] = run_number
run_data.update(serve_overrides)
run_data.update(bench_overrides)
return run_data
def run_benchmark(
server: ServerProcess | None,
bench_cmd: list[str],
*,
serve_overrides: ParameterSweepItem,
bench_overrides: ParameterSweepItem,
run_number: int,
output_path: Path,
dry_run: bool,
):
benchmark_cmd = [
*bench_overrides.apply_to_cmd(bench_cmd),
"--save-result",
"--result-dir",
str(output_path.parent),
"--result-filename",
output_path.name,
]
print("[BEGIN BENCHMARK]")
print(f"Benchmark overrides: {bench_overrides}")
print(f"Run Number: {run_number}")
print(f"Benchmark command: {benchmark_cmd}")
print(f"Output file: {output_path}")
run_data: dict[str, object]
if output_path.exists():
print("Found existing results. Skipping.")
with output_path.open("rb") as f:
run_data = json.load(f)
return _update_run_data(
run_data,
serve_overrides,
bench_overrides,
run_number,
)
if server is None:
if not dry_run:
raise ValueError(f"Cannot find results at {output_path}")
print("[END BENCHMARK]")
return None
output_path.parent.mkdir(parents=True, exist_ok=True)
server.run_subcommand(benchmark_cmd)
server.after_bench()
with output_path.open("rb") as f:
run_data = json.load(f)
run_data = _update_run_data(
run_data,
serve_overrides,
bench_overrides,
run_number,
)
with output_path.open("w") as f:
json.dump(run_data, f, indent=4)
print("[END BENCHMARK]")
return run_data
def _get_comb_base_path(
output_dir: Path,
serve_comb: ParameterSweepItem,
bench_comb: ParameterSweepItem,
):
parts = list[str]()
if serve_comb:
parts.extend(("SERVE-", serve_comb.as_text(sep="-")))
if bench_comb:
parts.extend(("BENCH-", bench_comb.as_text(sep="-")))
return output_dir / sanitize_filename("-".join(parts))
def _get_comb_run_path(base_path: Path, run_number: int | None):
if run_number is None:
return base_path / "summary.json"
return base_path / f"run={run_number}.json"
def _comb_needs_server(
serve_comb: ParameterSweepItem,
bench_combs: ParameterSweep,
output_dir: Path,
):
for bench_comb in bench_combs:
base_path = _get_comb_base_path(output_dir, serve_comb, bench_comb)
if not _get_comb_run_path(base_path, run_number=None).exists():
return True
return False
def run_comb(
server: ServerProcess | None,
bench_cmd: list[str],
*,
serve_comb: ParameterSweepItem,
bench_comb: ParameterSweepItem,
base_path: Path,
num_runs: int,
dry_run: bool,
):
comb_data = list[dict[str, object]]()
for run_number in range(num_runs):
run_data = run_benchmark(
server,
bench_cmd,
serve_overrides=serve_comb,
bench_overrides=bench_comb,
run_number=run_number,
output_path=_get_comb_run_path(base_path, run_number),
dry_run=dry_run,
)
if run_data is not None:
comb_data.append(run_data)
if dry_run:
return None
with _get_comb_run_path(base_path, run_number=None).open("w") as f:
json.dump(comb_data, f, indent=4)
return comb_data
def run_combs(
serve_cmd: list[str],
bench_cmd: list[str],
after_bench_cmd: list[str],
*,
show_stdout: bool,
serve_params: ParameterSweep,
bench_params: ParameterSweep,
output_dir: Path,
num_runs: int,
dry_run: bool,
):
all_data = list[dict[str, object]]()
for serve_comb in serve_params:
with (
run_server(
serve_cmd,
after_bench_cmd,
show_stdout=show_stdout,
serve_overrides=serve_comb,
dry_run=dry_run,
)
if _comb_needs_server(serve_comb, bench_params, output_dir)
else contextlib.nullcontext()
) as server:
for bench_comb in bench_params:
base_path = _get_comb_base_path(output_dir, serve_comb, bench_comb)
comb_data = run_comb(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb,
base_path=base_path,
num_runs=num_runs,
dry_run=dry_run,
)
if comb_data is not None:
all_data.extend(comb_data)
if dry_run:
return None
combined_df = pd.DataFrame.from_records(all_data)
combined_df.to_csv(output_dir / "summary.csv")
return combined_df
@dataclass
class SweepServeArgs:
serve_cmd: list[str]
bench_cmd: list[str]
after_bench_cmd: list[str]
show_stdout: bool
serve_params: ParameterSweep
bench_params: ParameterSweep
output_dir: Path
num_runs: int
dry_run: bool
resume: str | None
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
serve_cmd = shlex.split(args.serve_cmd)
bench_cmd = shlex.split(args.bench_cmd)
after_bench_cmd = (
[] if args.after_bench_cmd is None else shlex.split(args.after_bench_cmd)
)
if args.serve_params:
serve_params = ParameterSweep.read_json(args.serve_params)
else:
# i.e.: run serve_cmd without any modification
serve_params = ParameterSweep.from_records([{}])
if args.bench_params:
bench_params = ParameterSweep.read_json(args.bench_params)
else:
# i.e.: run bench_cmd without any modification
bench_params = ParameterSweep.from_records([{}])
num_runs = args.num_runs
if num_runs < 1:
raise ValueError("`num_runs` should be at least 1.")
return cls(
serve_cmd=serve_cmd,
bench_cmd=bench_cmd,
after_bench_cmd=after_bench_cmd,
show_stdout=args.show_stdout,
serve_params=serve_params,
bench_params=bench_params,
output_dir=Path(args.output_dir),
num_runs=num_runs,
dry_run=args.dry_run,
resume=args.resume,
)
@classmethod
def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser.add_argument(
"--serve-cmd",
type=str,
required=True,
help="The command used to run the server: `vllm serve ...`",
)
parser.add_argument(
"--bench-cmd",
type=str,
required=True,
help="The command used to run the benchmark: `vllm bench serve ...`",
)
parser.add_argument(
"--after-bench-cmd",
type=str,
default=None,
help="After a benchmark run is complete, invoke this command instead of "
"the default `ServerWrapper.clear_cache()`.",
)
parser.add_argument(
"--show-stdout",
action="store_true",
help="If set, logs the standard output of subcommands. "
"Useful for debugging but can be quite spammy.",
)
parser.add_argument(
"--serve-params",
type=str,
default=None,
help="Path to JSON file containing a list of parameter combinations "
"for the `vllm serve` command. "
"If both `serve_params` and `bench_params` are given, "
"this script will iterate over their Cartesian product.",
)
parser.add_argument(
"--bench-params",
type=str,
default=None,
help="Path to JSON file containing a list of parameter combinations "
"for the `vllm bench serve` command. "
"If both `serve_params` and `bench_params` are given, "
"this script will iterate over their Cartesian product.",
)
parser.add_argument(
"-o",
"--output-dir",
type=str,
default="results",
help="The directory to which results are written.",
)
parser.add_argument(
"--num-runs",
type=int,
default=3,
help="Number of runs per parameter combination.",
)
parser.add_argument(
"--dry-run",
action="store_true",
help="If set, prints the commands to run, "
"then exits without executing them.",
)
parser.add_argument(
"--resume",
type=str,
default=None,
help="Set this to the name of a directory under `output_dir` (which is a "
"timestamp) to resume a previous execution of this script, i.e., only run "
"parameter combinations for which there are still no output files.",
)
return parser
def run_main(args: SweepServeArgs):
timestamp = args.resume or datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = args.output_dir / timestamp
if args.resume and not output_dir.exists():
raise ValueError(f"Cannot resume from non-existent directory ({output_dir})")
try:
return run_combs(
serve_cmd=args.serve_cmd,
bench_cmd=args.bench_cmd,
after_bench_cmd=args.after_bench_cmd,
show_stdout=args.show_stdout,
serve_params=args.serve_params,
bench_params=args.bench_params,
output_dir=output_dir,
num_runs=args.num_runs,
dry_run=args.dry_run,
)
except BaseException as exc:
raise RuntimeError(
f"The script was terminated early. Use `--resume {timestamp}` "
f"to continue the script from its last checkpoint."
) from exc
def main(args: argparse.Namespace):
run_main(SweepServeArgs.from_cli_args(args))
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Run vLLM server benchmark under multiple settings."
)
SweepServeArgs.add_cli_args(parser)
main(parser.parse_args())


@@ -0,0 +1,483 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import contextlib
import json
import math
from dataclasses import asdict, dataclass
from datetime import datetime
from pathlib import Path
from typing import Literal, get_args
import pandas as pd
from typing_extensions import assert_never
from .param_sweep import ParameterSweep, ParameterSweepItem
from .serve import SweepServeArgs, run_benchmark, run_server
from .server import ServerProcess
from .sla_sweep import SLASweep, SLASweepItem
from .utils import sanitize_filename
def _get_sla_base_path(
output_dir: Path,
serve_comb: ParameterSweepItem,
bench_comb: ParameterSweepItem,
):
parts = list[str]()
if serve_comb:
parts.extend(("SERVE-", serve_comb.as_text(sep="-")))
if bench_comb:
parts.extend(("BENCH-", bench_comb.as_text(sep="-")))
return output_dir / sanitize_filename("-".join(parts))
def _get_sla_iter_path(
base_path: Path,
sla_comb: SLASweepItem,
sla_variable: str,
sla_value: int | None,
):
if sla_value is None:
prefix = sla_comb.as_text(sep="-")
return base_path / f"SLA--{prefix}.json"
return base_path / f"{sla_variable}={sla_value}"
def _get_sla_run_path(iter_path: Path, run_number: int | None):
if run_number is None:
return iter_path / "summary.json"
return iter_path / f"run={run_number}.json"
def _sla_needs_server(
serve_comb: ParameterSweepItem,
bench_combs: ParameterSweep,
sla_combs: SLASweep,
sla_variable: str,
output_dir: Path,
):
for bench_comb in bench_combs:
base_path = _get_sla_base_path(output_dir, serve_comb, bench_comb)
for sla_comb in sla_combs:
if not _get_sla_iter_path(
base_path,
sla_comb,
sla_variable,
sla_value=None,
).exists():
return True
return False
def run_sla(
server: ServerProcess | None,
bench_cmd: list[str],
*,
serve_comb: ParameterSweepItem,
bench_comb: ParameterSweepItem,
iter_path: Path,
num_runs: int,
dry_run: bool,
):
iter_data = list[dict[str, object]]()
for run_number in range(num_runs):
run_data = run_benchmark(
server,
bench_cmd,
serve_overrides=serve_comb,
bench_overrides=bench_comb,
run_number=run_number,
output_path=_get_sla_run_path(iter_path, run_number),
dry_run=dry_run,
)
if run_data is not None:
iter_data.append(run_data)
if dry_run:
return None
with _get_sla_run_path(iter_path, run_number=None).open("w") as f:
json.dump(iter_data, f, indent=4)
return iter_data
SLAVariable = Literal["request_rate", "max_concurrency"]
def _estimate_sla_value(run_data: dict[str, object], sla_variable: SLAVariable):
request_throughput = float(run_data["request_throughput"]) # type: ignore
if sla_variable == "request_rate":
return request_throughput
if sla_variable == "max_concurrency":
mean_latency_ms = float(run_data["mean_e2el_ms"]) # type: ignore
return request_throughput * mean_latency_ms / 1000
assert_never(sla_variable)
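# A small worked example (hypothetical numbers): the "max_concurrency" estimate
# follows Little's law, concurrency ~= throughput * mean end-to-end latency.
#
#   >>> run_data = {"request_throughput": 20.0, "mean_e2el_ms": 250.0}
#   >>> _estimate_sla_value(run_data, "request_rate")
#   20.0
#   >>> _estimate_sla_value(run_data, "max_concurrency")
#   5.0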
def _estimate_sla_bounds(
server: ServerProcess | None,
bench_cmd: list[str],
*,
serve_comb: ParameterSweepItem,
bench_comb: ParameterSweepItem,
sla_comb: SLASweepItem,
base_path: Path,
num_runs: int,
dry_run: bool,
sla_variable: SLAVariable,
init_value: int,
max_value: int,
):
sla_data = list[dict[str, object]]()
max_passing: int = 0
min_failing: int = 0
val: int = init_value
assert val > 0
while True:
print(f"Testing {sla_variable}: {val} req/s")
iter_data = run_sla(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb | {sla_variable: val},
iter_path=_get_sla_iter_path(base_path, sla_comb, sla_variable, val),
num_runs=num_runs,
dry_run=dry_run,
)
assert iter_data is not None
sla_data.extend(iter_data)
iter_data_mean = {
k: sum(float(run_data[k]) for run_data in iter_data) / len(iter_data) # type: ignore
for k in sla_comb
}
sla_results = [
criterion.print_and_validate(iter_data_mean, k)
for k, criterion in sla_comb.items()
]
if all(sla_results):
print("SLA criteria are met.")
max_passing = val
val *= 2
else:
print("SLA criteria are not met.")
min_failing = val
break
if val >= max_value:
break
return sla_data, (max_passing, min_failing)
def _find_sla_value(
server: ServerProcess | None,
bench_cmd: list[str],
*,
serve_comb: ParameterSweepItem,
bench_comb: ParameterSweepItem,
sla_comb: SLASweepItem,
base_path: Path,
num_runs: int,
dry_run: bool,
sla_variable: SLAVariable,
min_value: int,
max_value: int,
):
sla_data = list[dict[str, object]]()
left: int = min_value
right: int = max_value
while True:
val = (left + right) // 2
print(f"Testing {sla_variable}: {val} req/s")
iter_data = run_sla(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb | {sla_variable: val},
iter_path=_get_sla_iter_path(base_path, sla_comb, sla_variable, val),
num_runs=num_runs,
dry_run=dry_run,
)
assert iter_data is not None
sla_data.extend(iter_data)
iter_data_mean = {
k: sum(float(run_data[k]) for run_data in iter_data) / len(iter_data) # type: ignore
for k in sla_comb
}
sla_results = [
criterion.print_and_validate(iter_data_mean, k)
for k, criterion in sla_comb.items()
]
if all(sla_results):
print("SLA criteria are met.")
left = val
else:
print("SLA criteria are not met.")
right = val
if right - left <= 1:
break
return sla_data, left
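# A small worked trace (hypothetical): if the doubling phase above found that
# 8 passes and 16 fails, the bisection tests the midpoint 12, moves `left` up on
# a pass and `right` down on a failure, and stops once the bounds are adjacent,
# returning the largest value known to satisfy the SLA (`left`).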
def search_sla(
server: ServerProcess | None,
bench_cmd: list[str],
*,
serve_comb: ParameterSweepItem,
bench_comb: ParameterSweepItem,
sla_comb: SLASweepItem,
sla_variable: SLAVariable,
sla_inf_value: int = 65536, # The value that represents infinite QPS
base_path: Path,
num_runs: int,
dry_run: bool,
):
print("[SLA START]")
print(f"SLA criteria: {sla_comb.as_text()}")
sla_data_0 = run_sla(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb | {sla_variable: sla_inf_value},
iter_path=_get_sla_iter_path(base_path, sla_comb, sla_variable, sla_inf_value),
num_runs=num_runs,
dry_run=dry_run,
)
if sla_data_0 is None:
assert dry_run
print("Omitting SLA search.")
print("[SLA END]")
return None
sla_init_value = math.ceil(
sum(_estimate_sla_value(item, sla_variable) for item in sla_data_0)
/ len(sla_data_0)
)
print(f"Initial {sla_variable} to search: {sla_init_value} req/s.")
sla_data_1, (sla_min, sla_max) = _estimate_sla_bounds(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb,
sla_comb=sla_comb,
base_path=base_path,
num_runs=num_runs,
dry_run=dry_run,
sla_variable=sla_variable,
init_value=sla_init_value,
max_value=sla_inf_value,
)
print(f"Range of {sla_variable} to search: [{sla_min}, {sla_max}] req/s.")
sla_data_2, sla_value = _find_sla_value(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb,
sla_comb=sla_comb,
base_path=base_path,
num_runs=num_runs,
dry_run=dry_run,
sla_variable=sla_variable,
min_value=sla_min,
max_value=sla_max,
)
sla_data = sla_data_0 + sla_data_1 + sla_data_2
print(f"Maximum {sla_variable} for SLA: {sla_value} req/s.")
with _get_sla_iter_path(
base_path,
sla_comb,
sla_variable,
sla_value=None,
).open("w") as f:
json.dump(sla_data, f, indent=4)
print("[SLA END]")
return sla_data
def run_slas(
serve_cmd: list[str],
bench_cmd: list[str],
after_bench_cmd: list[str],
*,
show_stdout: bool,
serve_params: ParameterSweep,
bench_params: ParameterSweep,
sla_params: SLASweep,
sla_variable: SLAVariable,
output_dir: Path,
num_runs: int,
dry_run: bool,
):
if any(bench_comb.has_param(sla_variable) for bench_comb in bench_params):
raise ValueError(
f"You should not override `{sla_variable}` in `bench_params` in SLA mode, "
"since it is supposed to be determined automatically."
)
all_data = list[dict[str, object]]()
for serve_comb in serve_params:
with (
run_server(
serve_cmd,
after_bench_cmd,
show_stdout=show_stdout,
serve_overrides=serve_comb,
dry_run=dry_run,
)
if _sla_needs_server(
serve_comb,
bench_params,
sla_params,
sla_variable,
output_dir,
)
else contextlib.nullcontext()
) as server:
for bench_comb in bench_params:
for sla_comb in sla_params:
base_path = _get_sla_base_path(output_dir, serve_comb, bench_comb)
comb_data = search_sla(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb,
sla_comb=sla_comb,
sla_variable=sla_variable,
base_path=base_path,
num_runs=num_runs,
dry_run=dry_run,
)
if comb_data is not None:
all_data.extend(comb_data)
if dry_run:
return None
combined_df = pd.DataFrame.from_records(all_data)
combined_df.to_csv(output_dir / "summary.csv")
return combined_df
@dataclass
class SweepServeSLAArgs(SweepServeArgs):
sla_params: SLASweep
sla_variable: SLAVariable
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
base_args = super().from_cli_args(args)
if args.sla_params:
sla_params = SLASweep.read_json(args.sla_params)
else:
sla_params = SLASweep.from_records([])
return cls(
**asdict(base_args),
sla_params=sla_params,
sla_variable=args.sla_variable,
)
@classmethod
def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser = super().add_cli_args(parser)
parser.add_argument(
"--sla-params",
type=str,
required=True,
help="Path to JSON file containing a list of SLA constraints to satisfy. "
'Each constraint is expressed in `{"<KEY>": "<OP><VALUE>"}` format, '
'e.g.: `{"p99_e2el_ms": "<=500"}` means that '
"the E2E latency should be less than 500ms 99%% of the time. "
"Setting this option runs this script in SLA mode, which searches for "
"the maximum `sla_variable` that satisfies the constraints for "
"each combination of `serve_params`, `bench_params`, and `sla_params`.",
)
parser.add_argument(
"--sla-variable",
type=str,
choices=get_args(SLAVariable),
default="request_rate",
help="Whether to tune request rate or maximum concurrency to satisfy "
"the SLA constraints.",
)
return parser
def run_main(args: SweepServeSLAArgs):
timestamp = args.resume or datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = args.output_dir / timestamp
if args.resume and not output_dir.exists():
raise ValueError(f"Cannot resume from non-existent directory ({output_dir})")
try:
return run_slas(
serve_cmd=args.serve_cmd,
bench_cmd=args.bench_cmd,
after_bench_cmd=args.after_bench_cmd,
show_stdout=args.show_stdout,
serve_params=args.serve_params,
bench_params=args.bench_params,
sla_params=args.sla_params,
sla_variable=args.sla_variable,
output_dir=output_dir,
num_runs=args.num_runs,
dry_run=args.dry_run,
)
except BaseException as exc:
raise RuntimeError(
f"The script was terminated early. Use `--resume {timestamp}` "
f"to continue the script from its last checkpoint."
) from exc
def main(args: argparse.Namespace):
run_main(SweepServeSLAArgs.from_cli_args(args))
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Tune a variable to meet SLAs under multiple settings."
)
SweepServeSLAArgs.add_cli_args(parser)
main(parser.parse_args())


@@ -0,0 +1,114 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import os
import signal
import subprocess
from types import TracebackType
import requests
from typing_extensions import Self
class ServerProcess:
def __init__(
self,
server_cmd: list[str],
after_bench_cmd: list[str],
*,
show_stdout: bool,
) -> None:
super().__init__()
self.server_cmd = server_cmd
self.after_bench_cmd = after_bench_cmd
self.show_stdout = show_stdout
def __enter__(self) -> Self:
self.start()
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
exc_traceback: TracebackType | None,
) -> None:
self.stop()
def start(self):
# Create new process for clean termination
self._server_process = subprocess.Popen(
self.server_cmd,
start_new_session=True,
stdout=None if self.show_stdout else subprocess.DEVNULL,
# Need `VLLM_SERVER_DEV_MODE=1` for `reset_caches`
env=os.environ | {"VLLM_SERVER_DEV_MODE": "1"},
)
def stop(self):
server_process = self._server_process
if server_process.poll() is None:
# In case only some processes have been terminated
with contextlib.suppress(ProcessLookupError):
# We need to kill both API Server and Engine processes
os.killpg(os.getpgid(server_process.pid), signal.SIGKILL)
def run_subcommand(self, cmd: list[str]):
return subprocess.run(
cmd,
stdout=None if self.show_stdout else subprocess.DEVNULL,
check=True,
)
def after_bench(self) -> None:
if not self.after_bench_cmd:
self.reset_caches()
return
self.run_subcommand(self.after_bench_cmd)
def _get_vllm_server_address(self) -> str:
server_cmd = self.server_cmd
for host_key in ("--host",):
if host_key in server_cmd:
host = server_cmd[server_cmd.index(host_key) + 1]
break
else:
host = "localhost"
for port_key in ("-p", "--port"):
if port_key in server_cmd:
port = int(server_cmd[server_cmd.index(port_key) + 1])
break
else:
port = 8000 # The default value in vllm serve
return f"http://{host}:{port}"
def reset_caches(self) -> None:
server_cmd = self.server_cmd
# Use `.endswith()` to match `/bin/...`
if server_cmd[0].endswith("vllm"):
server_address = self._get_vllm_server_address()
print(f"Resetting caches at {server_address}")
res = requests.post(f"{server_address}/reset_prefix_cache")
res.raise_for_status()
res = requests.post(f"{server_address}/reset_mm_cache")
res.raise_for_status()
elif server_cmd[0].endswith("infinity_emb"):
if "--vector-disk-cache" in server_cmd:
raise NotImplementedError(
"Infinity server uses caching but does not expose a method "
"to reset the cache"
)
else:
raise NotImplementedError(
f"No implementation of `reset_caches` for `{server_cmd[0]}` server. "
"Please specify a custom command via `--after-bench-cmd`."
)


@@ -0,0 +1,132 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing_extensions import override
@dataclass
class SLACriterionBase(ABC):
target: float
@abstractmethod
def validate(self, actual: float) -> bool:
"""Return `True` if this criterion is met; otherwise `False`."""
raise NotImplementedError
@abstractmethod
def format_cond(self, lhs: str) -> str:
raise NotImplementedError
def print_and_validate(
self,
metrics: dict[str, float],
metrics_key: str,
) -> bool:
metric = metrics[metrics_key]
result = self.validate(metric)
cond = self.format_cond(f"{metrics_key} = {metric:.2f}")
print(f"Validating SLA: {cond} | " + ("PASSED" if result else "FAILED"))
return result
@dataclass
class SLALessThan(SLACriterionBase):
@override
def validate(self, actual: float) -> bool:
return actual < self.target
@override
def format_cond(self, lhs: str) -> str:
return f"{lhs}<{self.target:.2f}"
@dataclass
class SLALessThanOrEqualTo(SLACriterionBase):
@override
def validate(self, actual: float) -> bool:
return actual <= self.target
@override
def format_cond(self, lhs: str) -> str:
return f"{lhs}<={self.target:.2f}"
@dataclass
class SLAGreaterThan(SLACriterionBase):
@override
def validate(self, actual: float) -> bool:
return actual > self.target
@override
def format_cond(self, lhs: str) -> str:
return f"{lhs}>{self.target:.2f}"
@dataclass
class SLAGreaterThanOrEqualTo(SLACriterionBase):
@override
def validate(self, actual: float) -> bool:
return actual >= self.target
@override
def format_cond(self, lhs: str) -> str:
return f"{lhs}>={self.target:.2f}"
# NOTE: The ordering is important! Match longer op_keys first
SLA_CRITERIA: dict[str, type[SLACriterionBase]] = {
"<=": SLALessThanOrEqualTo,
">=": SLAGreaterThanOrEqualTo,
"<": SLALessThan,
">": SLAGreaterThan,
}
class SLASweep(list["SLASweepItem"]):
@classmethod
def read_json(cls, filepath: os.PathLike):
with open(filepath, "rb") as f:
records = json.load(f)
return cls.from_records(records)
@classmethod
def from_records(cls, records: list[dict[str, str]]):
if not isinstance(records, list):
raise TypeError(
f"The SLA sweep should be a list of dictionaries, "
f"but found type: {type(records)}"
)
return cls(SLASweepItem.from_record(record) for record in records)
class SLASweepItem(dict[str, SLACriterionBase]):
@classmethod
def from_record(cls, record: dict[str, str]):
sla_criteria: dict[str, SLACriterionBase] = {}
for metric_key, metric_value in record.items():
for op_key in SLA_CRITERIA:
if metric_value.startswith(op_key):
sla_criteria[metric_key] = SLA_CRITERIA[op_key](
float(metric_value.removeprefix(op_key))
)
break
else:
raise ValueError(
f"Invalid operator for "
f"SLA constraint '{metric_key}={metric_value}'. "
f"Valid operators are: {sorted(SLA_CRITERIA)}",
)
return cls(sla_criteria)
def as_text(self, sep: str = ", ") -> str:
return sep.join(v.format_cond(k) for k, v in self.items())
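# A minimal usage sketch (hypothetical metric values) of how one `--sla-params`
# record is parsed and checked:
#
#   >>> sla = SLASweepItem.from_record({"p99_e2el_ms": "<=500"})
#   >>> sla.as_text()
#   'p99_e2el_ms<=500.00'
#   >>> sla["p99_e2el_ms"].validate(432.1)
#   True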


@@ -0,0 +1,4 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
def sanitize_filename(filename: str) -> str:
return filename.replace("/", "_").replace("..", "__").strip("'").strip('"')