update
This commit is contained in:
0
vllm/benchmarks/__init__.py
Normal file
0
vllm/benchmarks/__init__.py
Normal file
3453
vllm/benchmarks/datasets.py
Normal file
3453
vllm/benchmarks/datasets.py
Normal file
File diff suppressed because it is too large
Load Diff
172
vllm/benchmarks/latency.py
Normal file
172
vllm/benchmarks/latency.py
Normal file
@@ -0,0 +1,172 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Benchmark the latency of processing a single batch of requests."""
|
||||
|
||||
import argparse
|
||||
import dataclasses
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from vllm.benchmarks.lib.utils import convert_to_pytorch_benchmark_format, write_to_json
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.inputs import PromptType
|
||||
from vllm.sampling_params import BeamSearchParams
|
||||
|
||||
|
||||
def save_to_pytorch_benchmark_format(
    args: argparse.Namespace, results: dict[str, Any]
) -> None:
    """Write the latency results in PyTorch benchmark format, if any.

    The records come from ``convert_to_pytorch_benchmark_format``. When that
    call returns something truthy, the records are written next to
    ``args.output_json``, with its extension replaced by ``.pytorch.json``.

    Args:
        args: Parsed CLI arguments; ``args.output_json`` is the base path.
        results: Must contain ``latencies``, ``avg_latency`` and
            ``percentiles``.
    """
    metrics = {"latency": results["latencies"]}
    extra_info = {key: results[key] for key in ("avg_latency", "percentiles")}
    pt_records = convert_to_pytorch_benchmark_format(
        args=args,
        metrics=metrics,
        extra_info=extra_info,
    )
    if not pt_records:
        return

    stem, _extension = os.path.splitext(args.output_json)
    write_to_json(f"{stem}.pytorch.json", pt_records)
|
||||
|
||||
|
||||
def add_cli_args(parser: argparse.ArgumentParser):
    """Register the latency-benchmark options, then the engine options.

    The benchmark-specific flags are added first, in the order listed below,
    followed by everything ``EngineArgs.add_cli_args`` contributes. Prefix
    caching is then switched off by default.
    """
    benchmark_options = (
        ("--input-len", {"type": int, "default": 32}),
        ("--output-len", {"type": int, "default": 128}),
        ("--batch-size", {"type": int, "default": 8}),
        (
            "--n",
            {
                "type": int,
                "default": 1,
                "help": "Number of generated sequences per prompt.",
            },
        ),
        ("--use-beam-search", {"action": "store_true"}),
        (
            "--num-iters-warmup",
            {
                "type": int,
                "default": 10,
                "help": "Number of iterations to run for warmup.",
            },
        ),
        (
            "--num-iters",
            {"type": int, "default": 30, "help": "Number of iterations to run."},
        ),
        (
            "--profile",
            {
                "action": "store_true",
                "help": "profile the generation process of a single batch",
            },
        ),
        (
            "--output-json",
            {
                "type": str,
                "default": None,
                "help": "Path to save the latency results in JSON format.",
            },
        ),
        (
            "--disable-detokenize",
            {
                "action": "store_true",
                "help": (
                    "Do not detokenize responses (i.e. do not include "
                    "detokenization time in the latency measurement)"
                ),
            },
        ),
    )
    for flag, options in benchmark_options:
        parser.add_argument(flag, **options)

    parser = EngineArgs.add_cli_args(parser)
    # V1 enables prefix caching by default which skews the latency
    # numbers. We need to disable prefix caching by default.
    parser.set_defaults(enable_prefix_caching=False)
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
    """Run the single-batch latency benchmark described by ``args``.

    An ``LLM`` is built from the engine arguments and fed a batch of random
    token-id prompts. After ``args.num_iters_warmup`` warmup iterations the
    function either profiles a single iteration (``--profile``) and returns,
    or times ``args.num_iters`` iterations, prints the mean and percentile
    latencies and, when ``args.output_json`` is set, saves them as JSON plus
    the PyTorch benchmark format.
    """
    engine_args = EngineArgs.from_cli_args(args)

    # Lazy import to avoid importing LLM when the bench command is not selected.
    from vllm import LLM, SamplingParams

    # NOTE(woosuk): If the request cannot be processed in a single batch,
    # the engine will automatically process the request in multiple batches.
    llm = LLM(**dataclasses.asdict(engine_args))
    assert llm.llm_engine.model_config.max_model_len >= (
        args.input_len + args.output_len
    ), (
        "Please ensure that max_model_len is greater than"
        " the sum of input_len and output_len."
    )

    sampling_params = SamplingParams(
        n=args.n,
        temperature=1.0,
        top_p=1.0,
        ignore_eos=True,
        max_tokens=args.output_len,
        detokenize=not args.disable_detokenize,
    )
    # One row of random token ids (drawn from [0, 10000)) per prompt.
    token_id_matrix = np.random.randint(
        10000, size=(args.batch_size, args.input_len)
    )
    dummy_prompts: list[PromptType] = [
        {"prompt_token_ids": token_ids} for token_ids in token_id_matrix.tolist()
    ]

    def llm_generate():
        """Process the dummy batch once, via sampling or beam search."""
        if args.use_beam_search:
            llm.beam_search(
                dummy_prompts,
                BeamSearchParams(
                    beam_width=args.n,
                    max_tokens=args.output_len,
                    ignore_eos=True,
                ),
            )
            return
        llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False)

    def run_to_completion(do_profile: bool = False):
        """Run one batch; return its latency in seconds unless profiling."""
        if do_profile:
            llm.start_profile()
            llm_generate()
            llm.stop_profile()
            return None
        started_at = time.perf_counter()
        llm_generate()
        finished_at = time.perf_counter()
        return finished_at - started_at

    print("Warming up...")
    for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
        run_to_completion(do_profile=False)

    if args.profile:
        profiler_config = engine_args.profiler_config
        profiler_kind = profiler_config.profiler
        if profiler_kind == "torch":
            print(
                "Profiling with torch profiler (results will be saved to"
                f" {profiler_config.torch_profiler_dir})..."
            )
        elif profiler_kind == "cuda":
            print("Profiling with cuda profiler ...")
        run_to_completion(do_profile=True)
        return

    # Benchmark.
    latencies = np.array(
        [
            run_to_completion(do_profile=False)
            for _ in tqdm(range(args.num_iters), desc="Bench iterations")
        ]
    )
    percentages = [10, 25, 50, 75, 90, 99]
    percentiles = np.percentile(latencies, percentages)
    print(f"Avg latency: {np.mean(latencies)} seconds")
    for percentage, percentile in zip(percentages, percentiles):
        print(f"{percentage}% percentile latency: {percentile} seconds")

    # Output JSON results if specified
    if not args.output_json:
        return
    results = {
        "avg_latency": np.mean(latencies),
        "latencies": latencies.tolist(),
        "percentiles": dict(zip(percentages, percentiles.tolist())),
    }
    with open(args.output_json, "w") as f:
        json.dump(results, f, indent=4)
    save_to_pytorch_benchmark_format(args, results)
|
||||
539
vllm/benchmarks/mm_processor.py
Normal file
539
vllm/benchmarks/mm_processor.py
Normal file
@@ -0,0 +1,539 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
r"""Benchmark multimodal processor latency.
|
||||
|
||||
This benchmark measures the latency of the mm processor module
|
||||
using multimodal prompts from datasets.
|
||||
MM processor stats are automatically enabled.
|
||||
|
||||
Run:
|
||||
vllm bench mm-processor \
|
||||
--model <your_model> \
|
||||
--dataset-name random-mm \
|
||||
--num-prompts 10 \
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import dataclasses
|
||||
import json
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
import numpy as np
|
||||
|
||||
from vllm.benchmarks.datasets import (
|
||||
MultiModalConversationDataset,
|
||||
VisionArenaDataset,
|
||||
)
|
||||
from vllm.benchmarks.throughput import get_requests
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.utils.gc_utils import freeze_gc_heap
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
pd = PlaceholderModule("pandas")
|
||||
|
||||
if TYPE_CHECKING: # Avoid having to mock during docs build
|
||||
from vllm.v1.engine.llm_engine import LLMEngine
|
||||
else:
|
||||
LLMEngine = object
|
||||
|
||||
|
||||
def get_timing_stats_from_engine(llm_engine: LLMEngine) -> dict[str, dict[str, float]]:
    """
    Gather per-request multimodal timing stats from the LLM engine.

    Two sources are combined, keyed by request id:

    * the stats returned by ``llm_engine.renderer._mm_timing_registry.stat()``
    * the per-worker results of the ``get_encoder_timing_stats`` collective
      RPC. When several workers report the same request, the largest
      ``encoder_forward_secs`` and the largest ``num_encoder_calls`` are kept.

    Args:
        llm_engine: Engine exposing ``vllm_config``, ``renderer`` and
            ``collective_rpc``.

    Returns:
        Mapping of request id to the merged stats dict. Empty when the
        observability config is missing or ``enable_mm_processor_stats`` is
        off.
    """
    obs_config = llm_engine.vllm_config.observability_config
    if not (obs_config and obs_config.enable_mm_processor_stats):
        return {}

    preprocessing_stats = llm_engine.renderer._mm_timing_registry.stat()

    # (key, default) pairs that are reduced with max() across workers.
    reduced_keys = (("encoder_forward_secs", 0.0), ("num_encoder_calls", 0))

    encoder_stats = dict[str, dict[str, float]]()
    for worker_stats in llm_engine.collective_rpc("get_encoder_timing_stats"):
        if not worker_stats:
            continue

        for request_id, reported in worker_stats.items():
            known = encoder_stats.get(request_id)
            if known is None:
                # First worker to report this request: take its stats as-is.
                encoder_stats[request_id] = dict(reported)
                continue

            # Aggregate timing metrics across workers
            for key, default in reduced_keys:
                known[key] = max(known.get(key, default), reported.get(key, default))

    merged_stats: dict[str, dict[str, float]] = {
        request_id: dict(prep_dict)
        for request_id, prep_dict in preprocessing_stats.items()
    }

    for request_id, enc_dict in encoder_stats.items():
        target_id = request_id
        if request_id not in merged_stats:
            # In V1 engine, the request_id in encoder_stats has a suffix
            # appended to the original request_id (which is used in
            # preprocessing_stats).
            # We try to strip the suffix to find the matching request.
            stripped_id = request_id.rpartition("-")[0]
            if stripped_id and stripped_id in merged_stats:
                target_id = stripped_id

        if target_id in merged_stats:
            merged_stats[target_id].update(enc_dict)
        else:
            merged_stats[target_id] = dict(enc_dict)

    return merged_stats
|
||||
|
||||
|
||||
def collect_mm_processor_stats(llm_engine: LLMEngine) -> dict[str, list[float]]:
    """
    Regroup the per-request timing stats by stage.

    Takes what ``get_timing_stats_from_engine`` returns and collects, for each
    stat key, the values reported across all requests.

    Returns:
        A ``defaultdict`` mapping stat key to the list of its values.
    """
    per_request_stats = get_timing_stats_from_engine(llm_engine)

    values_by_stage = defaultdict[str, list[float]](list)
    for request_stats in per_request_stats.values():
        for stage, value in request_stats.items():
            values_by_stage[stage].append(value)

    return values_by_stage
|
||||
|
||||
|
||||
def calculate_mm_processor_metrics(
    stats_by_stage: dict[str, list[float]],
    selected_percentiles: list[float],
    *,
    unit: Literal["us", "ms", "s"] = "ms",
) -> dict[str, dict[str, float]]:
    """
    Summarise each stage's samples as mean, median, std and percentiles.

    Args:
        stats_by_stage: Stat key mapped to its samples. Values are multiplied
            by the factor for ``unit``, except for ``num_encoder_calls``,
            which is left unscaled.
        selected_percentiles: Percentiles to report; each becomes a ``p{p}``
            key in the summary.
        unit: Target unit. ``_secs`` in a stat key is replaced by ``_{unit}``
            in the output key.

    Returns:
        Output key mapped to its summary dict. A stage with no samples gets
        ``0.0`` for every entry.
    """
    scale = {"us": 1000000, "ms": 1000, "s": 1}[unit]
    unit_suffix = "_" + unit

    metrics = {}
    for stage, times in stats_by_stage.items():
        label = stage.replace("_secs", unit_suffix)

        if times:
            # Counts are reported as-is; only durations are converted.
            if stage == "num_encoder_calls":
                values = times
            else:
                values = [t * scale for t in times]
            summary = {
                "mean": float(np.mean(values)),
                "median": float(np.median(values)),
                "std": float(np.std(values)),
            }
            for p in selected_percentiles:
                summary[f"p{p}"] = float(np.percentile(values, p))
        else:
            summary = {"mean": 0.0, "median": 0.0, "std": 0.0}
            for p in selected_percentiles:
                summary[f"p{p}"] = 0.0

        metrics[label] = summary

    return metrics
|
||||
|
||||
|
||||
def validate_args(args):
    """
    Fill in missing defaults on ``args`` and check the dataset selection.

    ``args.tokenizer`` falls back to ``args.model`` when unset or falsy, and
    ``dataset_path``, ``lora_path`` and ``max_loras`` are created as ``None``
    when absent.

    Raises:
        ValueError: If ``--dataset-name hf`` is used without a dataset path,
            or with a path that is not one of the supported multimodal
            datasets.
    """
    if not getattr(args, "tokenizer", None):
        args.tokenizer = args.model
    for optional_attr in ("dataset_path", "lora_path", "max_loras"):
        if not hasattr(args, optional_attr):
            setattr(args, optional_attr, None)

    # Only the HuggingFace dataset mode needs further validation.
    if args.dataset_name != "hf":
        return

    if not args.dataset_path:
        raise ValueError(
            "--dataset-path is required when using --dataset-name hf. "
            "For multimodal benchmarking, specify a dataset like "
            "'lmarena-ai/VisionArena-Chat'."
        )

    supported_mm_datasets = (
        VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys()
        | MultiModalConversationDataset.SUPPORTED_DATASET_PATHS
    )
    if args.dataset_path in supported_mm_datasets:
        return
    raise ValueError(
        f"{args.dataset_path} is not a supported multimodal dataset. "
        f"Supported multimodal datasets are: {sorted(supported_mm_datasets)}"
    )
|
||||
|
||||
|
||||
def benchmark_multimodal_processor(
    args: argparse.Namespace,
) -> dict[str, Any]:
    """
    Run the multimodal processor benchmark.

    Validates ``args``, builds an ``LLM`` from the engine arguments, optionally
    runs warmup requests, then sends the benchmark prompts through
    ``llm.chat`` and summarises the collected timing stats.

    Args:
        args: Parsed CLI arguments. ``args.seed`` is set to ``0`` when it is
            ``None``; ``validate_args`` may also fill in other defaults.

    Returns:
        A dict with the keys ``completed``, ``failed``, ``mean_e2el_ms``,
        ``median_e2el_ms``, ``std_e2el_ms``, ``percentiles_e2el_ms`` (a list
        of ``(percentile, value)`` pairs), ``mm_processor_stats`` and
        ``encoder_summary``.

    Raises:
        ValueError: Propagated from ``validate_args``.
        AssertionError: If any request needs more tokens than
            ``max_model_len``.
    """
    # Imported here rather than at module level.
    from vllm import LLM, SamplingParams

    validate_args(args)

    # A concrete seed is needed below, where the warmup seed is derived by
    # incrementing it.
    if args.seed is None:
        args.seed = 0

    engine_args = EngineArgs.from_cli_args(args)
    llm = LLM(**dataclasses.asdict(engine_args))

    tokenizer = llm.get_tokenizer()
    requests = get_requests(args, tokenizer)

    assert all(
        llm.llm_engine.model_config.max_model_len
        >= (request.prompt_len + request.expected_output_len)
        for request in requests
    ), (
        "Please ensure that max_model_len is greater than the sum of "
        "prompt_len and expected_output_len for all requests."
    )

    prompts = [request.prompt for request in requests]
    expected_output_lens = [request.expected_output_len for request in requests]

    # One SamplingParams per request so each keeps its own output length.
    sampling_params = [
        SamplingParams(
            n=1,
            temperature=0.0,
            max_tokens=output_len,
            detokenize=True,
        )
        for output_len in expected_output_lens
    ]

    # Comma-separated string such as "50,90,99" -> [50.0, 90.0, 99.0].
    selected_percentiles = [
        float(p) for p in getattr(args, "metric_percentiles", "99").split(",")
    ]

    freeze_gc_heap()

    num_warmups = getattr(args, "num_warmups", 0)
    if num_warmups > 0:
        print(f"Processing {num_warmups} warmup requests...")
        # Create a temporary args object for warmup requests
        warmup_args = argparse.Namespace(**vars(args))
        warmup_args.num_prompts = num_warmups
        # Use a different seed than the benchmark requests.
        warmup_args.seed += 1
        warmup_requests = get_requests(warmup_args, tokenizer)
        warmup_prompts = [req.prompt for req in warmup_requests]
        warmup_output_lens = [req.expected_output_len for req in warmup_requests]
        warmup_sampling_params = [
            SamplingParams(max_tokens=output_len) for output_len in warmup_output_lens
        ]
        llm.chat(
            warmup_prompts,
            warmup_sampling_params,
            use_tqdm=not getattr(args, "disable_tqdm", False),
        )

        # Clear stats from warmup requests
        # (the result is discarded; this relies on retrieval clearing the
        # registry, as the warning message further down states).
        collect_mm_processor_stats(llm.llm_engine)

    print(f"Processing {len(prompts)} requests...")
    # Wall-clock time of the whole chat call; only used for the fallback
    # latency estimate below.
    start_time = time.perf_counter()

    outputs = llm.chat(
        prompts, sampling_params, use_tqdm=not getattr(args, "disable_tqdm", False)
    )

    end_time = time.perf_counter()
    total_time = end_time - start_time

    mm_stats_by_stage = collect_mm_processor_stats(llm.llm_engine)

    if not any(mm_stats_by_stage.values()):
        print(
            "\n⚠️ Warning: No MM processor stats found in registry.\n"
            " This may indicate that:\n"
            " - No multimodal requests were processed\n"
            " - Stats were already retrieved (registry is cleared after retrieval)\n"
        )

    # Uses the default unit ("ms").
    mm_processor_metrics = calculate_mm_processor_metrics(
        mm_stats_by_stage, selected_percentiles
    )

    completed = len([o for o in outputs if o.finished])
    failed = len(outputs) - completed

    # Per-request end-to-end latencies, in milliseconds.
    e2el_times = []
    for output in outputs:
        if not output.finished or output.metrics is None:
            continue
        metrics = output.metrics
        # Calculate E2E latency as: TTFT + (last_token_ts - first_token_ts)
        if (
            getattr(metrics, "first_token_latency", None) is not None
            and getattr(metrics, "last_token_ts", None) is not None
            and getattr(metrics, "first_token_ts", None) is not None
        ):
            ttft = metrics.first_token_latency
            # Decode time is the duration between the first and last token generation
            decode_time = max(0.0, metrics.last_token_ts - metrics.first_token_ts)
            e2el_times.append((ttft + decode_time) * 1000)

    if not e2el_times and completed > 0:
        print(
            "\n⚠️ Warning: Detailed end-to-end latency metrics not available.\n"
            " Falling back to average request latency "
            "(total_time / num_completed_requests).\n"
        )
        # Every completed request gets the same average value, so the std and
        # percentiles computed below carry no spread information in this case.
        avg_time_per_request = total_time / completed
        e2el_times = [avg_time_per_request * 1000] * completed

    if e2el_times:
        mean_e2el_ms = float(np.mean(e2el_times))
        median_e2el_ms = float(np.median(e2el_times))
        std_e2el_ms = float(np.std(e2el_times))
        percentiles_e2el_ms = [
            (p, float(np.percentile(e2el_times, p))) for p in selected_percentiles
        ]
    else:
        # Nothing completed: report zeros rather than failing on empty input.
        mean_e2el_ms = 0.0
        median_e2el_ms = 0.0
        std_e2el_ms = 0.0
        percentiles_e2el_ms = [(p, 0.0) for p in selected_percentiles]

    # Stays empty when no request reported any encoder calls.
    encoder_summary = {}
    if (
        "num_encoder_calls" in mm_stats_by_stage
        and mm_stats_by_stage["num_encoder_calls"]
    ):
        encoder_calls = mm_stats_by_stage["num_encoder_calls"]
        encoder_summary = {
            "total_encoder_calls": int(sum(encoder_calls)),
            "num_requests_with_encoder_calls": len(encoder_calls),
        }

    benchmark_result = {
        "completed": completed,
        "failed": failed,
        "mean_e2el_ms": mean_e2el_ms,
        "median_e2el_ms": median_e2el_ms,
        "std_e2el_ms": std_e2el_ms,
        "percentiles_e2el_ms": percentiles_e2el_ms,
        "mm_processor_stats": mm_processor_metrics,
        "encoder_summary": encoder_summary,
    }

    return benchmark_result
|
||||
|
||||
|
||||
def add_cli_args(parser: argparse.ArgumentParser) -> None:
    """Register every option of the multimodal processor benchmark.

    The order is: engine options, the general benchmark options, the
    random and random-multimodal dataset options, and finally the
    HuggingFace dataset and output options. ``enable_mm_processor_stats``
    defaults to ``True``.
    """
    from vllm.engine.arg_utils import EngineArgs

    def register(option_table) -> None:
        """Add each ``(flag, kwargs)`` entry to the parser, in order."""
        for flag, options in option_table:
            parser.add_argument(flag, **options)

    EngineArgs.add_cli_args(parser)

    parser.set_defaults(enable_mm_processor_stats=True)

    register(
        (
            (
                "--dataset-name",
                {
                    "type": str,
                    "default": "random-mm",
                    "choices": ["random-mm", "hf"],
                    "help": (
                        "Name of the dataset to benchmark on. "
                        "Defaults to 'random-mm'."
                    ),
                },
            ),
            (
                "--num-prompts",
                {
                    "type": int,
                    "default": 10,
                    "help": "Number of prompts to process.",
                },
            ),
            (
                "--num-warmups",
                {
                    "type": int,
                    "default": 1,
                    "help": "Number of warmup prompts to process.",
                },
            ),
        )
    )

    from vllm.benchmarks.datasets import (
        add_random_dataset_base_args,
        add_random_multimodal_dataset_args,
    )

    add_random_dataset_base_args(parser)
    add_random_multimodal_dataset_args(parser)

    register(
        (
            # HuggingFace dataset arguments
            (
                "--dataset-path",
                {
                    "type": str,
                    "default": None,
                    "help": "Path to the dataset file or HuggingFace dataset name "
                    "(e.g., 'yale-nlp/MMVU', 'lmarena-ai/VisionArena-Chat').",
                },
            ),
            (
                "--hf-subset",
                {
                    "type": str,
                    "default": None,
                    "help": "Subset of the HuggingFace dataset (optional).",
                },
            ),
            (
                "--hf-split",
                {
                    "type": str,
                    "default": None,
                    "help": "Split of the HuggingFace dataset "
                    "(e.g., 'train', 'test', 'validation').",
                },
            ),
            (
                "--output-len",
                {
                    "type": int,
                    "default": None,
                    "help": "Output length for each request. "
                    "Overrides the default output lengths from the dataset.",
                },
            ),
            (
                "--output-json",
                {
                    "type": str,
                    "default": None,
                    "help": "Path to save the benchmark results in JSON format.",
                },
            ),
            (
                "--metric-percentiles",
                {
                    "type": str,
                    "default": "99",
                    "help": "Comma-separated list of percentiles to calculate "
                    "(e.g., '50,90,99').",
                },
            ),
            (
                "--disable-tqdm",
                {
                    "action": "store_true",
                    "help": "Disable tqdm progress bar.",
                },
            ),
        )
    )
|
||||
|
||||
|
||||
def main(args: argparse.Namespace) -> None:
    """Run the multimodal processor benchmark and report its results.

    Prints one table of per-stage processor metrics and one table of
    end-to-end latencies. When ``args.output_json`` is set, the result is
    also written there, together with a ``config`` section and a timestamp.
    """

    def requested_percentiles() -> list[float]:
        """Parse ``--metric-percentiles`` into a list of floats."""
        raw = getattr(args, "metric_percentiles", "99")
        return [float(p) for p in raw.split(",")]

    print("Starting multimodal processor benchmark...")
    result = benchmark_multimodal_processor(args)

    rule = "=" * 80
    print("\n" + rule)
    print("Multimodal Processor Benchmark Results")
    print(rule)

    if "mm_processor_stats" in result:
        print("\nMM Processor Metrics:")
        selected_percentiles = requested_percentiles()
        mm_data = []
        for stage, metrics in result["mm_processor_stats"].items():
            row = {"Stage": stage}
            for column, key in (("Mean", "mean"), ("Median", "median"), ("Std", "std")):
                row[column] = f"{metrics[key]:.2f}"
            for p in selected_percentiles:
                row[f"P{p}"] = f"{metrics.get(f'p{p}', 0.0):.2f}"
            mm_data.append(row)

        print(pd.DataFrame(mm_data).to_string(index=False))

        encoder_summary = result.get("encoder_summary")
        if encoder_summary:
            total_calls = encoder_summary["total_encoder_calls"]
            num_requests = encoder_summary["num_requests_with_encoder_calls"]
            print(
                f"\nSummary: {total_calls} total encoder calls "
                f"across {num_requests} requests."
            )

    if "mean_e2el_ms" in result:
        print("\nEnd-to-End Latency (ms):")
        selected_percentiles = requested_percentiles()

        e2el_data = [
            {"Metric": label, "Value (ms)": f"{result[key]:.2f}"}
            for label, key in (
                ("Mean", "mean_e2el_ms"),
                ("Median", "median_e2el_ms"),
                ("Std", "std_e2el_ms"),
            )
        ]
        for p in selected_percentiles:
            # First recorded value for this percentile, or 0.0 if absent.
            percentile_value = next(
                (val for pct, val in result["percentiles_e2el_ms"] if pct == p),
                0.0,
            )
            e2el_data.append(
                {"Metric": f"P{p}", "Value (ms)": f"{percentile_value:.2f}"}
            )

        print(pd.DataFrame(e2el_data).to_string(index=False))

    if not args.output_json:
        return

    result["config"] = {
        "model": args.model,
        "num_prompts": args.num_prompts,
        "input_len": getattr(args, "random_input_len", None),
        "output_len": getattr(args, "random_output_len", None),
    }
    result["timestamp"] = datetime.now().isoformat()

    with open(args.output_json, "w") as f:
        json.dump(result, f, indent=2)
    print(f"\nResults saved to {args.output_json}")
|
||||
|
||||
|
||||
# Script entry point: when the module is executed directly, build a parser,
# register the benchmark options on it and run the benchmark with the parsed
# arguments.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Benchmark mm processor latency")
    add_cli_args(parser)
    args = parser.parse_args()
    main(args)
|
||||
1816
vllm/benchmarks/serve.py
Normal file
1816
vllm/benchmarks/serve.py
Normal file
File diff suppressed because it is too large
Load Diff
321
vllm/benchmarks/startup.py
Normal file
321
vllm/benchmarks/startup.py
Normal file
@@ -0,0 +1,321 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Benchmark the cold and warm startup time of vLLM models.
|
||||
|
||||
This script measures total startup time (including model loading, compilation,
|
||||
and cache operations) for both cold and warm scenarios:
|
||||
- Cold startup: Fresh start with no caches (temporary cache directories)
|
||||
- Warm startup: Using cached compilation and model info
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import dataclasses
|
||||
import json
|
||||
import multiprocessing
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from vllm.benchmarks.lib.utils import (
|
||||
convert_to_pytorch_benchmark_format,
|
||||
write_to_json,
|
||||
)
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
|
||||
|
||||
@contextmanager
|
||||
def cold_startup():
|
||||
"""
|
||||
Context manager to measure cold startup time:
|
||||
1. Uses a temporary directory for vLLM cache to avoid any pollution
|
||||
between cold startup iterations.
|
||||
2. Uses inductor's fresh_cache to clear torch.compile caches.
|
||||
"""
|
||||
from torch._inductor.utils import fresh_cache
|
||||
|
||||
# Use temporary directory for caching to avoid any pollution between cold startups
|
||||
original_cache_root = os.environ.get("VLLM_CACHE_ROOT")
|
||||
temp_cache_dir = tempfile.mkdtemp(prefix="vllm_startup_bench_cold_")
|
||||
try:
|
||||
os.environ["VLLM_CACHE_ROOT"] = temp_cache_dir
|
||||
with fresh_cache():
|
||||
yield
|
||||
finally:
|
||||
# Clean up temporary cache directory
|
||||
shutil.rmtree(temp_cache_dir, ignore_errors=True)
|
||||
if original_cache_root:
|
||||
os.environ["VLLM_CACHE_ROOT"] = original_cache_root
|
||||
else:
|
||||
os.environ.pop("VLLM_CACHE_ROOT", None)
|
||||
|
||||
|
||||
def run_startup_in_subprocess(engine_args, result_queue):
    """
    Build an ``LLM`` and report how long that took through ``result_queue``.

    Intended to be the target of a subprocess so that each measurement is
    isolated from the others.

    On success a single dict is put on the queue, with the keys
    ``total_startup_time`` and ``compilation_time`` (the latter is ``0.0``
    when the engine exposes no compilation config). On any exception two
    items are put instead: ``None`` followed by the exception message.

    Args:
        engine_args: Dataclass instance whose fields are passed to ``LLM``.
        result_queue: Object with a ``put`` method used to return the result.
    """
    try:
        # Import inside the subprocess to avoid issues with forking
        from vllm import LLM

        # Measure total startup time
        started_at = time.perf_counter()
        llm = LLM(**dataclasses.asdict(engine_args))
        total_startup_time = time.perf_counter() - started_at

        # Extract compilation time if available
        vllm_config = getattr(llm.llm_engine, "vllm_config", None)
        compilation_config = getattr(vllm_config, "compilation_config", None)
        if compilation_config is None:
            compilation_time = 0.0
        else:
            compilation_time = compilation_config.compilation_time

        timings = {
            "total_startup_time": total_startup_time,
            "compilation_time": compilation_time,
        }
        result_queue.put(timings)

    except Exception as e:
        # Failure protocol: a None marker first, then the error text.
        result_queue.put(None)
        result_queue.put(str(e))
|
||||
|
||||
|
||||
def save_to_pytorch_benchmark_format(
    args: argparse.Namespace, results: dict[str, Any]
) -> None:
    """Write the startup results as four PyTorch-benchmark-format files.

    One record set is produced for each combination of phase (``cold``,
    ``warm``) and measurement (``startup``, ``compilation``). The average is
    passed as the metric and the raw times and percentiles as extra info.
    Each non-empty record set is written to
    ``<output_json stem>.<phase>_<measurement>.pytorch.json``.

    Args:
        args: Parsed CLI arguments; ``args.output_json`` is the base path.
        results: Must contain the ``avg_*_time``, ``*_times`` and
            ``*_percentiles`` entries for all four combinations.
    """
    base_name = os.path.splitext(args.output_json)[0]

    # Same order as the files have always been written in.
    series_names = (
        "cold_startup",
        "cold_compilation",
        "warm_startup",
        "warm_compilation",
    )
    for series in series_names:
        avg_key = f"avg_{series}_time"
        times_key = f"{series}_times"
        percentiles_key = f"{series}_percentiles"

        records = convert_to_pytorch_benchmark_format(
            args=args,
            metrics={avg_key: [results[avg_key]]},
            extra_info={
                times_key: results[times_key],
                percentiles_key: results[percentiles_key],
            },
        )
        if records:
            write_to_json(f"{base_name}.{series}.pytorch.json", records)
|
||||
|
||||
|
||||
def add_cli_args(parser: argparse.ArgumentParser):
    """Register the startup-benchmark options, then the engine options.

    Returns:
        The parser returned by ``EngineArgs.add_cli_args``.
    """
    benchmark_options = (
        (
            "--num-iters-cold",
            {
                "type": int,
                "default": 3,
                "help": "Number of cold startup iterations.",
            },
        ),
        (
            "--num-iters-warmup",
            {
                "type": int,
                "default": 1,
                "help": "Number of warmup iterations before benchmarking "
                "warm startups.",
            },
        ),
        (
            "--num-iters-warm",
            {
                "type": int,
                "default": 3,
                "help": "Number of warm startup iterations.",
            },
        ),
        (
            "--output-json",
            {
                "type": str,
                "default": None,
                "help": "Path to save the startup time results in JSON format.",
            },
        ),
    )
    for flag, options in benchmark_options:
        parser.add_argument(flag, **options)

    return EngineArgs.add_cli_args(parser)
|
||||
|
||||
|
||||
def _summarize_times(times, percentages):
    """Bundle raw timing samples with their mean and requested percentiles."""
    samples = np.array(times)
    return {
        "times": times,
        "avg": np.mean(samples),
        "percentiles": np.percentile(samples, percentages),
    }


def _print_startup_section(title, startup, compilation, percentages):
    """Print the averages and percentile tables of one (cold/warm) phase."""
    print(f"\n{title}:")
    print(f"Avg total startup time: {startup['avg']:.2f} seconds")
    print(f"Avg compilation time: {compilation['avg']:.2f} seconds")
    tables = (
        ("Startup time percentiles:", startup),
        ("Compilation time percentiles:", compilation),
    )
    for heading, stats in tables:
        print(heading)
        for percentage, percentile in zip(percentages, stats["percentiles"]):
            print(f" {percentage}%: {percentile:.2f} seconds")


def _startup_section_results(prefix, startup, compilation, percentages):
    """Build the JSON-serializable result entries of one (cold/warm) phase."""
    return {
        f"avg_{prefix}_startup_time": float(startup["avg"]),
        f"avg_{prefix}_compilation_time": float(compilation["avg"]),
        f"{prefix}_startup_times": startup["times"],
        f"{prefix}_compilation_times": compilation["times"],
        f"{prefix}_startup_percentiles": dict(
            zip(percentages, startup["percentiles"].tolist())
        ),
        f"{prefix}_compilation_percentiles": dict(
            zip(percentages, compilation["percentiles"].tolist())
        ),
    }


def main(args: argparse.Namespace):
    """Benchmark cold and warm engine startup and report the timings.

    Each measurement launches ``run_startup_in_subprocess`` in a fresh
    ``spawn``-ed process and reads its metrics back through a queue. Cold
    iterations run inside the ``cold_startup()`` context manager; warm
    iterations run after ``args.num_iters_warmup`` unmeasured warmup launches.
    Averages and percentiles are printed, and when ``args.output_json`` is set
    the results are also written to that file and passed to
    ``save_to_pytorch_benchmark_format``.

    Args:
        args: Parsed CLI arguments; reads ``num_iters_cold``,
            ``num_iters_warmup``, ``num_iters_warm`` and ``output_json`` in
            addition to whatever ``EngineArgs.from_cli_args`` consumes.

    Raises:
        RuntimeError: If a subprocess returns no result or reports a failure.
    """
    # Set multiprocessing start method to 'spawn' for clean process isolation
    # This ensures each subprocess starts fresh without inheriting state
    multiprocessing.set_start_method("spawn", force=True)

    engine_args = EngineArgs.from_cli_args(args)

    def create_llm_and_measure_startup():
        """
        Create LLM instance in a subprocess and measure startup time.
        Returns timing metrics, using subprocess for complete isolation.
        """
        # Create a queue for inter-process communication
        result_queue = multiprocessing.Queue()
        process = multiprocessing.Process(
            target=run_startup_in_subprocess,
            args=(engine_args, result_queue),
        )
        process.start()
        process.join()

        if result_queue.empty():
            raise RuntimeError("Subprocess did not return a result")

        result = result_queue.get()
        if result is not None:
            return result

        # A ``None`` result signals failure; an error message may follow it.
        if result_queue.empty():
            raise RuntimeError("Subprocess failed with unknown error")
        error_msg = result_queue.get()
        raise RuntimeError(f"Subprocess failed: {error_msg}")

    os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
    print("Setting VLLM_ENABLE_V1_MULTIPROCESSING=0 to collect startup metrics.\n")

    print("Measuring cold startup time...\n")
    cold_startup_times = []
    cold_compilation_times = []
    for _ in tqdm(range(args.num_iters_cold), desc="Cold startup iterations"):
        with cold_startup():
            metrics = create_llm_and_measure_startup()
        cold_startup_times.append(metrics["total_startup_time"])
        cold_compilation_times.append(metrics["compilation_time"])

    # Warmup for warm startup
    print("\nWarming up for warm startup measurement...\n")
    for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
        create_llm_and_measure_startup()

    print("\nMeasuring warm startup time...\n")
    warm_startup_times = []
    warm_compilation_times = []
    for _ in tqdm(range(args.num_iters_warm), desc="Warm startup iterations"):
        metrics = create_llm_and_measure_startup()
        warm_startup_times.append(metrics["total_startup_time"])
        warm_compilation_times.append(metrics["compilation_time"])

    # Calculate statistics
    percentages = [10, 25, 50, 75, 90, 99]
    cold_startup = _summarize_times(cold_startup_times, percentages)
    cold_compilation = _summarize_times(cold_compilation_times, percentages)
    warm_startup = _summarize_times(warm_startup_times, percentages)
    warm_compilation = _summarize_times(warm_compilation_times, percentages)

    print("\n" + "=" * 60)
    print("STARTUP TIME BENCHMARK RESULTS")
    print("=" * 60)

    _print_startup_section(
        "COLD STARTUP", cold_startup, cold_compilation, percentages
    )
    _print_startup_section(
        "WARM STARTUP", warm_startup, warm_compilation, percentages
    )

    print("=" * 60)

    # Output JSON results if specified
    if args.output_json:
        results = {
            **_startup_section_results(
                "cold", cold_startup, cold_compilation, percentages
            ),
            **_startup_section_results(
                "warm", warm_startup, warm_compilation, percentages
            ),
        }
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)
        save_to_pytorch_benchmark_format(args, results)
|
||||
0
vllm/benchmarks/sweep/__init__.py
Normal file
0
vllm/benchmarks/sweep/__init__.py
Normal file
44
vllm/benchmarks/sweep/cli.py
Normal file
44
vllm/benchmarks/sweep/cli.py
Normal file
@@ -0,0 +1,44 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
|
||||
from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG
|
||||
|
||||
from .plot import SweepPlotArgs
|
||||
from .plot import main as plot_main
|
||||
from .plot_pareto import SweepPlotParetoArgs
|
||||
from .plot_pareto import main as plot_pareto_main
|
||||
from .serve import SweepServeArgs
|
||||
from .serve import main as serve_main
|
||||
from .serve_sla import SweepServeSLAArgs
|
||||
from .serve_sla import main as serve_sla_main
|
||||
from .startup import SweepStartupArgs
|
||||
from .startup import main as startup_main
|
||||
|
||||
# Each entry pairs an argument-definition class with the entrypoint that runs
# it. `add_cli_args` below registers one subparser per entry, in this order.
SUBCOMMANDS = (
    (SweepServeArgs, serve_main),
    (SweepServeSLAArgs, serve_sla_main),
    (SweepStartupArgs, startup_main),
    (SweepPlotArgs, plot_main),
    (SweepPlotParetoArgs, plot_pareto_main),
)
|
||||
|
||||
|
||||
def add_cli_args(parser: argparse.ArgumentParser):
    """Attach one required sub-command per ``SUBCOMMANDS`` entry to ``parser``.

    The chosen sub-command name is stored under ``sweep_type`` and its
    entrypoint under ``dispatch_function``.

    Args:
        parser: The parser to which the sweep sub-commands are added.
    """
    sweep_parsers = parser.add_subparsers(required=True, dest="sweep_type")

    for args_cls, run_fn in SUBCOMMANDS:
        subcmd_name = args_cls.parser_name
        sub = sweep_parsers.add_parser(
            subcmd_name,
            description=args_cls.parser_help,
            usage=f"vllm bench sweep {subcmd_name} [options]",
        )
        # Remember which entrypoint `main` has to call for this sub-command.
        sub.set_defaults(dispatch_function=run_fn)
        args_cls.add_cli_args(sub)
        sub.epilog = VLLM_SUBCMD_PARSER_EPILOG.format(
            subcmd=f"sweep {subcmd_name}"
        )
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
    """Run the entrypoint stored on ``args.dispatch_function`` with ``args``."""
    entrypoint = args.dispatch_function
    entrypoint(args)
|
||||
159
vllm/benchmarks/sweep/param_sweep.py
Normal file
159
vllm/benchmarks/sweep/param_sweep.py
Normal file
@@ -0,0 +1,159 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import json
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ParameterSweep(list["ParameterSweepItem"]):
    """A list of ``ParameterSweepItem`` objects built from JSON-style records."""

    @classmethod
    def read_json(cls, filepath: os.PathLike):
        """Load a parameter sweep from a JSON file.

        Args:
            filepath: Path of the JSON file.

        Returns:
            The sweep built by ``read_from_dict`` when the file holds a JSON
            object, otherwise by ``from_records``.
        """
        with open(filepath, "rb") as f:
            data = json.load(f)

        # Support both list and dict formats
        if isinstance(data, dict):
            return cls.read_from_dict(data)

        return cls.from_records(data)

    @classmethod
    def read_from_dict(cls, data: dict[str, dict[str, object]]):
        """
        Read parameter sweep from a dict format where keys are names.

        Example:
            {
                "experiment1": {"max_tokens": 100, "temperature": 0.7},
                "experiment2": {"max_tokens": 200, "temperature": 0.9}
            }
        """
        records = [{"_benchmark_name": name, **params} for name, params in data.items()]
        return cls.from_records(records)

    @classmethod
    def from_records(cls, records: list[dict[str, object]]):
        """Build a sweep from a list of parameter dictionaries.

        Args:
            records: One dictionary per parameter combination.

        Returns:
            A new instance holding one ``ParameterSweepItem`` per record.

        Raises:
            TypeError: If ``records`` is not a list, or (raised by
                ``ParameterSweepItem.from_record``) an item is not a dict.
            ValueError: If two records share the same ``_benchmark_name``.
        """
        if not isinstance(records, list):
            raise TypeError(
                f"The parameter sweep should be a list of dictionaries, "
                f"but found type: {type(records)}"
            )

        # Validate that all _benchmark_name values are unique if provided.
        # Non-dict records are skipped here so that they reach the dedicated
        # type check in ParameterSweepItem.from_record instead of failing on
        # the membership test or the key lookup.
        seen = set()
        duplicates = set()
        for record in records:
            if not isinstance(record, dict) or "_benchmark_name" not in record:
                continue
            name = record["_benchmark_name"]
            # Single pass instead of calling list.count() for every name.
            if name in seen:
                duplicates.add(name)
            else:
                seen.add(name)

        if duplicates:
            raise ValueError(
                f"Duplicate _benchmark_name values found: {duplicates}. "
                f"All _benchmark_name values must be unique."
            )

        return cls(ParameterSweepItem.from_record(record) for record in records)
|
||||
|
||||
|
||||
class ParameterSweepItem(dict[str, object]):
    """One parameter combination of a sweep, stored as a plain mapping.

    The optional ``_benchmark_name`` key names the combination and is never
    emitted as a command-line parameter.
    """

    @classmethod
    def from_record(cls, record: dict[str, object]):
        """Wrap ``record`` in this class.

        Raises:
            TypeError: If ``record`` is not a dict.
        """
        if isinstance(record, dict):
            return cls(record)

        raise TypeError(
            f"Each item in the parameter sweep should be a dictionary, "
            f"but found type: {type(record)}"
        )

    def __or__(self, other: dict[str, Any]):
        """Merge like ``dict | other`` but keep the result in this class."""
        merged = super().__or__(other)
        return type(self)(merged)

    @property
    def name(self) -> str:
        """
        Name of this parameter sweep item.

        This is the '_benchmark_name' field when present; otherwise it is the
        text representation of all parameters joined with "-".
        """
        if "_benchmark_name" not in self:
            return self.as_text(sep="-")

        return str(self["_benchmark_name"])

    # In JSON, we prefer "_"
    def _iter_param_key_candidates(self, param_key: str):
        """Yield ``param_key`` as given, then with "-"->"_", then "_"->"-"."""
        if "." not in param_key:
            yield param_key
            yield param_key.replace("-", "_")
            yield param_key.replace("_", "-")
            return

        # Inner config arguments are not converted by the CLI, so only the
        # part before the first "." is varied.
        head, tail = param_key.split(".", 1)
        for head_variant in self._iter_param_key_candidates(head):
            yield head_variant + "." + tail

    # In CLI, we prefer "-"
    def _iter_cmd_key_candidates(self, param_key: str):
        """Yield the ``--``-prefixed spellings, dashed variant first."""
        variants = tuple(self._iter_param_key_candidates(param_key))
        yield from ("--" + variant for variant in variants[::-1])

    def _normalize_cmd_key(self, param_key: str):
        """Return the preferred command-line spelling of ``param_key``."""
        return next(self._iter_cmd_key_candidates(param_key))

    def has_param(self, param_key: str) -> bool:
        """Tell whether any spelling of ``param_key`` is a key of this item."""
        for candidate in self._iter_param_key_candidates(param_key):
            if candidate in self:
                return True
        return False

    def _normalize_cmd_kv_pair(self, k: str, v: object) -> list[str]:
        """
        Turn a key-value pair into command-line arguments.

        The returned list holds either:
        - A single element for boolean flags (e.g., ['--flag'] or ['--flag=true'])
        - Two elements for key-value pairs (e.g., ['--key', 'value'])
        """
        if not isinstance(v, bool):
            return [self._normalize_cmd_key(k), str(v)]

        # For nested params (containing "."), use =true/false syntax
        if "." in k:
            literal = "true" if v else "false"
            return [f"{self._normalize_cmd_key(k)}={literal}"]

        # Plain booleans become `--flag` / `--no-flag`.
        flag_name = k if v else "no-" + k
        return [self._normalize_cmd_key(flag_name)]

    def apply_to_cmd(self, cmd: list[str]) -> list[str]:
        """Return a copy of ``cmd`` with this item's parameters applied.

        A parameter whose flag (in any spelling) already occurs in ``cmd`` is
        overwritten in place; every other parameter is appended.
        """
        updated = list(cmd)

        for key, value in self.items():
            # Skip the '_benchmark_name' field, not a parameter
            if key == "_benchmark_name":
                continue

            # Serialize dict values as JSON
            if isinstance(value, dict):
                value = json.dumps(value)

            # Locate the first spelling of the flag that is already present.
            existing_idx = None
            for flag in self._iter_cmd_key_candidates(key):
                if flag in updated:
                    existing_idx = updated.index(flag)
                    break

            tokens = self._normalize_cmd_kv_pair(key, value)
            if existing_idx is None:
                # Add new parameter
                updated.extend(tokens)
                continue

            # Replace existing parameter: the flag itself, plus the following
            # value slot when the pair carries a value.
            updated[existing_idx] = tokens[0]
            if len(tokens) != 1:
                updated[existing_idx + 1] = tokens[1]

        return updated

    def as_text(self, sep: str = ", ") -> str:
        """Join all ``key=value`` pairs except ``_benchmark_name`` with ``sep``."""
        pairs = [f"{k}={v}" for k, v in self.items() if k != "_benchmark_name"]
        return sep.join(pairs)
|
||||
683
vllm/benchmarks/sweep/plot.py
Normal file
683
vllm/benchmarks/sweep/plot.py
Normal file
@@ -0,0 +1,683 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from types import TracebackType
|
||||
from typing import ClassVar
|
||||
|
||||
from typing_extensions import Self, override
|
||||
|
||||
from vllm.utils.collection_utils import full_groupby
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
|
||||
from .utils import sanitize_filename
|
||||
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
except ImportError:
|
||||
plt = PlaceholderModule("matplotlib").placeholder_attr("pyplot")
|
||||
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
pd = PlaceholderModule("pandas")
|
||||
|
||||
try:
    import seaborn as sns
except ImportError:
    # Bind the placeholder to `sns`, the alias the rest of this module uses,
    # so that a missing seaborn is reported by the placeholder rather than as
    # a NameError on `sns`.
    sns = PlaceholderModule("seaborn")
|
||||
|
||||
|
||||
@dataclass
class PlotFilterBase(ABC):
    """Base class of the ``<var><op><target>`` filters parsed from the CLI."""

    # Text on the left-hand side of the operator.
    var: str
    # Text on the right-hand side of the operator, without surrounding quotes.
    target: str

    @classmethod
    def parse_str(cls, s: str):
        """Parse one filter expression into the matching filter instance.

        The first key of ``PLOT_FILTERS`` (in dict order) that occurs in ``s``
        selects the filter class.

        Raises:
            ValueError: If ``s`` contains none of the known operators.
        """
        matched_op = next((op for op in PLOT_FILTERS if op in s), None)
        if matched_op is None:
            raise ValueError(
                f"Invalid operator for plot filter '{s}'. "
                f"Valid operators are: {sorted(PLOT_FILTERS)}",
            )

        column, raw_target = s.split(matched_op)
        cleaned_target = raw_target.removeprefix(matched_op).strip("'").strip('"')
        return PLOT_FILTERS[matched_op](column, cleaned_target)

    @abstractmethod
    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
        """Return the rows of ``df`` that pass this filter."""
        raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass
class PlotEqualTo(PlotFilterBase):
    """Keep the rows whose ``var`` column equals ``target``."""

    @override
    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
        # Compare numerically when the target parses as a float, otherwise
        # compare against the raw string.
        try:
            wanted = float(self.target)
        except ValueError:
            wanted = self.target

        mask = df[self.var] == wanted
        return df[mask]
|
||||
|
||||
|
||||
@dataclass
class PlotNotEqualTo(PlotFilterBase):
    """Keep the rows whose ``var`` column differs from ``target``."""

    @override
    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
        # Compare numerically when the target parses as a float, otherwise
        # compare against the raw string.
        try:
            unwanted = float(self.target)
        except ValueError:
            unwanted = self.target

        mask = df[self.var] != unwanted
        return df[mask]
|
||||
|
||||
|
||||
@dataclass
class PlotLessThan(PlotFilterBase):
    """Keep the rows whose ``var`` column is below the numeric ``target``."""

    @override
    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
        column = df[self.var]
        mask = column < float(self.target)
        return df[mask]
|
||||
|
||||
|
||||
@dataclass
class PlotLessThanOrEqualTo(PlotFilterBase):
    """Keep the rows whose ``var`` column is at most the numeric ``target``."""

    @override
    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
        column = df[self.var]
        mask = column <= float(self.target)
        return df[mask]
|
||||
|
||||
|
||||
@dataclass
class PlotGreaterThan(PlotFilterBase):
    """Keep the rows whose ``var`` column is above the numeric ``target``."""

    @override
    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
        column = df[self.var]
        mask = column > float(self.target)
        return df[mask]
|
||||
|
||||
|
||||
@dataclass
class PlotGreaterThanOrEqualTo(PlotFilterBase):
    """Keep the rows whose ``var`` column is at least the numeric ``target``."""

    @override
    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
        column = df[self.var]
        mask = column >= float(self.target)
        return df[mask]
|
||||
|
||||
|
||||
# NOTE: The ordering is important! `PlotFilterBase.parse_str` takes the first
# key found in the expression, so longer op_keys must come before the shorter
# ones they contain ("<=" before "<", ">=" before ">").
PLOT_FILTERS: dict[str, type[PlotFilterBase]] = {
    "==": PlotEqualTo,
    "!=": PlotNotEqualTo,
    "<=": PlotLessThanOrEqualTo,
    ">=": PlotGreaterThanOrEqualTo,
    "<": PlotLessThan,
    ">": PlotGreaterThan,
}
|
||||
|
||||
|
||||
class PlotFilters(list[PlotFilterBase]):
    """A sequence of filters that are applied one after another."""

    @classmethod
    def parse_str(cls, s: str):
        """Parse a comma-separated list of filter expressions.

        An empty string yields an empty instance.
        """
        if s:
            return cls(PlotFilterBase.parse_str(expr) for expr in s.split(","))

        return cls()

    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
        """Pass ``df`` through every filter in order and return the result."""
        filtered = df
        for plot_filter in self:
            filtered = plot_filter.apply(filtered)

        return filtered
|
||||
|
||||
|
||||
@dataclass
class PlotBinner:
    """Rounds the values of one column down to a multiple of ``bin_size``."""

    # Name of the column to bin.
    var: str
    # Width of each bin.
    bin_size: float

    @classmethod
    def parse_str(cls, s: str):
        """Parse one ``<var><op><bin_size>`` expression into a binner.

        The first key of ``PLOT_BINNERS`` that occurs in ``s`` selects the
        binner class.

        Raises:
            ValueError: If ``s`` contains none of the known operators.
        """
        matched_op = next((op for op in PLOT_BINNERS if op in s), None)
        if matched_op is None:
            raise ValueError(
                f"Invalid operator for plot binner '{s}'. "
                f"Valid operators are: {sorted(PLOT_BINNERS)}",
            )

        column, raw_size = s.split(matched_op)
        bin_size = float(raw_size.removeprefix(matched_op))
        return PLOT_BINNERS[matched_op](column, bin_size)

    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
        """Return a copy of ``df`` whose ``var`` column is binned."""
        binned = df.copy()
        values = binned[self.var]
        # Floor-divide then multiply back: the value snaps to the lower edge
        # of its bin.
        binned[self.var] = values // self.bin_size * self.bin_size
        return binned
|
||||
|
||||
|
||||
# Maps the operator that appears in a `--bin-by` expression to the binner
# class that `PlotBinner.parse_str` instantiates for it.
PLOT_BINNERS: dict[str, type[PlotBinner]] = {"%": PlotBinner}
|
||||
|
||||
|
||||
class PlotBinners(list[PlotBinner]):
    """A sequence of binners that are applied one after another."""

    @classmethod
    def parse_str(cls, s: str):
        """Parse a comma-separated list of binner expressions.

        An empty string yields an empty instance.
        """
        if s:
            return cls(PlotBinner.parse_str(expr) for expr in s.split(","))

        return cls()

    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
        """Pass ``df`` through every binner in order and return the result."""
        binned = df
        for binner in self:
            binned = binner.apply(binned)

        return binned
|
||||
|
||||
|
||||
def _json_load_bytes(path: Path) -> list[dict[str, object]]:
|
||||
with path.open("rb") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def _convert_inf_nan_strings(data: list[dict[str, object]]) -> list[dict[str, object]]:
|
||||
"""
|
||||
Convert string values "inf", "-inf", and "nan" to their float equivalents.
|
||||
|
||||
This handles the case where JSON serialization represents inf/nan as strings.
|
||||
"""
|
||||
converted_data = []
|
||||
for record in data:
|
||||
converted_record = {}
|
||||
for key, value in record.items():
|
||||
if isinstance(value, str):
|
||||
if value in ["inf", "-inf", "nan"]:
|
||||
converted_record[key] = float(value)
|
||||
else:
|
||||
converted_record[key] = value
|
||||
else:
|
||||
converted_record[key] = value
|
||||
converted_data.append(converted_record)
|
||||
return converted_data
|
||||
|
||||
|
||||
def _get_metric(run_data: dict[str, object], metric_key: str):
|
||||
try:
|
||||
return run_data[metric_key]
|
||||
except KeyError as exc:
|
||||
raise ValueError(f"Cannot find metric {metric_key!r} in {run_data=}") from exc
|
||||
|
||||
|
||||
def _get_group(run_data: dict[str, object], group_keys: list[str]):
    """Return a tuple of ``(key, str(value))`` pairs, one per group key.

    Values are looked up with ``_get_metric``, so a missing key raises
    ``ValueError``.
    """
    pairs = []
    for group_key in group_keys:
        pairs.append((group_key, str(_get_metric(run_data, group_key))))
    return tuple(pairs)
|
||||
|
||||
|
||||
def _get_fig_path(fig_dir: Path, group: tuple[tuple[str, str], ...], fig_name: str):
    """Build the PNG path of a figure inside ``fig_dir``.

    The file name is ``fig_name`` followed by one ``key=value`` part per group
    entry, joined with "-" and passed through ``sanitize_filename``.
    """
    # The figure name always comes first; group parts follow when present.
    name_parts = [fig_name, *(f"{k}={v}" for k, v in group)]
    file_name = "-".join(name_parts) + ".png"

    return fig_dir / sanitize_filename(file_name)
|
||||
|
||||
|
||||
class DummyExecutor:
|
||||
map = map
|
||||
|
||||
def __enter__(self) -> Self:
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
exc_traceback: TracebackType | None,
|
||||
) -> None:
|
||||
return None
|
||||
|
||||
|
||||
def _require_column(df: "pd.DataFrame", label: str, column: str) -> None:
    """Raise ``ValueError`` when ``column`` is not a column of ``df``."""
    if column not in df.columns:
        raise ValueError(
            f"Cannot find {label}={column!r} in parameter sweep results. "
            f"Available variables: {df.columns.tolist()}"
        )


def _group_labels(df: "pd.DataFrame", keys: list[str]):
    """Return per-row ``key=value`` labels joined by newlines, or "(All)"."""
    if not keys:
        return "(All)"

    return pd.concat(
        [k + "=" + df[k].astype(str) for k in keys],
        axis=1,
    ).agg("\n".join, axis=1)


def _plot_fig(
    fig_dir: Path,
    fig_group_data: tuple[tuple[tuple[str, str], ...], list[dict[str, object]]],
    row_by: list[str],
    col_by: list[str],
    curve_by: list[str],
    *,
    var_x: str,
    var_y: str,
    filter_by: PlotFilters,
    bin_by: PlotBinners,
    scale_x: str | None,
    scale_y: str | None,
    dry_run: bool,
    fig_name: str,
    error_bars: bool,
    fig_height: float,
    fig_dpi: int,
):
    """Draw one figure (a grid of line plots) and save it as a PNG.

    Args:
        fig_dir: Directory in which the figure is saved.
        fig_group_data: ``(group, records)`` pair; ``group`` is used for the
            file name and ``records`` are the rows to plot.
        row_by: Variables whose combinations define the grid rows.
        col_by: Variables whose combinations define the grid columns.
        curve_by: Variables whose combinations define the curves. Up to three
            are mapped to hue, style and size; more are merged into one hue.
        var_x: Column plotted on the x-axis.
        var_y: Column plotted on the y-axis.
        filter_by: Filters applied to the data before plotting.
        bin_by: Binners applied after the filters.
        scale_x: Optional x-axis scale passed to ``g.set(xscale=...)``.
        scale_y: Optional y-axis scale passed to ``g.set(yscale=...)``.
        dry_run: If true, only print the figure summary and return.
        fig_name: Prefix of the output file name.
        error_bars: Whether to draw standard-deviation error bars.
        fig_height: Height passed to ``sns.relplot``.
        fig_dpi: Resolution used when saving the figure.

    Raises:
        ValueError: If any of the requested variables is not present in the
            data.
    """
    fig_group, fig_data = fig_group_data

    row_groups = full_groupby(
        fig_data,
        key=lambda item: _get_group(item, row_by),
    )
    num_rows = len(row_groups)
    num_cols = max(
        len(full_groupby(row_data, key=lambda item: _get_group(item, col_by)))
        for _, row_data in row_groups
    )

    fig_path = _get_fig_path(fig_dir, fig_group, fig_name)

    print("[BEGIN FIGURE]")
    print(f"Group: {dict(fig_group)}")
    print(f"Grid: {num_rows} rows x {num_cols} cols")
    print(f"Output file: {fig_path}")

    if dry_run:
        print("[END FIGURE]")
        return

    # Convert string "inf", "-inf", and "nan" to their float equivalents
    fig_data = _convert_inf_nan_strings(fig_data)
    df = pd.DataFrame.from_records(fig_data)

    # Validate every requested variable before touching the data.
    _require_column(df, "var_x", var_x)
    _require_column(df, "var_y", var_y)
    for label, keys in (
        ("row_by", row_by),
        ("col_by", col_by),
        ("curve_by", curve_by),
    ):
        for k in keys:
            _require_column(df, label, k)

    df = filter_by.apply(df)
    df = bin_by.apply(df)

    # Sort by curve_by columns alphabetically for consistent legend ordering
    if curve_by:
        df = df.sort_values(by=curve_by)

    df["row_group"] = _group_labels(df, row_by)
    df["col_group"] = _group_labels(df, col_by)

    if len(curve_by) <= 3:
        # Map up to three variables onto hue/style/size; unused slots are None.
        hue, style, size, *_ = (*curve_by, None, None, None)
        curve_kwargs = {"hue": hue, "style": style, "size": size}
    else:
        # Too many variables for separate aesthetics: merge them into one hue.
        df["curve_group"] = _group_labels(df, curve_by)
        curve_kwargs = {"hue": "curve_group"}

    g = sns.relplot(
        df,
        x=var_x,
        y=var_y,
        markers=True,
        errorbar="sd" if error_bars else None,
        kind="line",
        row="row_group",
        col="col_group",
        height=fig_height,
        **curve_kwargs,
    )

    if row_by and col_by:
        g.set_titles("{row_name}\n{col_name}")
    elif row_by:
        g.set_titles("{row_name}")
    elif col_by:
        g.set_titles("{col_name}")
    else:
        g.set_titles("")

    if scale_x:
        g.set(xscale=scale_x)
    if scale_y:
        g.set(yscale=scale_y)

    g.savefig(fig_path, dpi=fig_dpi)
    plt.close(g.figure)

    print("[END FIGURE]")
|
||||
|
||||
|
||||
def plot(
    output_dir: Path,
    fig_dir: Path,
    fig_by: list[str],
    row_by: list[str],
    col_by: list[str],
    curve_by: list[str],
    *,
    var_x: str,
    var_y: str,
    filter_by: PlotFilters,
    bin_by: PlotBinners,
    scale_x: str | None,
    scale_y: str | None,
    dry_run: bool,
    fig_name: str = "FIGURE",
    error_bars: bool = True,
    fig_height: float = 6.4,
    fig_dpi: int = 300,
):
    """Plot every ``summary.json`` found under ``output_dir``.

    The records are grouped by the ``fig_by`` variables and one figure is
    drawn per group with ``_plot_fig``. A single group is plotted in-process;
    several groups are distributed over a ``ProcessPoolExecutor``.

    Args:
        output_dir: Directory searched recursively for ``summary.json`` files.
        fig_dir: Directory in which figures are saved; created if missing.
        fig_by: Variables whose combinations each get a separate figure.
        row_by, col_by, curve_by, var_x, var_y, filter_by, bin_by, scale_x,
        scale_y, dry_run, fig_name, error_bars, fig_height, fig_dpi:
            Forwarded unchanged to ``_plot_fig``.

    Raises:
        ValueError: If no results are found under ``output_dir``.
    """
    all_data = [
        run_data
        for path in output_dir.rglob("**/summary.json")
        for run_data in _json_load_bytes(path)
    ]

    if not all_data:
        raise ValueError(f"Did not find any parameter sweep results under {output_dir}")

    fig_dir.mkdir(parents=True, exist_ok=True)

    fig_groups = full_groupby(
        all_data,
        key=lambda item: _get_group(item, fig_by),
    )

    plot_one = partial(
        _plot_fig,
        fig_dir,
        row_by=row_by,
        col_by=col_by,
        curve_by=curve_by,
        var_x=var_x,
        var_y=var_y,
        filter_by=filter_by,
        bin_by=bin_by,
        scale_x=scale_x,
        scale_y=scale_y,
        dry_run=dry_run,
        fig_name=fig_name,
        error_bars=error_bars,
        fig_height=fig_height,
        fig_dpi=fig_dpi,
    )

    with DummyExecutor() if len(fig_groups) <= 1 else ProcessPoolExecutor() as executor:
        # Consume every result so that all workers run and any exception they
        # raised is re-raised here. `all(...)` must not be used for this:
        # `_plot_fig` returns None, so `all` would stop at the first result
        # and silently drop failures of the remaining figures.
        for _ in executor.map(plot_one, fig_groups):
            pass
|
||||
|
||||
|
||||
@dataclass
class SweepPlotArgs:
    """Validated arguments of the ``plot`` sweep sub-command."""

    output_dir: Path
    fig_dir: Path
    fig_by: list[str]
    row_by: list[str]
    col_by: list[str]
    curve_by: list[str]
    var_x: str
    var_y: str
    filter_by: PlotFilters
    bin_by: PlotBinners
    scale_x: str | None
    scale_y: str | None
    dry_run: bool
    fig_name: str = "FIGURE"
    error_bars: bool = True
    fig_height: float = 6.4
    fig_dpi: int = 300

    parser_name: ClassVar[str] = "plot"
    parser_help: ClassVar[str] = "Plot performance curves from parameter sweep results."

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build an instance from the namespace produced by ``add_cli_args``.

        Raises:
            ValueError: If the ``OUTPUT_DIR`` path does not exist.
        """
        output_dir = Path(args.OUTPUT_DIR)
        if not output_dir.exists():
            raise ValueError(f"No parameter sweep results under {output_dir}")

        def split_csv(raw: str | None) -> list[str]:
            # An empty or missing value means "no variables".
            return raw.split(",") if raw else []

        curve_by = split_csv(args.curve_by)
        row_by = split_csv(args.row_by)
        col_by = split_csv(args.col_by)
        fig_by = split_csv(args.fig_by)

        return cls(
            output_dir=output_dir,
            # The figure directory is given relative to the output directory.
            fig_dir=output_dir / args.fig_dir,
            fig_by=fig_by,
            row_by=row_by,
            col_by=col_by,
            curve_by=curve_by,
            var_x=args.var_x,
            var_y=args.var_y,
            filter_by=PlotFilters.parse_str(args.filter_by),
            bin_by=PlotBinners.parse_str(args.bin_by),
            scale_x=args.scale_x,
            scale_y=args.scale_y,
            dry_run=args.dry_run,
            fig_name=args.fig_name,
            # The CLI exposes the negated flag.
            error_bars=not args.no_error_bars,
            fig_height=args.fig_height,
            fig_dpi=args.fig_dpi,
        )

    @classmethod
    def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Register all options of the ``plot`` sub-command on ``parser``.

        Returns:
            The same ``parser``.
        """
        parser.add_argument(
            "OUTPUT_DIR",
            type=str,
            default="results",
            help="The directory containing the results to plot, "
            "i.e., the `--output-dir` argument to the parameter sweep script.",
        )
        parser.add_argument(
            "--fig-dir",
            type=str,
            default="",
            help="The directory to save the figures, relative to `OUTPUT_DIR`. "
            "By default, the same directory is used.",
        )

        # The four grouping options share one help template; only the grouped
        # unit and the default differ.
        grouping_options = (
            ("--fig-by", "figure", ""),
            ("--row-by", "row", ""),
            ("--col-by", "column", ""),
            ("--curve-by", "curve", None),
        )
        for flag, unit, default_value in grouping_options:
            parser.add_argument(
                flag,
                type=str,
                default=default_value,
                help=f"A comma-separated list of variables, such that a separate {unit} "
                "is created for each combination of these variables.",
            )

        parser.add_argument(
            "--var-x",
            type=str,
            default="request_throughput",
            help="The variable for the x-axis.",
        )
        parser.add_argument(
            "--var-y",
            type=str,
            default="p99_ttft_ms",
            help="The variable for the y-axis",
        )
        parser.add_argument(
            "--filter-by",
            type=str,
            default="",
            help="A comma-separated list of statements indicating values to filter by. "
            "This is useful to remove outliers. "
            "Example: `max_concurrency<1000,max_num_batched_tokens<=4096` means "
            "plot only the points where `max_concurrency` is less than 1000 and "
            "`max_num_batched_tokens` is no greater than 4096.",
        )
        parser.add_argument(
            "--bin-by",
            type=str,
            default="",
            help="A comma-separated list of statements indicating values to bin by. "
            "This is useful to avoid plotting points that are too close together. "
            "Example: `request_throughput%%1` means "
            "use a bin size of 1 for the `request_throughput` variable.",
        )

        for axis in ("x", "y"):
            parser.add_argument(
                f"--scale-{axis}",
                type=str,
                default=None,
                help=f"The scale to use for the {axis}-axis. "
                "Currently only accepts string values such as 'log' and 'sqrt'. "
                "See also: https://seaborn.pydata.org/generated/seaborn.objects.Plot.scale.html",
            )

        parser.add_argument(
            "--fig-name",
            type=str,
            default="FIGURE",
            help="Name prefix for the output figure file. "
            "Group data is always appended when present. "
            "Default: 'FIGURE'. Example: --fig-name my_performance_plot",
        )
        parser.add_argument(
            "--no-error-bars",
            action="store_true",
            help="If set, disables error bars on the plot. "
            "By default, error bars are shown.",
        )
        parser.add_argument(
            "--fig-height",
            type=float,
            default=6.4,
            help="Height of each subplot in inches. Default: 6.4",
        )
        parser.add_argument(
            "--fig-dpi",
            type=int,
            default=300,
            help="Resolution of the output figure in dots per inch. Default: 300",
        )
        parser.add_argument(
            "--dry-run",
            action="store_true",
            help="If set, prints the information about each figure to plot, "
            "then exits without drawing them.",
        )

        return parser
|
||||
|
||||
|
||||
def run_main(args: SweepPlotArgs):
    """Call `plot` with every setting stored on `args` and return its result."""
    plot_options = dict(
        output_dir=args.output_dir,
        fig_dir=args.fig_dir,
        fig_by=args.fig_by,
        row_by=args.row_by,
        col_by=args.col_by,
        curve_by=args.curve_by,
        var_x=args.var_x,
        var_y=args.var_y,
        filter_by=args.filter_by,
        bin_by=args.bin_by,
        scale_x=args.scale_x,
        scale_y=args.scale_y,
        dry_run=args.dry_run,
        fig_name=args.fig_name,
        error_bars=args.error_bars,
        fig_height=args.fig_height,
        fig_dpi=args.fig_dpi,
    )
    return plot(**plot_options)
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
    """Convert the parsed CLI namespace into `SweepPlotArgs` and run the plot."""
    plot_args = SweepPlotArgs.from_cli_args(args)
    run_main(plot_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: build the parser, register the plot options,
    # then hand the parsed namespace to `main`.
    parser = argparse.ArgumentParser(description=SweepPlotArgs.parser_help)
    SweepPlotArgs.add_cli_args(parser)

    cli_args = parser.parse_args()
    main(cli_args)
|
||||
399
vllm/benchmarks/sweep/plot_pareto.py
Normal file
399
vllm/benchmarks/sweep/plot_pareto.py
Normal file
@@ -0,0 +1,399 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
import math
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import ClassVar
|
||||
|
||||
from vllm.utils.collection_utils import full_groupby
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
|
||||
from .plot import DummyExecutor, _json_load_bytes
|
||||
from .utils import sanitize_filename
|
||||
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
except ImportError:
|
||||
plt = PlaceholderModule("matplotlib").placeholder_attr("pyplot")
|
||||
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
pd = PlaceholderModule("pandas")
|
||||
|
||||
try:
    import seaborn as sns
except ImportError:
    # Bind the placeholder to `sns`, the alias used by the plotting code below.
    # Binding it to `seaborn` would leave `sns` undefined when the package is
    # missing, and plotting would fail with a NameError instead of the
    # placeholder's own error.
    sns = PlaceholderModule("seaborn")
|
||||
|
||||
|
||||
def _first_present(run_data: dict[str, object], keys: list[str]):
|
||||
for key in keys:
|
||||
for candidate in {key, key.replace("_", "-"), key.replace("-", "_")}:
|
||||
if candidate in run_data:
|
||||
return run_data[candidate]
|
||||
return None
|
||||
|
||||
|
||||
def _get_numeric(
    run_data: dict[str, object],
    keys: list[str],
    *,
    allow_zero: bool = True,
) -> float | None:
    """Look up one of `keys` in `run_data` and convert the value to `float`.

    Returns:
        The converted number, or `None` when no value is found or when the
        number is zero and `allow_zero` is false.

    Raises:
        ValueError: If the value found cannot be converted to `float`.
    """
    raw = _first_present(run_data, keys)
    if raw is None:
        return None

    try:
        number = float(raw)
    except (TypeError, ValueError) as exc:
        raise ValueError(
            f"Expected numeric value for one of {keys}, "
            f"but found {raw!r} in {run_data=}"
        ) from exc

    zero_rejected = number == 0 and not allow_zero
    return None if zero_rejected else number
|
||||
|
||||
|
||||
def _infer_user_count(
    run_data: dict[str, object],
    user_count_var: str | None,
) -> float | None:
    """Determine the user count of a run, or `None` if it cannot be found.

    `user_count_var` (when given) and then `request_rate` are looked up
    first; `max_concurrent_requests` is used only if that lookup yields
    nothing. Zero values are treated as missing.
    """
    configured_keys = [user_count_var] if user_count_var else []
    configured_keys.append("request_rate")

    configured = _get_numeric(run_data, configured_keys, allow_zero=False)
    if configured is None:
        # Fallback to the observed peak if configured value is missing.
        return _get_numeric(
            run_data, ["max_concurrent_requests"], allow_zero=False
        )

    return configured
|
||||
|
||||
|
||||
def _infer_gpu_count(
    run_data: dict[str, object],
    gpu_count_var: str | None,
) -> float:
    """Determine how many GPUs a run used.

    A non-zero value stored under `gpu_count_var` wins. Otherwise the result
    is the product of the tensor-, pipeline- and data-parallel sizes found in
    `run_data`; sizes that are missing or zero do not contribute, so the
    result is 1.0 when none is present.
    """
    if gpu_count_var:
        explicit = _get_numeric(run_data, [gpu_count_var], allow_zero=False)
        if explicit:
            return explicit

    parallel_keys = (
        ["tensor_parallel_size", "tp"],
        ["pipeline_parallel_size", "pp"],
        ["data_parallel_size", "dp"],
    )
    sizes = [_get_numeric(run_data, keys) for keys in parallel_keys]

    world_size = 1.0
    for size in sizes:
        if size:
            world_size *= size

    return world_size
|
||||
|
||||
|
||||
def _get_throughput(
    run_data: dict[str, object],
    throughput_var: str,
) -> float:
    """Return the numeric value stored under `throughput_var`.

    Raises:
        ValueError: If the metric is absent from `run_data` (the message
            lists the keys that are available) or is not numeric.
    """
    value = _get_numeric(run_data, [throughput_var])
    if value is not None:
        return value

    raise ValueError(
        f"Cannot find throughput metric {throughput_var!r} in run data. "
        f"Available keys: {sorted(run_data)}"
    )
|
||||
|
||||
|
||||
def _prepare_records(
    all_data: list[dict[str, object]],
    *,
    user_count_var: str | None,
    gpu_count_var: str | None,
) -> tuple[list[dict[str, object]], int]:
    """Add per-user and per-GPU throughput to every record.

    Records whose user count cannot be determined are dropped and counted.

    Returns:
        A pair `(prepared, skipped)`: the enriched copies of the usable
        records, and the number of records that were dropped.
    """
    enriched: list[dict[str, object]] = []
    num_skipped = 0

    for record in all_data:
        throughput = _get_throughput(record, "output_throughput")

        user_count = _infer_user_count(record, user_count_var)
        if user_count is None:
            num_skipped += 1
            continue

        gpu_count = _infer_gpu_count(record, gpu_count_var)
        derived = {
            "tokens_per_user": throughput / user_count,
            "tokens_per_gpu": throughput / gpu_count,
            "user_count_estimate": user_count,
            "gpu_count": gpu_count,
        }
        # Derived fields come last so they override same-named input fields.
        enriched.append({**record, **derived})

    return enriched, num_skipped
|
||||
|
||||
|
||||
def _pareto_frontier(
|
||||
df: "pd.DataFrame",
|
||||
x_col: str,
|
||||
y_col: str,
|
||||
*,
|
||||
epsilon: float = 1e-9,
|
||||
) -> "pd.DataFrame":
|
||||
sorted_df = df.sort_values([x_col, y_col], ascending=[False, False])
|
||||
frontier_indices = []
|
||||
best_y = -math.inf
|
||||
|
||||
for idx, row in sorted_df.iterrows():
|
||||
y_val = row[y_col]
|
||||
if y_val >= best_y - epsilon:
|
||||
frontier_indices.append(idx)
|
||||
best_y = max(best_y, y_val)
|
||||
|
||||
return df.loc[frontier_indices]
|
||||
|
||||
|
||||
def _get_fig_path(
    fig_dir: Path,
    fig_group: tuple[tuple[str, str], ...],
) -> Path:
    """Build the PNG path for a figure group inside `fig_dir`.

    The file name is `PARETO`, followed by one `key=value` part per entry of
    `fig_group`, joined with `-` and passed through `sanitize_filename`.
    """
    name_parts = ["PARETO", *(f"{k}={v}" for k, v in fig_group)]
    file_name = sanitize_filename("-".join(name_parts) + ".png")
    return fig_dir / file_name
|
||||
|
||||
|
||||
def _plot_fig(
    fig_dir: Path,
    fig_group_data: tuple[tuple[tuple[str, str], ...], list[dict[str, object]]],
    label_by: list[str],
    *,
    dry_run: bool,
):
    """Draw one figure: all runs as a scatter plus their Pareto frontier.

    Args:
        fig_dir: Directory the figure is saved in.
        fig_group_data: Pair of the group key and the records in that group.
        label_by: Record fields written next to each frontier point.
        dry_run: If true, only print the figure information and return.
    """
    group_key, records = fig_group_data
    out_path = _get_fig_path(fig_dir, group_key)

    print("[BEGIN FIGURE]")
    print(f"Group: {dict(group_key)}")
    print(f"Output file: {out_path}")

    if dry_run:
        print("[END FIGURE]")
        return

    points = pd.DataFrame.from_records(records)
    points = points.dropna(subset=["tokens_per_user", "tokens_per_gpu"])

    if points.empty:
        print("No data points available after filtering; skipping.")
        print("[END FIGURE]")
        return

    # Sort by x so the frontier is drawn as a single left-to-right line.
    frontier = _pareto_frontier(
        points, "tokens_per_user", "tokens_per_gpu"
    ).sort_values("tokens_per_user")

    fig, ax = plt.subplots()
    sns.scatterplot(
        data=points,
        x="tokens_per_user",
        y="tokens_per_gpu",
        color="0.5",
        alpha=0.6,
        ax=ax,
        label="All runs",
    )
    sns.lineplot(
        data=frontier,
        x="tokens_per_user",
        y="tokens_per_gpu",
        marker="o",
        ax=ax,
        label="Pareto frontier",
    )

    if label_by:
        for _, point in frontier.iterrows():
            # Only fields that exist in the data contribute to the label.
            annotation = "\n".join(
                f"{key}={point[key]}" for key in label_by if key in point
            )
            if annotation:
                ax.text(
                    point["tokens_per_user"],
                    point["tokens_per_gpu"],
                    annotation,
                    fontsize=8,
                )

    ax.set_xlabel("Tokens/s/user")
    ax.set_ylabel("Tokens/s/GPU")
    ax.grid(True, linestyle="--", linewidth=0.5, alpha=0.6)
    fig.tight_layout()
    fig.savefig(out_path)
    plt.close(fig)

    print(
        f"Plotted {len(points)} points; Pareto frontier size: {len(frontier)}.",
    )
    print("[END FIGURE]")
|
||||
|
||||
|
||||
def plot_pareto(
    output_dir: Path,
    user_count_var: str | None,
    gpu_count_var: str | None,
    label_by: list[str],
    *,
    dry_run: bool,
):
    """Plot the Pareto frontier of the sweep results found under `output_dir`.

    Every `summary.json` below `output_dir` is loaded, the runs are enriched
    with per-user and per-GPU throughput, and the figures are written to
    `output_dir / "pareto"`.

    Args:
        output_dir: Root directory that is searched for `summary.json` files.
        user_count_var: Result key holding the user count, if any.
        gpu_count_var: Result key holding the GPU count, if any.
        label_by: Record fields annotated on the frontier points.
        dry_run: If true, figures are described but not drawn.

    Raises:
        ValueError: If no results are found, or if none of them has a
            usable user count.
    """
    fig_dir = output_dir / "pareto"
    raw_data = [
        run_data
        # `rglob` already searches recursively, so no `**/` prefix is needed.
        for path in output_dir.rglob("summary.json")
        for run_data in _json_load_bytes(path)
    ]

    if not raw_data:
        raise ValueError(f"Did not find any parameter sweep results under {output_dir}")

    fig_dir.mkdir(parents=True, exist_ok=True)

    prepared_data, skipped_missing_users = _prepare_records(
        raw_data,
        user_count_var=user_count_var,
        gpu_count_var=gpu_count_var,
    )

    if skipped_missing_users:
        print(
            f"Skipped {skipped_missing_users} runs without a user count "
            "(`max_concurrency` or `max_concurrent_requests`).",
        )

    if not prepared_data:
        raise ValueError(
            "No data points with both throughput and user count available "
            "to plot Pareto frontier.",
        )

    # The constant key puts every record into one figure group.
    fig_groups = full_groupby(
        prepared_data,
        key=lambda item: tuple(),
    )

    plot_one = partial(
        _plot_fig,
        fig_dir,
        label_by=label_by,
        dry_run=dry_run,
    )

    with DummyExecutor() if len(fig_groups) <= 1 else ProcessPoolExecutor() as executor:
        # Consume every result. `_plot_fig` returns None, so `all(...)` would
        # stop after the first item and errors raised by later tasks would
        # never surface.
        for _ in executor.map(plot_one, fig_groups):
            pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class SweepPlotParetoArgs:
|
||||
output_dir: Path
|
||||
user_count_var: str | None
|
||||
gpu_count_var: str | None
|
||||
label_by: list[str]
|
||||
dry_run: bool
|
||||
|
||||
parser_name: ClassVar[str] = "plot_pareto"
|
||||
parser_help: ClassVar[str] = (
|
||||
"Plot Pareto frontier between tokens/s/user and tokens/s/GPU "
|
||||
"from parameter sweep results."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
output_dir = Path(args.OUTPUT_DIR)
|
||||
if not output_dir.exists():
|
||||
raise ValueError(f"No parameter sweep results under {output_dir}")
|
||||
|
||||
label_by = [] if not args.label_by else args.label_by.split(",")
|
||||
|
||||
return cls(
|
||||
output_dir=output_dir,
|
||||
user_count_var=args.user_count_var,
|
||||
gpu_count_var=args.gpu_count_var,
|
||||
label_by=label_by,
|
||||
dry_run=args.dry_run,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def add_cli_args(cls, parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"OUTPUT_DIR",
|
||||
type=str,
|
||||
default="results",
|
||||
help="The directory containing the sweep results to plot.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--user-count-var",
|
||||
type=str,
|
||||
default="max_concurrency",
|
||||
help="Result key that stores concurrent user count. "
|
||||
"Falls back to max_concurrent_requests if missing.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gpu-count-var",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Result key that stores GPU count. "
|
||||
"If not provided, falls back to num_gpus/gpu_count "
|
||||
"or tensor_parallel_size * pipeline_parallel_size.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--label-by",
|
||||
type=str,
|
||||
default="max_concurrency,gpu_count",
|
||||
help="Comma-separated list of fields to annotate on Pareto frontier "
|
||||
"points.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dry-run",
|
||||
action="store_true",
|
||||
help="If set, prints the figures to plot without drawing them.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def run_main(args: SweepPlotParetoArgs):
    """Call `plot_pareto` with the settings stored on `args`."""
    pareto_options = dict(
        output_dir=args.output_dir,
        user_count_var=args.user_count_var,
        gpu_count_var=args.gpu_count_var,
        label_by=args.label_by,
        dry_run=args.dry_run,
    )
    return plot_pareto(**pareto_options)
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
    """Convert the parsed CLI namespace into `SweepPlotParetoArgs` and plot."""
    pareto_args = SweepPlotParetoArgs.from_cli_args(args)
    run_main(pareto_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: register the Pareto options and run with the
    # arguments given on the command line.
    parser = argparse.ArgumentParser(description=SweepPlotParetoArgs.parser_help)
    SweepPlotParetoArgs.add_cli_args(parser)

    cli_args = parser.parse_args()
    main(cli_args)
|
||||
498
vllm/benchmarks/sweep/serve.py
Normal file
498
vllm/benchmarks/sweep/serve.py
Normal file
@@ -0,0 +1,498 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
import contextlib
|
||||
import json
|
||||
import shlex
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import ClassVar
|
||||
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
|
||||
from .param_sweep import ParameterSweep, ParameterSweepItem
|
||||
from .server import ServerProcess
|
||||
from .utils import sanitize_filename
|
||||
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
pd = PlaceholderModule("pandas")
|
||||
|
||||
|
||||
@contextlib.contextmanager
def run_server(
    serve_cmd: list[str],
    after_bench_cmd: list[str],
    *,
    show_stdout: bool,
    serve_overrides: ParameterSweepItem,
    dry_run: bool,
    server_ready_timeout: int = 300,
):
    """Context manager that runs a server with `serve_overrides` applied.

    Yields the `ServerProcess` once `wait_until_ready` has returned, or
    `None` in dry-run mode, where only the command is printed.
    """
    server_cmd = serve_overrides.apply_to_cmd(serve_cmd)

    print("[BEGIN SERVER]")
    print(f"Server overrides: {serve_overrides}")
    print(f"Server command: {server_cmd}")

    if not dry_run:
        with ServerProcess(
            server_cmd, after_bench_cmd, show_stdout=show_stdout
        ) as server:
            server.wait_until_ready(timeout=server_ready_timeout)
            yield server
    else:
        yield None

    print("[END SERVER]")
|
||||
|
||||
|
||||
def _update_run_data(
    run_data: dict[str, object],
    serve_overrides: ParameterSweepItem,
    bench_overrides: ParameterSweepItem,
    run_number: int,
):
    """Record the run number and both override sets in `run_data`.

    `run_data` is modified in place and also returned. Bench overrides are
    applied last, so they win over serve overrides with the same key.
    """
    run_data["run_number"] = run_number
    for overrides in (serve_overrides, bench_overrides):
        run_data.update(overrides)

    return run_data
|
||||
|
||||
|
||||
def run_benchmark(
    server: ServerProcess | None,
    bench_cmd: list[str],
    *,
    serve_overrides: ParameterSweepItem,
    bench_overrides: ParameterSweepItem,
    run_number: int,
    output_path: Path,
    dry_run: bool,
):
    """Run one benchmark and return its result record.

    If `output_path` already exists, the benchmark is skipped and the stored
    result is returned. Otherwise the benchmark command is run through
    `server`, and the result file is rewritten with the run number and
    overrides added.

    Returns:
        The result record, or `None` when there is no server and `dry_run`
        is set.

    Raises:
        ValueError: If there is no stored result and no server while not in
            dry-run mode.
    """
    base_cmd = bench_overrides.apply_to_cmd(bench_cmd)
    benchmark_cmd = [
        *base_cmd,
        "--percentile-metrics",
        "ttft,tpot,itl,e2el",
        "--save-result",
        "--result-dir",
        str(output_path.parent),
        "--result-filename",
        output_path.name,
    ]

    print("[BEGIN BENCHMARK]")
    print(f"Benchmark overrides: {bench_overrides}")
    print(f"Run Number: {run_number}")
    print(f"Benchmark command: {benchmark_cmd}")
    print(f"Output file: {output_path}")

    if output_path.exists():
        print("Found existing results.")
        print("[SKIPPED BENCHMARK]")

        with output_path.open("rb") as f:
            cached_data = json.load(f)

        return _update_run_data(
            cached_data,
            serve_overrides,
            bench_overrides,
            run_number,
        )

    if server is None:
        if not dry_run:
            raise ValueError(f"Cannot find results at {output_path}")

        print("[END BENCHMARK]")
        return None

    output_path.parent.mkdir(parents=True, exist_ok=True)

    server.run_subcommand(benchmark_cmd)
    server.after_bench()

    with output_path.open("rb") as f:
        fresh_data = json.load(f)

    run_data = _update_run_data(
        fresh_data,
        serve_overrides,
        bench_overrides,
        run_number,
    )

    # Write the enriched record back to the result file.
    with output_path.open("w") as f:
        json.dump(run_data, f, indent=4)

    print("[END BENCHMARK]")

    return run_data
|
||||
|
||||
|
||||
def _get_comb_base_path(
    output_dir: Path,
    serve_comb: ParameterSweepItem,
    bench_comb: ParameterSweepItem,
):
    """Return the result directory for a serve/bench combination.

    The directory name joins `SERVE-` plus the serve combination's name and
    `BENCH-` plus the bench combination's name with `-`; a falsy combination
    contributes nothing. The name is passed through `sanitize_filename`.
    """
    name_parts: list[str] = []
    for prefix, comb in (("SERVE-", serve_comb), ("BENCH-", bench_comb)):
        if comb:
            name_parts += [prefix, comb.name]

    dir_name = sanitize_filename("-".join(name_parts))
    return output_dir / dir_name
|
||||
|
||||
|
||||
def _get_comb_run_path(base_path: Path, run_number: int | None):
|
||||
if run_number is None:
|
||||
return base_path / "summary.json"
|
||||
|
||||
return base_path / f"run={run_number}.json"
|
||||
|
||||
|
||||
def _comb_needs_server(
    serve_comb: ParameterSweepItem,
    bench_combs: ParameterSweep,
    output_dir: Path,
):
    """Return True if any bench combination still lacks its summary file."""
    return any(
        not _get_comb_run_path(
            _get_comb_base_path(output_dir, serve_comb, bench_comb),
            run_number=None,
        ).exists()
        for bench_comb in bench_combs
    )
|
||||
|
||||
|
||||
def server_ctx(
    serve_cmd: list[str],
    after_bench_cmd: list[str],
    *,
    show_stdout: bool,
    serve_comb: ParameterSweepItem,
    bench_params: ParameterSweep,
    output_dir: Path,
    dry_run: bool,
    server_ready_timeout: int = 300,
):
    """Return a context manager that provides the server for `serve_comb`.

    A real server context is returned only if some bench combination still
    has no summary file; otherwise a null context (yielding `None`) is
    returned.
    """
    if _comb_needs_server(serve_comb, bench_params, output_dir):
        return run_server(
            serve_cmd,
            after_bench_cmd,
            show_stdout=show_stdout,
            serve_overrides=serve_comb,
            dry_run=dry_run,
            server_ready_timeout=server_ready_timeout,
        )

    return contextlib.nullcontext()
|
||||
|
||||
|
||||
def _comb_is_valid(
    serve_comb: ParameterSweepItem,
    bench_comb: ParameterSweepItem,
    link_vars: list[tuple[str, str]],
) -> bool:
    """Check that every linked variable pair is present and equal.

    For each `(serve_key, bench_key)` pair, both keys must exist in their
    combination and hold equal values. With no pairs the result is True.
    """
    for serve_key, bench_key in link_vars:
        if serve_key not in serve_comb or bench_key not in bench_comb:
            return False
        if not (serve_comb[serve_key] == bench_comb[bench_key]):
            return False

    return True
|
||||
|
||||
|
||||
def run_comb(
    server: ServerProcess | None,
    bench_cmd: list[str],
    *,
    serve_comb: ParameterSweepItem,
    bench_comb: ParameterSweepItem,
    base_path: Path,
    num_runs: int,
    dry_run: bool,
    link_vars: list[tuple[str, str]],
):
    """Run `num_runs` benchmarks for one combination and write a summary.

    Returns:
        The list of run records, or `None` if the combination violates
        `link_vars` or if `dry_run` is set (no summary is written then).
    """
    if not _comb_is_valid(serve_comb, bench_comb, link_vars):
        return None

    collected: list[dict[str, object]] = []

    for run_number in range(num_runs):
        run_data = run_benchmark(
            server,
            bench_cmd,
            serve_overrides=serve_comb,
            bench_overrides=bench_comb,
            run_number=run_number,
            output_path=_get_comb_run_path(base_path, run_number),
            dry_run=dry_run,
        )
        if run_data is None:
            continue

        collected.append(run_data)

    if dry_run:
        return None

    summary_path = _get_comb_run_path(base_path, run_number=None)
    with summary_path.open("w") as f:
        json.dump(collected, f, indent=4)

    return collected
|
||||
|
||||
|
||||
def run_combs(
    serve_cmd: list[str],
    bench_cmd: list[str],
    after_bench_cmd: list[str],
    *,
    show_stdout: bool,
    server_ready_timeout: int,
    serve_params: ParameterSweep,
    bench_params: ParameterSweep,
    output_dir: Path,
    num_runs: int,
    dry_run: bool,
    link_vars: list[tuple[str, str]],
):
    """Benchmark every serve/bench combination and write `summary.csv`.

    One server context is opened per serve combination and used for all
    bench combinations.

    Returns:
        A DataFrame of all run records, or `None` in dry-run mode.
    """
    records: list[dict[str, object]] = []

    for serve_comb in serve_params:
        server_cm = server_ctx(
            serve_cmd,
            after_bench_cmd,
            show_stdout=show_stdout,
            serve_comb=serve_comb,
            bench_params=bench_params,
            output_dir=output_dir,
            dry_run=dry_run,
            server_ready_timeout=server_ready_timeout,
        )
        with server_cm as server:
            for bench_comb in bench_params:
                comb_data = run_comb(
                    server,
                    bench_cmd,
                    serve_comb=serve_comb,
                    bench_comb=bench_comb,
                    base_path=_get_comb_base_path(
                        output_dir, serve_comb, bench_comb
                    ),
                    num_runs=num_runs,
                    dry_run=dry_run,
                    link_vars=link_vars,
                )
                if comb_data is None:
                    continue

                records.extend(comb_data)

    if dry_run:
        return None

    combined_df = pd.DataFrame.from_records(records)
    combined_df.to_csv(output_dir / "summary.csv")

    return combined_df
|
||||
|
||||
|
||||
@dataclass
class SweepServeArgs:
    """Validated command-line options of the `serve` sweep command."""

    # Server command, split into arguments with `shlex.split`.
    serve_cmd: list[str]
    # Benchmark command, split into arguments with `shlex.split`.
    bench_cmd: list[str]
    # Command from `--after-bench-cmd`; empty list when the option is unset.
    after_bench_cmd: list[str]
    # Whether the standard output of subcommands is logged.
    show_stdout: bool
    # Seconds to wait for the server to become ready.
    server_ready_timeout: int
    # Parameter combinations for the server command.
    serve_params: ParameterSweep
    # Parameter combinations for the benchmark command.
    bench_params: ParameterSweep
    # Directory to which results are written.
    output_dir: Path
    # Number of runs per parameter combination (at least 1).
    num_runs: int
    # If true, commands are printed but not executed.
    dry_run: bool
    # Name of a previous run directory to resume, if any.
    resume: str | None
    # Pairs of (serve variable, bench variable) parsed from `--link-vars`.
    link_vars: list[tuple[str, str]]

    parser_name: ClassVar[str] = "serve"
    parser_help: ClassVar[str] = "Run vLLM server benchmark under multiple settings."

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build the options from a namespace produced by `add_cli_args`.

        Raises:
            ValueError: If `num_runs` is below 1 or `--link-vars` is malformed.
        """
        serve_cmd = shlex.split(args.serve_cmd)
        bench_cmd = shlex.split(args.bench_cmd)
        after_bench_cmd = (
            [] if args.after_bench_cmd is None else shlex.split(args.after_bench_cmd)
        )

        if args.serve_params:
            serve_params = ParameterSweep.read_json(args.serve_params)
        else:
            # i.e.: run serve_cmd without any modification
            serve_params = ParameterSweep.from_records([{}])

        if args.bench_params:
            bench_params = ParameterSweep.read_json(args.bench_params)
        else:
            # i.e.: run bench_cmd without any modification
            bench_params = ParameterSweep.from_records([{}])

        link_vars = cls.parse_link_vars(args.link_vars)

        num_runs = args.num_runs
        if num_runs < 1:
            raise ValueError("`num_runs` should be at least 1.")

        return cls(
            serve_cmd=serve_cmd,
            bench_cmd=bench_cmd,
            after_bench_cmd=after_bench_cmd,
            show_stdout=args.show_stdout,
            serve_params=serve_params,
            bench_params=bench_params,
            output_dir=Path(args.output_dir),
            num_runs=num_runs,
            dry_run=args.dry_run,
            resume=args.resume,
            link_vars=link_vars,
            server_ready_timeout=args.server_ready_timeout,
        )

    @classmethod
    def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Register this command's arguments on `parser` and return it."""
        parser.add_argument(
            "--serve-cmd",
            type=str,
            required=True,
            help="The command used to run the server: `vllm serve ...`",
        )
        parser.add_argument(
            "--bench-cmd",
            type=str,
            required=True,
            help="The command used to run the benchmark: `vllm bench serve ...`",
        )
        parser.add_argument(
            "--after-bench-cmd",
            type=str,
            default=None,
            help="After a benchmark run is complete, invoke this command instead of "
            "the default `ServerWrapper.clear_cache()`.",
        )
        parser.add_argument(
            "--show-stdout",
            action="store_true",
            help="If set, logs the standard output of subcommands. "
            "Useful for debugging but can be quite spammy.",
        )
        parser.add_argument(
            "--server-ready-timeout",
            type=int,
            default=300,
            help="Timeout in seconds to wait for the server to become ready.",
        )
        parser.add_argument(
            "--serve-params",
            type=str,
            default=None,
            help="Path to JSON file containing parameter combinations "
            "for the `vllm serve` command. Can be either a list of dicts or a dict "
            "where keys are benchmark names. "
            "If both `serve_params` and `bench_params` are given, "
            "this script will iterate over their Cartesian product.",
        )
        parser.add_argument(
            "--bench-params",
            type=str,
            default=None,
            help="Path to JSON file containing parameter combinations "
            "for the `vllm bench serve` command. Can be either a list of dicts or "
            "a dict where keys are benchmark names. "
            "If both `serve_params` and `bench_params` are given, "
            "this script will iterate over their Cartesian product.",
        )
        parser.add_argument(
            "-o",
            "--output-dir",
            type=str,
            default="results",
            help="The directory to which results are written.",
        )
        parser.add_argument(
            "--num-runs",
            type=int,
            default=3,
            help="Number of runs per parameter combination.",
        )
        parser.add_argument(
            "--dry-run",
            action="store_true",
            help="If set, prints the commands to run, "
            "then exits without executing them.",
        )
        parser.add_argument(
            "--resume",
            type=str,
            default=None,
            help="Set this to the name of a directory under `output_dir` (which is a "
            "timestamp) to resume a previous execution of this script, i.e., only run "
            "parameter combinations for which there are still no output files.",
        )

        parser.add_argument(
            "--link-vars",
            type=str,
            default="",
            help=(
                "Comma-separated list of linked variables between serve and bench, "
                "e.g. max_num_seqs=max_concurrency,max_model_len=random_input_len"
            ),
        )

        return parser

    @staticmethod
    def parse_link_vars(s: str) -> list[tuple[str, str]]:
        """Parse `serve_var=bench_var,...` into a list of name pairs.

        Whitespace around names is removed and blank entries (for example
        from a trailing comma) are ignored.

        Raises:
            ValueError: If an entry does not contain exactly one `=`.
        """
        if not s:
            return []

        pairs = []
        for item in s.split(","):
            if not item.strip():
                continue

            names = item.split("=")
            if len(names) != 2:
                raise ValueError(
                    f"Invalid `--link-vars` entry {item!r}: "
                    "expected the form `serve_var=bench_var`."
                )

            serve_var, bench_var = names
            pairs.append((serve_var.strip(), bench_var.strip()))

        return pairs
|
||||
|
||||
|
||||
def run_main(args: SweepServeArgs):
    """Run the sweep inside a timestamped (or resumed) output directory.

    Raises:
        ValueError: If `args.resume` names a directory that does not exist.
        RuntimeError: If the sweep is interrupted by any exception; the
            message names the `--resume` value that continues it.
    """
    resuming = bool(args.resume)
    if resuming:
        timestamp = args.resume
    else:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    output_dir = args.output_dir / timestamp

    if resuming and not output_dir.exists():
        raise ValueError(f"Cannot resume from non-existent directory ({output_dir})")

    sweep_options = dict(
        serve_cmd=args.serve_cmd,
        bench_cmd=args.bench_cmd,
        after_bench_cmd=args.after_bench_cmd,
        show_stdout=args.show_stdout,
        server_ready_timeout=args.server_ready_timeout,
        serve_params=args.serve_params,
        bench_params=args.bench_params,
        output_dir=output_dir,
        num_runs=args.num_runs,
        dry_run=args.dry_run,
        link_vars=args.link_vars,
    )

    try:
        return run_combs(**sweep_options)
    except BaseException as exc:
        # BaseException also covers interruption such as Ctrl-C, so the
        # resume hint is shown in that case too.
        raise RuntimeError(
            f"The script was terminated early. Use `--resume {timestamp}` "
            f"to continue the script from its last checkpoint."
        ) from exc
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
    """Convert the parsed CLI namespace into `SweepServeArgs` and run the sweep."""
    serve_args = SweepServeArgs.from_cli_args(args)
    run_main(serve_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: register the sweep options and run with the
    # arguments given on the command line.
    parser = argparse.ArgumentParser(description=SweepServeArgs.parser_help)
    SweepServeArgs.add_cli_args(parser)

    cli_args = parser.parse_args()
    main(cli_args)
|
||||
305
vllm/benchmarks/sweep/serve_sla.py
Normal file
305
vllm/benchmarks/sweep/serve_sla.py
Normal file
@@ -0,0 +1,305 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
import math
|
||||
from dataclasses import asdict, dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import ClassVar, Literal, get_args
|
||||
|
||||
import numpy as np
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
|
||||
from .param_sweep import ParameterSweep, ParameterSweepItem
|
||||
from .serve import (
|
||||
SweepServeArgs,
|
||||
_get_comb_base_path,
|
||||
run_comb,
|
||||
server_ctx,
|
||||
)
|
||||
from .server import ServerProcess
|
||||
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
pd = PlaceholderModule("pandas")
|
||||
|
||||
|
||||
SLAVariable = Literal["request_rate", "max_concurrency"]
|
||||
|
||||
|
||||
def _estimate_sla_value(run_data: dict[str, object], sla_variable: SLAVariable):
|
||||
request_throughput = float(run_data["request_throughput"]) # type: ignore
|
||||
if sla_variable == "request_rate":
|
||||
return request_throughput
|
||||
if sla_variable == "max_concurrency":
|
||||
mean_latency_ms = float(run_data["mean_e2el_ms"]) # type: ignore
|
||||
return request_throughput * mean_latency_ms / 1000
|
||||
|
||||
assert_never(sla_variable)
|
||||
|
||||
|
||||
def _estimate_sla_avg(runs: list[dict[str, object]], sla_variable: SLAVariable):
    """Return the mean of the per-run SLA estimates over `runs`."""
    estimates = [_estimate_sla_value(run, sla_variable) for run in runs]
    return sum(estimates) / len(runs)
|
||||
|
||||
|
||||
def run_comb_sla(
    server: ServerProcess | None,
    bench_cmd: list[str],
    *,
    serve_comb: ParameterSweepItem,
    bench_comb: ParameterSweepItem,
    output_dir: Path,
    num_runs: int,
    dry_run: bool,
    link_vars: list[tuple[str, str]],
    sla_variable: SLAVariable,
    sla_value: int,
) -> list[dict[str, object]] | None:
    """Run one combination with `sla_variable` fixed to `sla_value`.

    The SLA setting is merged into the bench combination, and the merged
    combination determines the result directory. Returns what `run_comb`
    returns.
    """
    bench_comb_sla = bench_comb | {sla_variable: sla_value}
    sla_base_path = _get_comb_base_path(output_dir, serve_comb, bench_comb_sla)

    return run_comb(
        server,
        bench_cmd,
        serve_comb=serve_comb,
        bench_comb=bench_comb_sla,
        base_path=sla_base_path,
        num_runs=num_runs,
        dry_run=dry_run,
        link_vars=link_vars,
    )
|
||||
|
||||
|
||||
def explore_sla(
    server: ServerProcess | None,
    bench_cmd: list[str],
    *,
    serve_comb: ParameterSweepItem,
    bench_comb: ParameterSweepItem,
    sla_variable: SLAVariable,
    sla_iters: int,
    output_dir: Path,
    num_runs: int,
    dry_run: bool,
    link_vars: list[tuple[str, str]],
):
    """Benchmark a combination at several values of `sla_variable`.

    Two runs come first: one with the variable set to 1 and one with it set
    to `num_prompts` (1000 if the bench combination has none). Their
    estimated SLA values bound a range from which up to `sla_iters - 2`
    distinct rounded intermediate values are then benchmarked.

    Returns:
        The run records in the order serial, intermediate, batch, or `None`
        if either of the two initial runs produced no data.

    Raises:
        ValueError: If `sla_iters` is below 2.
    """
    print("[SLA START]")
    print(f"Serve parameters: {serve_comb.as_text() or '(None)'}")
    print(f"Bench parameters: {bench_comb.as_text() or '(None)'}")
    print(f"Number of SLA iterations: {sla_iters}")

    if sla_iters < 2:
        raise ValueError("`sla_iters` should be at least 2")

    def _run_at(sla_value):
        # All runs share every setting except the SLA value itself.
        return run_comb_sla(
            server,
            bench_cmd,
            serve_comb=serve_comb,
            bench_comb=bench_comb,
            output_dir=output_dir,
            num_runs=num_runs,
            dry_run=dry_run,
            link_vars=link_vars,
            sla_variable=sla_variable,
            sla_value=sla_value,
        )

    serial_comb_data = _run_at(1)
    batch_comb_data = _run_at(
        int(bench_comb.get("num_prompts", 1000))  # type: ignore
    )

    if serial_comb_data is None or batch_comb_data is None:
        if dry_run:
            print("Omitting intermediate SLA iterations.")
            print("[SLA END]")

        return

    serial_sla_value = math.ceil(_estimate_sla_avg(serial_comb_data, sla_variable))
    print(f"Serial inference: {sla_variable}={serial_sla_value}")

    batch_sla_value = math.floor(_estimate_sla_avg(batch_comb_data, sla_variable))
    print(f"Batch inference: {sla_variable}={batch_sla_value}")

    # Avoid duplicated runs for intermediate values if the range between
    # `serial_sla_value` and `batch_sla_value` is small
    interior_points = np.linspace(serial_sla_value, batch_sla_value, sla_iters)[1:-1]
    inter_sla_values = sorted({round(point) for point in interior_points})

    inter_combs_data: list[dict[str, object]] = []
    for inter_sla_value in inter_sla_values:
        print(f"Exploring: {sla_variable}={inter_sla_value}")
        inter_comb_data = _run_at(inter_sla_value)
        if inter_comb_data is None:
            continue

        inter_combs_data.extend(inter_comb_data)

    print("[SLA END]")

    return serial_comb_data + inter_combs_data + batch_comb_data
|
||||
|
||||
|
||||
def run_slas(
    serve_cmd: list[str],
    bench_cmd: list[str],
    after_bench_cmd: list[str],
    *,
    show_stdout: bool,
    server_ready_timeout: int,
    serve_params: ParameterSweep,
    bench_params: ParameterSweep,
    sla_variable: SLAVariable,
    sla_iters: int,
    output_dir: Path,
    num_runs: int,
    dry_run: bool,
    link_vars: list[tuple[str, str]],
):
    """Run the SLA exploration for every serve/bench parameter combination.

    A server context is opened once per entry of `serve_params`; within it,
    `explore_sla` is invoked once per entry of `bench_params` and the records
    it returns are accumulated.

    Returns:
        `None` in dry-run mode. Otherwise a DataFrame built from all
        accumulated records, which is also written to
        `output_dir / "summary.csv"`.

    Raises:
        ValueError: If any bench combination already sets `sla_variable`.
    """
    # `sla_variable` is chosen by the exploration itself, so a user-supplied
    # value in the bench parameters is rejected up front.
    for candidate in bench_params:
        if candidate.has_param(sla_variable):
            raise ValueError(
                f"You should not override `{sla_variable}` in `bench_params` in SLA mode, "
                "since it is supposed to be determined automatically."
            )

    records: list[dict[str, object]] = []
    for serve_comb in serve_params:
        with server_ctx(
            serve_cmd,
            after_bench_cmd,
            show_stdout=show_stdout,
            server_ready_timeout=server_ready_timeout,
            serve_comb=serve_comb,
            bench_params=bench_params,
            output_dir=output_dir,
            dry_run=dry_run,
        ) as server:
            for bench_comb in bench_params:
                explored = explore_sla(
                    server,
                    bench_cmd,
                    serve_comb=serve_comb,
                    bench_comb=bench_comb,
                    sla_variable=sla_variable,
                    sla_iters=sla_iters,
                    output_dir=output_dir,
                    num_runs=num_runs,
                    dry_run=dry_run,
                    link_vars=link_vars,
                )
                if explored is None:
                    continue
                records.extend(explored)

    # Nothing is aggregated or written to disk in dry-run mode.
    if dry_run:
        return None

    summary = pd.DataFrame.from_records(records)
    summary.to_csv(output_dir / "summary.csv")
    return summary
|
||||
|
||||
|
||||
@dataclass
class SweepServeSLAArgs(SweepServeArgs):
    """Arguments of the `serve_sla` sweep: the base sweep arguments plus SLA options."""

    # Name of the benchmark variable that is adjusted in each iteration.
    sla_variable: SLAVariable
    # Number of exploration iterations (see the `--sla-iters` help text).
    sla_iters: int

    parser_name: ClassVar[str] = "serve_sla"
    parser_help: ClassVar[str] = (
        "Explore the latency-throughput space for determining SLAs."
    )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build an instance from parsed CLI arguments."""
        # NOTE: Call the base class explicitly rather than via super(), because
        # the base `from_cli_args` calls `cls()`.
        base_fields = asdict(SweepServeArgs.from_cli_args(args))

        return cls(
            sla_variable=args.sla_variable,
            sla_iters=args.sla_iters,
            **base_fields,
        )

    @classmethod
    def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Register the base sweep options followed by the SLA-specific ones."""
        parser = super().add_cli_args(parser)

        group = parser.add_argument_group("sla options")
        group.add_argument(
            "--sla-variable",
            type=str,
            choices=get_args(SLAVariable),
            default="request_rate",
            help="The variable to adjust in each iteration.",
        )
        group.add_argument(
            "--sla-iters",
            type=int,
            default=10,
            help=(
                "Number of iterations used to explore the latency-throughput space. "
                "This includes the first two iterations used to interpolate the value of "
                "`sla_variable` for remaining iterations."
            ),
        )

        return parser
|
||||
|
||||
|
||||
def run_main(args: SweepServeSLAArgs):
    """Run the SLA sweep described by `args` inside a timestamped output directory.

    The directory name is `args.resume` when given, otherwise the current
    time formatted as `%Y%m%d_%H%M%S`.

    Raises:
        ValueError: If resuming was requested but the directory does not exist.
        RuntimeError: If the sweep stops for any reason (including
            `KeyboardInterrupt`); the message tells the user how to resume.
    """
    if args.resume:
        timestamp = args.resume
    else:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = args.output_dir / timestamp

    if args.resume and not output_dir.exists():
        raise ValueError(f"Cannot resume from non-existent directory ({output_dir})")

    try:
        result = run_slas(
            serve_cmd=args.serve_cmd,
            bench_cmd=args.bench_cmd,
            after_bench_cmd=args.after_bench_cmd,
            show_stdout=args.show_stdout,
            server_ready_timeout=args.server_ready_timeout,
            serve_params=args.serve_params,
            bench_params=args.bench_params,
            sla_variable=args.sla_variable,
            sla_iters=args.sla_iters,
            output_dir=output_dir,
            num_runs=args.num_runs,
            dry_run=args.dry_run,
            link_vars=args.link_vars,
        )
    except BaseException as exc:
        # BaseException is caught on purpose so that interrupts also produce
        # the resume hint.
        raise RuntimeError(
            f"The script was terminated early. Use `--resume {timestamp}` "
            f"to continue the script from its last checkpoint."
        ) from exc
    return result
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
    """CLI entry point: convert the parsed namespace to typed args and run the sweep."""
    sweep_args = SweepServeSLAArgs.from_cli_args(args)
    run_main(sweep_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: build the CLI, parse the arguments and run the sweep.
    cli_parser = argparse.ArgumentParser(description=SweepServeSLAArgs.parser_help)
    SweepServeSLAArgs.add_cli_args(cli_parser)

    main(cli_parser.parse_args())
|
||||
142
vllm/benchmarks/sweep/server.py
Normal file
142
vllm/benchmarks/sweep/server.py
Normal file
@@ -0,0 +1,142 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import contextlib
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import time
|
||||
from types import TracebackType
|
||||
|
||||
import requests
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
class ServerProcess:
|
||||
VLLM_RESET_CACHE_ENDPOINTS = [
|
||||
"/reset_prefix_cache",
|
||||
"/reset_mm_cache",
|
||||
"/reset_encoder_cache",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_cmd: list[str],
|
||||
after_bench_cmd: list[str],
|
||||
*,
|
||||
show_stdout: bool,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.server_cmd = server_cmd
|
||||
self.after_bench_cmd = after_bench_cmd
|
||||
self.show_stdout = show_stdout
|
||||
|
||||
def __enter__(self) -> Self:
|
||||
self.start()
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
exc_traceback: TracebackType | None,
|
||||
) -> None:
|
||||
self.stop()
|
||||
|
||||
def start(self):
|
||||
# Create new process for clean termination
|
||||
self._server_process = subprocess.Popen(
|
||||
self.server_cmd,
|
||||
start_new_session=True,
|
||||
stdout=None if self.show_stdout else subprocess.DEVNULL,
|
||||
# Need `VLLM_SERVER_DEV_MODE=1` for `_reset_caches`
|
||||
env=os.environ | {"VLLM_SERVER_DEV_MODE": "1"},
|
||||
)
|
||||
|
||||
def stop(self):
|
||||
server_process = self._server_process
|
||||
|
||||
if server_process.poll() is None:
|
||||
# In case only some processes have been terminated
|
||||
with contextlib.suppress(ProcessLookupError):
|
||||
# We need to kill both API Server and Engine processes
|
||||
os.killpg(os.getpgid(server_process.pid), signal.SIGKILL)
|
||||
|
||||
def run_subcommand(self, cmd: list[str]):
|
||||
return subprocess.run(
|
||||
cmd,
|
||||
stdout=None if self.show_stdout else subprocess.DEVNULL,
|
||||
check=True,
|
||||
)
|
||||
|
||||
def after_bench(self) -> None:
|
||||
if not self.after_bench_cmd:
|
||||
self.reset_caches()
|
||||
return
|
||||
|
||||
self.run_subcommand(self.after_bench_cmd)
|
||||
|
||||
def _get_vllm_server_address(self) -> str:
|
||||
server_cmd = self.server_cmd
|
||||
|
||||
for host_key in ("--host",):
|
||||
if host_key in server_cmd:
|
||||
host = server_cmd[server_cmd.index(host_key) + 1]
|
||||
break
|
||||
else:
|
||||
host = "localhost"
|
||||
|
||||
for port_key in ("-p", "--port"):
|
||||
if port_key in server_cmd:
|
||||
port = int(server_cmd[server_cmd.index(port_key) + 1])
|
||||
break
|
||||
else:
|
||||
port = 8000 # The default value in vllm serve
|
||||
|
||||
return f"http://{host}:{port}"
|
||||
|
||||
def is_server_ready(self) -> bool:
|
||||
server_address = self._get_vllm_server_address()
|
||||
try:
|
||||
response = requests.get(f"{server_address}/health")
|
||||
return response.status_code == 200
|
||||
except requests.RequestException:
|
||||
return False
|
||||
|
||||
def wait_until_ready(self, timeout: int) -> None:
|
||||
start_time = time.monotonic()
|
||||
while not self.is_server_ready():
|
||||
# Check if server process has crashed
|
||||
if self._server_process.poll() is not None:
|
||||
returncode = self._server_process.returncode
|
||||
raise RuntimeError(
|
||||
f"Server process crashed with return code {returncode}"
|
||||
)
|
||||
if time.monotonic() - start_time > timeout:
|
||||
raise TimeoutError(
|
||||
f"Server failed to become ready within {timeout} seconds."
|
||||
)
|
||||
time.sleep(1)
|
||||
|
||||
def reset_caches(self) -> None:
|
||||
server_cmd = self.server_cmd
|
||||
|
||||
# Use `.endswith()` to match `/bin/...`
|
||||
if server_cmd[0].endswith("vllm"):
|
||||
server_address = self._get_vllm_server_address()
|
||||
print(f"Resetting caches at {server_address}")
|
||||
|
||||
for endpoint in self.VLLM_RESET_CACHE_ENDPOINTS:
|
||||
res = requests.post(server_address + endpoint)
|
||||
res.raise_for_status()
|
||||
elif server_cmd[0].endswith("infinity_emb"):
|
||||
if "--vector-disk-cache" in server_cmd:
|
||||
raise NotImplementedError(
|
||||
"Infinity server uses caching but does not expose a method "
|
||||
"to reset the cache"
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"No implementation of `reset_caches` for `{server_cmd[0]}` server. "
|
||||
"Please specify a custom command via `--after-bench-cmd`."
|
||||
)
|
||||
406
vllm/benchmarks/sweep/startup.py
Normal file
406
vllm/benchmarks/sweep/startup.py
Normal file
@@ -0,0 +1,406 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
import json
|
||||
import shlex
|
||||
import subprocess
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import ClassVar
|
||||
|
||||
from vllm.benchmarks.startup import add_cli_args as add_startup_cli_args
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
|
||||
from .param_sweep import ParameterSweep, ParameterSweepItem
|
||||
from .utils import sanitize_filename
|
||||
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
pd = PlaceholderModule("pandas")
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
def _get_supported_startup_keys() -> set[str]:
    """Collect the parameter names accepted by the startup benchmark CLI.

    The set holds `"config"`, every argparse destination registered by
    `add_startup_cli_args`, and every long option with its leading dashes
    removed and remaining dashes turned into underscores. The result is
    cached, so all callers share the same set object.
    """
    parser = FlexibleArgumentParser(add_help=False)
    add_startup_cli_args(parser)

    actions = parser._actions
    known: set[str] = {"config"}
    known.update(
        action.dest
        for action in actions
        if action.dest and action.dest is not argparse.SUPPRESS
    )
    known.update(
        flag.lstrip("-").replace("-", "_")
        for action in actions
        for flag in action.option_strings
        if flag.startswith("--")
    )
    return known
|
||||
|
||||
|
||||
def _is_supported_param(param_key: str, supported: set[str]) -> bool:
|
||||
if param_key == "_benchmark_name":
|
||||
return True
|
||||
prefix = param_key.split(".", 1)[0]
|
||||
normalized = prefix.replace("-", "_")
|
||||
return normalized in supported
|
||||
|
||||
|
||||
def _filter_params(
    params: ParameterSweep, *, supported: set[str], strict: bool
) -> ParameterSweep:
    """Drop parameters the startup benchmark does not accept.

    For each sweep item, keys rejected by `_is_supported_param` are removed.
    When anything was removed, a message naming the dropped keys is printed,
    or raised as `ValueError` if `strict` is set.

    Returns:
        A new `ParameterSweep` containing only the supported parameters.
    """
    result = []
    for item in params:
        kept: dict[str, object] = {
            key: value
            for key, value in item.items()
            if _is_supported_param(key, supported)
        }
        dropped = [key for key, _ in item.items() if key not in kept]

        if dropped:
            label = item.get("_benchmark_name") or item.as_text()
            label_part = f" for {label!s}" if label else ""
            dropped_part = ", ".join(sorted(dropped))
            message = f"Ignoring unsupported startup params{label_part}: {dropped_part}"
            if strict:
                raise ValueError(message)
            print(message)

        result.append(ParameterSweepItem.from_record(kept))

    return ParameterSweep(result)
|
||||
|
||||
|
||||
def _update_run_data(
|
||||
run_data: dict[str, object],
|
||||
serve_overrides: ParameterSweepItem,
|
||||
startup_overrides: ParameterSweepItem,
|
||||
run_number: int,
|
||||
) -> dict[str, object]:
|
||||
run_data["run_number"] = run_number
|
||||
run_data.update(serve_overrides)
|
||||
run_data.update(startup_overrides)
|
||||
return run_data
|
||||
|
||||
|
||||
def _strip_arg(cmd: list[str], keys: tuple[str, ...]) -> list[str]:
|
||||
stripped: list[str] = []
|
||||
skip_next = False
|
||||
for arg in cmd:
|
||||
if skip_next:
|
||||
skip_next = False
|
||||
continue
|
||||
if arg in keys:
|
||||
skip_next = True
|
||||
continue
|
||||
if any(arg.startswith(f"{key}=") for key in keys):
|
||||
continue
|
||||
stripped.append(arg)
|
||||
return stripped
|
||||
|
||||
|
||||
def _apply_output_json(cmd: list[str], output_path: Path) -> list[str]:
    """Return `cmd` with any existing output-JSON option replaced by `output_path`.

    Both `--output-json` and `--output_json` are stripped; a single
    `--output-json <output_path>` pair is appended at the end.
    """
    option_names = ("--output-json", "--output_json")
    cleaned = _strip_arg(cmd, option_names)
    cleaned.extend((option_names[0], str(output_path)))
    return cleaned
|
||||
|
||||
|
||||
def _get_comb_base_path(
    output_dir: Path,
    serve_comb: ParameterSweepItem,
    startup_comb: ParameterSweepItem,
) -> Path:
    """Build the result directory path for one serve/startup combination.

    Each non-empty combination contributes its prefix (`SERVE-` or
    `STARTUP-`) and its name; the pieces are joined with `-`, sanitised, and
    appended to `output_dir`.
    """
    segments: list[str] = []
    for prefix, comb in (("SERVE-", serve_comb), ("STARTUP-", startup_comb)):
        if comb:
            segments += [prefix, comb.name]

    dirname = sanitize_filename("-".join(segments))
    return output_dir / dirname
|
||||
|
||||
|
||||
def _get_comb_run_path(base_path: Path, run_number: int | None) -> Path:
|
||||
if run_number is None:
|
||||
return base_path / "summary.json"
|
||||
return base_path / f"run={run_number}.json"
|
||||
|
||||
|
||||
def run_benchmark(
    startup_cmd: list[str],
    *,
    serve_overrides: ParameterSweepItem,
    startup_overrides: ParameterSweepItem,
    run_number: int,
    output_path: Path,
    show_stdout: bool,
    dry_run: bool,
) -> dict[str, object] | None:
    """Execute one startup benchmark run, or reuse its existing result file.

    The command is `startup_cmd` with both override sets applied and its
    output-JSON option pointed at `output_path`.

    Returns:
        The run's JSON data annotated with the run number and overrides.
        `None` in dry-run mode when no result file exists yet.

    Raises:
        subprocess.CalledProcessError: If the benchmark command fails.
    """
    full_cmd = _apply_output_json(
        startup_overrides.apply_to_cmd(serve_overrides.apply_to_cmd(startup_cmd)),
        output_path,
    )

    print("[BEGIN BENCHMARK]")
    print(f"Serve overrides: {serve_overrides}")
    print(f"Startup overrides: {startup_overrides}")
    print(f"Run Number: {run_number}")
    print(f"Benchmark command: {full_cmd}")
    print(f"Output file: {output_path}")

    # A result file from an earlier execution is reused instead of re-running.
    if output_path.exists():
        print("Found existing results.")
        print("[SKIPPED BENCHMARK]")

        cached = json.loads(output_path.read_text(encoding="utf-8"))
        return _update_run_data(
            cached, serve_overrides, startup_overrides, run_number
        )

    if dry_run:
        print("[END BENCHMARK]")
        return None

    output_path.parent.mkdir(parents=True, exist_ok=True)
    stdout_target = None if show_stdout else subprocess.DEVNULL
    subprocess.run(full_cmd, stdout=stdout_target, check=True)

    # Re-read what the benchmark wrote, add the sweep metadata and persist
    # the annotated record back to the same file.
    fresh = json.loads(output_path.read_text(encoding="utf-8"))
    fresh = _update_run_data(fresh, serve_overrides, startup_overrides, run_number)

    with output_path.open("w", encoding="utf-8") as out_file:
        json.dump(fresh, out_file, indent=4)

    print("[END BENCHMARK]")
    return fresh
|
||||
|
||||
|
||||
def run_comb(
    startup_cmd: list[str],
    *,
    serve_comb: ParameterSweepItem,
    startup_comb: ParameterSweepItem,
    base_path: Path,
    num_runs: int,
    show_stdout: bool,
    dry_run: bool,
) -> list[dict[str, object]] | None:
    """Run `num_runs` benchmark runs for one parameter combination.

    Returns:
        `None` in dry-run mode. Otherwise the list of per-run records, which
        is also written to the combination's `summary.json`.
    """
    collected: list[dict[str, object]] = []
    for run_idx in range(num_runs):
        record = run_benchmark(
            startup_cmd,
            serve_overrides=serve_comb,
            startup_overrides=startup_comb,
            run_number=run_idx,
            output_path=_get_comb_run_path(base_path, run_idx),
            show_stdout=show_stdout,
            dry_run=dry_run,
        )
        if record is None:
            continue
        collected.append(record)

    if dry_run:
        return None

    summary_path = _get_comb_run_path(base_path, run_number=None)
    with summary_path.open("w", encoding="utf-8") as out_file:
        json.dump(collected, out_file, indent=4)

    return collected
|
||||
|
||||
|
||||
def run_combs(
    startup_cmd: list[str],
    *,
    serve_params: ParameterSweep,
    startup_params: ParameterSweep,
    output_dir: Path,
    num_runs: int,
    show_stdout: bool,
    dry_run: bool,
) -> "pd.DataFrame | None":
    """Run the startup benchmark for every serve x startup combination.

    Returns:
        `None` in dry-run mode. Otherwise a DataFrame of all run records,
        which is also written to `output_dir / "summary.csv"`.
    """
    records: list[dict[str, object]] = []
    for serve_comb in serve_params:
        for startup_comb in startup_params:
            comb_records = run_comb(
                startup_cmd,
                serve_comb=serve_comb,
                startup_comb=startup_comb,
                base_path=_get_comb_base_path(output_dir, serve_comb, startup_comb),
                num_runs=num_runs,
                show_stdout=show_stdout,
                dry_run=dry_run,
            )
            if comb_records is None:
                continue
            records.extend(comb_records)

    if dry_run:
        return None

    summary = pd.DataFrame.from_records(records)
    summary.to_csv(output_dir / "summary.csv")
    return summary
|
||||
|
||||
|
||||
@dataclass
class SweepStartupArgs:
    """Typed arguments of the `startup` sweep command."""

    # Tokenised command used to run the startup benchmark.
    startup_cmd: list[str]
    # Parameter combinations, already filtered to the supported keys.
    serve_params: ParameterSweep
    startup_params: ParameterSweep
    # Directory under which the timestamped result directory is created.
    output_dir: Path
    num_runs: int
    show_stdout: bool
    dry_run: bool
    # Name of an earlier result directory to resume, if any.
    resume: str | None
    strict_params: bool

    parser_name: ClassVar[str] = "startup"
    parser_help: ClassVar[str] = (
        "Benchmark vLLM startup time over parameter combinations."
    )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build an instance from parsed CLI arguments.

        Sweep files are loaded when given (otherwise a single empty
        combination is used) and filtered against the supported startup keys.

        Raises:
            ValueError: If `num_runs` is below 1, or, with `--strict-params`,
                if a sweep file contains an unsupported parameter.
        """
        startup_cmd = shlex.split(args.startup_cmd)

        serve_params = (
            ParameterSweep.read_json(args.serve_params)
            if args.serve_params
            else ParameterSweep.from_records([{}])
        )
        startup_params = (
            ParameterSweep.read_json(args.startup_params)
            if args.startup_params
            else ParameterSweep.from_records([{}])
        )

        supported_keys = _get_supported_startup_keys()
        is_strict = args.strict_params
        serve_params = _filter_params(
            serve_params, supported=supported_keys, strict=is_strict
        )
        startup_params = _filter_params(
            startup_params, supported=supported_keys, strict=is_strict
        )

        if args.num_runs < 1:
            raise ValueError("`num_runs` should be at least 1.")

        return cls(
            startup_cmd=startup_cmd,
            serve_params=serve_params,
            startup_params=startup_params,
            output_dir=Path(args.output_dir),
            num_runs=args.num_runs,
            show_stdout=args.show_stdout,
            dry_run=args.dry_run,
            resume=args.resume,
            strict_params=args.strict_params,
        )

    @classmethod
    def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Register the options of the startup sweep on `parser` and return it."""
        add = parser.add_argument

        add(
            "--startup-cmd",
            type=str,
            default="vllm bench startup",
            help="The command used to run the startup benchmark.",
        )
        add(
            "--serve-params",
            type=str,
            default=None,
            help=(
                "Path to JSON file containing parameter combinations "
                "for the `vllm serve` command. Only parameters supported by "
                "`vllm bench startup` will be applied."
            ),
        )
        add(
            "--startup-params",
            type=str,
            default=None,
            help=(
                "Path to JSON file containing parameter combinations "
                "for the `vllm bench startup` command."
            ),
        )
        add(
            "-o",
            "--output-dir",
            type=str,
            default="results",
            help="The directory to which results are written.",
        )
        add(
            "--num-runs",
            type=int,
            default=1,
            help="Number of runs per parameter combination.",
        )
        add(
            "--show-stdout",
            action="store_true",
            help="If set, logs the standard output of subcommands.",
        )
        add(
            "--dry-run",
            action="store_true",
            help=(
                "If set, prints the commands to run, "
                "then exits without executing them."
            ),
        )
        add(
            "--resume",
            type=str,
            default=None,
            help=(
                "Set this to the name of a directory under `output_dir` (which is a "
                "timestamp) to resume a previous execution of this script, i.e., only run "
                "parameter combinations for which there are still no output files."
            ),
        )
        add(
            "--strict-params",
            action="store_true",
            help=(
                "If set, unknown parameters in sweep files raise an error "
                "instead of being ignored."
            ),
        )
        return parser
|
||||
|
||||
|
||||
def run_main(args: SweepStartupArgs):
    """Run the startup sweep described by `args` inside a timestamped output directory.

    The directory name is `args.resume` when given, otherwise the current
    time formatted as `%Y%m%d_%H%M%S`.

    Raises:
        ValueError: If resuming was requested but the directory does not exist.
        RuntimeError: If the sweep stops for any reason (including
            `KeyboardInterrupt`); the message tells the user how to resume.
    """
    if args.resume:
        timestamp = args.resume
    else:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = args.output_dir / timestamp

    if args.resume and not output_dir.exists():
        raise ValueError(f"Cannot resume from non-existent directory ({output_dir})")

    try:
        result = run_combs(
            startup_cmd=args.startup_cmd,
            serve_params=args.serve_params,
            startup_params=args.startup_params,
            output_dir=output_dir,
            num_runs=args.num_runs,
            show_stdout=args.show_stdout,
            dry_run=args.dry_run,
        )
    except BaseException as exc:
        # BaseException is caught on purpose so that interrupts also produce
        # the resume hint.
        raise RuntimeError(
            f"The script was terminated early. Use `--resume {timestamp}` "
            f"to continue the script from its last checkpoint."
        ) from exc
    return result
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
    """CLI entry point: convert the parsed namespace to typed args and run the sweep."""
    sweep_args = SweepStartupArgs.from_cli_args(args)
    run_main(sweep_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: build the CLI, parse the arguments and run the sweep.
    cli_parser = argparse.ArgumentParser(description=SweepStartupArgs.parser_help)
    SweepStartupArgs.add_cli_args(cli_parser)
    main(cli_parser.parse_args())
|
||||
4
vllm/benchmarks/sweep/utils.py
Normal file
4
vllm/benchmarks/sweep/utils.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
def sanitize_filename(filename: str) -> str:
    """Make `filename` usable as a single path component.

    Slashes become `_`, each `..` becomes `__`, then surrounding single
    quotes and, after that, surrounding double quotes are stripped.
    """
    cleaned = filename.replace("/", "_")
    cleaned = cleaned.replace("..", "__")
    cleaned = cleaned.strip("'")
    return cleaned.strip('"')
|
||||
946
vllm/benchmarks/throughput.py
Normal file
946
vllm/benchmarks/throughput.py
Normal file
@@ -0,0 +1,946 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Benchmark offline inference throughput."""
|
||||
|
||||
import argparse
|
||||
import dataclasses
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
import warnings
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import uvloop
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
|
||||
|
||||
from vllm.benchmarks.datasets import (
|
||||
AIMODataset,
|
||||
BurstGPTDataset,
|
||||
ConversationDataset,
|
||||
InstructCoderDataset,
|
||||
MultiModalConversationDataset,
|
||||
PrefixRepetitionRandomDataset,
|
||||
RandomDataset,
|
||||
RandomDatasetForReranking,
|
||||
RandomMultiModalDataset,
|
||||
SampleRequest,
|
||||
ShareGPTDataset,
|
||||
SonnetDataset,
|
||||
VisionArenaDataset,
|
||||
add_random_dataset_base_args,
|
||||
add_random_multimodal_dataset_args,
|
||||
)
|
||||
from vllm.benchmarks.lib.utils import convert_to_pytorch_benchmark_format, write_to_json
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
|
||||
from vllm.inputs import TextPrompt, TokensPrompt
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import BeamSearchParams
|
||||
from vllm.tokenizers import TokenizerLike, get_tokenizer
|
||||
from vllm.utils.async_utils import merge_async_iterators
|
||||
|
||||
|
||||
def run_vllm(
    requests: list[SampleRequest],
    n: int,
    engine_args: EngineArgs,
    do_profile: bool,
    disable_detokenize: bool = False,
) -> tuple[float, list[RequestOutput] | None]:
    """Time offline generation of `requests` with the synchronous `LLM` engine.

    Args:
        requests: Requests to generate for.
        n: Number of sequences sampled per prompt.
        engine_args: Engine arguments, expanded into the `LLM` constructor.
        do_profile: Whether to wrap generation in start/stop profile calls.
        disable_detokenize: If set, sampling is configured with
            `detokenize=False`.

    Returns:
        `(elapsed_seconds, outputs)`; the elapsed time includes the profiler
        start/stop calls.
    """
    from vllm import LLM, SamplingParams

    llm = LLM(**dataclasses.asdict(engine_args))
    assert all(
        llm.llm_engine.model_config.max_model_len
        >= (request.prompt_len + request.expected_output_len)
        for request in requests
    ), (
        "Please ensure that max_model_len is greater than the sum of"
        " prompt_len and expected_output_len for all requests."
    )

    # Build the engine inputs: one prompt and one SamplingParams per request.
    prompts: list[TextPrompt | TokensPrompt] = []
    for request in requests:
        if "prompt_token_ids" in request.prompt:
            prompt = TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"])
        else:
            prompt = TextPrompt(prompt=request.prompt)
        if request.multi_modal_data:
            assert isinstance(request.multi_modal_data, dict)
            prompt["multi_modal_data"] = request.multi_modal_data
        prompts.append(prompt)

    sampling_params: list[SamplingParams] = [
        SamplingParams(
            n=n,
            temperature=1.0,
            top_p=1.0,
            ignore_eos=True,
            max_tokens=request.expected_output_len,
            detokenize=not disable_detokenize,
        )
        for request in requests
    ]

    lora_requests: list[LoRARequest] | None = (
        [request.lora_request for request in requests]
        if engine_args.enable_lora
        else None
    )

    # Beam search is hard-wired off here, so the first branch below is
    # currently never taken.
    use_beam_search = False

    outputs = None
    if use_beam_search:
        assert lora_requests is None, "BeamSearch API does not support LoRA"
        prompts = [request.prompt for request in requests]
        # output_len should be the same for all requests.
        output_len = requests[0].expected_output_len
        for request in requests:
            assert request.expected_output_len == output_len
        start = time.perf_counter()
        if do_profile:
            llm.start_profile()
        llm.beam_search(
            prompts,
            BeamSearchParams(
                beam_width=n,
                max_tokens=output_len,
                ignore_eos=True,
            ),
        )
        if do_profile:
            llm.stop_profile()
        end = time.perf_counter()
    else:
        start = time.perf_counter()
        if do_profile:
            llm.start_profile()
        outputs = llm.generate(
            prompts, sampling_params, lora_request=lora_requests, use_tqdm=True
        )
        if do_profile:
            llm.stop_profile()
        end = time.perf_counter()
    return end - start, outputs
|
||||
|
||||
|
||||
def run_vllm_chat(
    requests: list[SampleRequest],
    n: int,
    engine_args: EngineArgs,
    do_profile: bool,
    disable_detokenize: bool = False,
) -> tuple[float, list[RequestOutput]]:
    """
    Run vLLM chat benchmark. This function is recommended ONLY for benchmarking
    multimodal models as it properly handles multimodal inputs and chat
    formatting. For non-multimodal models, use run_vllm() instead.

    Returns the elapsed time in seconds (including the profiler start/stop
    calls when `do_profile` is set) together with the chat outputs.
    """
    from vllm import LLM, SamplingParams

    llm = LLM(**dataclasses.asdict(engine_args))

    assert all(
        llm.llm_engine.model_config.max_model_len
        >= (request.prompt_len + request.expected_output_len)
        for request in requests
    ), (
        "Please ensure that max_model_len is greater than the sum of "
        "prompt_len and expected_output_len for all requests."
    )

    # One prompt and one SamplingParams per request, in request order.
    prompts = [request.prompt for request in requests]
    sampling_params: list[SamplingParams] = [
        SamplingParams(
            n=n,
            temperature=1.0,
            top_p=1.0,
            ignore_eos=True,
            max_tokens=request.expected_output_len,
            detokenize=not disable_detokenize,
        )
        for request in requests
    ]

    start = time.perf_counter()
    if do_profile:
        llm.start_profile()
    outputs = llm.chat(prompts, sampling_params, use_tqdm=True)
    if do_profile:
        llm.stop_profile()
    elapsed = time.perf_counter() - start
    return elapsed, outputs
|
||||
|
||||
|
||||
async def run_vllm_async(
    requests: list[SampleRequest],
    n: int,
    engine_args: AsyncEngineArgs,
    do_profile: bool,
    disable_frontend_multiprocessing: bool = False,
    disable_detokenize: bool = False,
) -> float:
    """Time generation of `requests` through the async engine client.

    All requests are submitted up front and their output streams are drained
    concurrently; the outputs themselves are discarded.

    Returns:
        Elapsed seconds, including the profiler start/stop calls when
        `do_profile` is set.
    """
    from vllm import SamplingParams
    from vllm.entrypoints.openai.api_server import (
        build_async_engine_client_from_engine_args,
    )

    async with build_async_engine_client_from_engine_args(
        engine_args,
        disable_frontend_multiprocessing=disable_frontend_multiprocessing,
    ) as llm:
        model_config = llm.model_config
        assert all(
            model_config.max_model_len
            >= (request.prompt_len + request.expected_output_len)
            for request in requests
        ), (
            "Please ensure that max_model_len is greater than the sum of"
            " prompt_len and expected_output_len for all requests."
        )

        # Build the engine inputs: (prompt, sampling params, LoRA request)
        # triples in request order.
        submissions: list[
            tuple[TextPrompt | TokensPrompt, SamplingParams, LoRARequest | None]
        ] = []
        for request in requests:
            if "prompt_token_ids" in request.prompt:
                prompt = TokensPrompt(
                    prompt_token_ids=request.prompt["prompt_token_ids"]
                )
            else:
                prompt = TextPrompt(prompt=request.prompt)

            if request.multi_modal_data:
                assert isinstance(request.multi_modal_data, dict)
                prompt["multi_modal_data"] = request.multi_modal_data

            params = SamplingParams(
                n=n,
                temperature=1.0,
                top_p=1.0,
                ignore_eos=True,
                max_tokens=request.expected_output_len,
                detokenize=not disable_detokenize,
            )
            submissions.append((prompt, params, request.lora_request))

        generators = []
        start = time.perf_counter()
        if do_profile:
            await llm.start_profile()
        for idx, (prompt, params, lora_request) in enumerate(submissions):
            generators.append(
                llm.generate(
                    prompt, params, lora_request=lora_request, request_id=f"test{idx}"
                )
            )
        # Drain every stream; only the time to completion matters here.
        async for _idx, _result in merge_async_iterators(*generators):
            pass
        if do_profile:
            await llm.stop_profile()
        end = time.perf_counter()
        return end - start
|
||||
|
||||
|
||||
def run_hf(
    requests: list[SampleRequest],
    model: str,
    tokenizer: TokenizerLike,
    n: int,
    max_batch_size: int,
    trust_remote_code: bool,
    disable_detokenize: bool = False,
) -> float:
    """Time generation of `requests` with a Hugging Face causal LM on CUDA.

    Requests are greedily grouped into batches of at most `max_batch_size`.
    A batch is also closed early when adding the next request would push the
    longest prompt plus the longest output beyond 2048.

    Returns:
        Elapsed seconds for all batches; batch decoding is included unless
        `disable_detokenize` is set.
    """
    assert isinstance(tokenizer, PreTrainedTokenizerBase), (
        "the hf backend only supports HF tokenizers"
    )
    llm = AutoModelForCausalLM.from_pretrained(
        model, dtype=torch.float16, trust_remote_code=trust_remote_code
    )
    if llm.config.model_type == "llama":
        # To enable padding in the HF backend.
        tokenizer.pad_token = tokenizer.eos_token
    llm = llm.cuda()

    pbar = tqdm(total=len(requests))
    start = time.perf_counter()
    batch: list[str] = []
    max_prompt_len = 0
    max_output_len = 0
    last_idx = len(requests) - 1
    for idx, request in enumerate(requests):
        # Add the prompt to the batch.
        batch.append(request.prompt)
        max_prompt_len = max(max_prompt_len, request.prompt_len)
        max_output_len = max(max_output_len, request.expected_output_len)
        if len(batch) < max_batch_size and idx != last_idx:
            # Check if we can add more requests to the batch.
            upcoming = requests[idx + 1]
            padded_total = max(max_prompt_len, upcoming.prompt_len) + max(
                max_output_len, upcoming.expected_output_len
            )
            if padded_total <= 2048:
                # We can add more requests to the batch.
                continue

        # Generate the sequences.
        input_ids = tokenizer(batch, return_tensors="pt", padding=True).input_ids
        llm_outputs = llm.generate(
            input_ids=input_ids.cuda(),
            do_sample=True,
            num_return_sequences=n,
            temperature=1.0,
            top_p=1.0,
            use_cache=True,
            max_new_tokens=max_output_len,
        )
        if not disable_detokenize:
            # Include the decoding time.
            tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
        pbar.update(len(batch))

        # Clear the batch.
        batch = []
        max_prompt_len = max_output_len = 0
    return time.perf_counter() - start
|
||||
|
||||
|
||||
def save_to_pytorch_benchmark_format(
    args: argparse.Namespace, results: dict[str, Any]
) -> None:
    """Export the throughput results as PyTorch-benchmark records.

    ``requests_per_second`` and ``tokens_per_second`` are passed to the
    converter as single-element metric lists; ``elapsed_time``,
    ``num_requests`` and ``total_num_tokens`` are passed as extra info.
    Nothing is written when the converter returns no records.

    Args:
        args: Parsed CLI arguments; ``args.output_json`` determines the
            location of the exported file.
        results: Result dictionary holding the five keys named above.
    """
    throughput_metrics = {
        name: [results[name]]
        for name in ("requests_per_second", "tokens_per_second")
    }
    run_details = {}
    for key in ("elapsed_time", "num_requests", "total_num_tokens"):
        run_details[key] = results[key]

    records = convert_to_pytorch_benchmark_format(
        args=args,
        metrics=throughput_metrics,
        extra_info=run_details,
    )
    if not records:
        return

    # The export sits next to the main report: same stem, with the original
    # extension replaced by ".pytorch.json".
    stem, _extension = os.path.splitext(args.output_json)
    write_to_json(f"{stem}.pytorch.json", records)
|
||||
|
||||
|
||||
def _prefer_random_arg(args, name):
    """Return ``args.random_<name>`` when it is set, else ``args.<name>``.

    Either attribute may be missing from ``args``; a missing attribute is
    treated the same as ``None``.
    """
    random_value = getattr(args, f"random_{name}", None)
    if random_value is not None:
        return random_value
    return getattr(args, name, None)


def get_requests(args, tokenizer):
    """Build the dataset selected by ``args`` and sample benchmark requests.

    Args:
        args: Parsed CLI arguments. ``args.dataset_name`` and
            ``args.dataset_path`` select the dataset class; the remaining
            attributes are forwarded as sampling options.
        tokenizer: Tokenizer handed to the dataset's ``sample`` method.

    Returns:
        The sampled requests, restricted to this process's share by
        ``filter_requests_for_dp`` when ``args.data_parallel_size`` is not 1.

    Raises:
        ValueError: If the dataset name is unknown, if an ``hf`` dataset path
            matches none of the supported dataset classes, or if the sonnet
            dataset is requested with a tokenizer that has no chat template.
    """
    # Common parameters for all dataset types.
    common_kwargs = {
        "dataset_path": args.dataset_path,
        "random_seed": args.seed,
    }
    sample_kwargs = {
        "tokenizer": tokenizer,
        "lora_path": args.lora_path,
        "max_loras": args.max_loras,
        "num_requests": args.num_prompts,
    }

    if args.dataset_name == "random" or (
        args.dataset_path is None
        and args.dataset_name not in {"prefix_repetition", "random-mm", "random-rerank"}
    ):
        dataset_cls = RandomDataset
        sample_kwargs["range_ratio"] = args.random_range_ratio
        # prefer random_* arguments, fall back to regular arguments
        sample_kwargs["prefix_len"] = _prefer_random_arg(args, "prefix_len")
        sample_kwargs["input_len"] = _prefer_random_arg(args, "input_len")
        sample_kwargs["output_len"] = _prefer_random_arg(args, "output_len")
    elif args.dataset_name == "sharegpt":
        dataset_cls = ShareGPTDataset
        if args.backend == "vllm-chat":
            sample_kwargs["enable_multimodal_chat"] = True
        if args.output_len is not None:
            sample_kwargs["output_len"] = args.output_len
    elif args.dataset_name == "sonnet":
        # A tokenizer object may not expose ``default_chat_template`` at all;
        # treat a missing attribute as "no default template" rather than
        # failing with AttributeError.
        if not (
            tokenizer.chat_template
            or getattr(tokenizer, "default_chat_template", None)
        ):
            raise ValueError(
                "Tokenizer/model must have chat template for sonnet dataset."
            )
        dataset_cls = SonnetDataset
        sample_kwargs["prefix_len"] = args.prefix_len
        sample_kwargs["return_prompt_formatted"] = True
        if args.input_len is not None:
            sample_kwargs["input_len"] = args.input_len
        if args.output_len is not None:
            sample_kwargs["output_len"] = args.output_len
    elif args.dataset_name == "burstgpt":
        dataset_cls = BurstGPTDataset
    elif args.dataset_name == "hf":
        if args.output_len is not None:
            sample_kwargs["output_len"] = args.output_len
        if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = VisionArenaDataset
            common_kwargs["dataset_subset"] = None
            common_kwargs["dataset_split"] = "train"
            sample_kwargs["enable_multimodal_chat"] = True
        elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = InstructCoderDataset
            common_kwargs["dataset_split"] = "train"
        elif args.dataset_path in MultiModalConversationDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = MultiModalConversationDataset
            common_kwargs["dataset_subset"] = args.hf_subset
            common_kwargs["dataset_split"] = args.hf_split
            sample_kwargs["enable_multimodal_chat"] = True
        elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = ConversationDataset
            common_kwargs["dataset_subset"] = args.hf_subset
            common_kwargs["dataset_split"] = args.hf_split
            sample_kwargs["enable_multimodal_chat"] = True
        elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = AIMODataset
            common_kwargs["dataset_subset"] = None
            common_kwargs["dataset_split"] = "train"
        else:
            # Without this branch ``dataset_cls`` would be left unbound and
            # the call below would fail with an UnboundLocalError.
            raise ValueError(f"{args.dataset_path} is not supported by hf dataset.")
    elif args.dataset_name == "prefix_repetition":
        dataset_cls = PrefixRepetitionRandomDataset
        sample_kwargs["prefix_len"] = args.prefix_repetition_prefix_len
        sample_kwargs["suffix_len"] = args.prefix_repetition_suffix_len
        sample_kwargs["num_prefixes"] = args.prefix_repetition_num_prefixes
        sample_kwargs["output_len"] = args.prefix_repetition_output_len
    elif args.dataset_name == "random-mm":
        dataset_cls = RandomMultiModalDataset
        # prefer random_* arguments, fall back to regular arguments
        sample_kwargs["input_len"] = _prefer_random_arg(args, "input_len")
        sample_kwargs["output_len"] = _prefer_random_arg(args, "output_len")
        sample_kwargs["base_items_per_request"] = getattr(
            args, "random_mm_base_items_per_request", None
        )
        sample_kwargs["num_mm_items_range_ratio"] = getattr(
            args, "random_mm_num_mm_items_range_ratio", None
        )
        sample_kwargs["limit_mm_per_prompt"] = getattr(
            args, "random_mm_limit_mm_per_prompt", None
        )
        sample_kwargs["bucket_config"] = getattr(args, "random_mm_bucket_config", None)
        sample_kwargs["enable_multimodal_chat"] = True
        sample_kwargs["prefix_len"] = _prefer_random_arg(args, "prefix_len")
        sample_kwargs["range_ratio"] = args.random_range_ratio
    elif args.dataset_name == "random-rerank":
        dataset_cls = RandomDatasetForReranking
        # prefer random_* arguments, fall back to regular arguments
        sample_kwargs["input_len"] = _prefer_random_arg(args, "input_len")
        sample_kwargs["output_len"] = _prefer_random_arg(args, "output_len")
        sample_kwargs["batchsize"] = getattr(args, "random_batch_size", 1)
        sample_kwargs["is_reranker"] = not getattr(args, "no_reranker", False)
        sample_kwargs["range_ratio"] = args.random_range_ratio
    else:
        raise ValueError(f"Unknown dataset name: {args.dataset_name}")

    # Remove None values so they are not passed to ``sample`` explicitly.
    sample_kwargs = {k: v for k, v in sample_kwargs.items() if v is not None}
    requests = dataset_cls(**common_kwargs).sample(**sample_kwargs)
    requests = filter_requests_for_dp(requests, args.data_parallel_size)
    return requests
|
||||
|
||||
|
||||
def filter_requests_for_dp(requests, data_parallel_size):
    """Keep only the requests assigned to this process's data-parallel rank.

    With ``data_parallel_size == 1`` the input is returned unchanged.
    Otherwise the data-parallel rank is derived from the ``RANK`` and
    ``WORLD_SIZE`` environment variables, and every request whose index is
    congruent to that rank modulo ``data_parallel_size`` is kept.

    Args:
        requests: Iterable of sampled requests.
        data_parallel_size: Number of data-parallel replicas.

    Returns:
        ``requests`` itself when ``data_parallel_size`` is 1, otherwise a new
        list holding this rank's share.

    Raises:
        KeyError: If ``RANK`` or ``WORLD_SIZE`` is not set in the environment.
        ValueError: If ``WORLD_SIZE`` is smaller than ``data_parallel_size``.
    """
    # Note(zhuohan): The way we get data_parallel_rank is hacky and only
    # works for external launcher mode. Should be cleaned up and deprecated
    # in the future with a better vLLM distributed process design.
    if data_parallel_size == 1:
        return requests

    global_rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    ranks_per_replica = world_size // data_parallel_size
    if ranks_per_replica == 0:
        # Fail with an explicit message instead of the ZeroDivisionError the
        # rank computation below would otherwise raise.
        raise ValueError(
            f"WORLD_SIZE ({world_size}) must be at least data_parallel_size "
            f"({data_parallel_size})"
        )
    data_parallel_rank = global_rank // ranks_per_replica
    return [
        request
        for index, request in enumerate(requests)
        if index % data_parallel_size == data_parallel_rank
    ]
|
||||
|
||||
|
||||
def validate_args(args):
    """
    Validate command-line arguments.

    Mutates ``args`` in place: copies the deprecated ``--dataset`` value into
    ``dataset_path``, defaults ``tokenizer`` to ``model``, and falls back to
    the ``random`` dataset when no dataset path is given for a dataset that
    needs one. Options that are set but irrelevant for the chosen dataset
    only produce warnings.

    Raises:
        ValueError: If the backend is unsupported, a required option is
            missing, or an option is incompatible with the chosen backend or
            dataset.
    """

    # === Deprecation and Defaulting ===
    if args.dataset is not None:
        warnings.warn(
            "The '--dataset' argument will be deprecated in the next release. "
            "Please use '--dataset-name' and '--dataset-path' instead.",
            stacklevel=2,
        )
        args.dataset_path = args.dataset

    if not getattr(args, "tokenizer", None):
        args.tokenizer = args.model

    # === Backend Validation ===
    valid_backends = {"vllm", "hf", "mii", "vllm-chat"}
    if args.backend not in valid_backends:
        raise ValueError(f"Unsupported backend: {args.backend}")

    # === Dataset Configuration ===
    # "random-mm" and "random-rerank" are excluded here as well so that the
    # path-less fallback agrees with get_requests(); otherwise they would be
    # silently replaced by the plain random dataset.
    if (
        not args.dataset
        and not args.dataset_path
        and args.dataset_name
        not in {"prefix_repetition", "random-mm", "random-rerank"}
    ):
        print("When dataset path is not set, it will default to random dataset")
        args.dataset_name = "random"
        random_input_len = getattr(args, "random_input_len", None)
        if args.input_len is None and random_input_len is None:
            raise ValueError(
                "Either --input-len or --random-input-len must be provided "
                "for a random dataset"
            )

    # === Dataset Name Specific Checks ===
    # --hf-subset and --hf-split: only used
    # when dataset_name is 'hf'
    if args.dataset_name != "hf" and (
        getattr(args, "hf_subset", None) is not None
        or getattr(args, "hf_split", None) is not None
    ):
        warnings.warn(
            "--hf-subset and --hf-split will be ignored "
            "since --dataset-name is not 'hf'.",
            stacklevel=2,
        )
    elif args.dataset_name == "hf":
        if args.dataset_path in (
            VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys()
            | MultiModalConversationDataset.SUPPORTED_DATASET_PATHS
            | ConversationDataset.SUPPORTED_DATASET_PATHS
        ):
            # Raise explicitly (asserts are stripped under ``python -O``).
            if args.backend != "vllm-chat":
                raise ValueError(
                    f"{args.dataset_path} needs to use vllm-chat as the backend."
                )
        elif args.dataset_path in (
            InstructCoderDataset.SUPPORTED_DATASET_PATHS
            | AIMODataset.SUPPORTED_DATASET_PATHS
        ):
            if args.backend != "vllm":
                raise ValueError(
                    f"{args.dataset_path} needs to use vllm as the backend."
                )
        else:
            raise ValueError(f"{args.dataset_path} is not supported by hf dataset.")

    # --random-range-ratio: only used when dataset_name is 'random',
    # 'random-mm', or 'random-rerank'
    if (
        args.dataset_name not in {"random", "random-mm", "random-rerank"}
        and args.random_range_ratio is not None
    ):
        warnings.warn(
            "--random-range-ratio will be ignored since "
            "--dataset-name is not 'random', 'random-mm', or 'random-rerank'.",
            stacklevel=2,
        )

    # --random-batch-size: only used when dataset_name is 'random-rerank'
    if (
        args.dataset_name != "random-rerank"
        and getattr(args, "random_batch_size", None) is not None
    ) and args.random_batch_size != 1:
        warnings.warn(
            "--random-batch-size will be ignored since "
            "--dataset-name is not 'random-rerank'.",
            stacklevel=2,
        )

    # --no-reranker: only used when dataset_name is 'random-rerank'
    if args.dataset_name != "random-rerank" and getattr(args, "no_reranker", False):
        warnings.warn(
            "--no-reranker will be ignored since "
            "--dataset-name is not 'random-rerank'.",
            stacklevel=2,
        )

    # --prefix-len defaults to 0 in add_cli_args, so a plain ``is not None``
    # test is true on every run; only a non-zero value counts as "set".
    prefix_len_set = args.prefix_len not in (None, 0)

    # --prefix-len: only used when dataset_name is 'random', 'random-mm',
    # 'sonnet', or not set.
    if (
        args.dataset_name not in {"random", "random-mm", "sonnet", None}
        and prefix_len_set
    ):
        warnings.warn(
            "--prefix-len will be ignored since --dataset-name "
            "is not 'random', 'random-mm', 'sonnet', or not set.",
            stacklevel=2,
        )

    # === Random Dataset Argument Conflict Detection ===
    # Check for conflicts between regular and random arguments when using
    # random datasets
    if args.dataset_name in {"random", "random-mm", "random-rerank"}:
        random_input_len = getattr(args, "random_input_len", None)
        random_output_len = getattr(args, "random_output_len", None)
        random_prefix_len = getattr(args, "random_prefix_len", None)

        if args.input_len is not None and random_input_len is not None:
            warnings.warn(
                "Both --input-len and --random-input-len are specified. "
                "The random version (--random-input-len) will be preferred "
                "in this run.",
                stacklevel=2,
            )
        if args.output_len is not None and random_output_len is not None:
            warnings.warn(
                "Both --output-len and --random-output-len are specified. "
                "The random version (--random-output-len) will be preferred "
                "in this run.",
                stacklevel=2,
            )
        if prefix_len_set and random_prefix_len is not None:
            warnings.warn(
                "Both --prefix-len and --random-prefix-len are specified. "
                "The random version (--random-prefix-len) will be preferred "
                "in this run.",
                stacklevel=2,
            )

    # === LoRA Settings ===
    if getattr(args, "enable_lora", False) and args.backend != "vllm":
        raise ValueError("LoRA benchmarking is only supported for vLLM backend")
    if getattr(args, "enable_lora", False) and args.lora_path is None:
        raise ValueError("LoRA path must be provided when enable_lora is True")

    # === Backend-specific Validations ===
    if args.backend == "hf" and args.hf_max_batch_size is None:
        raise ValueError("HF max batch size is required for HF backend")
    if args.backend != "hf" and args.hf_max_batch_size is not None:
        raise ValueError("HF max batch size is only for HF backend.")

    if (
        args.backend in {"hf", "mii"}
        and getattr(args, "quantization", None) is not None
    ):
        raise ValueError("Quantization is only for vLLM backend.")

    if args.backend == "mii" and args.dtype != "auto":
        raise ValueError("dtype must be auto for MII backend.")
    if args.backend == "mii" and args.n != 1:
        raise ValueError("n must be 1 for MII backend.")
    if args.backend == "mii" and args.tokenizer != args.model:
        raise ValueError("Tokenizer must be the same as the model for MII backend.")

    if args.data_parallel_size > 1 and (
        args.distributed_executor_backend != "external_launcher" or args.async_engine
    ):
        # --data-parallel is not supported fully.
        # Old issue: https://github.com/vllm-project/vllm/issues/16222
        # Currently we only support data parallel with external launcher
        # mode (i.e., launch with torchrun).
        raise ValueError(
            "Data parallel is only supported with external launcher mode "
            "with synchronous engine in offline benchmark, "
            "please use benchmark serving instead"
        )
|
||||
|
||||
|
||||
def add_cli_args(parser: argparse.ArgumentParser):
    """Register the throughput benchmark's command-line options on ``parser``.

    Adds the backend/dataset selection options, request-shaping options,
    LoRA, Hugging Face dataset and prefix-repetition options, then delegates
    to ``add_random_dataset_base_args``, ``add_random_multimodal_dataset_args``
    and ``AsyncEngineArgs.add_cli_args`` for the remaining options.

    Args:
        parser: The parser to extend in place.
    """
    parser.add_argument(
        "--backend",
        type=str,
        choices=["vllm", "hf", "mii", "vllm-chat"],
        default="vllm",
    )
    parser.add_argument(
        "--dataset-name",
        type=str,
        choices=[
            "sharegpt",
            "random",
            "sonnet",
            "burstgpt",
            "hf",
            "prefix_repetition",
            "random-mm",
            "random-rerank",
        ],
        help="Name of the dataset to benchmark on.",
        default="sharegpt",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default=None,
        # Implicit concatenation instead of a backslash continuation, which
        # would embed the source indentation inside the help text.
        help="Path to the ShareGPT dataset, will be deprecated in "
        "the next release. The dataset is expected to "
        "be a json in form of list[dict[..., conversations: "
        "list[dict[..., value: <prompt_or_response>]]]]",
    )
    parser.add_argument(
        "--dataset-path", type=str, default=None, help="Path to the dataset"
    )
    parser.add_argument(
        "--input-len",
        type=int,
        default=None,
        help="Input prompt length for each request",
    )
    parser.add_argument(
        "--output-len",
        type=int,
        default=None,
        help="Output length for each request. Overrides the "
        "output length from the dataset.",
    )
    parser.add_argument(
        "--n", type=int, default=1, help="Number of generated sequences per prompt."
    )
    parser.add_argument(
        "--num-prompts", type=int, default=1000, help="Number of prompts to process."
    )
    parser.add_argument(
        "--hf-max-batch-size",
        type=int,
        default=None,
        help="Maximum batch size for HF backend.",
    )
    parser.add_argument(
        "--output-json",
        type=str,
        default=None,
        help="Path to save the throughput results in JSON format.",
    )
    parser.add_argument(
        "--async-engine",
        action="store_true",
        default=False,
        help="Use vLLM async engine rather than LLM class.",
    )
    parser.add_argument(
        "--disable-frontend-multiprocessing",
        action="store_true",
        default=False,
        help="Disable decoupled async engine frontend.",
    )
    parser.add_argument(
        "--disable-detokenize",
        action="store_true",
        help=(
            "Do not detokenize the response (i.e. do not include "
            "detokenization time in the measurement)"
        ),
    )
    # LoRA
    parser.add_argument(
        "--lora-path",
        type=str,
        default=None,
        help="Path to the lora adapters to use. This can be an absolute path, "
        "a relative path, or a Hugging Face model identifier.",
    )
    parser.add_argument(
        "--prefix-len",
        type=int,
        default=0,
        help="Number of fixed prefix tokens before the random "
        "context in a request (default: 0).",
    )

    # hf dataset
    parser.add_argument(
        "--hf-subset",
        type=str,
        default=None,
        help="Subset of the HF dataset.",
    )
    parser.add_argument(
        "--hf-split",
        type=str,
        default=None,
        help="Split of the HF dataset.",
    )
    parser.add_argument(
        "--profile",
        action="store_true",
        default=False,
        help="Use vLLM Profiling. --profiler-config must be provided on the server.",
    )

    # prefix repetition dataset
    parser.add_argument(
        "--prefix-repetition-prefix-len",
        type=int,
        default=None,
        help="Number of prefix tokens per request, used only for prefix "
        "repetition dataset.",
    )
    parser.add_argument(
        "--prefix-repetition-suffix-len",
        type=int,
        default=None,
        help="Number of suffix tokens per request, used only for prefix "
        "repetition dataset. Total input length is prefix_len + suffix_len.",
    )
    parser.add_argument(
        "--prefix-repetition-num-prefixes",
        type=int,
        default=None,
        help="Number of prefixes to generate, used only for prefix repetition "
        "dataset. Prompts per prefix is num_requests // num_prefixes.",
    )
    parser.add_argument(
        "--prefix-repetition-output-len",
        type=int,
        default=None,
        help="Number of output tokens per request, used only for prefix "
        "repetition dataset.",
    )

    # (random, random-mm, random-rerank)
    add_random_dataset_base_args(parser)
    add_random_multimodal_dataset_args(parser)

    # The return value is not needed here: this function returns nothing.
    AsyncEngineArgs.add_cli_args(parser)
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
    """Run the throughput benchmark described by ``args`` and report results.

    Validates the arguments, seeds ``random``, samples the requests, runs
    them on the selected backend, prints throughput figures, and optionally
    writes them to ``args.output_json`` (plus the PyTorch-benchmark export).

    Raises:
        ValueError: If ``args.backend`` has no runner in this function.
        NotImplementedError: If profiling is requested for the ``hf`` backend.
    """
    validate_args(args)
    if args.seed is None:
        args.seed = 0
    random.seed(args.seed)

    # Sample the requests.
    # mistral_common tokenizer is only supported on vllm and vllm-chat backends;
    # for hf and mii backends, we use hf tokenizer
    if args.tokenizer_mode == "auto" and args.backend in ("hf", "mii"):
        args.tokenizer_mode = "hf"
    tokenizer = get_tokenizer(
        args.tokenizer,
        tokenizer_mode=args.tokenizer_mode,
        trust_remote_code=args.trust_remote_code,
    )
    requests = get_requests(args, tokenizer)
    is_multi_modal = any(req.multi_modal_data is not None for req in requests)

    # Only the synchronous vllm and vllm-chat runners hand back outputs.
    request_outputs: list[RequestOutput] | None = None
    backend = args.backend
    if backend == "vllm" and args.async_engine:
        elapsed_time = uvloop.run(
            run_vllm_async(
                requests,
                args.n,
                AsyncEngineArgs.from_cli_args(args),
                disable_frontend_multiprocessing=args.disable_frontend_multiprocessing,
                disable_detokenize=args.disable_detokenize,
                do_profile=args.profile,
            )
        )
    elif backend == "vllm":
        elapsed_time, request_outputs = run_vllm(
            requests,
            args.n,
            EngineArgs.from_cli_args(args),
            disable_detokenize=args.disable_detokenize,
            do_profile=args.profile,
        )
    elif backend == "hf":
        assert args.tensor_parallel_size == 1
        if args.profile:
            raise NotImplementedError("Profiling not implemented yet for backend='hf'.")
        elapsed_time = run_hf(
            requests,
            args.model,
            tokenizer,
            args.n,
            args.hf_max_batch_size,
            args.trust_remote_code,
            args.disable_detokenize,
        )
    elif backend == "vllm-chat":
        elapsed_time, request_outputs = run_vllm_chat(
            requests,
            args.n,
            EngineArgs.from_cli_args(args),
            disable_detokenize=args.disable_detokenize,
            do_profile=args.profile,
        )
    else:
        raise ValueError(f"Unknown backend: {args.backend}")

    if request_outputs:
        # Note: with the vllm and vllm-chat backends,
        # we have request_outputs, which we use to count tokens.
        counted = [ro for ro in request_outputs if isinstance(ro, RequestOutput)]
        total_prompt_tokens = sum(
            len(ro.prompt_token_ids) if ro.prompt_token_ids else 0 for ro in counted
        )
        total_output_tokens = sum(
            len(completion.token_ids)
            for ro in counted
            for completion in ro.outputs
            if completion
        )
        total_num_tokens = total_prompt_tokens + total_output_tokens
    else:
        # No outputs to inspect: fall back to the lengths the requests were
        # sampled with.
        total_prompt_tokens = sum(req.prompt_len for req in requests)
        total_output_tokens = sum(req.expected_output_len for req in requests)
        total_num_tokens = total_prompt_tokens + total_output_tokens

    if is_multi_modal and args.backend != "vllm-chat":
        print(
            "\033[91mWARNING\033[0m: Multi-modal request with "
            f"{args.backend} backend detected. The "
            "following metrics are not accurate because image tokens are not"
            " counted. See vllm-project/vllm/issues/9778 for details."
        )
        # TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
        # vllm-chat backend counts the image tokens now

    num_requests = len(requests)
    requests_per_second = num_requests / elapsed_time
    tokens_per_second = total_num_tokens / elapsed_time
    output_tokens_per_second = total_output_tokens / elapsed_time
    print(
        f"Throughput: {requests_per_second:.2f} requests/s, "
        f"{tokens_per_second:.2f} total tokens/s, "
        f"{output_tokens_per_second:.2f} output tokens/s"
    )
    print(f"Total num prompt tokens: {total_prompt_tokens}")
    print(f"Total num output tokens: {total_output_tokens}")

    # Output JSON results if specified
    if not args.output_json:
        return
    results = {
        "elapsed_time": elapsed_time,
        "num_requests": num_requests,
        "total_num_tokens": total_num_tokens,
        "requests_per_second": requests_per_second,
        "tokens_per_second": tokens_per_second,
    }
    with open(args.output_json, "w") as f:
        json.dump(results, f, indent=4)
    save_to_pytorch_benchmark_format(args, results)
|
||||
Reference in New Issue
Block a user