Upgrade to vllm 0.17.0 corex v4.1 overlay
This commit is contained in:
@@ -10,14 +10,14 @@ from .plot_pareto import SweepPlotParetoArgs
|
||||
from .plot_pareto import main as plot_pareto_main
|
||||
from .serve import SweepServeArgs
|
||||
from .serve import main as serve_main
|
||||
from .serve_sla import SweepServeSLAArgs
|
||||
from .serve_sla import main as serve_sla_main
|
||||
from .serve_workload import SweepServeWorkloadArgs
|
||||
from .serve_workload import main as serve_workload_main
|
||||
from .startup import SweepStartupArgs
|
||||
from .startup import main as startup_main
|
||||
|
||||
SUBCOMMANDS = (
|
||||
(SweepServeArgs, serve_main),
|
||||
(SweepServeSLAArgs, serve_sla_main),
|
||||
(SweepServeWorkloadArgs, serve_workload_main),
|
||||
(SweepStartupArgs, startup_main),
|
||||
(SweepPlotArgs, plot_main),
|
||||
(SweepPlotParetoArgs, plot_pareto_main),
|
||||
|
||||
@@ -324,6 +324,11 @@ def _plot_fig(
|
||||
df = filter_by.apply(df)
|
||||
df = bin_by.apply(df)
|
||||
|
||||
if len(df) == 0:
|
||||
print(f"No data to plot. Filters: {filter_by}")
|
||||
print("[END FIGURE]")
|
||||
return
|
||||
|
||||
# Sort by curve_by columns alphabetically for consistent legend ordering
|
||||
if curve_by:
|
||||
df = df.sort_values(by=curve_by)
|
||||
@@ -494,7 +499,7 @@ class SweepPlotArgs:
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
output_dir = Path(args.OUTPUT_DIR)
|
||||
output_dir = Path(args.EXPERIMENT_DIR)
|
||||
if not output_dir.exists():
|
||||
raise ValueError(f"No parameter sweep results under {output_dir}")
|
||||
|
||||
@@ -526,11 +531,9 @@ class SweepPlotArgs:
|
||||
@classmethod
|
||||
def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
|
||||
parser.add_argument(
|
||||
"OUTPUT_DIR",
|
||||
"EXPERIMENT_DIR",
|
||||
type=str,
|
||||
default="results",
|
||||
help="The directory containing the results to plot, "
|
||||
"i.e., the `--output-dir` argument to the parameter sweep script.",
|
||||
help="The directory containing the sweep results to plot.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fig-dir",
|
||||
@@ -570,13 +573,13 @@ class SweepPlotArgs:
|
||||
parser.add_argument(
|
||||
"--var-x",
|
||||
type=str,
|
||||
default="request_throughput",
|
||||
default="total_token_throughput",
|
||||
help="The variable for the x-axis.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--var-y",
|
||||
type=str,
|
||||
default="p99_ttft_ms",
|
||||
default="median_ttft_ms",
|
||||
help="The variable for the y-axis",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
||||
@@ -325,7 +325,7 @@ class SweepPlotParetoArgs:
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
output_dir = Path(args.OUTPUT_DIR)
|
||||
output_dir = Path(args.EXPERIMENT_DIR)
|
||||
if not output_dir.exists():
|
||||
raise ValueError(f"No parameter sweep results under {output_dir}")
|
||||
|
||||
@@ -342,9 +342,8 @@ class SweepPlotParetoArgs:
|
||||
@classmethod
|
||||
def add_cli_args(cls, parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"OUTPUT_DIR",
|
||||
"EXPERIMENT_DIR",
|
||||
type=str,
|
||||
default="results",
|
||||
help="The directory containing the sweep results to plot.",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
||||
@@ -4,6 +4,7 @@ import argparse
|
||||
import contextlib
|
||||
import json
|
||||
import shlex
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
@@ -135,17 +136,21 @@ def run_benchmark(
|
||||
|
||||
|
||||
def _get_comb_base_path(
|
||||
output_dir: Path,
|
||||
experiment_dir: Path,
|
||||
serve_comb: ParameterSweepItem,
|
||||
bench_comb: ParameterSweepItem,
|
||||
*,
|
||||
extra_parts: tuple[str, ...] = (),
|
||||
):
|
||||
parts = list[str]()
|
||||
if serve_comb:
|
||||
parts.extend(("SERVE-", serve_comb.name))
|
||||
if bench_comb:
|
||||
parts.extend(("BENCH-", bench_comb.name))
|
||||
if extra_parts:
|
||||
parts.extend(extra_parts)
|
||||
|
||||
return output_dir / sanitize_filename("-".join(parts))
|
||||
return experiment_dir / sanitize_filename("-".join(parts))
|
||||
|
||||
|
||||
def _get_comb_run_path(base_path: Path, run_number: int | None):
|
||||
@@ -158,10 +163,10 @@ def _get_comb_run_path(base_path: Path, run_number: int | None):
|
||||
def _comb_needs_server(
|
||||
serve_comb: ParameterSweepItem,
|
||||
bench_combs: ParameterSweep,
|
||||
output_dir: Path,
|
||||
experiment_dir: Path,
|
||||
):
|
||||
for bench_comb in bench_combs:
|
||||
base_path = _get_comb_base_path(output_dir, serve_comb, bench_comb)
|
||||
base_path = _get_comb_base_path(experiment_dir, serve_comb, bench_comb)
|
||||
if not _get_comb_run_path(base_path, run_number=None).exists():
|
||||
return True
|
||||
|
||||
@@ -175,11 +180,11 @@ def server_ctx(
|
||||
show_stdout: bool,
|
||||
serve_comb: ParameterSweepItem,
|
||||
bench_params: ParameterSweep,
|
||||
output_dir: Path,
|
||||
experiment_dir: Path,
|
||||
dry_run: bool,
|
||||
server_ready_timeout: int = 300,
|
||||
):
|
||||
if not _comb_needs_server(serve_comb, bench_params, output_dir):
|
||||
if not _comb_needs_server(serve_comb, bench_params, experiment_dir):
|
||||
return contextlib.nullcontext()
|
||||
|
||||
return run_server(
|
||||
@@ -211,10 +216,10 @@ def run_comb(
|
||||
*,
|
||||
serve_comb: ParameterSweepItem,
|
||||
bench_comb: ParameterSweepItem,
|
||||
link_vars: list[tuple[str, str]],
|
||||
base_path: Path,
|
||||
num_runs: int,
|
||||
dry_run: bool,
|
||||
link_vars: list[tuple[str, str]],
|
||||
):
|
||||
if not _comb_is_valid(serve_comb, bench_comb, link_vars):
|
||||
return None
|
||||
@@ -253,10 +258,10 @@ def run_combs(
|
||||
server_ready_timeout: int,
|
||||
serve_params: ParameterSweep,
|
||||
bench_params: ParameterSweep,
|
||||
output_dir: Path,
|
||||
link_vars: list[tuple[str, str]],
|
||||
experiment_dir: Path,
|
||||
num_runs: int,
|
||||
dry_run: bool,
|
||||
link_vars: list[tuple[str, str]],
|
||||
):
|
||||
all_data = list[dict[str, object]]()
|
||||
for serve_comb in serve_params:
|
||||
@@ -266,22 +271,22 @@ def run_combs(
|
||||
show_stdout=show_stdout,
|
||||
serve_comb=serve_comb,
|
||||
bench_params=bench_params,
|
||||
output_dir=output_dir,
|
||||
experiment_dir=experiment_dir,
|
||||
dry_run=dry_run,
|
||||
server_ready_timeout=server_ready_timeout,
|
||||
) as server:
|
||||
for bench_comb in bench_params:
|
||||
base_path = _get_comb_base_path(output_dir, serve_comb, bench_comb)
|
||||
base_path = _get_comb_base_path(experiment_dir, serve_comb, bench_comb)
|
||||
|
||||
comb_data = run_comb(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_comb=serve_comb,
|
||||
bench_comb=bench_comb,
|
||||
link_vars=link_vars,
|
||||
base_path=base_path,
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
link_vars=link_vars,
|
||||
)
|
||||
|
||||
if comb_data is not None:
|
||||
@@ -291,7 +296,7 @@ def run_combs(
|
||||
return None
|
||||
|
||||
combined_df = pd.DataFrame.from_records(all_data)
|
||||
combined_df.to_csv(output_dir / "summary.csv")
|
||||
combined_df.to_csv(experiment_dir / "summary.csv")
|
||||
|
||||
return combined_df
|
||||
|
||||
@@ -305,11 +310,12 @@ class SweepServeArgs:
|
||||
server_ready_timeout: int
|
||||
serve_params: ParameterSweep
|
||||
bench_params: ParameterSweep
|
||||
link_vars: list[tuple[str, str]]
|
||||
output_dir: Path
|
||||
experiment_name: str
|
||||
num_runs: int
|
||||
dry_run: bool
|
||||
resume: str | None
|
||||
link_vars: list[tuple[str, str]]
|
||||
resume: bool
|
||||
|
||||
parser_name: ClassVar[str] = "serve"
|
||||
parser_help: ClassVar[str] = "Run vLLM server benchmark under multiple settings."
|
||||
@@ -336,6 +342,11 @@ class SweepServeArgs:
|
||||
|
||||
link_vars = cls.parse_link_vars(args.link_vars)
|
||||
|
||||
if args.experiment_name:
|
||||
experiment_name = args.experiment_name
|
||||
else:
|
||||
experiment_name = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
num_runs = args.num_runs
|
||||
if num_runs < 1:
|
||||
raise ValueError("`num_runs` should be at least 1.")
|
||||
@@ -347,11 +358,12 @@ class SweepServeArgs:
|
||||
show_stdout=args.show_stdout,
|
||||
serve_params=serve_params,
|
||||
bench_params=bench_params,
|
||||
link_vars=link_vars,
|
||||
output_dir=Path(args.output_dir),
|
||||
experiment_name=experiment_name,
|
||||
num_runs=num_runs,
|
||||
dry_run=args.dry_run,
|
||||
resume=args.resume,
|
||||
link_vars=link_vars,
|
||||
server_ready_timeout=args.server_ready_timeout,
|
||||
)
|
||||
|
||||
@@ -388,6 +400,7 @@ class SweepServeArgs:
|
||||
default=300,
|
||||
help="Timeout in seconds to wait for the server to become ready.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--serve-params",
|
||||
type=str,
|
||||
@@ -398,6 +411,16 @@ class SweepServeArgs:
|
||||
"If both `serve_params` and `bench_params` are given, "
|
||||
"this script will iterate over their Cartesian product.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--link-vars",
|
||||
type=str,
|
||||
default="",
|
||||
help=(
|
||||
"Comma-separated list of linked variables between serve and bench, "
|
||||
"e.g. max_num_seqs=max_concurrency,max_model_len=random_input_len"
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bench-params",
|
||||
type=str,
|
||||
@@ -413,7 +436,15 @@ class SweepServeArgs:
|
||||
"--output-dir",
|
||||
type=str,
|
||||
default="results",
|
||||
help="The directory to which results are written.",
|
||||
help="The main directory to which results are written.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-e",
|
||||
"--experiment-name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The name of this experiment (defaults to current timestamp). "
|
||||
"Results will be stored under `output_dir/experiment_name`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-runs",
|
||||
@@ -429,21 +460,10 @@ class SweepServeArgs:
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Set this to the name of a directory under `output_dir` (which is a "
|
||||
"timestamp) to resume a previous execution of this script, i.e., only run "
|
||||
"parameter combinations for which there are still no output files.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--link-vars",
|
||||
type=str,
|
||||
default="",
|
||||
help=(
|
||||
"Comma-separated list of linked variables between serve and bench, "
|
||||
"e.g. max_num_seqs=max_concurrency,max_model_len=random_input_len"
|
||||
),
|
||||
action="store_true",
|
||||
help="Resume a previous execution of this script, i.e., only run "
|
||||
"parameter combinations for which there are still no output files "
|
||||
"under `output_dir/experiment_name`.",
|
||||
)
|
||||
|
||||
return parser
|
||||
@@ -458,33 +478,52 @@ class SweepServeArgs:
|
||||
pairs.append((a.strip(), b.strip()))
|
||||
return pairs
|
||||
|
||||
def resolve_experiment_dir(self) -> Path:
|
||||
experiment_dir = self.output_dir / self.experiment_name
|
||||
|
||||
if self.resume:
|
||||
if not experiment_dir.exists():
|
||||
raise ValueError(f"Cannot resume from non-existent {experiment_dir=}")
|
||||
else:
|
||||
if experiment_dir.exists():
|
||||
raise ValueError(f"Cannot overwrite existing {experiment_dir=}")
|
||||
|
||||
return experiment_dir
|
||||
|
||||
@contextmanager
|
||||
def run_ctx(self, experiment_dir: Path):
|
||||
if self.dry_run:
|
||||
yield
|
||||
print(f"Experiment will be saved at: {experiment_dir}")
|
||||
return
|
||||
|
||||
try:
|
||||
yield
|
||||
print(f"Experiment has been saved at: {experiment_dir}")
|
||||
except BaseException as exc:
|
||||
raise RuntimeError(
|
||||
"The script was terminated early. Use `--resume` "
|
||||
"to continue the script from its last checkpoint."
|
||||
) from exc
|
||||
|
||||
|
||||
def run_main(args: SweepServeArgs):
|
||||
timestamp = args.resume or datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
output_dir = args.output_dir / timestamp
|
||||
experiment_dir = args.resolve_experiment_dir()
|
||||
|
||||
if args.resume and not output_dir.exists():
|
||||
raise ValueError(f"Cannot resume from non-existent directory ({output_dir})")
|
||||
|
||||
try:
|
||||
with args.run_ctx(experiment_dir):
|
||||
return run_combs(
|
||||
serve_cmd=args.serve_cmd,
|
||||
bench_cmd=args.bench_cmd,
|
||||
link_vars=args.link_vars,
|
||||
after_bench_cmd=args.after_bench_cmd,
|
||||
show_stdout=args.show_stdout,
|
||||
server_ready_timeout=args.server_ready_timeout,
|
||||
serve_params=args.serve_params,
|
||||
bench_params=args.bench_params,
|
||||
output_dir=output_dir,
|
||||
experiment_dir=experiment_dir,
|
||||
num_runs=args.num_runs,
|
||||
dry_run=args.dry_run,
|
||||
link_vars=args.link_vars,
|
||||
)
|
||||
except BaseException as exc:
|
||||
raise RuntimeError(
|
||||
f"The script was terminated early. Use `--resume {timestamp}` "
|
||||
f"to continue the script from its last checkpoint."
|
||||
) from exc
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
|
||||
@@ -1,305 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
import math
|
||||
from dataclasses import asdict, dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import ClassVar, Literal, get_args
|
||||
|
||||
import numpy as np
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
|
||||
from .param_sweep import ParameterSweep, ParameterSweepItem
|
||||
from .serve import (
|
||||
SweepServeArgs,
|
||||
_get_comb_base_path,
|
||||
run_comb,
|
||||
server_ctx,
|
||||
)
|
||||
from .server import ServerProcess
|
||||
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
pd = PlaceholderModule("pandas")
|
||||
|
||||
|
||||
SLAVariable = Literal["request_rate", "max_concurrency"]
|
||||
|
||||
|
||||
def _estimate_sla_value(run_data: dict[str, object], sla_variable: SLAVariable):
|
||||
request_throughput = float(run_data["request_throughput"]) # type: ignore
|
||||
if sla_variable == "request_rate":
|
||||
return request_throughput
|
||||
if sla_variable == "max_concurrency":
|
||||
mean_latency_ms = float(run_data["mean_e2el_ms"]) # type: ignore
|
||||
return request_throughput * mean_latency_ms / 1000
|
||||
|
||||
assert_never(sla_variable)
|
||||
|
||||
|
||||
def _estimate_sla_avg(runs: list[dict[str, object]], sla_variable: SLAVariable):
|
||||
return sum(_estimate_sla_value(run, sla_variable) for run in runs) / len(runs)
|
||||
|
||||
|
||||
def run_comb_sla(
|
||||
server: ServerProcess | None,
|
||||
bench_cmd: list[str],
|
||||
*,
|
||||
serve_comb: ParameterSweepItem,
|
||||
bench_comb: ParameterSweepItem,
|
||||
output_dir: Path,
|
||||
num_runs: int,
|
||||
dry_run: bool,
|
||||
link_vars: list[tuple[str, str]],
|
||||
sla_variable: SLAVariable,
|
||||
sla_value: int,
|
||||
) -> list[dict[str, object]] | None:
|
||||
bench_comb_sla = bench_comb | {sla_variable: sla_value}
|
||||
|
||||
return run_comb(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_comb=serve_comb,
|
||||
bench_comb=bench_comb_sla,
|
||||
base_path=_get_comb_base_path(output_dir, serve_comb, bench_comb_sla),
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
link_vars=link_vars,
|
||||
)
|
||||
|
||||
|
||||
def explore_sla(
|
||||
server: ServerProcess | None,
|
||||
bench_cmd: list[str],
|
||||
*,
|
||||
serve_comb: ParameterSweepItem,
|
||||
bench_comb: ParameterSweepItem,
|
||||
sla_variable: SLAVariable,
|
||||
sla_iters: int,
|
||||
output_dir: Path,
|
||||
num_runs: int,
|
||||
dry_run: bool,
|
||||
link_vars: list[tuple[str, str]],
|
||||
):
|
||||
print("[SLA START]")
|
||||
print(f"Serve parameters: {serve_comb.as_text() or '(None)'}")
|
||||
print(f"Bench parameters: {bench_comb.as_text() or '(None)'}")
|
||||
print(f"Number of SLA iterations: {sla_iters}")
|
||||
|
||||
if sla_iters < 2:
|
||||
raise ValueError("`sla_iters` should be at least 2")
|
||||
|
||||
serial_comb_data = run_comb_sla(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_comb=serve_comb,
|
||||
bench_comb=bench_comb,
|
||||
output_dir=output_dir,
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
link_vars=link_vars,
|
||||
sla_variable=sla_variable,
|
||||
sla_value=1,
|
||||
)
|
||||
batch_comb_data = run_comb_sla(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_comb=serve_comb,
|
||||
bench_comb=bench_comb,
|
||||
output_dir=output_dir,
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
link_vars=link_vars,
|
||||
sla_variable=sla_variable,
|
||||
sla_value=int(bench_comb.get("num_prompts", 1000)), # type: ignore
|
||||
)
|
||||
|
||||
if serial_comb_data is None or batch_comb_data is None:
|
||||
if dry_run:
|
||||
print("Omitting intermediate SLA iterations.")
|
||||
print("[SLA END]")
|
||||
|
||||
return
|
||||
|
||||
serial_sla_value = math.ceil(_estimate_sla_avg(serial_comb_data, sla_variable))
|
||||
print(f"Serial inference: {sla_variable}={serial_sla_value}")
|
||||
|
||||
batch_sla_value = math.floor(_estimate_sla_avg(batch_comb_data, sla_variable))
|
||||
print(f"Batch inference: {sla_variable}={batch_sla_value}")
|
||||
|
||||
# Avoid duplicated runs for intermediate values if the range between
|
||||
# `serial_sla_value` and `batch_sla_value` is small
|
||||
inter_sla_values = np.linspace(serial_sla_value, batch_sla_value, sla_iters)[1:-1]
|
||||
inter_sla_values = sorted(set(map(round, inter_sla_values)))
|
||||
|
||||
inter_combs_data: list[dict[str, object]] = []
|
||||
for inter_sla_value in inter_sla_values:
|
||||
print(f"Exploring: {sla_variable}={inter_sla_value}")
|
||||
inter_comb_data = run_comb_sla(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_comb=serve_comb,
|
||||
bench_comb=bench_comb,
|
||||
output_dir=output_dir,
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
link_vars=link_vars,
|
||||
sla_variable=sla_variable,
|
||||
sla_value=inter_sla_value,
|
||||
)
|
||||
if inter_comb_data is not None:
|
||||
inter_combs_data.extend(inter_comb_data)
|
||||
|
||||
print("[SLA END]")
|
||||
|
||||
return serial_comb_data + inter_combs_data + batch_comb_data
|
||||
|
||||
|
||||
def run_slas(
|
||||
serve_cmd: list[str],
|
||||
bench_cmd: list[str],
|
||||
after_bench_cmd: list[str],
|
||||
*,
|
||||
show_stdout: bool,
|
||||
server_ready_timeout: int,
|
||||
serve_params: ParameterSweep,
|
||||
bench_params: ParameterSweep,
|
||||
sla_variable: SLAVariable,
|
||||
sla_iters: int,
|
||||
output_dir: Path,
|
||||
num_runs: int,
|
||||
dry_run: bool,
|
||||
link_vars: list[tuple[str, str]],
|
||||
):
|
||||
if any(bench_comb.has_param(sla_variable) for bench_comb in bench_params):
|
||||
raise ValueError(
|
||||
f"You should not override `{sla_variable}` in `bench_params` in SLA mode, "
|
||||
"since it is supposed to be determined automatically."
|
||||
)
|
||||
|
||||
all_data = list[dict[str, object]]()
|
||||
for serve_comb in serve_params:
|
||||
with server_ctx(
|
||||
serve_cmd,
|
||||
after_bench_cmd,
|
||||
show_stdout=show_stdout,
|
||||
server_ready_timeout=server_ready_timeout,
|
||||
serve_comb=serve_comb,
|
||||
bench_params=bench_params,
|
||||
output_dir=output_dir,
|
||||
dry_run=dry_run,
|
||||
) as server:
|
||||
for bench_comb in bench_params:
|
||||
comb_data = explore_sla(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_comb=serve_comb,
|
||||
bench_comb=bench_comb,
|
||||
sla_variable=sla_variable,
|
||||
sla_iters=sla_iters,
|
||||
output_dir=output_dir,
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
link_vars=link_vars,
|
||||
)
|
||||
|
||||
if comb_data is not None:
|
||||
all_data.extend(comb_data)
|
||||
|
||||
if dry_run:
|
||||
return None
|
||||
|
||||
combined_df = pd.DataFrame.from_records(all_data)
|
||||
combined_df.to_csv(output_dir / "summary.csv")
|
||||
|
||||
return combined_df
|
||||
|
||||
|
||||
@dataclass
|
||||
class SweepServeSLAArgs(SweepServeArgs):
|
||||
sla_variable: SLAVariable
|
||||
sla_iters: int
|
||||
|
||||
parser_name: ClassVar[str] = "serve_sla"
|
||||
parser_help: ClassVar[str] = (
|
||||
"Explore the latency-throughput space for determining SLAs."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
# NOTE: Don't use super() as `from_cli_args` calls `cls()`
|
||||
base_args = SweepServeArgs.from_cli_args(args)
|
||||
|
||||
return cls(
|
||||
**asdict(base_args),
|
||||
sla_variable=args.sla_variable,
|
||||
sla_iters=args.sla_iters,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
|
||||
parser = super().add_cli_args(parser)
|
||||
|
||||
sla_group = parser.add_argument_group("sla options")
|
||||
sla_group.add_argument(
|
||||
"--sla-variable",
|
||||
type=str,
|
||||
choices=get_args(SLAVariable),
|
||||
default="request_rate",
|
||||
help="The variable to adjust in each iteration.",
|
||||
)
|
||||
sla_group.add_argument(
|
||||
"--sla-iters",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of iterations used to explore the latency-throughput space. "
|
||||
"This includes the first two iterations used to interpolate the value of "
|
||||
"`sla_variable` for remaining iterations.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def run_main(args: SweepServeSLAArgs):
|
||||
timestamp = args.resume or datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
output_dir = args.output_dir / timestamp
|
||||
|
||||
if args.resume and not output_dir.exists():
|
||||
raise ValueError(f"Cannot resume from non-existent directory ({output_dir})")
|
||||
|
||||
try:
|
||||
return run_slas(
|
||||
serve_cmd=args.serve_cmd,
|
||||
bench_cmd=args.bench_cmd,
|
||||
after_bench_cmd=args.after_bench_cmd,
|
||||
show_stdout=args.show_stdout,
|
||||
server_ready_timeout=args.server_ready_timeout,
|
||||
serve_params=args.serve_params,
|
||||
bench_params=args.bench_params,
|
||||
sla_variable=args.sla_variable,
|
||||
sla_iters=args.sla_iters,
|
||||
output_dir=output_dir,
|
||||
num_runs=args.num_runs,
|
||||
dry_run=args.dry_run,
|
||||
link_vars=args.link_vars,
|
||||
)
|
||||
except BaseException as exc:
|
||||
raise RuntimeError(
|
||||
f"The script was terminated early. Use `--resume {timestamp}` "
|
||||
f"to continue the script from its last checkpoint."
|
||||
) from exc
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
run_main(SweepServeSLAArgs.from_cli_args(args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description=SweepServeSLAArgs.parser_help)
|
||||
SweepServeSLAArgs.add_cli_args(parser)
|
||||
|
||||
main(parser.parse_args())
|
||||
328
vllm/benchmarks/sweep/serve_workload.py
Normal file
328
vllm/benchmarks/sweep/serve_workload.py
Normal file
@@ -0,0 +1,328 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
import math
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
from typing import ClassVar, Literal, get_args
|
||||
|
||||
import numpy as np
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.benchmarks.datasets import DEFAULT_NUM_PROMPTS
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
|
||||
from .param_sweep import ParameterSweep, ParameterSweepItem
|
||||
from .serve import (
|
||||
SweepServeArgs,
|
||||
_get_comb_base_path,
|
||||
run_comb,
|
||||
server_ctx,
|
||||
)
|
||||
from .server import ServerProcess
|
||||
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
pd = PlaceholderModule("pandas")
|
||||
|
||||
|
||||
WorkloadVariable = Literal["request_rate", "max_concurrency"]
|
||||
|
||||
|
||||
def _estimate_workload_value(
|
||||
run_data: dict[str, object],
|
||||
workload_var: WorkloadVariable,
|
||||
):
|
||||
request_throughput = float(run_data["request_throughput"]) # type: ignore
|
||||
if workload_var == "request_rate":
|
||||
return request_throughput
|
||||
if workload_var == "max_concurrency":
|
||||
mean_latency_ms = float(run_data["mean_e2el_ms"]) # type: ignore
|
||||
return request_throughput * mean_latency_ms / 1000
|
||||
|
||||
assert_never(workload_var)
|
||||
|
||||
|
||||
def _estimate_workload_avg(
|
||||
runs: list[dict[str, object]],
|
||||
workload_var: WorkloadVariable,
|
||||
):
|
||||
total = sum(_estimate_workload_value(run, workload_var) for run in runs)
|
||||
return total / len(runs)
|
||||
|
||||
|
||||
def run_comb_workload(
|
||||
server: ServerProcess | None,
|
||||
bench_cmd: list[str],
|
||||
*,
|
||||
serve_comb: ParameterSweepItem,
|
||||
bench_comb: ParameterSweepItem,
|
||||
link_vars: list[tuple[str, str]],
|
||||
experiment_dir: Path,
|
||||
num_runs: int,
|
||||
dry_run: bool,
|
||||
workload_var: WorkloadVariable,
|
||||
workload_value: int,
|
||||
) -> list[dict[str, object]] | None:
|
||||
bench_comb_workload = bench_comb | {workload_var: workload_value}
|
||||
|
||||
return run_comb(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_comb=serve_comb,
|
||||
bench_comb=bench_comb_workload,
|
||||
link_vars=link_vars,
|
||||
base_path=_get_comb_base_path(
|
||||
experiment_dir,
|
||||
serve_comb,
|
||||
bench_comb,
|
||||
extra_parts=("WL-", f"{workload_var}={workload_value}"),
|
||||
),
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
|
||||
|
||||
def explore_comb_workloads(
|
||||
server: ServerProcess | None,
|
||||
bench_cmd: list[str],
|
||||
*,
|
||||
serve_comb: ParameterSweepItem,
|
||||
bench_comb: ParameterSweepItem,
|
||||
link_vars: list[tuple[str, str]],
|
||||
workload_var: WorkloadVariable,
|
||||
workload_iters: int,
|
||||
experiment_dir: Path,
|
||||
num_runs: int,
|
||||
dry_run: bool,
|
||||
):
|
||||
print("[WL START]")
|
||||
print(f"Serve parameters: {serve_comb.as_text() or '(None)'}")
|
||||
print(f"Bench parameters: {bench_comb.as_text() or '(None)'}")
|
||||
print(f"Number of workload iterations: {workload_iters}")
|
||||
|
||||
if workload_iters < 2:
|
||||
raise ValueError("`workload_iters` should be at least 2")
|
||||
|
||||
dataset_size = DEFAULT_NUM_PROMPTS
|
||||
if "num_prompts" in bench_comb:
|
||||
dataset_size = int(bench_comb["num_prompts"]) # type: ignore
|
||||
else:
|
||||
for i, arg in enumerate(bench_cmd):
|
||||
if arg == "--num-prompts" and i + 1 < len(bench_cmd):
|
||||
dataset_size = int(bench_cmd[i + 1])
|
||||
break
|
||||
elif arg.startswith("--num-prompts="):
|
||||
dataset_size = int(arg.split("=", 1)[1])
|
||||
break
|
||||
|
||||
print(f"Dataset size: {dataset_size}")
|
||||
|
||||
serial_workload_data = run_comb_workload(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_comb=serve_comb,
|
||||
bench_comb=bench_comb | {"max_concurrency": 1},
|
||||
link_vars=link_vars,
|
||||
experiment_dir=experiment_dir,
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
workload_var=workload_var,
|
||||
workload_value=1,
|
||||
)
|
||||
batch_workload_data = run_comb_workload(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_comb=serve_comb,
|
||||
bench_comb=bench_comb | {"max_concurrency": dataset_size},
|
||||
link_vars=link_vars,
|
||||
experiment_dir=experiment_dir,
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
workload_var=workload_var,
|
||||
workload_value=dataset_size,
|
||||
)
|
||||
|
||||
if serial_workload_data is None or batch_workload_data is None:
|
||||
if dry_run:
|
||||
print("Omitting intermediate Workload iterations.")
|
||||
print("[WL END]")
|
||||
|
||||
return
|
||||
|
||||
serial_workload_value = math.ceil(
|
||||
_estimate_workload_avg(serial_workload_data, workload_var)
|
||||
)
|
||||
print(f"Serial inference: {workload_var}={serial_workload_value}")
|
||||
|
||||
batch_workload_value = math.floor(
|
||||
_estimate_workload_avg(batch_workload_data, workload_var)
|
||||
)
|
||||
print(f"Batch inference: {workload_var}={batch_workload_value}")
|
||||
|
||||
# Avoid duplicated runs for intermediate values if the range between
|
||||
# `serial_workload_value` and `batch_workload_value` is small
|
||||
inter_workload_values = np.linspace(
|
||||
serial_workload_value, batch_workload_value, workload_iters
|
||||
)[1:-1]
|
||||
inter_workload_values = sorted(set(map(round, inter_workload_values)))
|
||||
|
||||
inter_workloads_data: list[dict[str, object]] = []
|
||||
for inter_workload_value in inter_workload_values:
|
||||
print(f"Exploring: {workload_var}={inter_workload_value}")
|
||||
inter_workload_data = run_comb_workload(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_comb=serve_comb,
|
||||
bench_comb=bench_comb,
|
||||
link_vars=link_vars,
|
||||
experiment_dir=experiment_dir,
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
workload_var=workload_var,
|
||||
workload_value=inter_workload_value,
|
||||
)
|
||||
if inter_workload_data is not None:
|
||||
inter_workloads_data.extend(inter_workload_data)
|
||||
|
||||
print("[WL END]")
|
||||
|
||||
return serial_workload_data + inter_workloads_data + batch_workload_data
|
||||
|
||||
|
||||
def explore_combs_workloads(
|
||||
serve_cmd: list[str],
|
||||
bench_cmd: list[str],
|
||||
after_bench_cmd: list[str],
|
||||
*,
|
||||
show_stdout: bool,
|
||||
server_ready_timeout: int,
|
||||
serve_params: ParameterSweep,
|
||||
bench_params: ParameterSweep,
|
||||
link_vars: list[tuple[str, str]],
|
||||
workload_var: WorkloadVariable,
|
||||
workload_iters: int,
|
||||
experiment_dir: Path,
|
||||
num_runs: int,
|
||||
dry_run: bool,
|
||||
):
|
||||
if any(bench_comb.has_param(workload_var) for bench_comb in bench_params):
|
||||
raise ValueError(
|
||||
f"You should not override `{workload_var}` in `bench_params` "
|
||||
"since it is supposed to be explored automatically."
|
||||
)
|
||||
|
||||
all_data = list[dict[str, object]]()
|
||||
for serve_comb in serve_params:
|
||||
with server_ctx(
|
||||
serve_cmd,
|
||||
after_bench_cmd,
|
||||
show_stdout=show_stdout,
|
||||
server_ready_timeout=server_ready_timeout,
|
||||
serve_comb=serve_comb,
|
||||
bench_params=bench_params,
|
||||
experiment_dir=experiment_dir,
|
||||
dry_run=dry_run,
|
||||
) as server:
|
||||
for bench_comb in bench_params:
|
||||
comb_data = explore_comb_workloads(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_comb=serve_comb,
|
||||
bench_comb=bench_comb,
|
||||
link_vars=link_vars,
|
||||
workload_var=workload_var,
|
||||
workload_iters=workload_iters,
|
||||
experiment_dir=experiment_dir,
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
|
||||
if comb_data is not None:
|
||||
all_data.extend(comb_data)
|
||||
|
||||
if dry_run:
|
||||
return None
|
||||
|
||||
combined_df = pd.DataFrame.from_records(all_data)
|
||||
combined_df.to_csv(experiment_dir / "summary.csv")
|
||||
|
||||
return combined_df
|
||||
|
||||
|
||||
@dataclass
class SweepServeWorkloadArgs(SweepServeArgs):
    """CLI arguments for sweeping over workload levels on top of a serve sweep."""

    # Workload dimension (e.g. request rate) that is explored automatically.
    workload_var: WorkloadVariable
    # Total number of workload levels, including the two bootstrap iterations
    # used to interpolate `workload_var` for the remaining iterations.
    workload_iters: int

    parser_name: ClassVar[str] = "serve_workload"
    parser_help: ClassVar[str] = (
        "Explore the latency-throughput tradeoff for different workload levels."
    )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Construct this dataclass from parsed CLI arguments."""
        # NOTE: Don't use super() as `from_cli_args` calls `cls()`
        base_args = SweepServeArgs.from_cli_args(args)
        base_kwargs = asdict(base_args)

        return cls(
            workload_var=args.workload_var,
            workload_iters=args.workload_iters,
            **base_kwargs,
        )

    @classmethod
    def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Register the workload-specific options on top of the base serve options."""
        parser = super().add_cli_args(parser)

        group = parser.add_argument_group("workload options")
        group.add_argument(
            "--workload-var",
            type=str,
            choices=get_args(WorkloadVariable),
            default="request_rate",
            help="The variable to adjust in each iteration.",
        )
        group.add_argument(
            "--workload-iters",
            type=int,
            default=10,
            help="Number of workload levels to explore. "
            "This includes the first two iterations used to interpolate the value of "
            "`workload_var` for remaining iterations.",
        )

        return parser
|
||||
|
||||
|
||||
def run_main(args: SweepServeWorkloadArgs):
    """Run the workload-level sweep described by `args`.

    Returns the combined results (a DataFrame), or None for dry runs.
    """
    experiment_dir = args.resolve_experiment_dir()

    # Forward every relevant CLI option to the sweep driver.
    sweep_kwargs = dict(
        serve_cmd=args.serve_cmd,
        bench_cmd=args.bench_cmd,
        after_bench_cmd=args.after_bench_cmd,
        show_stdout=args.show_stdout,
        server_ready_timeout=args.server_ready_timeout,
        serve_params=args.serve_params,
        bench_params=args.bench_params,
        link_vars=args.link_vars,
        workload_var=args.workload_var,
        workload_iters=args.workload_iters,
        experiment_dir=experiment_dir,
        num_runs=args.num_runs,
        dry_run=args.dry_run,
    )

    with args.run_ctx(experiment_dir):
        return explore_combs_workloads(**sweep_kwargs)
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
    """Entry point used by the sweep CLI dispatcher."""
    parsed_args = SweepServeWorkloadArgs.from_cli_args(args)
    run_main(parsed_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Stand-alone entry point: build the parser and dispatch to `main`.
    cli_parser = argparse.ArgumentParser(
        description=SweepServeWorkloadArgs.parser_help,
    )
    SweepServeWorkloadArgs.add_cli_args(cli_parser)
    main(cli_parser.parse_args())
|
||||
@@ -4,6 +4,7 @@ import argparse
|
||||
import json
|
||||
import shlex
|
||||
import subprocess
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from functools import lru_cache
|
||||
@@ -111,7 +112,7 @@ def _apply_output_json(cmd: list[str], output_path: Path) -> list[str]:
|
||||
|
||||
|
||||
def _get_comb_base_path(
|
||||
output_dir: Path,
|
||||
experiment_dir: Path,
|
||||
serve_comb: ParameterSweepItem,
|
||||
startup_comb: ParameterSweepItem,
|
||||
) -> Path:
|
||||
@@ -120,7 +121,8 @@ def _get_comb_base_path(
|
||||
parts.extend(("SERVE-", serve_comb.name))
|
||||
if startup_comb:
|
||||
parts.extend(("STARTUP-", startup_comb.name))
|
||||
return output_dir / sanitize_filename("-".join(parts))
|
||||
|
||||
return experiment_dir / sanitize_filename("-".join(parts))
|
||||
|
||||
|
||||
def _get_comb_run_path(base_path: Path, run_number: int | None) -> Path:
|
||||
@@ -225,7 +227,7 @@ def run_combs(
|
||||
*,
|
||||
serve_params: ParameterSweep,
|
||||
startup_params: ParameterSweep,
|
||||
output_dir: Path,
|
||||
experiment_dir: Path,
|
||||
num_runs: int,
|
||||
show_stdout: bool,
|
||||
dry_run: bool,
|
||||
@@ -233,7 +235,7 @@ def run_combs(
|
||||
all_data = list[dict[str, object]]()
|
||||
for serve_comb in serve_params:
|
||||
for startup_comb in startup_params:
|
||||
base_path = _get_comb_base_path(output_dir, serve_comb, startup_comb)
|
||||
base_path = _get_comb_base_path(experiment_dir, serve_comb, startup_comb)
|
||||
comb_data = run_comb(
|
||||
startup_cmd,
|
||||
serve_comb=serve_comb,
|
||||
@@ -250,7 +252,7 @@ def run_combs(
|
||||
return None
|
||||
|
||||
combined_df = pd.DataFrame.from_records(all_data)
|
||||
combined_df.to_csv(output_dir / "summary.csv")
|
||||
combined_df.to_csv(experiment_dir / "summary.csv")
|
||||
return combined_df
|
||||
|
||||
|
||||
@@ -260,11 +262,11 @@ class SweepStartupArgs:
|
||||
serve_params: ParameterSweep
|
||||
startup_params: ParameterSweep
|
||||
output_dir: Path
|
||||
experiment_name: str
|
||||
num_runs: int
|
||||
show_stdout: bool
|
||||
dry_run: bool
|
||||
resume: str | None
|
||||
strict_params: bool
|
||||
resume: bool
|
||||
|
||||
parser_name: ClassVar[str] = "startup"
|
||||
parser_help: ClassVar[str] = (
|
||||
@@ -286,13 +288,19 @@ class SweepStartupArgs:
|
||||
startup_params = ParameterSweep.from_records([{}])
|
||||
|
||||
supported = _get_supported_startup_keys()
|
||||
strict_params = args.strict_params
|
||||
serve_params = _filter_params(
|
||||
serve_params, supported=supported, strict=args.strict_params
|
||||
serve_params, supported=supported, strict=strict_params
|
||||
)
|
||||
startup_params = _filter_params(
|
||||
startup_params, supported=supported, strict=args.strict_params
|
||||
startup_params, supported=supported, strict=strict_params
|
||||
)
|
||||
|
||||
if args.experiment_name:
|
||||
experiment_name = args.experiment_name
|
||||
else:
|
||||
experiment_name = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
if args.num_runs < 1:
|
||||
raise ValueError("`num_runs` should be at least 1.")
|
||||
|
||||
@@ -301,11 +309,11 @@ class SweepStartupArgs:
|
||||
serve_params=serve_params,
|
||||
startup_params=startup_params,
|
||||
output_dir=Path(args.output_dir),
|
||||
experiment_name=experiment_name,
|
||||
num_runs=args.num_runs,
|
||||
show_stdout=args.show_stdout,
|
||||
dry_run=args.dry_run,
|
||||
resume=args.resume,
|
||||
strict_params=args.strict_params,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -316,6 +324,7 @@ class SweepStartupArgs:
|
||||
default="vllm bench startup",
|
||||
help="The command used to run the startup benchmark.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--serve-params",
|
||||
type=str,
|
||||
@@ -331,12 +340,27 @@ class SweepStartupArgs:
|
||||
help="Path to JSON file containing parameter combinations "
|
||||
"for the `vllm bench startup` command.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--strict-params",
|
||||
action="store_true",
|
||||
help="If set, unknown parameters in sweep files raise an error "
|
||||
"instead of being ignored.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--output-dir",
|
||||
type=str,
|
||||
default="results",
|
||||
help="The directory to which results are written.",
|
||||
help="The main directory to which results are written.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-e",
|
||||
"--experiment-name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The name of this experiment (defaults to current timestamp). "
|
||||
"Results will be stored under `output_dir/experiment_name`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-runs",
|
||||
@@ -357,43 +381,56 @@ class SweepStartupArgs:
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Set this to the name of a directory under `output_dir` (which is a "
|
||||
"timestamp) to resume a previous execution of this script, i.e., only run "
|
||||
"parameter combinations for which there are still no output files.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--strict-params",
|
||||
action="store_true",
|
||||
help="If set, unknown parameters in sweep files raise an error "
|
||||
"instead of being ignored.",
|
||||
help="Resume a previous execution of this script, i.e., only run "
|
||||
"parameter combinations for which there are still no output files "
|
||||
"under `output_dir/experiment_name`.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
def resolve_experiment_dir(self) -> Path:
    """Compute and validate the per-experiment results directory.

    The directory is `output_dir/experiment_name`. When resuming, it must
    already exist; when starting fresh, it must not.

    Raises:
        ValueError: if the existence check contradicts the `resume` flag.
    """
    experiment_dir = self.output_dir / self.experiment_name

    exists = experiment_dir.exists()
    if self.resume and not exists:
        raise ValueError(f"Cannot resume from non-existent {experiment_dir=}")
    if not self.resume and exists:
        raise ValueError(f"Cannot overwrite existing {experiment_dir=}")

    return experiment_dir
|
||||
|
||||
@contextmanager
def run_ctx(self, experiment_dir: Path):
    """Wrap one sweep execution with user-facing status messages.

    Dry runs only announce the target directory. Real runs report the save
    location on success; any early termination is re-raised as a
    RuntimeError that tells the user to pass `--resume`.
    """
    if self.dry_run:
        # Nothing is written in dry-run mode; just report the target path.
        yield
        print(f"Experiment will be saved at: {experiment_dir}")
        return

    try:
        yield
        print(f"Experiment has been saved at: {experiment_dir}")
    except BaseException as err:
        # BaseException on purpose: KeyboardInterrupt/SystemExit should also
        # surface the resume hint.
        raise RuntimeError(
            "The script was terminated early. Use `--resume` "
            "to continue the script from its last checkpoint."
        ) from err
|
||||
|
||||
|
||||
def run_main(args: SweepStartupArgs):
|
||||
timestamp = args.resume or datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
output_dir = args.output_dir / timestamp
|
||||
experiment_dir = args.resolve_experiment_dir()
|
||||
|
||||
if args.resume and not output_dir.exists():
|
||||
raise ValueError(f"Cannot resume from non-existent directory ({output_dir})")
|
||||
|
||||
try:
|
||||
with args.run_ctx(experiment_dir):
|
||||
return run_combs(
|
||||
startup_cmd=args.startup_cmd,
|
||||
serve_params=args.serve_params,
|
||||
startup_params=args.startup_params,
|
||||
output_dir=output_dir,
|
||||
experiment_dir=experiment_dir,
|
||||
num_runs=args.num_runs,
|
||||
show_stdout=args.show_stdout,
|
||||
dry_run=args.dry_run,
|
||||
)
|
||||
except BaseException as exc:
|
||||
raise RuntimeError(
|
||||
f"The script was terminated early. Use `--resume {timestamp}` "
|
||||
f"to continue the script from its last checkpoint."
|
||||
) from exc
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
|
||||
Reference in New Issue
Block a user