This commit is contained in:
root
2026-04-09 11:23:47 +08:00
parent 8082d5f4b2
commit 72387e4fa8
1885 changed files with 611521 additions and 1 deletions

View File

View File

@@ -0,0 +1,44 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG
from .plot import SweepPlotArgs
from .plot import main as plot_main
from .plot_pareto import SweepPlotParetoArgs
from .plot_pareto import main as plot_pareto_main
from .serve import SweepServeArgs
from .serve import main as serve_main
from .serve_sla import SweepServeSLAArgs
from .serve_sla import main as serve_sla_main
from .startup import SweepStartupArgs
from .startup import main as startup_main
SUBCOMMANDS = (
(SweepServeArgs, serve_main),
(SweepServeSLAArgs, serve_sla_main),
(SweepStartupArgs, startup_main),
(SweepPlotArgs, plot_main),
(SweepPlotParetoArgs, plot_pareto_main),
)
def add_cli_args(parser: argparse.ArgumentParser):
subparsers = parser.add_subparsers(required=True, dest="sweep_type")
for cmd, entrypoint in SUBCOMMANDS:
cmd_subparser = subparsers.add_parser(
cmd.parser_name,
description=cmd.parser_help,
usage=f"vllm bench sweep {cmd.parser_name} [options]",
)
cmd_subparser.set_defaults(dispatch_function=entrypoint)
cmd.add_cli_args(cmd_subparser)
cmd_subparser.epilog = VLLM_SUBCMD_PARSER_EPILOG.format(
subcmd=f"sweep {cmd.parser_name}"
)
def main(args: argparse.Namespace):
args.dispatch_function(args)

View File

@@ -0,0 +1,159 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import os
from typing import Any
class ParameterSweep(list["ParameterSweepItem"]):
@classmethod
def read_json(cls, filepath: os.PathLike):
with open(filepath, "rb") as f:
data = json.load(f)
# Support both list and dict formats
if isinstance(data, dict):
return cls.read_from_dict(data)
return cls.from_records(data)
@classmethod
def read_from_dict(cls, data: dict[str, dict[str, object]]):
"""
Read parameter sweep from a dict format where keys are names.
Example:
{
"experiment1": {"max_tokens": 100, "temperature": 0.7},
"experiment2": {"max_tokens": 200, "temperature": 0.9}
}
"""
records = [{"_benchmark_name": name, **params} for name, params in data.items()]
return cls.from_records(records)
@classmethod
def from_records(cls, records: list[dict[str, object]]):
if not isinstance(records, list):
raise TypeError(
f"The parameter sweep should be a list of dictionaries, "
f"but found type: {type(records)}"
)
# Validate that all _benchmark_name values are unique if provided
names = [r["_benchmark_name"] for r in records if "_benchmark_name" in r]
if names and len(names) != len(set(names)):
duplicates = [name for name in names if names.count(name) > 1]
raise ValueError(
f"Duplicate _benchmark_name values found: {set(duplicates)}. "
f"All _benchmark_name values must be unique."
)
return cls(ParameterSweepItem.from_record(record) for record in records)
class ParameterSweepItem(dict[str, object]):
@classmethod
def from_record(cls, record: dict[str, object]):
if not isinstance(record, dict):
raise TypeError(
f"Each item in the parameter sweep should be a dictionary, "
f"but found type: {type(record)}"
)
return cls(record)
def __or__(self, other: dict[str, Any]):
return type(self)(super().__or__(other))
@property
def name(self) -> str:
"""
Get the name for this parameter sweep item.
Returns the '_benchmark_name' field if present, otherwise returns a text
representation of all parameters.
"""
if "_benchmark_name" in self:
return str(self["_benchmark_name"])
return self.as_text(sep="-")
# In JSON, we prefer "_"
def _iter_param_key_candidates(self, param_key: str):
# Inner config arguments are not converted by the CLI
if "." in param_key:
prefix, rest = param_key.split(".", 1)
for prefix_candidate in self._iter_param_key_candidates(prefix):
yield prefix_candidate + "." + rest
return
yield param_key
yield param_key.replace("-", "_")
yield param_key.replace("_", "-")
# In CLI, we prefer "-"
def _iter_cmd_key_candidates(self, param_key: str):
for k in reversed(tuple(self._iter_param_key_candidates(param_key))):
yield "--" + k
def _normalize_cmd_key(self, param_key: str):
return next(self._iter_cmd_key_candidates(param_key))
def has_param(self, param_key: str) -> bool:
return any(k in self for k in self._iter_param_key_candidates(param_key))
def _normalize_cmd_kv_pair(self, k: str, v: object) -> list[str]:
"""
Normalize a key-value pair into command-line arguments.
Returns a list containing either:
- A single element for boolean flags (e.g., ['--flag'] or ['--flag=true'])
- Two elements for key-value pairs (e.g., ['--key', 'value'])
"""
if isinstance(v, bool):
# For nested params (containing "."), use =true/false syntax
if "." in k:
return [f"{self._normalize_cmd_key(k)}={'true' if v else 'false'}"]
else:
return [self._normalize_cmd_key(k if v else "no-" + k)]
else:
return [self._normalize_cmd_key(k), str(v)]
def apply_to_cmd(self, cmd: list[str]) -> list[str]:
cmd = list(cmd)
for k, v in self.items():
# Skip the '_benchmark_name' field, not a parameter
if k == "_benchmark_name":
continue
# Serialize dict values as JSON
if isinstance(v, dict):
v = json.dumps(v)
for k_candidate in self._iter_cmd_key_candidates(k):
try:
k_idx = cmd.index(k_candidate)
# Replace existing parameter
normalized = self._normalize_cmd_kv_pair(k, v)
if len(normalized) == 1:
# Boolean flag
cmd[k_idx] = normalized[0]
else:
# Key-value pair
cmd[k_idx] = normalized[0]
cmd[k_idx + 1] = normalized[1]
break
except ValueError:
continue
else:
# Add new parameter
cmd.extend(self._normalize_cmd_kv_pair(k, v))
return cmd
def as_text(self, sep: str = ", ") -> str:
return sep.join(f"{k}={v}" for k, v in self.items() if k != "_benchmark_name")

View File

@@ -0,0 +1,683 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import json
from abc import ABC, abstractmethod
from concurrent.futures import ProcessPoolExecutor
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from types import TracebackType
from typing import ClassVar
from typing_extensions import Self, override
from vllm.utils.collection_utils import full_groupby
from vllm.utils.import_utils import PlaceholderModule
from .utils import sanitize_filename
try:
import matplotlib.pyplot as plt
except ImportError:
plt = PlaceholderModule("matplotlib").placeholder_attr("pyplot")
try:
import pandas as pd
except ImportError:
pd = PlaceholderModule("pandas")
try:
import seaborn as sns
except ImportError:
seaborn = PlaceholderModule("seaborn")
@dataclass
class PlotFilterBase(ABC):
var: str
target: str
@classmethod
def parse_str(cls, s: str):
for op_key in PLOT_FILTERS:
if op_key in s:
key, value = s.split(op_key)
return PLOT_FILTERS[op_key](
key,
value.removeprefix(op_key).strip("'").strip('"'),
)
else:
raise ValueError(
f"Invalid operator for plot filter '{s}'. "
f"Valid operators are: {sorted(PLOT_FILTERS)}",
)
@abstractmethod
def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
"""Applies this filter to a DataFrame."""
raise NotImplementedError
@dataclass
class PlotEqualTo(PlotFilterBase):
@override
def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
try:
target = float(self.target)
except ValueError:
target = self.target
return df[df[self.var] == target]
@dataclass
class PlotNotEqualTo(PlotFilterBase):
@override
def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
try:
target = float(self.target)
except ValueError:
target = self.target
return df[df[self.var] != target]
@dataclass
class PlotLessThan(PlotFilterBase):
@override
def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
return df[df[self.var] < float(self.target)]
@dataclass
class PlotLessThanOrEqualTo(PlotFilterBase):
@override
def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
return df[df[self.var] <= float(self.target)]
@dataclass
class PlotGreaterThan(PlotFilterBase):
@override
def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
return df[df[self.var] > float(self.target)]
@dataclass
class PlotGreaterThanOrEqualTo(PlotFilterBase):
@override
def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
return df[df[self.var] >= float(self.target)]
# NOTE: The ordering is important! Match longer op_keys first
PLOT_FILTERS: dict[str, type[PlotFilterBase]] = {
"==": PlotEqualTo,
"!=": PlotNotEqualTo,
"<=": PlotLessThanOrEqualTo,
">=": PlotGreaterThanOrEqualTo,
"<": PlotLessThan,
">": PlotGreaterThan,
}
class PlotFilters(list[PlotFilterBase]):
@classmethod
def parse_str(cls, s: str):
if not s:
return cls()
return cls(PlotFilterBase.parse_str(e) for e in s.split(","))
def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
for item in self:
df = item.apply(df)
return df
@dataclass
class PlotBinner:
var: str
bin_size: float
@classmethod
def parse_str(cls, s: str):
for op_key in PLOT_BINNERS:
if op_key in s:
key, value = s.split(op_key)
return PLOT_BINNERS[op_key](key, float(value.removeprefix(op_key)))
else:
raise ValueError(
f"Invalid operator for plot binner '{s}'. "
f"Valid operators are: {sorted(PLOT_BINNERS)}",
)
def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
"""Applies this binner to a DataFrame."""
df = df.copy()
df[self.var] = df[self.var] // self.bin_size * self.bin_size
return df
PLOT_BINNERS: dict[str, type[PlotBinner]] = {
"%": PlotBinner,
}
class PlotBinners(list[PlotBinner]):
@classmethod
def parse_str(cls, s: str):
if not s:
return cls()
return cls(PlotBinner.parse_str(e) for e in s.split(","))
def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
for item in self:
df = item.apply(df)
return df
def _json_load_bytes(path: Path) -> list[dict[str, object]]:
with path.open("rb") as f:
return json.load(f)
def _convert_inf_nan_strings(data: list[dict[str, object]]) -> list[dict[str, object]]:
"""
Convert string values "inf", "-inf", and "nan" to their float equivalents.
This handles the case where JSON serialization represents inf/nan as strings.
"""
converted_data = []
for record in data:
converted_record = {}
for key, value in record.items():
if isinstance(value, str):
if value in ["inf", "-inf", "nan"]:
converted_record[key] = float(value)
else:
converted_record[key] = value
else:
converted_record[key] = value
converted_data.append(converted_record)
return converted_data
def _get_metric(run_data: dict[str, object], metric_key: str):
try:
return run_data[metric_key]
except KeyError as exc:
raise ValueError(f"Cannot find metric {metric_key!r} in {run_data=}") from exc
def _get_group(run_data: dict[str, object], group_keys: list[str]):
return tuple((k, str(_get_metric(run_data, k))) for k in group_keys)
def _get_fig_path(fig_dir: Path, group: tuple[tuple[str, str], ...], fig_name: str):
parts = list[str]()
# Start with figure name (always provided, defaults to "FIGURE")
parts.append(fig_name)
# Always append group data if present
if group:
parts.extend(f"{k}={v}" for k, v in group)
return fig_dir / sanitize_filename("-".join(parts) + ".png")
class DummyExecutor:
map = map
def __enter__(self) -> Self:
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
exc_traceback: TracebackType | None,
) -> None:
return None
def _plot_fig(
fig_dir: Path,
fig_group_data: tuple[tuple[tuple[str, str], ...], list[dict[str, object]]],
row_by: list[str],
col_by: list[str],
curve_by: list[str],
*,
var_x: str,
var_y: str,
filter_by: PlotFilters,
bin_by: PlotBinners,
scale_x: str | None,
scale_y: str | None,
dry_run: bool,
fig_name: str,
error_bars: bool,
fig_height: float,
fig_dpi: int,
):
fig_group, fig_data = fig_group_data
row_groups = full_groupby(
fig_data,
key=lambda item: _get_group(item, row_by),
)
num_rows = len(row_groups)
num_cols = max(
len(full_groupby(row_data, key=lambda item: _get_group(item, col_by)))
for _, row_data in row_groups
)
fig_path = _get_fig_path(fig_dir, fig_group, fig_name)
print("[BEGIN FIGURE]")
print(f"Group: {dict(fig_group)}")
print(f"Grid: {num_rows} rows x {num_cols} cols")
print(f"Output file: {fig_path}")
if dry_run:
print("[END FIGURE]")
return
# Convert string "inf", "-inf", and "nan" to their float equivalents
fig_data = _convert_inf_nan_strings(fig_data)
df = pd.DataFrame.from_records(fig_data)
if var_x not in df.columns:
raise ValueError(
f"Cannot find {var_x=!r} in parameter sweep results. "
f"Available variables: {df.columns.tolist()}"
)
if var_y not in df.columns:
raise ValueError(
f"Cannot find {var_y=!r} in parameter sweep results. "
f"Available variables: {df.columns.tolist()}"
)
for k in row_by:
if k not in df.columns:
raise ValueError(
f"Cannot find row_by={k!r} in parameter sweep results. "
f"Available variables: {df.columns.tolist()}"
)
for k in col_by:
if k not in df.columns:
raise ValueError(
f"Cannot find col_by={k!r} in parameter sweep results. "
f"Available variables: {df.columns.tolist()}"
)
for k in curve_by:
if k not in df.columns:
raise ValueError(
f"Cannot find curve_by={k!r} in parameter sweep results. "
f"Available variables: {df.columns.tolist()}"
)
df = filter_by.apply(df)
df = bin_by.apply(df)
# Sort by curve_by columns alphabetically for consistent legend ordering
if curve_by:
df = df.sort_values(by=curve_by)
df["row_group"] = (
pd.concat(
[k + "=" + df[k].astype(str) for k in row_by],
axis=1,
).agg("\n".join, axis=1)
if row_by
else "(All)"
)
df["col_group"] = (
pd.concat(
[k + "=" + df[k].astype(str) for k in col_by],
axis=1,
).agg("\n".join, axis=1)
if col_by
else "(All)"
)
if len(curve_by) <= 3:
hue, style, size, *_ = (*curve_by, None, None, None)
g = sns.relplot(
df,
x=var_x,
y=var_y,
hue=hue,
style=style,
size=size,
markers=True,
errorbar="sd" if error_bars else None,
kind="line",
row="row_group",
col="col_group",
height=fig_height,
)
else:
df["curve_group"] = (
pd.concat(
[k + "=" + df[k].astype(str) for k in curve_by],
axis=1,
).agg("\n".join, axis=1)
if curve_by
else "(All)"
)
g = sns.relplot(
df,
x=var_x,
y=var_y,
hue="curve_group",
markers=True,
errorbar="sd" if error_bars else None,
kind="line",
row="row_group",
col="col_group",
height=fig_height,
)
if row_by and col_by:
g.set_titles("{row_name}\n{col_name}")
elif row_by:
g.set_titles("{row_name}")
elif col_by:
g.set_titles("{col_name}")
else:
g.set_titles("")
if scale_x:
g.set(xscale=scale_x)
if scale_y:
g.set(yscale=scale_y)
g.savefig(fig_path, dpi=fig_dpi)
plt.close(g.figure)
print("[END FIGURE]")
def plot(
output_dir: Path,
fig_dir: Path,
fig_by: list[str],
row_by: list[str],
col_by: list[str],
curve_by: list[str],
*,
var_x: str,
var_y: str,
filter_by: PlotFilters,
bin_by: PlotBinners,
scale_x: str | None,
scale_y: str | None,
dry_run: bool,
fig_name: str = "FIGURE",
error_bars: bool = True,
fig_height: float = 6.4,
fig_dpi: int = 300,
):
all_data = [
run_data
for path in output_dir.rglob("**/summary.json")
for run_data in _json_load_bytes(path)
]
if not all_data:
raise ValueError(f"Did not find any parameter sweep results under {output_dir}")
fig_dir.mkdir(parents=True, exist_ok=True)
fig_groups = full_groupby(
all_data,
key=lambda item: _get_group(item, fig_by),
)
with DummyExecutor() if len(fig_groups) <= 1 else ProcessPoolExecutor() as executor:
# Resolve the iterable to ensure that the workers are run
all(
executor.map(
partial(
_plot_fig,
fig_dir,
row_by=row_by,
col_by=col_by,
curve_by=curve_by,
var_x=var_x,
var_y=var_y,
filter_by=filter_by,
bin_by=bin_by,
scale_x=scale_x,
scale_y=scale_y,
dry_run=dry_run,
fig_name=fig_name,
error_bars=error_bars,
fig_height=fig_height,
fig_dpi=fig_dpi,
),
fig_groups,
)
)
@dataclass
class SweepPlotArgs:
output_dir: Path
fig_dir: Path
fig_by: list[str]
row_by: list[str]
col_by: list[str]
curve_by: list[str]
var_x: str
var_y: str
filter_by: PlotFilters
bin_by: PlotBinners
scale_x: str | None
scale_y: str | None
dry_run: bool
fig_name: str = "FIGURE"
error_bars: bool = True
fig_height: float = 6.4
fig_dpi: int = 300
parser_name: ClassVar[str] = "plot"
parser_help: ClassVar[str] = "Plot performance curves from parameter sweep results."
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
output_dir = Path(args.OUTPUT_DIR)
if not output_dir.exists():
raise ValueError(f"No parameter sweep results under {output_dir}")
curve_by = [] if not args.curve_by else args.curve_by.split(",")
row_by = [] if not args.row_by else args.row_by.split(",")
col_by = [] if not args.col_by else args.col_by.split(",")
fig_by = [] if not args.fig_by else args.fig_by.split(",")
return cls(
output_dir=output_dir,
fig_dir=output_dir / args.fig_dir,
fig_by=fig_by,
row_by=row_by,
col_by=col_by,
curve_by=curve_by,
var_x=args.var_x,
var_y=args.var_y,
filter_by=PlotFilters.parse_str(args.filter_by),
bin_by=PlotBinners.parse_str(args.bin_by),
scale_x=args.scale_x,
scale_y=args.scale_y,
dry_run=args.dry_run,
fig_name=args.fig_name,
error_bars=not args.no_error_bars,
fig_height=args.fig_height,
fig_dpi=args.fig_dpi,
)
@classmethod
def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser.add_argument(
"OUTPUT_DIR",
type=str,
default="results",
help="The directory containing the results to plot, "
"i.e., the `--output-dir` argument to the parameter sweep script.",
)
parser.add_argument(
"--fig-dir",
type=str,
default="",
help="The directory to save the figures, relative to `OUTPUT_DIR`. "
"By default, the same directory is used.",
)
parser.add_argument(
"--fig-by",
type=str,
default="",
help="A comma-separated list of variables, such that a separate figure "
"is created for each combination of these variables.",
)
parser.add_argument(
"--row-by",
type=str,
default="",
help="A comma-separated list of variables, such that a separate row "
"is created for each combination of these variables.",
)
parser.add_argument(
"--col-by",
type=str,
default="",
help="A comma-separated list of variables, such that a separate column "
"is created for each combination of these variables.",
)
parser.add_argument(
"--curve-by",
type=str,
default=None,
help="A comma-separated list of variables, such that a separate curve "
"is created for each combination of these variables.",
)
parser.add_argument(
"--var-x",
type=str,
default="request_throughput",
help="The variable for the x-axis.",
)
parser.add_argument(
"--var-y",
type=str,
default="p99_ttft_ms",
help="The variable for the y-axis",
)
parser.add_argument(
"--filter-by",
type=str,
default="",
help="A comma-separated list of statements indicating values to filter by. "
"This is useful to remove outliers. "
"Example: `max_concurrency<1000,max_num_batched_tokens<=4096` means "
"plot only the points where `max_concurrency` is less than 1000 and "
"`max_num_batched_tokens` is no greater than 4096.",
)
parser.add_argument(
"--bin-by",
type=str,
default="",
help="A comma-separated list of statements indicating values to bin by. "
"This is useful to avoid plotting points that are too close together. "
"Example: `request_throughput%%1` means "
"use a bin size of 1 for the `request_throughput` variable.",
)
parser.add_argument(
"--scale-x",
type=str,
default=None,
help="The scale to use for the x-axis. "
"Currently only accepts string values such as 'log' and 'sqrt'. "
"See also: https://seaborn.pydata.org/generated/seaborn.objects.Plot.scale.html",
)
parser.add_argument(
"--scale-y",
type=str,
default=None,
help="The scale to use for the y-axis. "
"Currently only accepts string values such as 'log' and 'sqrt'. "
"See also: https://seaborn.pydata.org/generated/seaborn.objects.Plot.scale.html",
)
parser.add_argument(
"--fig-name",
type=str,
default="FIGURE",
help="Name prefix for the output figure file. "
"Group data is always appended when present. "
"Default: 'FIGURE'. Example: --fig-name my_performance_plot",
)
parser.add_argument(
"--no-error-bars",
action="store_true",
help="If set, disables error bars on the plot. "
"By default, error bars are shown.",
)
parser.add_argument(
"--fig-height",
type=float,
default=6.4,
help="Height of each subplot in inches. Default: 6.4",
)
parser.add_argument(
"--fig-dpi",
type=int,
default=300,
help="Resolution of the output figure in dots per inch. Default: 300",
)
parser.add_argument(
"--dry-run",
action="store_true",
help="If set, prints the information about each figure to plot, "
"then exits without drawing them.",
)
return parser
def run_main(args: SweepPlotArgs):
return plot(
output_dir=args.output_dir,
fig_dir=args.fig_dir,
fig_by=args.fig_by,
row_by=args.row_by,
col_by=args.col_by,
curve_by=args.curve_by,
var_x=args.var_x,
var_y=args.var_y,
filter_by=args.filter_by,
bin_by=args.bin_by,
scale_x=args.scale_x,
scale_y=args.scale_y,
dry_run=args.dry_run,
fig_name=args.fig_name,
error_bars=args.error_bars,
fig_height=args.fig_height,
fig_dpi=args.fig_dpi,
)
def main(args: argparse.Namespace):
run_main(SweepPlotArgs.from_cli_args(args))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description=SweepPlotArgs.parser_help)
SweepPlotArgs.add_cli_args(parser)
main(parser.parse_args())

View File

@@ -0,0 +1,399 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import math
from concurrent.futures import ProcessPoolExecutor
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import ClassVar
from vllm.utils.collection_utils import full_groupby
from vllm.utils.import_utils import PlaceholderModule
from .plot import DummyExecutor, _json_load_bytes
from .utils import sanitize_filename
try:
import matplotlib.pyplot as plt
except ImportError:
plt = PlaceholderModule("matplotlib").placeholder_attr("pyplot")
try:
import pandas as pd
except ImportError:
pd = PlaceholderModule("pandas")
try:
import seaborn as sns
except ImportError:
seaborn = PlaceholderModule("seaborn")
def _first_present(run_data: dict[str, object], keys: list[str]):
for key in keys:
for candidate in {key, key.replace("_", "-"), key.replace("-", "_")}:
if candidate in run_data:
return run_data[candidate]
return None
def _get_numeric(
run_data: dict[str, object],
keys: list[str],
*,
allow_zero: bool = True,
) -> float | None:
value = _first_present(run_data, keys)
if value is None:
return None
try:
numeric = float(value)
except (TypeError, ValueError) as exc:
raise ValueError(
f"Expected numeric value for one of {keys}, "
f"but found {value!r} in {run_data=}"
) from exc
if not allow_zero and numeric == 0:
return None
return numeric
def _infer_user_count(
run_data: dict[str, object],
user_count_var: str | None,
) -> float | None:
candidates = [user_count_var] if user_count_var else []
candidates.extend(["request_rate"])
user_count = _get_numeric(run_data, candidates, allow_zero=False)
if user_count is not None:
return user_count
# Fallback to the observed peak if configured value is missing.
return _get_numeric(run_data, ["max_concurrent_requests"], allow_zero=False)
def _infer_gpu_count(
run_data: dict[str, object],
gpu_count_var: str | None,
) -> float:
direct_candidates = [gpu_count_var] if gpu_count_var else []
direct_gpu_count = _get_numeric(run_data, direct_candidates, allow_zero=False)
if direct_gpu_count:
return direct_gpu_count
tp_size = _get_numeric(run_data, ["tensor_parallel_size", "tp"])
pp_size = _get_numeric(run_data, ["pipeline_parallel_size", "pp"])
dp_size = _get_numeric(run_data, ["data_parallel_size", "dp"])
world_size = 1.0
if tp_size:
world_size *= tp_size
if pp_size:
world_size *= pp_size
if dp_size:
world_size *= dp_size
return world_size
def _get_throughput(
run_data: dict[str, object],
throughput_var: str,
) -> float:
throughput = _get_numeric(run_data, [throughput_var])
if throughput is None:
raise ValueError(
f"Cannot find throughput metric {throughput_var!r} in run data. "
f"Available keys: {sorted(run_data)}"
)
return throughput
def _prepare_records(
all_data: list[dict[str, object]],
*,
user_count_var: str | None,
gpu_count_var: str | None,
) -> tuple[list[dict[str, object]], int]:
prepared = []
skipped_missing_users = 0
for record in all_data:
throughput = _get_throughput(record, "output_throughput")
user_count = _infer_user_count(record, user_count_var)
if user_count is None:
skipped_missing_users += 1
continue
gpu_count = _infer_gpu_count(record, gpu_count_var)
tokens_per_user = throughput / user_count
tokens_per_gpu = throughput / gpu_count
prepared.append(
{
**record,
"tokens_per_user": tokens_per_user,
"tokens_per_gpu": tokens_per_gpu,
"user_count_estimate": user_count,
"gpu_count": gpu_count,
}
)
return prepared, skipped_missing_users
def _pareto_frontier(
df: "pd.DataFrame",
x_col: str,
y_col: str,
*,
epsilon: float = 1e-9,
) -> "pd.DataFrame":
sorted_df = df.sort_values([x_col, y_col], ascending=[False, False])
frontier_indices = []
best_y = -math.inf
for idx, row in sorted_df.iterrows():
y_val = row[y_col]
if y_val >= best_y - epsilon:
frontier_indices.append(idx)
best_y = max(best_y, y_val)
return df.loc[frontier_indices]
def _get_fig_path(
fig_dir: Path,
fig_group: tuple[tuple[str, str], ...],
) -> Path:
parts = ["PARETO"]
if fig_group:
parts.extend(f"{k}={v}" for k, v in fig_group)
filename = sanitize_filename("-".join(parts) + ".png")
return fig_dir / filename
def _plot_fig(
fig_dir: Path,
fig_group_data: tuple[tuple[tuple[str, str], ...], list[dict[str, object]]],
label_by: list[str],
*,
dry_run: bool,
):
fig_group, fig_data = fig_group_data
fig_path = _get_fig_path(fig_dir, fig_group)
print("[BEGIN FIGURE]")
print(f"Group: {dict(fig_group)}")
print(f"Output file: {fig_path}")
if dry_run:
print("[END FIGURE]")
return
df = pd.DataFrame.from_records(fig_data)
df = df.dropna(subset=["tokens_per_user", "tokens_per_gpu"])
if df.empty:
print("No data points available after filtering; skipping.")
print("[END FIGURE]")
return
frontier = _pareto_frontier(df, "tokens_per_user", "tokens_per_gpu")
frontier = frontier.sort_values("tokens_per_user")
fig, ax = plt.subplots()
sns.scatterplot(
data=df,
x="tokens_per_user",
y="tokens_per_gpu",
color="0.5",
alpha=0.6,
ax=ax,
label="All runs",
)
sns.lineplot(
data=frontier,
x="tokens_per_user",
y="tokens_per_gpu",
marker="o",
ax=ax,
label="Pareto frontier",
)
if label_by:
for _, row in frontier.iterrows():
label_parts = []
for key in label_by:
if key in row:
label_parts.append(f"{key}={row[key]}")
if label_parts:
ax.text(
row["tokens_per_user"],
row["tokens_per_gpu"],
"\n".join(label_parts),
fontsize=8,
)
ax.set_xlabel("Tokens/s/user")
ax.set_ylabel("Tokens/s/GPU")
ax.grid(True, linestyle="--", linewidth=0.5, alpha=0.6)
fig.tight_layout()
fig.savefig(fig_path)
plt.close(fig)
print(
f"Plotted {len(df)} points; Pareto frontier size: {len(frontier)}.",
)
print("[END FIGURE]")
def plot_pareto(
output_dir: Path,
user_count_var: str | None,
gpu_count_var: str | None,
label_by: list[str],
*,
dry_run: bool,
):
fig_dir = output_dir / "pareto"
raw_data = [
run_data
for path in output_dir.rglob("**/summary.json")
for run_data in _json_load_bytes(path)
]
if not raw_data:
raise ValueError(f"Did not find any parameter sweep results under {output_dir}")
fig_dir.mkdir(parents=True, exist_ok=True)
prepared_data, skipped_missing_users = _prepare_records(
raw_data,
user_count_var=user_count_var,
gpu_count_var=gpu_count_var,
)
if skipped_missing_users:
print(
f"Skipped {skipped_missing_users} runs without a user count "
"(`max_concurrency` or `max_concurrent_requests`).",
)
if not prepared_data:
raise ValueError(
"No data points with both throughput and user count available "
"to plot Pareto frontier.",
)
fig_groups = full_groupby(
prepared_data,
key=lambda item: tuple(),
)
with DummyExecutor() if len(fig_groups) <= 1 else ProcessPoolExecutor() as executor:
all(
executor.map(
partial(
_plot_fig,
fig_dir,
label_by=label_by,
dry_run=dry_run,
),
fig_groups,
)
)
@dataclass
class SweepPlotParetoArgs:
output_dir: Path
user_count_var: str | None
gpu_count_var: str | None
label_by: list[str]
dry_run: bool
parser_name: ClassVar[str] = "plot_pareto"
parser_help: ClassVar[str] = (
"Plot Pareto frontier between tokens/s/user and tokens/s/GPU "
"from parameter sweep results."
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
output_dir = Path(args.OUTPUT_DIR)
if not output_dir.exists():
raise ValueError(f"No parameter sweep results under {output_dir}")
label_by = [] if not args.label_by else args.label_by.split(",")
return cls(
output_dir=output_dir,
user_count_var=args.user_count_var,
gpu_count_var=args.gpu_count_var,
label_by=label_by,
dry_run=args.dry_run,
)
@classmethod
def add_cli_args(cls, parser: argparse.ArgumentParser):
parser.add_argument(
"OUTPUT_DIR",
type=str,
default="results",
help="The directory containing the sweep results to plot.",
)
parser.add_argument(
"--user-count-var",
type=str,
default="max_concurrency",
help="Result key that stores concurrent user count. "
"Falls back to max_concurrent_requests if missing.",
)
parser.add_argument(
"--gpu-count-var",
type=str,
default=None,
help="Result key that stores GPU count. "
"If not provided, falls back to num_gpus/gpu_count "
"or tensor_parallel_size * pipeline_parallel_size.",
)
parser.add_argument(
"--label-by",
type=str,
default="max_concurrency,gpu_count",
help="Comma-separated list of fields to annotate on Pareto frontier "
"points.",
)
parser.add_argument(
"--dry-run",
action="store_true",
help="If set, prints the figures to plot without drawing them.",
)
return parser
def run_main(args: SweepPlotParetoArgs):
return plot_pareto(
output_dir=args.output_dir,
user_count_var=args.user_count_var,
gpu_count_var=args.gpu_count_var,
label_by=args.label_by,
dry_run=args.dry_run,
)
def main(args: argparse.Namespace):
run_main(SweepPlotParetoArgs.from_cli_args(args))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description=SweepPlotParetoArgs.parser_help)
SweepPlotParetoArgs.add_cli_args(parser)
main(parser.parse_args())

View File

@@ -0,0 +1,498 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import contextlib
import json
import shlex
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import ClassVar
from vllm.utils.import_utils import PlaceholderModule
from .param_sweep import ParameterSweep, ParameterSweepItem
from .server import ServerProcess
from .utils import sanitize_filename
try:
import pandas as pd
except ImportError:
pd = PlaceholderModule("pandas")
@contextlib.contextmanager
def run_server(
serve_cmd: list[str],
after_bench_cmd: list[str],
*,
show_stdout: bool,
serve_overrides: ParameterSweepItem,
dry_run: bool,
server_ready_timeout: int = 300,
):
server_cmd = serve_overrides.apply_to_cmd(serve_cmd)
print("[BEGIN SERVER]")
print(f"Server overrides: {serve_overrides}")
print(f"Server command: {server_cmd}")
if dry_run:
yield None
print("[END SERVER]")
return
with ServerProcess(server_cmd, after_bench_cmd, show_stdout=show_stdout) as server:
server.wait_until_ready(timeout=server_ready_timeout)
yield server
print("[END SERVER]")
def _update_run_data(
run_data: dict[str, object],
serve_overrides: ParameterSweepItem,
bench_overrides: ParameterSweepItem,
run_number: int,
):
run_data["run_number"] = run_number
run_data.update(serve_overrides)
run_data.update(bench_overrides)
return run_data
def run_benchmark(
server: ServerProcess | None,
bench_cmd: list[str],
*,
serve_overrides: ParameterSweepItem,
bench_overrides: ParameterSweepItem,
run_number: int,
output_path: Path,
dry_run: bool,
):
benchmark_cmd = [
*bench_overrides.apply_to_cmd(bench_cmd),
"--percentile-metrics",
"ttft,tpot,itl,e2el",
"--save-result",
"--result-dir",
str(output_path.parent),
"--result-filename",
output_path.name,
]
print("[BEGIN BENCHMARK]")
print(f"Benchmark overrides: {bench_overrides}")
print(f"Run Number: {run_number}")
print(f"Benchmark command: {benchmark_cmd}")
print(f"Output file: {output_path}")
run_data: dict[str, object]
if output_path.exists():
print("Found existing results.")
print("[SKIPPED BENCHMARK]")
with output_path.open("rb") as f:
run_data = json.load(f)
return _update_run_data(
run_data,
serve_overrides,
bench_overrides,
run_number,
)
if server is None:
if not dry_run:
raise ValueError(f"Cannot find results at {output_path}")
print("[END BENCHMARK]")
return None
output_path.parent.mkdir(parents=True, exist_ok=True)
server.run_subcommand(benchmark_cmd)
server.after_bench()
with output_path.open("rb") as f:
run_data = json.load(f)
run_data = _update_run_data(
run_data,
serve_overrides,
bench_overrides,
run_number,
)
with output_path.open("w") as f:
json.dump(run_data, f, indent=4)
print("[END BENCHMARK]")
return run_data
def _get_comb_base_path(
output_dir: Path,
serve_comb: ParameterSweepItem,
bench_comb: ParameterSweepItem,
):
parts = list[str]()
if serve_comb:
parts.extend(("SERVE-", serve_comb.name))
if bench_comb:
parts.extend(("BENCH-", bench_comb.name))
return output_dir / sanitize_filename("-".join(parts))
def _get_comb_run_path(base_path: Path, run_number: int | None):
if run_number is None:
return base_path / "summary.json"
return base_path / f"run={run_number}.json"
def _comb_needs_server(
serve_comb: ParameterSweepItem,
bench_combs: ParameterSweep,
output_dir: Path,
):
for bench_comb in bench_combs:
base_path = _get_comb_base_path(output_dir, serve_comb, bench_comb)
if not _get_comb_run_path(base_path, run_number=None).exists():
return True
return False
def server_ctx(
serve_cmd: list[str],
after_bench_cmd: list[str],
*,
show_stdout: bool,
serve_comb: ParameterSweepItem,
bench_params: ParameterSweep,
output_dir: Path,
dry_run: bool,
server_ready_timeout: int = 300,
):
if not _comb_needs_server(serve_comb, bench_params, output_dir):
return contextlib.nullcontext()
return run_server(
serve_cmd,
after_bench_cmd,
show_stdout=show_stdout,
serve_overrides=serve_comb,
dry_run=dry_run,
server_ready_timeout=server_ready_timeout,
)
def _comb_is_valid(
serve_comb: ParameterSweepItem,
bench_comb: ParameterSweepItem,
link_vars: list[tuple[str, str]],
) -> bool:
return all(
serve_key in serve_comb
and bench_key in bench_comb
and serve_comb[serve_key] == bench_comb[bench_key]
for serve_key, bench_key in link_vars
)
def run_comb(
server: ServerProcess | None,
bench_cmd: list[str],
*,
serve_comb: ParameterSweepItem,
bench_comb: ParameterSweepItem,
base_path: Path,
num_runs: int,
dry_run: bool,
link_vars: list[tuple[str, str]],
):
if not _comb_is_valid(serve_comb, bench_comb, link_vars):
return None
comb_data = list[dict[str, object]]()
for run_number in range(num_runs):
run_data = run_benchmark(
server,
bench_cmd,
serve_overrides=serve_comb,
bench_overrides=bench_comb,
run_number=run_number,
output_path=_get_comb_run_path(base_path, run_number),
dry_run=dry_run,
)
if run_data is not None:
comb_data.append(run_data)
if dry_run:
return None
with _get_comb_run_path(base_path, run_number=None).open("w") as f:
json.dump(comb_data, f, indent=4)
return comb_data
def run_combs(
serve_cmd: list[str],
bench_cmd: list[str],
after_bench_cmd: list[str],
*,
show_stdout: bool,
server_ready_timeout: int,
serve_params: ParameterSweep,
bench_params: ParameterSweep,
output_dir: Path,
num_runs: int,
dry_run: bool,
link_vars: list[tuple[str, str]],
):
all_data = list[dict[str, object]]()
for serve_comb in serve_params:
with server_ctx(
serve_cmd,
after_bench_cmd,
show_stdout=show_stdout,
serve_comb=serve_comb,
bench_params=bench_params,
output_dir=output_dir,
dry_run=dry_run,
server_ready_timeout=server_ready_timeout,
) as server:
for bench_comb in bench_params:
base_path = _get_comb_base_path(output_dir, serve_comb, bench_comb)
comb_data = run_comb(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb,
base_path=base_path,
num_runs=num_runs,
dry_run=dry_run,
link_vars=link_vars,
)
if comb_data is not None:
all_data.extend(comb_data)
if dry_run:
return None
combined_df = pd.DataFrame.from_records(all_data)
combined_df.to_csv(output_dir / "summary.csv")
return combined_df
@dataclass
class SweepServeArgs:
serve_cmd: list[str]
bench_cmd: list[str]
after_bench_cmd: list[str]
show_stdout: bool
server_ready_timeout: int
serve_params: ParameterSweep
bench_params: ParameterSweep
output_dir: Path
num_runs: int
dry_run: bool
resume: str | None
link_vars: list[tuple[str, str]]
parser_name: ClassVar[str] = "serve"
parser_help: ClassVar[str] = "Run vLLM server benchmark under multiple settings."
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
serve_cmd = shlex.split(args.serve_cmd)
bench_cmd = shlex.split(args.bench_cmd)
after_bench_cmd = (
[] if args.after_bench_cmd is None else shlex.split(args.after_bench_cmd)
)
if args.serve_params:
serve_params = ParameterSweep.read_json(args.serve_params)
else:
# i.e.: run serve_cmd without any modification
serve_params = ParameterSweep.from_records([{}])
if args.bench_params:
bench_params = ParameterSweep.read_json(args.bench_params)
else:
# i.e.: run bench_cmd without any modification
bench_params = ParameterSweep.from_records([{}])
link_vars = cls.parse_link_vars(args.link_vars)
num_runs = args.num_runs
if num_runs < 1:
raise ValueError("`num_runs` should be at least 1.")
return cls(
serve_cmd=serve_cmd,
bench_cmd=bench_cmd,
after_bench_cmd=after_bench_cmd,
show_stdout=args.show_stdout,
serve_params=serve_params,
bench_params=bench_params,
output_dir=Path(args.output_dir),
num_runs=num_runs,
dry_run=args.dry_run,
resume=args.resume,
link_vars=link_vars,
server_ready_timeout=args.server_ready_timeout,
)
@classmethod
def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser.add_argument(
"--serve-cmd",
type=str,
required=True,
help="The command used to run the server: `vllm serve ...`",
)
parser.add_argument(
"--bench-cmd",
type=str,
required=True,
help="The command used to run the benchmark: `vllm bench serve ...`",
)
parser.add_argument(
"--after-bench-cmd",
type=str,
default=None,
help="After a benchmark run is complete, invoke this command instead of "
"the default `ServerWrapper.clear_cache()`.",
)
parser.add_argument(
"--show-stdout",
action="store_true",
help="If set, logs the standard output of subcommands. "
"Useful for debugging but can be quite spammy.",
)
parser.add_argument(
"--server-ready-timeout",
type=int,
default=300,
help="Timeout in seconds to wait for the server to become ready.",
)
parser.add_argument(
"--serve-params",
type=str,
default=None,
help="Path to JSON file containing parameter combinations "
"for the `vllm serve` command. Can be either a list of dicts or a dict "
"where keys are benchmark names. "
"If both `serve_params` and `bench_params` are given, "
"this script will iterate over their Cartesian product.",
)
parser.add_argument(
"--bench-params",
type=str,
default=None,
help="Path to JSON file containing parameter combinations "
"for the `vllm bench serve` command. Can be either a list of dicts or "
"a dict where keys are benchmark names. "
"If both `serve_params` and `bench_params` are given, "
"this script will iterate over their Cartesian product.",
)
parser.add_argument(
"-o",
"--output-dir",
type=str,
default="results",
help="The directory to which results are written.",
)
parser.add_argument(
"--num-runs",
type=int,
default=3,
help="Number of runs per parameter combination.",
)
parser.add_argument(
"--dry-run",
action="store_true",
help="If set, prints the commands to run, "
"then exits without executing them.",
)
parser.add_argument(
"--resume",
type=str,
default=None,
help="Set this to the name of a directory under `output_dir` (which is a "
"timestamp) to resume a previous execution of this script, i.e., only run "
"parameter combinations for which there are still no output files.",
)
parser.add_argument(
"--link-vars",
type=str,
default="",
help=(
"Comma-separated list of linked variables between serve and bench, "
"e.g. max_num_seqs=max_concurrency,max_model_len=random_input_len"
),
)
return parser
@staticmethod
def parse_link_vars(s: str) -> list[tuple[str, str]]:
if not s:
return []
pairs = []
for item in s.split(","):
a, b = item.split("=")
pairs.append((a.strip(), b.strip()))
return pairs
def run_main(args: SweepServeArgs):
timestamp = args.resume or datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = args.output_dir / timestamp
if args.resume and not output_dir.exists():
raise ValueError(f"Cannot resume from non-existent directory ({output_dir})")
try:
return run_combs(
serve_cmd=args.serve_cmd,
bench_cmd=args.bench_cmd,
after_bench_cmd=args.after_bench_cmd,
show_stdout=args.show_stdout,
server_ready_timeout=args.server_ready_timeout,
serve_params=args.serve_params,
bench_params=args.bench_params,
output_dir=output_dir,
num_runs=args.num_runs,
dry_run=args.dry_run,
link_vars=args.link_vars,
)
except BaseException as exc:
raise RuntimeError(
f"The script was terminated early. Use `--resume {timestamp}` "
f"to continue the script from its last checkpoint."
) from exc
def main(args: argparse.Namespace):
run_main(SweepServeArgs.from_cli_args(args))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description=SweepServeArgs.parser_help)
SweepServeArgs.add_cli_args(parser)
main(parser.parse_args())

View File

@@ -0,0 +1,305 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import math
from dataclasses import asdict, dataclass
from datetime import datetime
from pathlib import Path
from typing import ClassVar, Literal, get_args
import numpy as np
from typing_extensions import assert_never
from vllm.utils.import_utils import PlaceholderModule
from .param_sweep import ParameterSweep, ParameterSweepItem
from .serve import (
SweepServeArgs,
_get_comb_base_path,
run_comb,
server_ctx,
)
from .server import ServerProcess
try:
import pandas as pd
except ImportError:
pd = PlaceholderModule("pandas")
SLAVariable = Literal["request_rate", "max_concurrency"]
def _estimate_sla_value(run_data: dict[str, object], sla_variable: SLAVariable):
request_throughput = float(run_data["request_throughput"]) # type: ignore
if sla_variable == "request_rate":
return request_throughput
if sla_variable == "max_concurrency":
mean_latency_ms = float(run_data["mean_e2el_ms"]) # type: ignore
return request_throughput * mean_latency_ms / 1000
assert_never(sla_variable)
def _estimate_sla_avg(runs: list[dict[str, object]], sla_variable: SLAVariable):
return sum(_estimate_sla_value(run, sla_variable) for run in runs) / len(runs)
def run_comb_sla(
server: ServerProcess | None,
bench_cmd: list[str],
*,
serve_comb: ParameterSweepItem,
bench_comb: ParameterSweepItem,
output_dir: Path,
num_runs: int,
dry_run: bool,
link_vars: list[tuple[str, str]],
sla_variable: SLAVariable,
sla_value: int,
) -> list[dict[str, object]] | None:
bench_comb_sla = bench_comb | {sla_variable: sla_value}
return run_comb(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb_sla,
base_path=_get_comb_base_path(output_dir, serve_comb, bench_comb_sla),
num_runs=num_runs,
dry_run=dry_run,
link_vars=link_vars,
)
def explore_sla(
server: ServerProcess | None,
bench_cmd: list[str],
*,
serve_comb: ParameterSweepItem,
bench_comb: ParameterSweepItem,
sla_variable: SLAVariable,
sla_iters: int,
output_dir: Path,
num_runs: int,
dry_run: bool,
link_vars: list[tuple[str, str]],
):
print("[SLA START]")
print(f"Serve parameters: {serve_comb.as_text() or '(None)'}")
print(f"Bench parameters: {bench_comb.as_text() or '(None)'}")
print(f"Number of SLA iterations: {sla_iters}")
if sla_iters < 2:
raise ValueError("`sla_iters` should be at least 2")
serial_comb_data = run_comb_sla(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb,
output_dir=output_dir,
num_runs=num_runs,
dry_run=dry_run,
link_vars=link_vars,
sla_variable=sla_variable,
sla_value=1,
)
batch_comb_data = run_comb_sla(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb,
output_dir=output_dir,
num_runs=num_runs,
dry_run=dry_run,
link_vars=link_vars,
sla_variable=sla_variable,
sla_value=int(bench_comb.get("num_prompts", 1000)), # type: ignore
)
if serial_comb_data is None or batch_comb_data is None:
if dry_run:
print("Omitting intermediate SLA iterations.")
print("[SLA END]")
return
serial_sla_value = math.ceil(_estimate_sla_avg(serial_comb_data, sla_variable))
print(f"Serial inference: {sla_variable}={serial_sla_value}")
batch_sla_value = math.floor(_estimate_sla_avg(batch_comb_data, sla_variable))
print(f"Batch inference: {sla_variable}={batch_sla_value}")
# Avoid duplicated runs for intermediate values if the range between
# `serial_sla_value` and `batch_sla_value` is small
inter_sla_values = np.linspace(serial_sla_value, batch_sla_value, sla_iters)[1:-1]
inter_sla_values = sorted(set(map(round, inter_sla_values)))
inter_combs_data: list[dict[str, object]] = []
for inter_sla_value in inter_sla_values:
print(f"Exploring: {sla_variable}={inter_sla_value}")
inter_comb_data = run_comb_sla(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb,
output_dir=output_dir,
num_runs=num_runs,
dry_run=dry_run,
link_vars=link_vars,
sla_variable=sla_variable,
sla_value=inter_sla_value,
)
if inter_comb_data is not None:
inter_combs_data.extend(inter_comb_data)
print("[SLA END]")
return serial_comb_data + inter_combs_data + batch_comb_data
def run_slas(
serve_cmd: list[str],
bench_cmd: list[str],
after_bench_cmd: list[str],
*,
show_stdout: bool,
server_ready_timeout: int,
serve_params: ParameterSweep,
bench_params: ParameterSweep,
sla_variable: SLAVariable,
sla_iters: int,
output_dir: Path,
num_runs: int,
dry_run: bool,
link_vars: list[tuple[str, str]],
):
if any(bench_comb.has_param(sla_variable) for bench_comb in bench_params):
raise ValueError(
f"You should not override `{sla_variable}` in `bench_params` in SLA mode, "
"since it is supposed to be determined automatically."
)
all_data = list[dict[str, object]]()
for serve_comb in serve_params:
with server_ctx(
serve_cmd,
after_bench_cmd,
show_stdout=show_stdout,
server_ready_timeout=server_ready_timeout,
serve_comb=serve_comb,
bench_params=bench_params,
output_dir=output_dir,
dry_run=dry_run,
) as server:
for bench_comb in bench_params:
comb_data = explore_sla(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb,
sla_variable=sla_variable,
sla_iters=sla_iters,
output_dir=output_dir,
num_runs=num_runs,
dry_run=dry_run,
link_vars=link_vars,
)
if comb_data is not None:
all_data.extend(comb_data)
if dry_run:
return None
combined_df = pd.DataFrame.from_records(all_data)
combined_df.to_csv(output_dir / "summary.csv")
return combined_df
@dataclass
class SweepServeSLAArgs(SweepServeArgs):
sla_variable: SLAVariable
sla_iters: int
parser_name: ClassVar[str] = "serve_sla"
parser_help: ClassVar[str] = (
"Explore the latency-throughput space for determining SLAs."
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
# NOTE: Don't use super() as `from_cli_args` calls `cls()`
base_args = SweepServeArgs.from_cli_args(args)
return cls(
**asdict(base_args),
sla_variable=args.sla_variable,
sla_iters=args.sla_iters,
)
@classmethod
def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser = super().add_cli_args(parser)
sla_group = parser.add_argument_group("sla options")
sla_group.add_argument(
"--sla-variable",
type=str,
choices=get_args(SLAVariable),
default="request_rate",
help="The variable to adjust in each iteration.",
)
sla_group.add_argument(
"--sla-iters",
type=int,
default=10,
help="Number of iterations used to explore the latency-throughput space. "
"This includes the first two iterations used to interpolate the value of "
"`sla_variable` for remaining iterations.",
)
return parser
def run_main(args: SweepServeSLAArgs):
timestamp = args.resume or datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = args.output_dir / timestamp
if args.resume and not output_dir.exists():
raise ValueError(f"Cannot resume from non-existent directory ({output_dir})")
try:
return run_slas(
serve_cmd=args.serve_cmd,
bench_cmd=args.bench_cmd,
after_bench_cmd=args.after_bench_cmd,
show_stdout=args.show_stdout,
server_ready_timeout=args.server_ready_timeout,
serve_params=args.serve_params,
bench_params=args.bench_params,
sla_variable=args.sla_variable,
sla_iters=args.sla_iters,
output_dir=output_dir,
num_runs=args.num_runs,
dry_run=args.dry_run,
link_vars=args.link_vars,
)
except BaseException as exc:
raise RuntimeError(
f"The script was terminated early. Use `--resume {timestamp}` "
f"to continue the script from its last checkpoint."
) from exc
def main(args: argparse.Namespace):
run_main(SweepServeSLAArgs.from_cli_args(args))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description=SweepServeSLAArgs.parser_help)
SweepServeSLAArgs.add_cli_args(parser)
main(parser.parse_args())

View File

@@ -0,0 +1,142 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import os
import signal
import subprocess
import time
from types import TracebackType
import requests
from typing_extensions import Self
class ServerProcess:
VLLM_RESET_CACHE_ENDPOINTS = [
"/reset_prefix_cache",
"/reset_mm_cache",
"/reset_encoder_cache",
]
def __init__(
self,
server_cmd: list[str],
after_bench_cmd: list[str],
*,
show_stdout: bool,
) -> None:
super().__init__()
self.server_cmd = server_cmd
self.after_bench_cmd = after_bench_cmd
self.show_stdout = show_stdout
def __enter__(self) -> Self:
self.start()
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
exc_traceback: TracebackType | None,
) -> None:
self.stop()
def start(self):
# Create new process for clean termination
self._server_process = subprocess.Popen(
self.server_cmd,
start_new_session=True,
stdout=None if self.show_stdout else subprocess.DEVNULL,
# Need `VLLM_SERVER_DEV_MODE=1` for `_reset_caches`
env=os.environ | {"VLLM_SERVER_DEV_MODE": "1"},
)
def stop(self):
server_process = self._server_process
if server_process.poll() is None:
# In case only some processes have been terminated
with contextlib.suppress(ProcessLookupError):
# We need to kill both API Server and Engine processes
os.killpg(os.getpgid(server_process.pid), signal.SIGKILL)
def run_subcommand(self, cmd: list[str]):
return subprocess.run(
cmd,
stdout=None if self.show_stdout else subprocess.DEVNULL,
check=True,
)
def after_bench(self) -> None:
if not self.after_bench_cmd:
self.reset_caches()
return
self.run_subcommand(self.after_bench_cmd)
def _get_vllm_server_address(self) -> str:
server_cmd = self.server_cmd
for host_key in ("--host",):
if host_key in server_cmd:
host = server_cmd[server_cmd.index(host_key) + 1]
break
else:
host = "localhost"
for port_key in ("-p", "--port"):
if port_key in server_cmd:
port = int(server_cmd[server_cmd.index(port_key) + 1])
break
else:
port = 8000 # The default value in vllm serve
return f"http://{host}:{port}"
def is_server_ready(self) -> bool:
server_address = self._get_vllm_server_address()
try:
response = requests.get(f"{server_address}/health")
return response.status_code == 200
except requests.RequestException:
return False
def wait_until_ready(self, timeout: int) -> None:
start_time = time.monotonic()
while not self.is_server_ready():
# Check if server process has crashed
if self._server_process.poll() is not None:
returncode = self._server_process.returncode
raise RuntimeError(
f"Server process crashed with return code {returncode}"
)
if time.monotonic() - start_time > timeout:
raise TimeoutError(
f"Server failed to become ready within {timeout} seconds."
)
time.sleep(1)
def reset_caches(self) -> None:
server_cmd = self.server_cmd
# Use `.endswith()` to match `/bin/...`
if server_cmd[0].endswith("vllm"):
server_address = self._get_vllm_server_address()
print(f"Resetting caches at {server_address}")
for endpoint in self.VLLM_RESET_CACHE_ENDPOINTS:
res = requests.post(server_address + endpoint)
res.raise_for_status()
elif server_cmd[0].endswith("infinity_emb"):
if "--vector-disk-cache" in server_cmd:
raise NotImplementedError(
"Infinity server uses caching but does not expose a method "
"to reset the cache"
)
else:
raise NotImplementedError(
f"No implementation of `reset_caches` for `{server_cmd[0]}` server. "
"Please specify a custom command via `--after-bench-cmd`."
)

View File

@@ -0,0 +1,406 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import json
import shlex
import subprocess
from dataclasses import dataclass
from datetime import datetime
from functools import lru_cache
from pathlib import Path
from typing import ClassVar
from vllm.benchmarks.startup import add_cli_args as add_startup_cli_args
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.import_utils import PlaceholderModule
from .param_sweep import ParameterSweep, ParameterSweepItem
from .utils import sanitize_filename
try:
import pandas as pd
except ImportError:
pd = PlaceholderModule("pandas")
@lru_cache(maxsize=1)
def _get_supported_startup_keys() -> set[str]:
parser = FlexibleArgumentParser(add_help=False)
add_startup_cli_args(parser)
supported: set[str] = {"config"}
for action in parser._actions:
if action.dest and action.dest is not argparse.SUPPRESS:
supported.add(action.dest)
for option in action.option_strings:
if option.startswith("--"):
supported.add(option.lstrip("-").replace("-", "_"))
return supported
def _is_supported_param(param_key: str, supported: set[str]) -> bool:
if param_key == "_benchmark_name":
return True
prefix = param_key.split(".", 1)[0]
normalized = prefix.replace("-", "_")
return normalized in supported
def _filter_params(
params: ParameterSweep, *, supported: set[str], strict: bool
) -> ParameterSweep:
filtered = []
for item in params:
kept: dict[str, object] = {}
dropped: list[str] = []
for key, value in item.items():
if _is_supported_param(key, supported):
kept[key] = value
else:
dropped.append(key)
if dropped:
label = item.get("_benchmark_name") or item.as_text()
message = (
"Ignoring unsupported startup params"
f"{' for ' + str(label) if label else ''}: "
f"{', '.join(sorted(dropped))}"
)
if strict:
raise ValueError(message)
print(message)
filtered.append(ParameterSweepItem.from_record(kept))
return ParameterSweep(filtered)
def _update_run_data(
run_data: dict[str, object],
serve_overrides: ParameterSweepItem,
startup_overrides: ParameterSweepItem,
run_number: int,
) -> dict[str, object]:
run_data["run_number"] = run_number
run_data.update(serve_overrides)
run_data.update(startup_overrides)
return run_data
def _strip_arg(cmd: list[str], keys: tuple[str, ...]) -> list[str]:
stripped: list[str] = []
skip_next = False
for arg in cmd:
if skip_next:
skip_next = False
continue
if arg in keys:
skip_next = True
continue
if any(arg.startswith(f"{key}=") for key in keys):
continue
stripped.append(arg)
return stripped
def _apply_output_json(cmd: list[str], output_path: Path) -> list[str]:
keys = ("--output-json", "--output_json")
cmd = _strip_arg(cmd, keys)
return [*cmd, keys[0], str(output_path)]
def _get_comb_base_path(
output_dir: Path,
serve_comb: ParameterSweepItem,
startup_comb: ParameterSweepItem,
) -> Path:
parts = list[str]()
if serve_comb:
parts.extend(("SERVE-", serve_comb.name))
if startup_comb:
parts.extend(("STARTUP-", startup_comb.name))
return output_dir / sanitize_filename("-".join(parts))
def _get_comb_run_path(base_path: Path, run_number: int | None) -> Path:
if run_number is None:
return base_path / "summary.json"
return base_path / f"run={run_number}.json"
def run_benchmark(
startup_cmd: list[str],
*,
serve_overrides: ParameterSweepItem,
startup_overrides: ParameterSweepItem,
run_number: int,
output_path: Path,
show_stdout: bool,
dry_run: bool,
) -> dict[str, object] | None:
cmd = serve_overrides.apply_to_cmd(startup_cmd)
cmd = startup_overrides.apply_to_cmd(cmd)
cmd = _apply_output_json(cmd, output_path)
print("[BEGIN BENCHMARK]")
print(f"Serve overrides: {serve_overrides}")
print(f"Startup overrides: {startup_overrides}")
print(f"Run Number: {run_number}")
print(f"Benchmark command: {cmd}")
print(f"Output file: {output_path}")
if output_path.exists():
print("Found existing results.")
print("[SKIPPED BENCHMARK]")
with output_path.open("r", encoding="utf-8") as f:
run_data = json.load(f)
return _update_run_data(
run_data, serve_overrides, startup_overrides, run_number
)
if dry_run:
print("[END BENCHMARK]")
return None
output_path.parent.mkdir(parents=True, exist_ok=True)
subprocess.run(
cmd,
stdout=None if show_stdout else subprocess.DEVNULL,
check=True,
)
with output_path.open("r", encoding="utf-8") as f:
run_data = json.load(f)
run_data = _update_run_data(
run_data, serve_overrides, startup_overrides, run_number
)
with output_path.open("w", encoding="utf-8") as f:
json.dump(run_data, f, indent=4)
print("[END BENCHMARK]")
return run_data
def run_comb(
startup_cmd: list[str],
*,
serve_comb: ParameterSweepItem,
startup_comb: ParameterSweepItem,
base_path: Path,
num_runs: int,
show_stdout: bool,
dry_run: bool,
) -> list[dict[str, object]] | None:
comb_data = list[dict[str, object]]()
for run_number in range(num_runs):
run_data = run_benchmark(
startup_cmd,
serve_overrides=serve_comb,
startup_overrides=startup_comb,
run_number=run_number,
output_path=_get_comb_run_path(base_path, run_number),
show_stdout=show_stdout,
dry_run=dry_run,
)
if run_data is not None:
comb_data.append(run_data)
if dry_run:
return None
with _get_comb_run_path(base_path, run_number=None).open(
"w", encoding="utf-8"
) as f:
json.dump(comb_data, f, indent=4)
return comb_data
def run_combs(
startup_cmd: list[str],
*,
serve_params: ParameterSweep,
startup_params: ParameterSweep,
output_dir: Path,
num_runs: int,
show_stdout: bool,
dry_run: bool,
) -> "pd.DataFrame | None":
all_data = list[dict[str, object]]()
for serve_comb in serve_params:
for startup_comb in startup_params:
base_path = _get_comb_base_path(output_dir, serve_comb, startup_comb)
comb_data = run_comb(
startup_cmd,
serve_comb=serve_comb,
startup_comb=startup_comb,
base_path=base_path,
num_runs=num_runs,
show_stdout=show_stdout,
dry_run=dry_run,
)
if comb_data is not None:
all_data.extend(comb_data)
if dry_run:
return None
combined_df = pd.DataFrame.from_records(all_data)
combined_df.to_csv(output_dir / "summary.csv")
return combined_df
@dataclass
class SweepStartupArgs:
startup_cmd: list[str]
serve_params: ParameterSweep
startup_params: ParameterSweep
output_dir: Path
num_runs: int
show_stdout: bool
dry_run: bool
resume: str | None
strict_params: bool
parser_name: ClassVar[str] = "startup"
parser_help: ClassVar[str] = (
"Benchmark vLLM startup time over parameter combinations."
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
startup_cmd = shlex.split(args.startup_cmd)
if args.serve_params:
serve_params = ParameterSweep.read_json(args.serve_params)
else:
serve_params = ParameterSweep.from_records([{}])
if args.startup_params:
startup_params = ParameterSweep.read_json(args.startup_params)
else:
startup_params = ParameterSweep.from_records([{}])
supported = _get_supported_startup_keys()
serve_params = _filter_params(
serve_params, supported=supported, strict=args.strict_params
)
startup_params = _filter_params(
startup_params, supported=supported, strict=args.strict_params
)
if args.num_runs < 1:
raise ValueError("`num_runs` should be at least 1.")
return cls(
startup_cmd=startup_cmd,
serve_params=serve_params,
startup_params=startup_params,
output_dir=Path(args.output_dir),
num_runs=args.num_runs,
show_stdout=args.show_stdout,
dry_run=args.dry_run,
resume=args.resume,
strict_params=args.strict_params,
)
@classmethod
def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser.add_argument(
"--startup-cmd",
type=str,
default="vllm bench startup",
help="The command used to run the startup benchmark.",
)
parser.add_argument(
"--serve-params",
type=str,
default=None,
help="Path to JSON file containing parameter combinations "
"for the `vllm serve` command. Only parameters supported by "
"`vllm bench startup` will be applied.",
)
parser.add_argument(
"--startup-params",
type=str,
default=None,
help="Path to JSON file containing parameter combinations "
"for the `vllm bench startup` command.",
)
parser.add_argument(
"-o",
"--output-dir",
type=str,
default="results",
help="The directory to which results are written.",
)
parser.add_argument(
"--num-runs",
type=int,
default=1,
help="Number of runs per parameter combination.",
)
parser.add_argument(
"--show-stdout",
action="store_true",
help="If set, logs the standard output of subcommands.",
)
parser.add_argument(
"--dry-run",
action="store_true",
help="If set, prints the commands to run, "
"then exits without executing them.",
)
parser.add_argument(
"--resume",
type=str,
default=None,
help="Set this to the name of a directory under `output_dir` (which is a "
"timestamp) to resume a previous execution of this script, i.e., only run "
"parameter combinations for which there are still no output files.",
)
parser.add_argument(
"--strict-params",
action="store_true",
help="If set, unknown parameters in sweep files raise an error "
"instead of being ignored.",
)
return parser
def run_main(args: SweepStartupArgs):
timestamp = args.resume or datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = args.output_dir / timestamp
if args.resume and not output_dir.exists():
raise ValueError(f"Cannot resume from non-existent directory ({output_dir})")
try:
return run_combs(
startup_cmd=args.startup_cmd,
serve_params=args.serve_params,
startup_params=args.startup_params,
output_dir=output_dir,
num_runs=args.num_runs,
show_stdout=args.show_stdout,
dry_run=args.dry_run,
)
except BaseException as exc:
raise RuntimeError(
f"The script was terminated early. Use `--resume {timestamp}` "
f"to continue the script from its last checkpoint."
) from exc
def main(args: argparse.Namespace):
run_main(SweepStartupArgs.from_cli_args(args))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description=SweepStartupArgs.parser_help)
SweepStartupArgs.add_cli_args(parser)
main(parser.parse_args())

View File

@@ -0,0 +1,4 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
def sanitize_filename(filename: str) -> str:
return filename.replace("/", "_").replace("..", "__").strip("'").strip('"')