Upgrade to vllm 0.17.0 corex v4.1 overlay

2026-04-29 19:38:22 +08:00
parent 8fac6062e4
commit 938d0854a5
430 changed files with 35969 additions and 14511 deletions


@@ -34,6 +34,7 @@ from collections.abc import AsyncGenerator, Iterable
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Any, Literal
import aiohttp
@@ -1183,6 +1184,49 @@ def save_to_pytorch_benchmark_format(
        write_to_json(pt_file, pt_records)

def compute_result_filename(
args: argparse.Namespace,
model_id: str,
label: str,
current_dt: str,
) -> str | None:
"""Compute the result filename based on benchmark configuration.
Args:
args: Command line arguments containing result configuration
model_id: The model identifier
label: The benchmark label
current_dt: Current datetime string
Returns:
The computed filename path or None if no result saving is requested
"""
    # Plots derive their output paths from this filename, so they need one too.
    if not (args.plot_timeline or args.plot_dataset_stats
            or args.save_result or args.append_result):
        return None
base_model_id = model_id.split("/")[-1]
max_concurrency_str = (
f"-concurrency{args.max_concurrency}"
if args.max_concurrency is not None
else ""
)
label = label or args.backend
if args.ramp_up_strategy is not None:
file_name = f"{label}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
else:
file_name = f"{label}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
if args.result_filename:
file_name = args.result_filename
if args.result_dir:
os.makedirs(args.result_dir, exist_ok=True)
file_name = os.path.join(args.result_dir, file_name)
return file_name
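
For orientation, a quick sketch of the name this helper yields, assuming the function above is in scope; the model id and timestamp below are invented for illustration:

from argparse import Namespace

# Hypothetical flag values; no --result-filename or --result-dir override.
opts = Namespace(
    plot_timeline=False, plot_dataset_stats=False, save_result=True,
    append_result=False, max_concurrency=8, backend="vllm",
    ramp_up_strategy=None, request_rate=4.0,
    result_filename=None, result_dir=None,
)
name = compute_result_filename(opts, "org/Example-8B", "", "20260429-193822")
# -> "vllm-4.0qps-concurrency8-Example-8B-20260429-193822.json"
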
def add_cli_args(parser: argparse.ArgumentParser):
add_dataset_parser(parser)
parser.add_argument(
@@ -1277,6 +1321,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
- "slow" will always use the slow tokenizer.\n
- "mistral" will always use the tokenizer from `mistral_common`.\n
- "deepseek_v32" will always use the tokenizer from `deepseek_v32`.\n
- "qwen_vl" will always use the tokenizer from `qwen_vl`.\n
- Other custom values can be supported via plugins.""",
)
parser.add_argument("--use-beam-search", action="store_true")
@@ -1535,6 +1580,30 @@ def add_cli_args(parser: argparse.ArgumentParser):
"connecting to servers with self-signed certificates.",
)
parser.add_argument(
"--plot-timeline",
action="store_true",
help="Generate an HTML timeline plot showing request execution. "
"The plot will be saved alongside the results JSON file.",
)
parser.add_argument(
"--timeline-itl-thresholds",
type=float,
nargs=2,
default=[25.0, 50.0],
metavar=("THRESHOLD1", "THRESHOLD2"),
help="ITL thresholds in milliseconds for timeline plot coloring. "
"Specify two values to categorize inter-token latencies into three groups: "
"below first threshold (green), between thresholds (orange), "
"and above second threshold (red). Default: 25 50 (milliseconds).",
)
parser.add_argument(
"--plot-dataset-stats",
action="store_true",
help="Generate a matplotlib figure with dataset statistics showing "
"prompt tokens, output tokens, and combined token distributions.",
)
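
As a quick check on how the two-value threshold flag parses, a minimal standalone parser carrying only the three new arguments (definitions copied from above):

import argparse

p = argparse.ArgumentParser()
p.add_argument("--plot-timeline", action="store_true")
p.add_argument("--timeline-itl-thresholds", type=float, nargs=2,
               default=[25.0, 50.0], metavar=("THRESHOLD1", "THRESHOLD2"))
p.add_argument("--plot-dataset-stats", action="store_true")

opts = p.parse_args(["--plot-timeline", "--timeline-itl-thresholds", "20", "40"])
assert opts.plot_timeline and not opts.plot_dataset_stats
assert opts.timeline_itl_thresholds == [20.0, 40.0]  # two floats, in milliseconds
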
def main(args: argparse.Namespace) -> dict[str, Any]:
return asyncio.run(main_async(args))
@@ -1770,6 +1839,86 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
# Merge with benchmark result
result_json = {**result_json, **benchmark_result}
# Compute file_name once before using it for plots or saving results
file_name = compute_result_filename(args, model_id, label, current_dt)
# Generate timeline plot if requested
if args.plot_timeline:
try:
from vllm.benchmarks.plot import generate_timeline_plot
# Prepare per-request data for timeline
per_request_data = []
start_times = benchmark_result.get("start_times", [])
ttfts = benchmark_result.get("ttfts", [])
itls = benchmark_result.get("itls", [])
input_lens = benchmark_result.get("input_lens", [])
output_lens = benchmark_result.get("output_lens", [])
if start_times and ttfts and itls:
for i in range(len(start_times)):
# Calculate latency as ttft + sum of all itls
latency = ttfts[i] + sum(itls[i]) if itls[i] else ttfts[i]
per_request_data.append(
{
"start_time": start_times[i],
"ttft": ttfts[i],
"itl": itls[i],
"latency": latency,
"prompt_len": input_lens[i],
"output_tokens": output_lens[i],
}
)
timeline_path = Path(file_name).with_suffix(".timeline.html")
# Convert thresholds from milliseconds to seconds
itl_thresholds_sec = [t / 1000.0 for t in args.timeline_itl_thresholds]
generate_timeline_plot(
per_request_data, timeline_path, itl_thresholds=itl_thresholds_sec
)
else:
warnings.warn(
"Timeline plot requires detailed metrics. "
"Ensure the benchmark completed successfully.",
stacklevel=2,
)
except Exception as e:
warnings.warn(f"Failed to generate timeline plot: {e}", stacklevel=2)
# Generate dataset statistics plot if requested
if args.plot_dataset_stats:
try:
from vllm.benchmarks.plot import generate_dataset_stats_plot
# Prepare per-request data for dataset stats
per_request_data = []
input_lens = benchmark_result.get("input_lens", [])
output_lens = benchmark_result.get("output_lens", [])
if input_lens and output_lens:
for req_input_len, req_output_len in zip(input_lens, output_lens):
per_request_data.append(
{
"prompt_len": req_input_len,
"output_tokens": req_output_len,
}
)
stats_path = Path(file_name).with_suffix(".dataset_stats.png")
generate_dataset_stats_plot(per_request_data, stats_path)
else:
warnings.warn(
"Dataset statistics plot requires input and "
"output length data. Ensure the benchmark completed "
"successfully.",
stacklevel=2,
)
except Exception as e:
warnings.warn(
f"Failed to generate dataset statistics plot: {e}", stacklevel=2
)
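
The plotting helpers themselves live in vllm.benchmarks.plot and are not part of this hunk. As a rough stand-in only, a histogram sketch of what generate_dataset_stats_plot consumes could look like this (hypothetical implementation; matplotlib assumed available):

from pathlib import Path
import matplotlib
matplotlib.use("Agg")  # render headlessly on benchmark hosts
import matplotlib.pyplot as plt

def sketch_dataset_stats_plot(per_request_data: list[dict], out_path: Path) -> None:
    # Same record shape as built above: {"prompt_len": ..., "output_tokens": ...}
    prompts = [r["prompt_len"] for r in per_request_data]
    outputs = [r["output_tokens"] for r in per_request_data]
    combined = [p + o for p, o in zip(prompts, outputs)]
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    for ax, data, title in zip(
        axes,
        (prompts, outputs, combined),
        ("Prompt tokens", "Output tokens", "Prompt + output tokens"),
    ):
        ax.hist(data, bins=30)
        ax.set_title(title)
        ax.set_xlabel("tokens per request")
    fig.tight_layout()
    fig.savefig(out_path)

Note that Path.with_suffix replaces only the final extension, so a results file named foo.json gains siblings foo.timeline.html and foo.dataset_stats.png next to it.
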
if not args.save_detailed:
# Remove fields with too many data points
for field in [
@@ -1786,24 +1935,8 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
if field in benchmark_result:
del benchmark_result[field]
    # Save to file
if args.save_result or args.append_result:
base_model_id = model_id.split("/")[-1]
max_concurrency_str = (
f"-concurrency{args.max_concurrency}"
if args.max_concurrency is not None
else ""
)
label = label or args.backend
if args.ramp_up_strategy is not None:
file_name = f"{label}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
else:
file_name = f"{label}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
if args.result_filename:
file_name = args.result_filename
if args.result_dir:
os.makedirs(args.result_dir, exist_ok=True)
file_name = os.path.join(args.result_dir, file_name)
with open(
file_name, mode="a+" if args.append_result else "w", encoding="utf-8"
) as outfile: