Files
2026-04-09 11:23:47 +08:00

306 lines
9.2 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import math
from dataclasses import asdict, dataclass
from datetime import datetime
from pathlib import Path
from typing import ClassVar, Literal, get_args
import numpy as np
from typing_extensions import assert_never
from vllm.utils.import_utils import PlaceholderModule
from .param_sweep import ParameterSweep, ParameterSweepItem
from .serve import (
SweepServeArgs,
_get_comb_base_path,
run_comb,
server_ctx,
)
from .server import ServerProcess
try:
import pandas as pd
except ImportError:
pd = PlaceholderModule("pandas")
# The benchmark parameter that SLA exploration varies from run to run. The
# literal values double as the `--sla-variable` CLI choices and as the key
# merged into each benchmark combination.
SLAVariable = Literal["request_rate", "max_concurrency"]
def _estimate_sla_value(run_data: dict[str, object], sla_variable: SLAVariable):
request_throughput = float(run_data["request_throughput"]) # type: ignore
if sla_variable == "request_rate":
return request_throughput
if sla_variable == "max_concurrency":
mean_latency_ms = float(run_data["mean_e2el_ms"]) # type: ignore
return request_throughput * mean_latency_ms / 1000
assert_never(sla_variable)
def _estimate_sla_avg(runs: list[dict[str, object]], sla_variable: SLAVariable):
    """Return the arithmetic mean of ``_estimate_sla_value`` over ``runs``.

    ``runs`` must be non-empty; an empty list raises ``ZeroDivisionError``.
    """
    estimates = [_estimate_sla_value(run, sla_variable) for run in runs]
    return sum(estimates) / len(estimates)
def run_comb_sla(
    server: ServerProcess | None,
    bench_cmd: list[str],
    *,
    serve_comb: ParameterSweepItem,
    bench_comb: ParameterSweepItem,
    output_dir: Path,
    num_runs: int,
    dry_run: bool,
    link_vars: list[tuple[str, str]],
    sla_variable: SLAVariable,
    sla_value: int,
) -> list[dict[str, object]] | None:
    """Run one benchmark combination with ``sla_variable`` pinned to ``sla_value``.

    The SLA variable is merged into ``bench_comb`` with ``|`` (overriding any
    existing entry). The result base path is derived from that merged
    combination, so presumably every SLA value lands in its own location.

    Returns:
        Whatever ``run_comb`` returns for the merged combination.
    """
    pinned_bench_comb = bench_comb | {sla_variable: sla_value}
    comb_base_path = _get_comb_base_path(
        output_dir, serve_comb, pinned_bench_comb
    )

    return run_comb(
        server,
        bench_cmd,
        serve_comb=serve_comb,
        bench_comb=pinned_bench_comb,
        base_path=comb_base_path,
        num_runs=num_runs,
        dry_run=dry_run,
        link_vars=link_vars,
    )
def _interpolate_sla_values(
    serial_sla_value: int, batch_sla_value: int, sla_iters: int
) -> list[int]:
    """Choose the intermediate SLA values lying between the two estimated bounds.

    ``sla_iters`` points are spaced evenly from ``serial_sla_value`` to
    ``batch_sla_value`` (both inclusive). The two endpoints are dropped, since
    they are the values already estimated from the serial and batch runs.
    The rest are rounded to integers, de-duplicated (so a narrow range does
    not cause repeated runs of the same value) and sorted ascending.
    """
    points = np.linspace(serial_sla_value, batch_sla_value, sla_iters)[1:-1]
    # `.tolist()` yields Python floats, so `round` returns Python ints no matter
    # how the installed NumPy version implements `round()` on its own scalars;
    # the values end up in CLI arguments and paths, where `2.0` != `2`.
    return sorted({round(point) for point in points.tolist()})


def explore_sla(
    server: ServerProcess | None,
    bench_cmd: list[str],
    *,
    serve_comb: ParameterSweepItem,
    bench_comb: ParameterSweepItem,
    sla_variable: SLAVariable,
    sla_iters: int,
    output_dir: Path,
    num_runs: int,
    dry_run: bool,
    link_vars: list[tuple[str, str]],
) -> list[dict[str, object]] | None:
    """Sweep ``sla_variable`` across the latency-throughput space for one combo.

    Two anchor runs come first:

    - a "serial" run with the SLA variable set to 1, and
    - a "batch" run with it set to the combination's ``num_prompts``
      (default 1000).

    The SLA value actually achieved by each anchor bounds a range, from which
    up to ``sla_iters - 2`` intermediate values are interpolated and
    benchmarked.

    Returns:
        All run records (serial, then intermediate values in ascending order,
        then batch), or ``None`` when either anchor run returned no data
        (as happens in a dry run).

    Raises:
        ValueError: If ``sla_iters`` is less than 2.
    """
    # Validate before printing the banner so a bad argument never leaves an
    # unterminated "[SLA START]" in the log.
    if sla_iters < 2:
        raise ValueError("`sla_iters` should be at least 2")

    print("[SLA START]")
    print(f"Serve parameters: {serve_comb.as_text() or '(None)'}")
    print(f"Bench parameters: {bench_comb.as_text() or '(None)'}")
    print(f"Number of SLA iterations: {sla_iters}")

    def run_at(sla_value: int) -> list[dict[str, object]] | None:
        # Every run shares the same server and combination settings; only the
        # value of the SLA variable differs.
        return run_comb_sla(
            server,
            bench_cmd,
            serve_comb=serve_comb,
            bench_comb=bench_comb,
            output_dir=output_dir,
            num_runs=num_runs,
            dry_run=dry_run,
            link_vars=link_vars,
            sla_variable=sla_variable,
            sla_value=sla_value,
        )

    serial_comb_data = run_at(1)
    batch_comb_data = run_at(int(bench_comb.get("num_prompts", 1000)))  # type: ignore

    if serial_comb_data is None or batch_comb_data is None:
        if dry_run:
            print("Omitting intermediate SLA iterations.")
        # Always close the banner so START/END markers stay balanced.
        print("[SLA END]")
        return None

    # Presumably rounded inward (ceil the low end, floor the high end) so the
    # interpolated range stays within what was actually observed.
    serial_sla_value = math.ceil(_estimate_sla_avg(serial_comb_data, sla_variable))
    print(f"Serial inference: {sla_variable}={serial_sla_value}")
    batch_sla_value = math.floor(_estimate_sla_avg(batch_comb_data, sla_variable))
    print(f"Batch inference: {sla_variable}={batch_sla_value}")

    inter_combs_data: list[dict[str, object]] = []
    for inter_sla_value in _interpolate_sla_values(
        serial_sla_value, batch_sla_value, sla_iters
    ):
        print(f"Exploring: {sla_variable}={inter_sla_value}")
        inter_comb_data = run_at(inter_sla_value)
        if inter_comb_data is not None:
            inter_combs_data.extend(inter_comb_data)

    print("[SLA END]")
    return serial_comb_data + inter_combs_data + batch_comb_data
def run_slas(
    serve_cmd: list[str],
    bench_cmd: list[str],
    after_bench_cmd: list[str],
    *,
    show_stdout: bool,
    server_ready_timeout: int,
    serve_params: ParameterSweep,
    bench_params: ParameterSweep,
    sla_variable: SLAVariable,
    sla_iters: int,
    output_dir: Path,
    num_runs: int,
    dry_run: bool,
    link_vars: list[tuple[str, str]],
):
    """Run SLA exploration for every serve/bench parameter combination.

    A ``server_ctx`` context is entered once per serve combination and shared
    by all bench combinations explored under it.

    Returns:
        ``None`` for a dry run. Otherwise, a pandas DataFrame of every
        collected run record, which is also written to
        ``output_dir / "summary.csv"``.

    Raises:
        ValueError: If any bench combination already sets ``sla_variable``,
            since that value is chosen automatically.
    """
    for candidate in bench_params:
        if candidate.has_param(sla_variable):
            raise ValueError(
                f"You should not override `{sla_variable}` in `bench_params` "
                "in SLA mode, since it is supposed to be determined automatically."
            )

    collected: list[dict[str, object]] = []

    for serve_comb in serve_params:
        with server_ctx(
            serve_cmd,
            after_bench_cmd,
            show_stdout=show_stdout,
            server_ready_timeout=server_ready_timeout,
            serve_comb=serve_comb,
            bench_params=bench_params,
            output_dir=output_dir,
            dry_run=dry_run,
        ) as server:
            for bench_comb in bench_params:
                sla_records = explore_sla(
                    server,
                    bench_cmd,
                    serve_comb=serve_comb,
                    bench_comb=bench_comb,
                    sla_variable=sla_variable,
                    sla_iters=sla_iters,
                    output_dir=output_dir,
                    num_runs=num_runs,
                    dry_run=dry_run,
                    link_vars=link_vars,
                )
                if sla_records is not None:
                    collected.extend(sla_records)

    if dry_run:
        return None

    summary_df = pd.DataFrame.from_records(collected)
    summary_df.to_csv(output_dir / "summary.csv")
    return summary_df
@dataclass
class SweepServeSLAArgs(SweepServeArgs):
    """Arguments for the ``serve_sla`` sweep.

    These are the base serve-sweep arguments plus the SLA variable to explore
    and the number of exploration iterations.
    """

    # Benchmark parameter varied across iterations.
    sla_variable: SLAVariable
    # Total iterations, including the two anchor runs used for interpolation.
    sla_iters: int

    parser_name: ClassVar[str] = "serve_sla"
    parser_help: ClassVar[str] = (
        "Explore the latency-throughput space for determining SLAs."
    )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build an instance from parsed command-line arguments."""
        # NOTE: Don't use super() as `from_cli_args` calls `cls()`; build the
        # base-class fields explicitly, then add the SLA-specific ones.
        base_fields = asdict(SweepServeArgs.from_cli_args(args))
        return cls(
            **base_fields,
            sla_variable=args.sla_variable,
            sla_iters=args.sla_iters,
        )

    @classmethod
    def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Register the base sweep options plus an "sla options" group."""
        parser = super().add_cli_args(parser)

        group = parser.add_argument_group("sla options")
        group.add_argument(
            "--sla-variable",
            type=str,
            choices=get_args(SLAVariable),
            default="request_rate",
            help="The variable to adjust in each iteration.",
        )
        group.add_argument(
            "--sla-iters",
            type=int,
            default=10,
            help=(
                "Number of iterations used to explore the latency-throughput space. "
                "This includes the first two iterations used to interpolate the value of "
                "`sla_variable` for remaining iterations."
            ),
        )

        return parser
def run_main(args: SweepServeSLAArgs):
    """Run the SLA sweep described by ``args`` and return its result.

    Output goes to ``args.output_dir / <timestamp>``. The timestamp is
    ``args.resume`` when that is set, otherwise the current local time.

    Raises:
        ValueError: If ``args.resume`` names a directory that does not exist.
        RuntimeError: Wrapping any exception raised during the sweep, with a
            hint showing how to resume from the last checkpoint.
    """
    if args.resume:
        timestamp = args.resume
    else:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    output_dir = args.output_dir / timestamp
    if args.resume and not output_dir.exists():
        raise ValueError(f"Cannot resume from non-existent directory ({output_dir})")

    try:
        return run_slas(
            serve_cmd=args.serve_cmd,
            bench_cmd=args.bench_cmd,
            after_bench_cmd=args.after_bench_cmd,
            show_stdout=args.show_stdout,
            server_ready_timeout=args.server_ready_timeout,
            serve_params=args.serve_params,
            bench_params=args.bench_params,
            sla_variable=args.sla_variable,
            sla_iters=args.sla_iters,
            output_dir=output_dir,
            num_runs=args.num_runs,
            dry_run=args.dry_run,
            link_vars=args.link_vars,
        )
    except BaseException as exc:
        # Deliberately broad (includes KeyboardInterrupt) so that any early
        # termination still tells the user how to resume.
        raise RuntimeError(
            f"The script was terminated early. Use `--resume {timestamp}` "
            "to continue the script from its last checkpoint."
        ) from exc
def main(args: argparse.Namespace):
    """CLI entry point: convert parsed arguments and run the SLA sweep."""
    sla_args = SweepServeSLAArgs.from_cli_args(args)
    run_main(sla_args)
if __name__ == "__main__":
    # Standalone invocation: build a parser carrying the SLA sweep options,
    # parse the command line, and hand the result to `main`.
    cli_parser = argparse.ArgumentParser(description=SweepServeSLAArgs.parser_help)
    SweepServeSLAArgs.add_cli_args(cli_parser)
    main(cli_parser.parse_args())