[Benchmark] Add plot utility for parameter sweep (#27168)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

@@ -7,7 +7,7 @@ toc_depth: 4

 vLLM provides comprehensive benchmarking tools for performance testing and evaluation:

 - **[Benchmark CLI](#benchmark-cli)**: `vllm bench` CLI tools and specialized benchmark scripts for interactive performance testing
-- **[Batch Scripts](#batch-scripts)**: Run `vllm bench` against multiple configurations conveniently
+- **[Parameter sweeps](#parameter-sweeps)**: Automate `vllm bench` runs for multiple configurations
 - **[Performance benchmarks](#performance-benchmarks)**: Automated CI benchmarks for development
 - **[Nightly benchmarks](#nightly-benchmarks)**: Comparative benchmarks against alternatives

@@ -925,15 +925,13 @@ throughput numbers correctly is also adjusted.

 </details>

-## Batch Scripts
+## Parameter Sweeps

-### Batch Serving Script
+### Online Benchmark

-[`vllm/benchmarks/serve_multi.py`](../../vllm/benchmarks/serve_multi.py) automatically starts `vllm serve` and runs `vllm bench serve` over multiple configurations.
+[`vllm/benchmarks/sweep/serve.py`](../../vllm/benchmarks/sweep/serve.py) automatically starts `vllm serve` and runs `vllm bench serve` to evaluate vLLM over multiple configurations.

-#### Batch Mode
-
-The basic purpose of this script is to evaluate vLLM under different settings. Follows these steps to run the script:
+Follow these steps to run the script:

 1. Construct the base command to `vllm serve`, and pass it to the `--serve-cmd` option.
 2. Construct the base command to `vllm bench serve`, and pass it to the `--bench-cmd` option.
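The remaining steps are cut off by this hunk, but they point `--serve-params` and `--bench-params` at JSON files listing the parameter combinations to sweep over; the script then iterates over the Cartesian product of the two lists. As a minimal sketch of what such files contain (the flag values are illustrative and `bench_hparams.json` is a hypothetical file name), they can be generated like this:

```python
# Sketch only: write hypothetical parameter files for the sweep.
# Keys are CLI flags of the respective base commands; both "-" and "_"
# spellings are accepted (see ParameterSweepItem below).
import json

serve_params = [
    {"max_num_batched_tokens": 2048},
    {"max_num_batched_tokens": 8192},
]
bench_params = [
    {"max_concurrency": 64},
    {"max_concurrency": 256},
]

with open("benchmarks/serve_hparams.json", "w") as f:
    json.dump(serve_params, f, indent=4)
with open("benchmarks/bench_hparams.json", "w") as f:
    json.dump(bench_params, f, indent=4)
```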
@@ -996,7 +994,7 @@ The basic purpose of this script is to evaluate vLLM under different settings. F

 Example command:

 ```bash
-python vllm/benchmarks/serve_multi.py \
+python -m vllm.benchmarks.sweep.serve \
     --serve-cmd 'vllm serve meta-llama/Llama-2-7b-chat-hf' \
     --bench-cmd 'vllm bench serve --model meta-llama/Llama-2-7b-chat-hf --backend vllm --endpoint /v1/completions --dataset-name sharegpt --dataset-path benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json' \
     --serve-params benchmarks/serve_hparams.json \

@@ -1018,9 +1016,9 @@ python vllm/benchmarks/serve_multi.py \

 !!! tip
     You can use the `--resume` option to continue the parameter sweep if one of the runs failed.

-#### SLA Mode
+### SLA Auto-Tuner

-By passing SLA constraints via `--sla-params`, you can run this script in SLA mode, causing it to adjust either the request rate or concurrency (choose using `--sla-variable`) in order to satisfy the SLA constraints.
+[`vllm/benchmarks/sweep/serve_sla.py`](../../vllm/benchmarks/sweep/serve_sla.py) is a wrapper over [`vllm/benchmarks/sweep/serve.py`](../../vllm/benchmarks/sweep/serve.py) that tunes either the request rate or concurrency (choose using `--sla-variable`) in order to satisfy the SLA constraints given by `--sla-params`.

 For example, to ensure E2E latency within different target values for 99% of requests:

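The JSON file itself falls outside this hunk; as a sketch of the expected format (each constraint follows the `{"<KEY>": "<OP><VALUE>"}` convention described by the `--sla-params` help text below, and the target latencies are illustrative):

```python
# Sketch only: SLA constraints as they would appear in an --sla-params file
# (a JSON list of dictionaries); the targets below are illustrative.
sla_params = [
    {"p99_e2el_ms": "<=250"},   # 99% of requests finish within 250 ms
    {"p99_e2el_ms": "<=500"},
    {"p99_e2el_ms": "<=1000"},
]
```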
@@ -1044,7 +1042,7 @@ For example, to ensure E2E latency within different target values for 99% of req

 Example command:

 ```bash
-python vllm/benchmarks/serve_multi.py \
+python -m vllm.benchmarks.sweep.serve_sla \
     --serve-cmd 'vllm serve meta-llama/Llama-2-7b-chat-hf' \
     --bench-cmd 'vllm bench serve --model meta-llama/Llama-2-7b-chat-hf --backend vllm --endpoint /v1/completions --dataset-name sharegpt --dataset-path benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json' \
     --serve-params benchmarks/serve_hparams.json \
@@ -1066,6 +1064,24 @@ The algorithm for adjusting the SLA variable is as follows:

 For a given combination of `--serve-params` and `--bench-params`, we share the benchmark results across `--sla-params` to avoid rerunning benchmarks with the same SLA variable value.

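The algorithm list referenced in the hunk header above is not shown here, but `serve_sla.py` (added below) proceeds in three stages: one unconstrained run to estimate a starting value, exponential doubling to bracket the SLA boundary, then bisection inside that bracket. A simplified sketch of that control flow, with the actual benchmark invocation stubbed out as `run_and_check`:

```python
# Simplified sketch of the search performed by vllm/benchmarks/sweep/serve_sla.py.
# `run_and_check` stands in for benchmarking at a given value of the SLA
# variable and returning (estimated_value_from_throughput, all_criteria_met).
import math
from collections.abc import Callable

SLA_INF_VALUE = 65536  # treated as "infinite" request rate / concurrency


def search_sla_value(run_and_check: Callable[[int], tuple[float, bool]]) -> int:
    # Stage 1: one unconstrained run to estimate a starting point.
    estimate, _ = run_and_check(SLA_INF_VALUE)
    val = max(1, math.ceil(estimate))

    # Stage 2: double the value until the SLA fails (or the cap is reached),
    # bracketing the boundary between passing and failing values.
    max_passing, min_failing = 0, SLA_INF_VALUE
    while True:
        _, passed = run_and_check(val)
        if passed:
            max_passing = val
            val *= 2
        else:
            min_failing = val
            break
        if val >= SLA_INF_VALUE:
            break

    # Stage 3: bisect within the bracket; the largest passing value is the
    # maximum SLA-compliant setting.
    left, right = max_passing, min_failing
    while right - left > 1:
        mid = (left + right) // 2
        _, passed = run_and_check(mid)
        if passed:
            left = mid
        else:
            right = mid
    return left
```

Each tested value is written to its own run directory, so values probed for one `--sla-params` entry are reused for another, which is what the result sharing described above refers to.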
+### Visualizer
+
+[`vllm/benchmarks/sweep/plot.py`](../../vllm/benchmarks/sweep/plot.py) can be used to plot performance curves from parameter sweep results.
+
+Example command:
+
+```bash
+python -m vllm.benchmarks.sweep.plot benchmarks/results/<timestamp> \
+    --var-x max_concurrency \
+    --row-by random_input_len \
+    --col-by random_output_len \
+    --curve-by api_server_count,max_num_batched_tokens \
+    --filter-by 'max_concurrency<=1024'
+```
+
+!!! tip
+    You can use `--dry-run` to preview the figures to be plotted.
+
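For context on what the visualizer consumes: the sweep scripts write one sub-directory per parameter combination (its name is derived from the overridden flags), containing a `run=<N>.json` file per repetition and a per-combination `summary.json`, plus a top-level `summary.csv`. `plot.py` simply collects every `summary.json` below the directory you pass it; a minimal sketch of that discovery step (the timestamp is hypothetical):

```python
# Sketch only: how plot.py discovers parameter sweep results.
from pathlib import Path

output_dir = Path("benchmarks/results") / "20250101_000000"  # hypothetical timestamp
for path in output_dir.rglob("**/summary.json"):
    print(path)  # one summary.json per parameter combination
```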
 ## Performance Benchmarks

 The performance benchmarks are used for development to confirm whether new changes improve performance under various workloads. They are triggered on every commit with both the `perf-benchmarks` and `ready` labels, and when a PR is merged into vLLM.
File diff suppressed because it is too large
vllm/benchmarks/sweep/__init__.py (new, empty file)
vllm/benchmarks/sweep/param_sweep.py (new file)
@@ -0,0 +1,91 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import os
from typing import Any


class ParameterSweep(list["ParameterSweepItem"]):
    @classmethod
    def read_json(cls, filepath: os.PathLike):
        with open(filepath, "rb") as f:
            records = json.load(f)

        return cls.from_records(records)

    @classmethod
    def from_records(cls, records: list[dict[str, object]]):
        if not isinstance(records, list):
            raise TypeError(
                f"The parameter sweep should be a list of dictionaries, "
                f"but found type: {type(records)}"
            )

        return cls(ParameterSweepItem.from_record(record) for record in records)


class ParameterSweepItem(dict[str, object]):
    @classmethod
    def from_record(cls, record: dict[str, object]):
        if not isinstance(record, dict):
            raise TypeError(
                f"Each item in the parameter sweep should be a dictionary, "
                f"but found type: {type(record)}"
            )

        return cls(record)

    def __or__(self, other: dict[str, Any]):
        return type(self)(super().__or__(other))

    # In JSON, we prefer "_"
    def _iter_param_key_candidates(self, param_key: str):
        # Inner config arguments are not converted by the CLI
        if "." in param_key:
            prefix, rest = param_key.split(".", 1)
            for prefix_candidate in self._iter_param_key_candidates(prefix):
                yield prefix_candidate + "." + rest

            return

        yield param_key
        yield param_key.replace("-", "_")
        yield param_key.replace("_", "-")

    # In CLI, we prefer "-"
    def _iter_cmd_key_candidates(self, param_key: str):
        for k in reversed(tuple(self._iter_param_key_candidates(param_key))):
            yield "--" + k

    def _normalize_cmd_key(self, param_key: str):
        return next(self._iter_cmd_key_candidates(param_key))

    def has_param(self, param_key: str) -> bool:
        return any(k in self for k in self._iter_param_key_candidates(param_key))

    def apply_to_cmd(self, cmd: list[str]) -> list[str]:
        cmd = list(cmd)

        for k, v in self.items():
            for k_candidate in self._iter_cmd_key_candidates(k):
                try:
                    k_idx = cmd.index(k_candidate)

                    if isinstance(v, bool):
                        cmd[k_idx] = self._normalize_cmd_key(k if v else "no-" + k)
                    else:
                        cmd[k_idx + 1] = str(v)

                    break
                except ValueError:
                    continue
            else:
                if isinstance(v, bool):
                    cmd.append(self._normalize_cmd_key(k if v else "no-" + k))
                else:
                    cmd.extend([self._normalize_cmd_key(k), str(v)])

        return cmd

    def as_text(self, sep: str = ", ") -> str:
        return sep.join(f"{k}={v}" for k, v in self.items())
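A quick usage sketch (not part of the module) showing how these overrides are applied: keys may use either `-` or `_` spelling, flags already present in the base command are overridden in place, missing flags are appended, and boolean values toggle between `--<flag>` and `--no-<flag>`.

```python
# Sketch: applying a sweep item to a base command.
from vllm.benchmarks.sweep.param_sweep import ParameterSweepItem

item = ParameterSweepItem.from_record(
    {"max_num_batched_tokens": 8192, "enable-prefix-caching": False}
)
base_cmd = [
    "vllm", "serve", "meta-llama/Llama-2-7b-chat-hf",
    "--max-num-batched-tokens", "2048",
]

print(item.apply_to_cmd(base_cmd))
# ['vllm', 'serve', 'meta-llama/Llama-2-7b-chat-hf',
#  '--max-num-batched-tokens', '8192', '--no-enable-prefix-caching']
```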
vllm/benchmarks/sweep/plot.py (new file)
@@ -0,0 +1,530 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from types import TracebackType
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
import seaborn as sns
|
||||
from typing_extensions import Self, override
|
||||
|
||||
from vllm.utils.collection_utils import full_groupby
|
||||
|
||||
from .utils import sanitize_filename
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlotFilterBase(ABC):
|
||||
var: str
|
||||
target: str
|
||||
|
||||
@classmethod
|
||||
def parse_str(cls, s: str):
|
||||
for op_key in PLOT_FILTERS:
|
||||
if op_key in s:
|
||||
key, value = s.split(op_key)
|
||||
return PLOT_FILTERS[op_key](
|
||||
key,
|
||||
value.removeprefix(op_key).strip("'").strip('"'),
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid operator for plot filter '{s}'. "
|
||||
f"Valid operators are: {sorted(PLOT_FILTERS)}",
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def apply(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""Applies this filter to a DataFrame."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlotEqualTo(PlotFilterBase):
|
||||
@override
|
||||
def apply(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
try:
|
||||
target = float(self.target)
|
||||
except ValueError:
|
||||
target = self.target
|
||||
|
||||
return df[df[self.var] == target]
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlotLessThan(PlotFilterBase):
|
||||
@override
|
||||
def apply(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
return df[df[self.var] < float(self.target)]
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlotLessThanOrEqualTo(PlotFilterBase):
|
||||
@override
|
||||
def apply(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
return df[df[self.var] <= float(self.target)]
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlotGreaterThan(PlotFilterBase):
|
||||
@override
|
||||
def apply(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
return df[df[self.var] > float(self.target)]
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlotGreaterThanOrEqualTo(PlotFilterBase):
|
||||
@override
|
||||
def apply(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
return df[df[self.var] >= float(self.target)]
|
||||
|
||||
|
||||
# NOTE: The ordering is important! Match longer op_keys first
|
||||
PLOT_FILTERS: dict[str, type[PlotFilterBase]] = {
|
||||
"==": PlotEqualTo,
|
||||
"<=": PlotLessThanOrEqualTo,
|
||||
">=": PlotGreaterThanOrEqualTo,
|
||||
"<": PlotLessThan,
|
||||
">": PlotGreaterThan,
|
||||
}
|
||||
|
||||
|
||||
class PlotFilters(list[PlotFilterBase]):
|
||||
@classmethod
|
||||
def parse_str(cls, s: str):
|
||||
if not s:
|
||||
return cls()
|
||||
|
||||
return cls(PlotFilterBase.parse_str(e) for e in s.split(","))
|
||||
|
||||
def apply(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
for item in self:
|
||||
df = item.apply(df)
|
||||
|
||||
return df
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlotBinner:
|
||||
var: str
|
||||
bin_size: float
|
||||
|
||||
@classmethod
|
||||
def parse_str(cls, s: str):
|
||||
for op_key in PLOT_BINNERS:
|
||||
if op_key in s:
|
||||
key, value = s.split(op_key)
|
||||
return PLOT_BINNERS[op_key](key, float(value.removeprefix(op_key)))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid operator for plot binner '{s}'. "
|
||||
f"Valid operators are: {sorted(PLOT_BINNERS)}",
|
||||
)
|
||||
|
||||
def apply(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""Applies this binner to a DataFrame."""
|
||||
df = df.copy()
|
||||
df[self.var] = df[self.var] // self.bin_size * self.bin_size
|
||||
return df
|
||||
|
||||
|
||||
PLOT_BINNERS: dict[str, type[PlotBinner]] = {
|
||||
"%": PlotBinner,
|
||||
}
|
||||
|
||||
|
||||
class PlotBinners(list[PlotBinner]):
|
||||
@classmethod
|
||||
def parse_str(cls, s: str):
|
||||
if not s:
|
||||
return cls()
|
||||
|
||||
return cls(PlotBinner.parse_str(e) for e in s.split(","))
|
||||
|
||||
def apply(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
for item in self:
|
||||
df = item.apply(df)
|
||||
|
||||
return df
|
||||
|
||||
|
||||
def _json_load_bytes(path: Path) -> list[dict[str, object]]:
|
||||
with path.open("rb") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def _get_metric(run_data: dict[str, object], metric_key: str):
|
||||
try:
|
||||
return run_data[metric_key]
|
||||
except KeyError as exc:
|
||||
raise ValueError(f"Cannot find metric {metric_key!r} in {run_data=}") from exc
|
||||
|
||||
|
||||
def _get_group(run_data: dict[str, object], group_keys: list[str]):
|
||||
return tuple((k, str(_get_metric(run_data, k))) for k in group_keys)
|
||||
|
||||
|
||||
def _get_fig_path(fig_dir: Path, group: tuple[tuple[str, str], ...]):
|
||||
parts = list[str]()
|
||||
if group:
|
||||
parts.extend(("FIGURE-", *(f"{k}={v}" for k, v in group)))
|
||||
else:
|
||||
parts.append("figure")
|
||||
|
||||
return fig_dir / sanitize_filename("-".join(parts) + ".png")
|
||||
|
||||
|
||||
class DummyExecutor:
|
||||
map = map
|
||||
|
||||
def __enter__(self) -> Self:
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
exc_traceback: TracebackType | None,
|
||||
) -> None:
|
||||
return None
|
||||
|
||||
|
||||
def _plot_fig(
|
||||
fig_dir: Path,
|
||||
fig_group_data: tuple[tuple[tuple[str, str], ...], list[dict[str, object]]],
|
||||
row_by: list[str],
|
||||
col_by: list[str],
|
||||
curve_by: list[str],
|
||||
*,
|
||||
var_x: str,
|
||||
var_y: str,
|
||||
filter_by: PlotFilters,
|
||||
bin_by: PlotBinners,
|
||||
scale_x: str | None,
|
||||
scale_y: str | None,
|
||||
dry_run: bool,
|
||||
):
|
||||
fig_group, fig_data = fig_group_data
|
||||
|
||||
row_groups = full_groupby(
|
||||
fig_data,
|
||||
key=lambda item: _get_group(item, row_by),
|
||||
)
|
||||
num_rows = len(row_groups)
|
||||
num_cols = max(
|
||||
len(full_groupby(row_data, key=lambda item: _get_group(item, col_by)))
|
||||
for _, row_data in row_groups
|
||||
)
|
||||
|
||||
fig_path = _get_fig_path(fig_dir, fig_group)
|
||||
|
||||
print("[BEGIN FIGURE]")
|
||||
print(f"Group: {dict(fig_group)}")
|
||||
print(f"Grid: {num_rows} rows x {num_cols} cols")
|
||||
print(f"Output file: {fig_path}")
|
||||
|
||||
if dry_run:
|
||||
print("[END FIGURE]")
|
||||
return
|
||||
|
||||
df = pd.DataFrame.from_records(fig_data)
|
||||
|
||||
if var_x not in df.columns:
|
||||
raise ValueError(
|
||||
f"Cannot find {var_x=!r} in parameter sweep results. "
|
||||
f"Available variables: {df.columns.tolist()}"
|
||||
)
|
||||
if var_y not in df.columns:
|
||||
raise ValueError(
|
||||
f"Cannot find {var_y=!r} in parameter sweep results. "
|
||||
f"Available variables: {df.columns.tolist()}"
|
||||
)
|
||||
for k in row_by:
|
||||
if k not in df.columns:
|
||||
raise ValueError(
|
||||
f"Cannot find row_by={k!r} in parameter sweep results. "
|
||||
f"Available variables: {df.columns.tolist()}"
|
||||
)
|
||||
for k in col_by:
|
||||
if k not in df.columns:
|
||||
raise ValueError(
|
||||
f"Cannot find col_by={k!r} in parameter sweep results. "
|
||||
f"Available variables: {df.columns.tolist()}"
|
||||
)
|
||||
for k in curve_by:
|
||||
if k not in df.columns:
|
||||
raise ValueError(
|
||||
f"Cannot find curve_by={k!r} in parameter sweep results. "
|
||||
f"Available variables: {df.columns.tolist()}"
|
||||
)
|
||||
|
||||
df = filter_by.apply(df)
|
||||
df = bin_by.apply(df)
|
||||
|
||||
df["row_group"] = (
|
||||
pd.concat(
|
||||
[k + "=" + df[k].astype(str) for k in row_by],
|
||||
axis=1,
|
||||
).agg("\n".join, axis=1)
|
||||
if row_by
|
||||
else "(All)"
|
||||
)
|
||||
|
||||
df["col_group"] = (
|
||||
pd.concat(
|
||||
[k + "=" + df[k].astype(str) for k in col_by],
|
||||
axis=1,
|
||||
).agg("\n".join, axis=1)
|
||||
if col_by
|
||||
else "(All)"
|
||||
)
|
||||
|
||||
g = sns.FacetGrid(df, row="row_group", col="col_group")
|
||||
|
||||
if row_by and col_by:
|
||||
g.set_titles("{row_name}\n{col_name}")
|
||||
elif row_by:
|
||||
g.set_titles("{row_name}")
|
||||
elif col_by:
|
||||
g.set_titles("{col_name}")
|
||||
else:
|
||||
g.set_titles("")
|
||||
|
||||
if scale_x:
|
||||
g.set(xscale=scale_x)
|
||||
if scale_y:
|
||||
g.set(yscale=scale_y)
|
||||
|
||||
if len(curve_by) <= 3:
|
||||
hue, style, size, *_ = (*curve_by, None, None, None)
|
||||
|
||||
g.map_dataframe(
|
||||
sns.lineplot,
|
||||
x=var_x,
|
||||
y=var_y,
|
||||
hue=hue,
|
||||
style=style,
|
||||
size=size,
|
||||
markers=True,
|
||||
)
|
||||
|
||||
g.add_legend(title=hue)
|
||||
else:
|
||||
df["curve_group"] = (
|
||||
pd.concat(
|
||||
[k + "=" + df[k].astype(str) for k in curve_by],
|
||||
axis=1,
|
||||
).agg("\n".join, axis=1)
|
||||
if curve_by
|
||||
else "(All)"
|
||||
)
|
||||
|
||||
g.map_dataframe(
|
||||
sns.lineplot,
|
||||
x=var_x,
|
||||
y=var_y,
|
||||
hue="curve_group",
|
||||
markers=True,
|
||||
)
|
||||
|
||||
g.add_legend()
|
||||
|
||||
g.savefig(fig_path)
|
||||
plt.close(g.figure)
|
||||
|
||||
print("[END FIGURE]")
|
||||
|
||||
|
||||
def plot(
|
||||
output_dir: Path,
|
||||
fig_dir: Path,
|
||||
fig_by: list[str],
|
||||
row_by: list[str],
|
||||
col_by: list[str],
|
||||
curve_by: list[str],
|
||||
*,
|
||||
var_x: str,
|
||||
var_y: str,
|
||||
filter_by: PlotFilters,
|
||||
bin_by: PlotBinners,
|
||||
scale_x: str | None,
|
||||
scale_y: str | None,
|
||||
dry_run: bool,
|
||||
):
|
||||
all_data = [
|
||||
run_data
|
||||
for path in output_dir.rglob("**/summary.json")
|
||||
for run_data in _json_load_bytes(path)
|
||||
]
|
||||
|
||||
if not all_data:
|
||||
raise ValueError(f"Did not find any parameter sweep results under {output_dir}")
|
||||
|
||||
fig_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
fig_groups = full_groupby(
|
||||
all_data,
|
||||
key=lambda item: _get_group(item, fig_by),
|
||||
)
|
||||
|
||||
with DummyExecutor() if len(fig_groups) <= 1 else ProcessPoolExecutor() as executor:
|
||||
# Resolve the iterable to ensure that the workers are run
|
||||
all(
|
||||
executor.map(
|
||||
partial(
|
||||
_plot_fig,
|
||||
fig_dir,
|
||||
row_by=row_by,
|
||||
col_by=col_by,
|
||||
curve_by=curve_by,
|
||||
var_x=var_x,
|
||||
var_y=var_y,
|
||||
filter_by=filter_by,
|
||||
bin_by=bin_by,
|
||||
scale_x=scale_x,
|
||||
scale_y=scale_y,
|
||||
dry_run=dry_run,
|
||||
),
|
||||
fig_groups,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def add_cli_args(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"OUTPUT_DIR",
|
||||
type=str,
|
||||
default="results",
|
||||
help="The directory containing the results to plot, "
|
||||
"i.e., the `--output-dir` argument to the parameter sweep script.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fig-dir",
|
||||
type=str,
|
||||
default="",
|
||||
help="The directory to save the figures, relative to `OUTPUT_DIR`. "
|
||||
"By default, the same directory is used.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fig-by",
|
||||
type=str,
|
||||
default="",
|
||||
help="A comma-separated list of variables, such that a separate figure "
|
||||
"is created for each combination of these variables.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--row-by",
|
||||
type=str,
|
||||
default="",
|
||||
help="A comma-separated list of variables, such that a separate row "
|
||||
"is created for each combination of these variables.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--col-by",
|
||||
type=str,
|
||||
default="",
|
||||
help="A comma-separated list of variables, such that a separate column "
|
||||
"is created for each combination of these variables.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--curve-by",
|
||||
type=str,
|
||||
default=None,
|
||||
help="A comma-separated list of variables, such that a separate curve "
|
||||
"is created for each combination of these variables.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--var-x",
|
||||
type=str,
|
||||
default="request_throughput",
|
||||
help="The variable for the x-axis.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--var-y",
|
||||
type=str,
|
||||
default="p99_e2el_ms",
|
||||
help="The variable for the y-axis",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--filter-by",
|
||||
type=str,
|
||||
default="",
|
||||
help="A comma-separated list of statements indicating values to filter by. "
|
||||
"This is useful to remove outliers. "
|
||||
"Example: `max_concurrency<1000,max_num_batched_tokens<=4096` means "
|
||||
"plot only the points where `max_concurrency` is less than 1000 and "
|
||||
"`max_num_batched_tokens` is no greater than 4096.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bin-by",
|
||||
type=str,
|
||||
default="",
|
||||
help="A comma-separated list of statements indicating values to bin by. "
|
||||
"This is useful to avoid plotting points that are too close together. "
|
||||
"Example: `request_throughput%1` means "
|
||||
"use a bin size of 1 for the `request_throughput` variable.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scale-x",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The scale to use for the x-axis. "
|
||||
"Currently only accepts string values such as 'log' and 'sqrt'. "
|
||||
"See also: https://seaborn.pydata.org/generated/seaborn.objects.Plot.scale.html",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scale-y",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The scale to use for the y-axis. "
|
||||
"Currently only accepts string values such as 'log' and 'sqrt'. "
|
||||
"See also: https://seaborn.pydata.org/generated/seaborn.objects.Plot.scale.html",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dry-run",
|
||||
action="store_true",
|
||||
help="If set, prints the information about each figure to plot, "
|
||||
"then exits without drawing them.",
|
||||
)
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
output_dir = Path(args.OUTPUT_DIR)
|
||||
if not output_dir.exists():
|
||||
raise ValueError(f"No parameter sweep results under {output_dir}")
|
||||
|
||||
curve_by = [] if not args.curve_by else args.curve_by.split(",")
|
||||
row_by = [] if not args.row_by else args.row_by.split(",")
|
||||
col_by = [] if not args.col_by else args.col_by.split(",")
|
||||
fig_by = [] if not args.fig_by else args.fig_by.split(",")
|
||||
|
||||
plot(
|
||||
output_dir=output_dir,
|
||||
fig_dir=output_dir / args.fig_dir,
|
||||
fig_by=fig_by,
|
||||
row_by=row_by,
|
||||
col_by=col_by,
|
||||
curve_by=curve_by,
|
||||
var_x=args.var_x,
|
||||
var_y=args.var_y,
|
||||
filter_by=PlotFilters.parse_str(args.filter_by),
|
||||
bin_by=PlotBinners.parse_str(args.bin_by),
|
||||
scale_x=args.scale_x,
|
||||
scale_y=args.scale_y,
|
||||
dry_run=args.dry_run,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Plot performance curves from parameter sweep results."
|
||||
)
|
||||
add_cli_args(parser)
|
||||
|
||||
main(parser.parse_args())
|
||||
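To make the `--filter-by` and `--bin-by` syntax above concrete, a small usage sketch (the DataFrame contents are made up):

```python
# Sketch: how filter/binner strings from the CLI are parsed and applied.
import pandas as pd

from vllm.benchmarks.sweep.plot import PlotBinners, PlotFilters

df = pd.DataFrame(
    {
        "max_concurrency": [64, 256, 2048],
        "request_throughput": [3.2, 9.7, 10.1],
    }
)

filters = PlotFilters.parse_str("max_concurrency<=1024")
binners = PlotBinners.parse_str("request_throughput%1")

df = filters.apply(df)  # keeps only rows with max_concurrency <= 1024
df = binners.apply(df)  # floors request_throughput into bins of size 1
print(df)
```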
vllm/benchmarks/sweep/serve.py (new file)
@@ -0,0 +1,407 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
import contextlib
|
||||
import json
|
||||
import shlex
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from .param_sweep import ParameterSweep, ParameterSweepItem
|
||||
from .server import ServerProcess
|
||||
from .utils import sanitize_filename
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def run_server(
|
||||
serve_cmd: list[str],
|
||||
after_bench_cmd: list[str],
|
||||
*,
|
||||
show_stdout: bool,
|
||||
serve_overrides: ParameterSweepItem,
|
||||
dry_run: bool,
|
||||
):
|
||||
server_cmd = serve_overrides.apply_to_cmd(serve_cmd)
|
||||
|
||||
print("[BEGIN SERVER]")
|
||||
print(f"Server overrides: {serve_overrides}")
|
||||
print(f"Server command: {server_cmd}")
|
||||
|
||||
if dry_run:
|
||||
yield None
|
||||
print("[END SERVER]")
|
||||
return
|
||||
|
||||
with ServerProcess(server_cmd, after_bench_cmd, show_stdout=show_stdout) as server:
|
||||
yield server
|
||||
|
||||
print("[END SERVER]")
|
||||
|
||||
|
||||
def _update_run_data(
|
||||
run_data: dict[str, object],
|
||||
serve_overrides: ParameterSweepItem,
|
||||
bench_overrides: ParameterSweepItem,
|
||||
run_number: int,
|
||||
):
|
||||
run_data["run_number"] = run_number
|
||||
run_data.update(serve_overrides)
|
||||
run_data.update(bench_overrides)
|
||||
|
||||
return run_data
|
||||
|
||||
|
||||
def run_benchmark(
|
||||
server: ServerProcess | None,
|
||||
bench_cmd: list[str],
|
||||
*,
|
||||
serve_overrides: ParameterSweepItem,
|
||||
bench_overrides: ParameterSweepItem,
|
||||
run_number: int,
|
||||
output_path: Path,
|
||||
dry_run: bool,
|
||||
):
|
||||
benchmark_cmd = [
|
||||
*bench_overrides.apply_to_cmd(bench_cmd),
|
||||
"--save-result",
|
||||
"--result-dir",
|
||||
str(output_path.parent),
|
||||
"--result-filename",
|
||||
output_path.name,
|
||||
]
|
||||
|
||||
print("[BEGIN BENCHMARK]")
|
||||
print(f"Benchmark overrides: {bench_overrides}")
|
||||
print(f"Run Number: {run_number}")
|
||||
print(f"Benchmark command: {benchmark_cmd}")
|
||||
print(f"Output file: {output_path}")
|
||||
|
||||
run_data: dict[str, object]
|
||||
|
||||
if output_path.exists():
|
||||
print("Found existing results. Skipping.")
|
||||
|
||||
with output_path.open("rb") as f:
|
||||
run_data = json.load(f)
|
||||
return _update_run_data(
|
||||
run_data,
|
||||
serve_overrides,
|
||||
bench_overrides,
|
||||
run_number,
|
||||
)
|
||||
|
||||
if server is None:
|
||||
if not dry_run:
|
||||
raise ValueError(f"Cannot find results at {output_path}")
|
||||
|
||||
print("[END BENCHMARK]")
|
||||
return None
|
||||
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
server.run_subcommand(benchmark_cmd)
|
||||
server.after_bench()
|
||||
|
||||
with output_path.open("rb") as f:
|
||||
run_data = json.load(f)
|
||||
|
||||
run_data = _update_run_data(
|
||||
run_data,
|
||||
serve_overrides,
|
||||
bench_overrides,
|
||||
run_number,
|
||||
)
|
||||
|
||||
with output_path.open("w") as f:
|
||||
json.dump(run_data, f, indent=4)
|
||||
|
||||
print("[END BENCHMARK]")
|
||||
|
||||
return run_data
|
||||
|
||||
|
||||
def _get_comb_base_path(
|
||||
output_dir: Path,
|
||||
serve_comb: ParameterSweepItem,
|
||||
bench_comb: ParameterSweepItem,
|
||||
):
|
||||
parts = list[str]()
|
||||
if serve_comb:
|
||||
parts.extend(("SERVE-", serve_comb.as_text(sep="-")))
|
||||
if bench_comb:
|
||||
parts.extend(("BENCH-", bench_comb.as_text(sep="-")))
|
||||
|
||||
return output_dir / sanitize_filename("-".join(parts))
|
||||
|
||||
|
||||
def _get_comb_run_path(base_path: Path, run_number: int | None):
|
||||
if run_number is None:
|
||||
return base_path / "summary.json"
|
||||
|
||||
return base_path / f"run={run_number}.json"
|
||||
|
||||
|
||||
def _comb_needs_server(
|
||||
serve_comb: ParameterSweepItem,
|
||||
bench_combs: ParameterSweep,
|
||||
output_dir: Path,
|
||||
):
|
||||
for bench_comb in bench_combs:
|
||||
base_path = _get_comb_base_path(output_dir, serve_comb, bench_comb)
|
||||
if not _get_comb_run_path(base_path, run_number=None).exists():
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def run_comb(
|
||||
server: ServerProcess | None,
|
||||
bench_cmd: list[str],
|
||||
*,
|
||||
serve_comb: ParameterSweepItem,
|
||||
bench_comb: ParameterSweepItem,
|
||||
base_path: Path,
|
||||
num_runs: int,
|
||||
dry_run: bool,
|
||||
):
|
||||
comb_data = list[dict[str, object]]()
|
||||
|
||||
for run_number in range(num_runs):
|
||||
run_data = run_benchmark(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_overrides=serve_comb,
|
||||
bench_overrides=bench_comb,
|
||||
run_number=run_number,
|
||||
output_path=_get_comb_run_path(base_path, run_number),
|
||||
dry_run=dry_run,
|
||||
)
|
||||
|
||||
if run_data is not None:
|
||||
comb_data.append(run_data)
|
||||
|
||||
if dry_run:
|
||||
return None
|
||||
|
||||
with _get_comb_run_path(base_path, run_number=None).open("w") as f:
|
||||
json.dump(comb_data, f, indent=4)
|
||||
|
||||
return comb_data
|
||||
|
||||
|
||||
def run_combs(
|
||||
serve_cmd: list[str],
|
||||
bench_cmd: list[str],
|
||||
after_bench_cmd: list[str],
|
||||
*,
|
||||
show_stdout: bool,
|
||||
serve_params: ParameterSweep,
|
||||
bench_params: ParameterSweep,
|
||||
output_dir: Path,
|
||||
num_runs: int,
|
||||
dry_run: bool,
|
||||
):
|
||||
all_data = list[dict[str, object]]()
|
||||
for serve_comb in serve_params:
|
||||
with (
|
||||
run_server(
|
||||
serve_cmd,
|
||||
after_bench_cmd,
|
||||
show_stdout=show_stdout,
|
||||
serve_overrides=serve_comb,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
if _comb_needs_server(serve_comb, bench_params, output_dir)
|
||||
else contextlib.nullcontext()
|
||||
) as server:
|
||||
for bench_comb in bench_params:
|
||||
base_path = _get_comb_base_path(output_dir, serve_comb, bench_comb)
|
||||
|
||||
comb_data = run_comb(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_comb=serve_comb,
|
||||
bench_comb=bench_comb,
|
||||
base_path=base_path,
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
|
||||
if comb_data is not None:
|
||||
all_data.extend(comb_data)
|
||||
|
||||
if dry_run:
|
||||
return None
|
||||
|
||||
combined_df = pd.DataFrame.from_records(all_data)
|
||||
combined_df.to_csv(output_dir / "summary.csv")
|
||||
|
||||
return combined_df
|
||||
|
||||
|
||||
@dataclass
|
||||
class SweepServeArgs:
|
||||
serve_cmd: list[str]
|
||||
bench_cmd: list[str]
|
||||
after_bench_cmd: list[str]
|
||||
show_stdout: bool
|
||||
serve_params: ParameterSweep
|
||||
bench_params: ParameterSweep
|
||||
output_dir: Path
|
||||
num_runs: int
|
||||
dry_run: bool
|
||||
resume: str | None
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
serve_cmd = shlex.split(args.serve_cmd)
|
||||
bench_cmd = shlex.split(args.bench_cmd)
|
||||
after_bench_cmd = (
|
||||
[] if args.after_bench_cmd is None else shlex.split(args.after_bench_cmd)
|
||||
)
|
||||
|
||||
if args.serve_params:
|
||||
serve_params = ParameterSweep.read_json(args.serve_params)
|
||||
else:
|
||||
# i.e.: run serve_cmd without any modification
|
||||
serve_params = ParameterSweep.from_records([{}])
|
||||
|
||||
if args.bench_params:
|
||||
bench_params = ParameterSweep.read_json(args.bench_params)
|
||||
else:
|
||||
# i.e.: run bench_cmd without any modification
|
||||
bench_params = ParameterSweep.from_records([{}])
|
||||
|
||||
num_runs = args.num_runs
|
||||
if num_runs < 1:
|
||||
raise ValueError("`num_runs` should be at least 1.")
|
||||
|
||||
return cls(
|
||||
serve_cmd=serve_cmd,
|
||||
bench_cmd=bench_cmd,
|
||||
after_bench_cmd=after_bench_cmd,
|
||||
show_stdout=args.show_stdout,
|
||||
serve_params=serve_params,
|
||||
bench_params=bench_params,
|
||||
output_dir=Path(args.output_dir),
|
||||
num_runs=num_runs,
|
||||
dry_run=args.dry_run,
|
||||
resume=args.resume,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
|
||||
parser.add_argument(
|
||||
"--serve-cmd",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The command used to run the server: `vllm serve ...`",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bench-cmd",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The command used to run the benchmark: `vllm bench serve ...`",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--after-bench-cmd",
|
||||
type=str,
|
||||
default=None,
|
||||
help="After a benchmark run is complete, invoke this command instead of "
|
||||
"the default `ServerWrapper.clear_cache()`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--show-stdout",
|
||||
action="store_true",
|
||||
help="If set, logs the standard output of subcommands. "
|
||||
"Useful for debugging but can be quite spammy.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--serve-params",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to JSON file containing a list of parameter combinations "
|
||||
"for the `vllm serve` command. "
|
||||
"If both `serve_params` and `bench_params` are given, "
|
||||
"this script will iterate over their Cartesian product.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bench-params",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to JSON file containing a list of parameter combinations "
|
||||
"for the `vllm bench serve` command. "
|
||||
"If both `serve_params` and `bench_params` are given, "
|
||||
"this script will iterate over their Cartesian product.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--output-dir",
|
||||
type=str,
|
||||
default="results",
|
||||
help="The directory to which results are written.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-runs",
|
||||
type=int,
|
||||
default=3,
|
||||
help="Number of runs per parameter combination.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dry-run",
|
||||
action="store_true",
|
||||
help="If set, prints the commands to run, "
|
||||
"then exits without executing them.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Set this to the name of a directory under `output_dir` (which is a "
|
||||
"timestamp) to resume a previous execution of this script, i.e., only run "
|
||||
"parameter combinations for which there are still no output files.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def run_main(args: SweepServeArgs):
|
||||
timestamp = args.resume or datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
output_dir = args.output_dir / timestamp
|
||||
|
||||
if args.resume and not output_dir.exists():
|
||||
raise ValueError(f"Cannot resume from non-existent directory ({output_dir})")
|
||||
|
||||
try:
|
||||
return run_combs(
|
||||
serve_cmd=args.serve_cmd,
|
||||
bench_cmd=args.bench_cmd,
|
||||
after_bench_cmd=args.after_bench_cmd,
|
||||
show_stdout=args.show_stdout,
|
||||
serve_params=args.serve_params,
|
||||
bench_params=args.bench_params,
|
||||
output_dir=output_dir,
|
||||
num_runs=args.num_runs,
|
||||
dry_run=args.dry_run,
|
||||
)
|
||||
except BaseException as exc:
|
||||
raise RuntimeError(
|
||||
f"The script was terminated early. Use `--resume {timestamp}` "
|
||||
f"to continue the script from its last checkpoint."
|
||||
) from exc
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
run_main(SweepServeArgs.from_cli_args(args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run vLLM server benchmark under multiple settings."
|
||||
)
|
||||
SweepServeArgs.add_cli_args(parser)
|
||||
|
||||
main(parser.parse_args())
|
||||
vllm/benchmarks/sweep/serve_sla.py (new file)
@@ -0,0 +1,483 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
import contextlib
|
||||
import json
|
||||
import math
|
||||
from dataclasses import asdict, dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Literal, get_args
|
||||
|
||||
import pandas as pd
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from .param_sweep import ParameterSweep, ParameterSweepItem
|
||||
from .serve import SweepServeArgs, run_benchmark, run_server
|
||||
from .server import ServerProcess
|
||||
from .sla_sweep import SLASweep, SLASweepItem
|
||||
from .utils import sanitize_filename
|
||||
|
||||
|
||||
def _get_sla_base_path(
|
||||
output_dir: Path,
|
||||
serve_comb: ParameterSweepItem,
|
||||
bench_comb: ParameterSweepItem,
|
||||
):
|
||||
parts = list[str]()
|
||||
if serve_comb:
|
||||
parts.extend(("SERVE-", serve_comb.as_text(sep="-")))
|
||||
if bench_comb:
|
||||
parts.extend(("BENCH-", bench_comb.as_text(sep="-")))
|
||||
|
||||
return output_dir / sanitize_filename("-".join(parts))
|
||||
|
||||
|
||||
def _get_sla_iter_path(
|
||||
base_path: Path,
|
||||
sla_comb: SLASweepItem,
|
||||
sla_variable: str,
|
||||
sla_value: int | None,
|
||||
):
|
||||
if sla_value is None:
|
||||
prefix = sla_comb.as_text(sep="-")
|
||||
return base_path / f"SLA--{prefix}.json"
|
||||
|
||||
return base_path / f"{sla_variable}={sla_value}"
|
||||
|
||||
|
||||
def _get_sla_run_path(iter_path: Path, run_number: int | None):
|
||||
if run_number is None:
|
||||
return iter_path / "summary.json"
|
||||
|
||||
return iter_path / f"run={run_number}.json"
|
||||
|
||||
|
||||
def _sla_needs_server(
|
||||
serve_comb: ParameterSweepItem,
|
||||
bench_combs: ParameterSweep,
|
||||
sla_combs: SLASweep,
|
||||
sla_variable: str,
|
||||
output_dir: Path,
|
||||
):
|
||||
for bench_comb in bench_combs:
|
||||
base_path = _get_sla_base_path(output_dir, serve_comb, bench_comb)
|
||||
for sla_comb in sla_combs:
|
||||
if not _get_sla_iter_path(
|
||||
base_path,
|
||||
sla_comb,
|
||||
sla_variable,
|
||||
sla_value=None,
|
||||
).exists():
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def run_sla(
|
||||
server: ServerProcess | None,
|
||||
bench_cmd: list[str],
|
||||
*,
|
||||
serve_comb: ParameterSweepItem,
|
||||
bench_comb: ParameterSweepItem,
|
||||
iter_path: Path,
|
||||
num_runs: int,
|
||||
dry_run: bool,
|
||||
):
|
||||
iter_data = list[dict[str, object]]()
|
||||
|
||||
for run_number in range(num_runs):
|
||||
run_data = run_benchmark(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_overrides=serve_comb,
|
||||
bench_overrides=bench_comb,
|
||||
run_number=run_number,
|
||||
output_path=_get_sla_run_path(iter_path, run_number),
|
||||
dry_run=dry_run,
|
||||
)
|
||||
|
||||
if run_data is not None:
|
||||
iter_data.append(run_data)
|
||||
|
||||
if dry_run:
|
||||
return None
|
||||
|
||||
with _get_sla_run_path(iter_path, run_number=None).open("w") as f:
|
||||
json.dump(iter_data, f, indent=4)
|
||||
|
||||
return iter_data
|
||||
|
||||
|
||||
SLAVariable = Literal["request_rate", "max_concurrency"]
|
||||
|
||||
|
||||
def _estimate_sla_value(run_data: dict[str, object], sla_variable: SLAVariable):
|
||||
request_throughput = float(run_data["request_throughput"]) # type: ignore
|
||||
if sla_variable == "request_rate":
|
||||
return request_throughput
|
||||
if sla_variable == "max_concurrency":
|
||||
mean_latency_ms = float(run_data["mean_e2el_ms"]) # type: ignore
|
||||
return request_throughput * mean_latency_ms / 1000
|
||||
|
||||
assert_never(sla_variable)
|
||||
|
||||
|
||||
def _estimate_sla_bounds(
|
||||
server: ServerProcess | None,
|
||||
bench_cmd: list[str],
|
||||
*,
|
||||
serve_comb: ParameterSweepItem,
|
||||
bench_comb: ParameterSweepItem,
|
||||
sla_comb: SLASweepItem,
|
||||
base_path: Path,
|
||||
num_runs: int,
|
||||
dry_run: bool,
|
||||
sla_variable: SLAVariable,
|
||||
init_value: int,
|
||||
max_value: int,
|
||||
):
|
||||
sla_data = list[dict[str, object]]()
|
||||
|
||||
max_passing: int = 0
|
||||
min_failing: int = 0
|
||||
|
||||
val: int = init_value
|
||||
assert val > 0
|
||||
|
||||
while True:
|
||||
print(f"Testing {sla_variable}: {val} req/s")
|
||||
|
||||
iter_data = run_sla(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_comb=serve_comb,
|
||||
bench_comb=bench_comb | {sla_variable: val},
|
||||
iter_path=_get_sla_iter_path(base_path, sla_comb, sla_variable, val),
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
|
||||
assert iter_data is not None
|
||||
sla_data.extend(iter_data)
|
||||
|
||||
iter_data_mean = {
|
||||
k: sum(float(run_data[k]) for run_data in iter_data) / len(iter_data) # type: ignore
|
||||
for k in sla_comb
|
||||
}
|
||||
|
||||
sla_results = [
|
||||
criterion.print_and_validate(iter_data_mean, k)
|
||||
for k, criterion in sla_comb.items()
|
||||
]
|
||||
|
||||
if all(sla_results):
|
||||
print("SLA criteria are met.")
|
||||
max_passing = val
|
||||
val *= 2
|
||||
else:
|
||||
print("SLA criteria are not met.")
|
||||
min_failing = val
|
||||
break
|
||||
|
||||
if val >= max_value:
|
||||
break
|
||||
|
||||
return sla_data, (max_passing, min_failing)
|
||||
|
||||
|
||||
def _find_sla_value(
|
||||
server: ServerProcess | None,
|
||||
bench_cmd: list[str],
|
||||
*,
|
||||
serve_comb: ParameterSweepItem,
|
||||
bench_comb: ParameterSweepItem,
|
||||
sla_comb: SLASweepItem,
|
||||
base_path: Path,
|
||||
num_runs: int,
|
||||
dry_run: bool,
|
||||
sla_variable: SLAVariable,
|
||||
min_value: int,
|
||||
max_value: int,
|
||||
):
|
||||
sla_data = list[dict[str, object]]()
|
||||
|
||||
left: int = min_value
|
||||
right: int = max_value
|
||||
|
||||
while True:
|
||||
val = (left + right) // 2
|
||||
print(f"Testing {sla_variable}: {val} req/s")
|
||||
|
||||
iter_data = run_sla(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_comb=serve_comb,
|
||||
bench_comb=bench_comb | {sla_variable: val},
|
||||
iter_path=_get_sla_iter_path(base_path, sla_comb, sla_variable, val),
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
|
||||
assert iter_data is not None
|
||||
sla_data.extend(iter_data)
|
||||
|
||||
iter_data_mean = {
|
||||
k: sum(float(run_data[k]) for run_data in iter_data) / len(iter_data) # type: ignore
|
||||
for k in sla_comb
|
||||
}
|
||||
|
||||
sla_results = [
|
||||
criterion.print_and_validate(iter_data_mean, k)
|
||||
for k, criterion in sla_comb.items()
|
||||
]
|
||||
|
||||
if all(sla_results):
|
||||
print("SLA criteria are met.")
|
||||
left = val
|
||||
else:
|
||||
print("SLA criteria are not met.")
|
||||
right = val
|
||||
|
||||
if right - left <= 1:
|
||||
break
|
||||
|
||||
return sla_data, left
|
||||
|
||||
|
||||
def search_sla(
|
||||
server: ServerProcess | None,
|
||||
bench_cmd: list[str],
|
||||
*,
|
||||
serve_comb: ParameterSweepItem,
|
||||
bench_comb: ParameterSweepItem,
|
||||
sla_comb: SLASweepItem,
|
||||
sla_variable: SLAVariable,
|
||||
sla_inf_value: int = 65536, # The value that represents infinite QPS
|
||||
base_path: Path,
|
||||
num_runs: int,
|
||||
dry_run: bool,
|
||||
):
|
||||
print("[SLA START]")
|
||||
print(f"SLA criteria: {sla_comb.as_text()}")
|
||||
|
||||
sla_data_0 = run_sla(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_comb=serve_comb,
|
||||
bench_comb=bench_comb | {sla_variable: sla_inf_value},
|
||||
iter_path=_get_sla_iter_path(base_path, sla_comb, sla_variable, sla_inf_value),
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
if sla_data_0 is None:
|
||||
assert dry_run
|
||||
print("Omitting SLA search.")
|
||||
print("[SLA END]")
|
||||
return None
|
||||
|
||||
sla_init_value = math.ceil(
|
||||
sum(_estimate_sla_value(item, sla_variable) for item in sla_data_0)
|
||||
/ len(sla_data_0)
|
||||
)
|
||||
print(f"Initial {sla_variable} to search: {sla_init_value} req/s.")
|
||||
|
||||
sla_data_1, (sla_min, sla_max) = _estimate_sla_bounds(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_comb=serve_comb,
|
||||
bench_comb=bench_comb,
|
||||
sla_comb=sla_comb,
|
||||
base_path=base_path,
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
sla_variable=sla_variable,
|
||||
init_value=sla_init_value,
|
||||
max_value=sla_inf_value,
|
||||
)
|
||||
print(f"Range of {sla_variable} to search: [{sla_min}, {sla_max}] req/s.")
|
||||
|
||||
sla_data_2, sla_value = _find_sla_value(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_comb=serve_comb,
|
||||
bench_comb=bench_comb,
|
||||
sla_comb=sla_comb,
|
||||
base_path=base_path,
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
sla_variable=sla_variable,
|
||||
min_value=sla_min,
|
||||
max_value=sla_max,
|
||||
)
|
||||
|
||||
sla_data = sla_data_0 + sla_data_1 + sla_data_2
|
||||
print(f"Maximum {sla_variable} for SLA: {sla_value} req/s.")
|
||||
|
||||
with _get_sla_iter_path(
|
||||
base_path,
|
||||
sla_comb,
|
||||
sla_variable,
|
||||
sla_value=None,
|
||||
).open("w") as f:
|
||||
json.dump(sla_data, f, indent=4)
|
||||
|
||||
print("[SLA END]")
|
||||
|
||||
return sla_data
|
||||
|
||||
|
||||
def run_slas(
|
||||
serve_cmd: list[str],
|
||||
bench_cmd: list[str],
|
||||
after_bench_cmd: list[str],
|
||||
*,
|
||||
show_stdout: bool,
|
||||
serve_params: ParameterSweep,
|
||||
bench_params: ParameterSweep,
|
||||
sla_params: SLASweep,
|
||||
sla_variable: SLAVariable,
|
||||
output_dir: Path,
|
||||
num_runs: int,
|
||||
dry_run: bool,
|
||||
):
|
||||
if any(bench_comb.has_param(sla_variable) for bench_comb in bench_params):
|
||||
raise ValueError(
|
||||
f"You should not override `{sla_variable}` in `bench_params` in SLA mode, "
|
||||
"since it is supposed to be determined automatically."
|
||||
)
|
||||
|
||||
all_data = list[dict[str, object]]()
|
||||
for serve_comb in serve_params:
|
||||
with (
|
||||
run_server(
|
||||
serve_cmd,
|
||||
after_bench_cmd,
|
||||
show_stdout=show_stdout,
|
||||
serve_overrides=serve_comb,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
if _sla_needs_server(
|
||||
serve_comb,
|
||||
bench_params,
|
||||
sla_params,
|
||||
sla_variable,
|
||||
output_dir,
|
||||
)
|
||||
else contextlib.nullcontext()
|
||||
) as server:
|
||||
for bench_comb in bench_params:
|
||||
for sla_comb in sla_params:
|
||||
base_path = _get_sla_base_path(output_dir, serve_comb, bench_comb)
|
||||
|
||||
comb_data = search_sla(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_comb=serve_comb,
|
||||
bench_comb=bench_comb,
|
||||
sla_comb=sla_comb,
|
||||
sla_variable=sla_variable,
|
||||
base_path=base_path,
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
|
||||
if comb_data is not None:
|
||||
all_data.extend(comb_data)
|
||||
|
||||
if dry_run:
|
||||
return None
|
||||
|
||||
combined_df = pd.DataFrame.from_records(all_data)
|
||||
combined_df.to_csv(output_dir / "summary.csv")
|
||||
|
||||
return combined_df
|
||||
|
||||
|
||||
@dataclass
|
||||
class SweepServeSLAArgs(SweepServeArgs):
|
||||
sla_params: SLASweep
|
||||
sla_variable: SLAVariable
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
base_args = super().from_cli_args(args)
|
||||
|
||||
if args.sla_params:
|
||||
sla_params = SLASweep.read_json(args.sla_params)
|
||||
else:
|
||||
sla_params = SLASweep.from_records([])
|
||||
|
||||
return cls(
|
||||
**asdict(base_args),
|
||||
sla_params=sla_params,
|
||||
sla_variable=args.sla_variable,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
|
||||
parser = super().add_cli_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
"--sla-params",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to JSON file containing a list of SLA constraints to satisfy. "
|
||||
'Each constraint is expressed in `{"<KEY>": "<OP><VALUE>"}` format, '
|
||||
'e.g.: `{"p99_e2el_ms": "<=500"}` means that '
|
||||
"the E2E latency should be less than 500ms 99%% of the time. "
|
||||
"Setting this option runs this script in SLA mode, which searches for "
|
||||
"the maximum `sla_variable` that satisfies the constraints for "
|
||||
"each combination of `serve_params`, `bench_params`, and `sla_params`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sla-variable",
|
||||
type=str,
|
||||
choices=get_args(SLAVariable),
|
||||
default="request_rate",
|
||||
help="Whether to tune request rate or maximum concurrency to satisfy "
|
||||
"the SLA constraints.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def run_main(args: SweepServeSLAArgs):
|
||||
timestamp = args.resume or datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
output_dir = args.output_dir / timestamp
|
||||
|
||||
if args.resume and not output_dir.exists():
|
||||
raise ValueError(f"Cannot resume from non-existent directory ({output_dir})")
|
||||
|
||||
try:
|
||||
return run_slas(
|
||||
serve_cmd=args.serve_cmd,
|
||||
bench_cmd=args.bench_cmd,
|
||||
after_bench_cmd=args.after_bench_cmd,
|
||||
show_stdout=args.show_stdout,
|
||||
serve_params=args.serve_params,
|
||||
bench_params=args.bench_params,
|
||||
sla_params=args.sla_params,
|
||||
sla_variable=args.sla_variable,
|
||||
output_dir=output_dir,
|
||||
num_runs=args.num_runs,
|
||||
dry_run=args.dry_run,
|
||||
)
|
||||
except BaseException as exc:
|
||||
raise RuntimeError(
|
||||
f"The script was terminated early. Use `--resume {timestamp}` "
|
||||
f"to continue the script from its last checkpoint."
|
||||
) from exc
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
run_main(SweepServeSLAArgs.from_cli_args(args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Tune a variable to meet SLAs under multiple settings."
|
||||
)
|
||||
SweepServeSLAArgs.add_cli_args(parser)
|
||||
|
||||
main(parser.parse_args())
|
||||
vllm/benchmarks/sweep/server.py (new file)
@@ -0,0 +1,114 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import contextlib
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
from types import TracebackType
|
||||
|
||||
import requests
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
class ServerProcess:
|
||||
def __init__(
|
||||
self,
|
||||
server_cmd: list[str],
|
||||
after_bench_cmd: list[str],
|
||||
*,
|
||||
show_stdout: bool,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.server_cmd = server_cmd
|
||||
self.after_bench_cmd = after_bench_cmd
|
||||
self.show_stdout = show_stdout
|
||||
|
||||
def __enter__(self) -> Self:
|
||||
self.start()
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
exc_traceback: TracebackType | None,
|
||||
) -> None:
|
||||
self.stop()
|
||||
|
||||
def start(self):
|
||||
# Create new process for clean termination
|
||||
self._server_process = subprocess.Popen(
|
||||
self.server_cmd,
|
||||
start_new_session=True,
|
||||
stdout=None if self.show_stdout else subprocess.DEVNULL,
|
||||
# Need `VLLM_SERVER_DEV_MODE=1` for `_reset_caches`
|
||||
env=os.environ | {"VLLM_SERVER_DEV_MODE": "1"},
|
||||
)
|
||||
|
||||
def stop(self):
|
||||
server_process = self._server_process
|
||||
|
||||
if server_process.poll() is None:
|
||||
# In case only some processes have been terminated
|
||||
with contextlib.suppress(ProcessLookupError):
|
||||
# We need to kill both API Server and Engine processes
|
||||
os.killpg(os.getpgid(server_process.pid), signal.SIGKILL)
|
||||
|
||||
def run_subcommand(self, cmd: list[str]):
|
||||
return subprocess.run(
|
||||
cmd,
|
||||
stdout=None if self.show_stdout else subprocess.DEVNULL,
|
||||
check=True,
|
||||
)
|
||||
|
||||
def after_bench(self) -> None:
|
||||
if not self.after_bench_cmd:
|
||||
self.reset_caches()
|
||||
return
|
||||
|
||||
self.run_subcommand(self.after_bench_cmd)
|
||||
|
||||
def _get_vllm_server_address(self) -> str:
|
||||
server_cmd = self.server_cmd
|
||||
|
||||
for host_key in ("--host",):
|
||||
if host_key in server_cmd:
|
||||
host = server_cmd[server_cmd.index(host_key) + 1]
|
||||
break
|
||||
else:
|
||||
host = "localhost"
|
||||
|
||||
for port_key in ("-p", "--port"):
|
||||
if port_key in server_cmd:
|
||||
port = int(server_cmd[server_cmd.index(port_key) + 1])
|
||||
break
|
||||
else:
|
||||
port = 8000 # The default value in vllm serve
|
||||
|
||||
return f"http://{host}:{port}"
|
||||
|
||||
def reset_caches(self) -> None:
|
||||
server_cmd = self.server_cmd
|
||||
|
||||
# Use `.endswith()` to match `/bin/...`
|
||||
if server_cmd[0].endswith("vllm"):
|
||||
server_address = self._get_vllm_server_address()
|
||||
print(f"Resetting caches at {server_address}")
|
||||
|
||||
res = requests.post(f"{server_address}/reset_prefix_cache")
|
||||
res.raise_for_status()
|
||||
|
||||
res = requests.post(f"{server_address}/reset_mm_cache")
|
||||
res.raise_for_status()
|
||||
elif server_cmd[0].endswith("infinity_emb"):
|
||||
if "--vector-disk-cache" in server_cmd:
|
||||
raise NotImplementedError(
|
||||
"Infinity server uses caching but does not expose a method "
|
||||
"to reset the cache"
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"No implementation of `reset_caches` for `{server_cmd[0]}` server. "
|
||||
"Please specify a custom command via `--after-bench-cmd`."
|
||||
)
|
||||
vllm/benchmarks/sweep/sla_sweep.py (new file)
@@ -0,0 +1,132 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import json
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
@dataclass
|
||||
class SLACriterionBase(ABC):
|
||||
target: float
|
||||
|
||||
@abstractmethod
|
||||
def validate(self, actual: float) -> bool:
|
||||
"""Return `True` if this criterion is met; otherwise `False`."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def format_cond(self, lhs: str) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def print_and_validate(
|
||||
self,
|
||||
metrics: dict[str, float],
|
||||
metrics_key: str,
|
||||
) -> bool:
|
||||
metric = metrics[metrics_key]
|
||||
result = self.validate(metric)
|
||||
|
||||
cond = self.format_cond(f"{metrics_key} = {metric:.2f}")
|
||||
print(f"Validating SLA: {cond} | " + ("PASSED" if result else "FAILED"))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class SLALessThan(SLACriterionBase):
|
||||
@override
|
||||
def validate(self, actual: float) -> bool:
|
||||
return actual < self.target
|
||||
|
||||
@override
|
||||
def format_cond(self, lhs: str) -> str:
|
||||
return f"{lhs}<{self.target:.2f}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class SLALessThanOrEqualTo(SLACriterionBase):
|
||||
@override
|
||||
def validate(self, actual: float) -> bool:
|
||||
return actual <= self.target
|
||||
|
||||
@override
|
||||
def format_cond(self, lhs: str) -> str:
|
||||
return f"{lhs}<={self.target:.2f}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class SLAGreaterThan(SLACriterionBase):
|
||||
@override
|
||||
def validate(self, actual: float) -> bool:
|
||||
return actual > self.target
|
||||
|
||||
@override
|
||||
def format_cond(self, lhs: str) -> str:
|
||||
return f"{lhs}>{self.target:.2f}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class SLAGreaterThanOrEqualTo(SLACriterionBase):
|
||||
@override
|
||||
def validate(self, actual: float) -> bool:
|
||||
return actual >= self.target
|
||||
|
||||
@override
|
||||
def format_cond(self, lhs: str) -> str:
|
||||
return f"{lhs}>={self.target:.2f}"
|
||||
|
||||
|
||||
# NOTE: The ordering is important! Match longer op_keys first
|
||||
SLA_CRITERIA: dict[str, type[SLACriterionBase]] = {
|
||||
"<=": SLALessThanOrEqualTo,
|
||||
">=": SLAGreaterThanOrEqualTo,
|
||||
"<": SLALessThan,
|
||||
">": SLAGreaterThan,
|
||||
}
|
||||
|
||||
|
||||
class SLASweep(list["SLASweepItem"]):
|
||||
@classmethod
|
||||
def read_json(cls, filepath: os.PathLike):
|
||||
with open(filepath, "rb") as f:
|
||||
records = json.load(f)
|
||||
|
||||
return cls.from_records(records)
|
||||
|
||||
@classmethod
|
||||
def from_records(cls, records: list[dict[str, str]]):
|
||||
if not isinstance(records, list):
|
||||
raise TypeError(
|
||||
f"The SLA sweep should be a list of dictionaries, "
|
||||
f"but found type: {type(records)}"
|
||||
)
|
||||
|
||||
return cls(SLASweepItem.from_record(record) for record in records)
|
||||
|
||||
|
||||
class SLASweepItem(dict[str, SLACriterionBase]):
|
||||
@classmethod
|
||||
def from_record(cls, record: dict[str, str]):
|
||||
sla_criteria: dict[str, SLACriterionBase] = {}
|
||||
|
||||
for metric_key, metric_value in record.items():
|
||||
for op_key in SLA_CRITERIA:
|
||||
if metric_value.startswith(op_key):
|
||||
sla_criteria[metric_key] = SLA_CRITERIA[op_key](
|
||||
float(metric_value.removeprefix(op_key))
|
||||
)
|
||||
break
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid operator for "
|
||||
f"SLA constraint '{metric_key}={metric_value}'. "
|
||||
f"Valid operators are: {sorted(SLA_CRITERIA)}",
|
||||
)
|
||||
|
||||
return cls(sla_criteria)
|
||||
|
||||
def as_text(self, sep: str = ", ") -> str:
|
||||
return sep.join(v.format_cond(k) for k, v in self.items())
|
||||
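A small usage sketch of the SLA criteria above (the metric value is made up): each `--sla-params` record maps a metric name from the benchmark summary to an operator-prefixed target, and the parsed criteria are validated against the averaged metrics.

```python
# Sketch: parsing and checking one SLA constraint record.
from vllm.benchmarks.sweep.sla_sweep import SLASweepItem

item = SLASweepItem.from_record({"p99_e2el_ms": "<=500"})
metrics = {"p99_e2el_ms": 412.3}  # hypothetical averaged benchmark metrics

ok = all(
    criterion.print_and_validate(metrics, key) for key, criterion in item.items()
)
print("SLA met" if ok else "SLA not met")
```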
vllm/benchmarks/sweep/utils.py (new file)
@@ -0,0 +1,4 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
def sanitize_filename(filename: str) -> str:
    return filename.replace("/", "_").replace("..", "__").strip("'").strip('"')