# enginex-mlu370-vllm/vllm-v0.6.2/benchmarks/benchmark_serving_concurrency.py
"""Benchmark online serving throughput.
On the server side, run one of the following commands:
vLLM OpenAI API server
vllm serve <your_model> \
--swap-space 16 \
--disable-log-requests
(TGI backend)
./launch_tgi_server.sh <your_model> <max_batch_total_tokens>
On the client side, run:
python benchmarks/benchmark_serving.py \
--backend <backend> \
--model <your_model> \
--dataset-name sharegpt \
--dataset-path <path to dataset> \
--request-rate <request_rate> \ # By default <request_rate> is inf
--num-prompts <num_prompts> # By default <num_prompts> is 1000
when using tgi backend, add
--endpoint /generate_stream
to the end of the command above.
"""
import argparse
import asyncio
import json
import os
import random
import time
import warnings
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput,
RequestFuncOutput)
from tqdm.asyncio import tqdm
from transformers import PreTrainedTokenizerBase
try:
from vllm.transformers_utils.tokenizer import get_tokenizer
except ImportError:
from backend_request_func import get_tokenizer
try:
from vllm.utils import FlexibleArgumentParser
except ImportError:
from argparse import ArgumentParser as FlexibleArgumentParser
from concurrent_executor import (ConcurrentExecutor, MluRequestFuncOutput)
from benchmark_serving import (BenchmarkMetrics,
sample_sharegpt_requests,
sample_random_requests,
sample_sonnet_requests)
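# This is a concurrency-oriented variant of benchmark_serving.py: instead of
# pacing requests with a request rate, it keeps a fixed number of requests in
# flight (see --concurrency-num) via ConcurrentExecutor, and it additionally
# reports server-side metrics (queueing, scheduling, TTFT, TPOT, E2E) that the
# MLU vLLM server attaches to each response.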
@dataclass
class MluBenchmarkMetrics(BenchmarkMetrics):
# time_in_queue: first_scheduled_time - arrival_time
mean_time_in_queue_ms: float
std_time_in_queue_ms: float
median_time_in_queue_ms: float
percentiles_time_in_queue_ms: List[Tuple[float, float]]
# time_schedule: sum(all schedule step times)
mean_time_schedule_ms: float
std_time_schedule_ms: float
median_time_schedule_ms: float
percentiles_time_schedule_ms: List[Tuple[float, float]]
# ttft: first_token_time - arrival_time
mean_time_ttft_ms: float
std_time_ttft_ms: float
median_time_ttft_ms: float
percentiles_time_ttft_ms: List[Tuple[float, float]]
# e2e: finished_time - arrival_time
mean_time_e2e_ms: float
std_time_e2e_ms: float
median_time_e2e_ms: float
percentiles_time_e2e_ms: List[Tuple[float, float]]
# tpot: (finished_time - first_token_time) / (output_len - 1)
mean_time_tpot_ms: float
std_time_tpot_ms: float
median_time_tpot_ms: float
percentiles_time_tpot_ms: List[Tuple[float, float]]
    prompt_tokens: int  # total prompt tokens received by the server
    completion_tokens: int  # total tokens generated by the server
    server_output_throughput: float  # server-side output token throughput (tok/s)
    server_total_token_throughput: float  # server-side total token throughput (tok/s)
def calculate_metrics(
input_requests: List[Tuple[str, int, int]],
outputs: List[RequestFuncOutput],
dur_s: float,
tokenizer: PreTrainedTokenizerBase,
selected_percentile_metrics: List[str],
selected_percentiles: List[float],
) -> Tuple[MluBenchmarkMetrics, List[int]]:
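    """Aggregate per-request outputs into benchmark metrics.

    Client-side latencies (TTFT, TPOT, ITL, E2E) are measured by this process;
    server-side latencies are taken from the `metric` and `usage` fields that
    the server attaches to each successful response. Returns the aggregated
    metrics together with the per-request output lengths.
    """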
actual_output_lens: List[int] = []
total_input = 0
completed = 0
itls: List[float] = []
tpots: List[float] = []
ttfts: List[float] = []
e2els: List[float] = []
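    # Server-reported latencies, parsed from each response's `metric` field.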
time_in_queues: List[float] = []
time_schedules: List[float] = []
time_ttfts: List[float] = []
time_e2es: List[float] = []
time_tpots: List[float] = []
prompt_tokens: List[int] = []
completion_tokens: List[int] = []
for i in range(len(outputs)):
if outputs[i].success:
# We use the tokenizer to count the number of output tokens for all
# serving backends instead of looking at len(outputs[i].itl) since
# multiple output tokens may be bundled together
            # Note: this may inflate the output token count slightly.
output_len = len(
tokenizer(outputs[i].generated_text,
add_special_tokens=False).input_ids)
actual_output_lens.append(output_len)
total_input += input_requests[i][1]
if output_len > 1:
tpots.append(
(outputs[i].latency - outputs[i].ttft) / (output_len - 1))
itls += outputs[i].itl
ttfts.append(outputs[i].ttft)
e2els.append(outputs[i].latency)
completed += 1
            # Collect server-side metrics reported in the response.
time_in_queues.append(outputs[i].metric["time_in_queue"])
time_schedules.append(outputs[i].metric["scheduler_time"])
time_ttfts.append(outputs[i].metric["first_token_time"] - outputs[i].metric["arrival_time"])
time_e2es.append(outputs[i].metric["finished_time"] - outputs[i].metric["arrival_time"])
if outputs[i].usage["completion_tokens"] > 1:
time_tpots.append(
(outputs[i].metric["finished_time"] - outputs[i].metric["first_token_time"]) /
(outputs[i].usage["completion_tokens"] - 1))
prompt_tokens.append(outputs[i].usage["prompt_tokens"])
completion_tokens.append(outputs[i].usage["completion_tokens"])
else:
actual_output_lens.append(0)
if completed == 0:
warnings.warn(
"All requests failed. This is likely due to a misconfiguration "
"on the benchmark arguments.",
stacklevel=2)
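    # The `x or 0` guards below let numpy aggregate an empty list (e.g. when
    # the backend does not stream and no TTFT samples were collected).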
metrics = MluBenchmarkMetrics(
completed=completed,
total_input=total_input,
total_output=sum(actual_output_lens),
request_throughput=completed / dur_s,
output_throughput=sum(actual_output_lens) / dur_s,
total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s,
mean_ttft_ms=np.mean(ttfts or 0) *
1000, # ttfts is empty if streaming is not supported by backend
std_ttft_ms=np.std(ttfts or 0) * 1000,
median_ttft_ms=np.median(ttfts or 0) * 1000,
percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000)
for p in selected_percentiles],
mean_tpot_ms=np.mean(tpots or 0) * 1000,
std_tpot_ms=np.std(tpots or 0) * 1000,
median_tpot_ms=np.median(tpots or 0) * 1000,
percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000)
for p in selected_percentiles],
mean_itl_ms=np.mean(itls or 0) * 1000,
std_itl_ms=np.std(itls or 0) * 1000,
median_itl_ms=np.median(itls or 0) * 1000,
percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000)
for p in selected_percentiles],
        mean_e2el_ms=np.mean(e2els or 0) * 1000,
        std_e2el_ms=np.std(e2els or 0) * 1000,
        median_e2el_ms=np.median(e2els or 0) * 1000,
percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000)
for p in selected_percentiles],
mean_time_in_queue_ms=np.mean(time_in_queues or 0) * 1000,
std_time_in_queue_ms=np.std(time_in_queues or 0) * 1000,
median_time_in_queue_ms=np.median(time_in_queues or 0) * 1000,
percentiles_time_in_queue_ms=[(p, np.percentile(time_in_queues or 0, p) * 1000)
for p in selected_percentiles],
mean_time_schedule_ms=np.mean(time_schedules or 0) * 1000,
std_time_schedule_ms=np.std(time_schedules or 0) * 1000,
median_time_schedule_ms=np.median(time_schedules or 0) * 1000,
percentiles_time_schedule_ms=[(p, np.percentile(time_schedules or 0, p) * 1000)
for p in selected_percentiles],
mean_time_ttft_ms=np.mean(time_ttfts or 0) * 1000,
std_time_ttft_ms=np.std(time_ttfts or 0) * 1000,
median_time_ttft_ms=np.median(time_ttfts or 0) * 1000,
percentiles_time_ttft_ms=[(p, np.percentile(time_ttfts or 0, p) * 1000)
for p in selected_percentiles],
mean_time_e2e_ms=np.mean(time_e2es or 0) * 1000,
std_time_e2e_ms=np.std(time_e2es or 0) * 1000,
median_time_e2e_ms=np.median(time_e2es or 0) * 1000,
percentiles_time_e2e_ms=[(p, np.percentile(time_e2es or 0, p) * 1000)
for p in selected_percentiles],
mean_time_tpot_ms=np.mean(time_tpots or 0) * 1000,
std_time_tpot_ms=np.std(time_tpots or 0) * 1000,
median_time_tpot_ms=np.median(time_tpots or 0) * 1000,
percentiles_time_tpot_ms=[(p, np.percentile(time_tpots or 0, p) * 1000)
for p in selected_percentiles],
prompt_tokens=sum(prompt_tokens),
completion_tokens=sum(completion_tokens),
server_output_throughput=sum(completion_tokens) / dur_s,
server_total_token_throughput=(sum(prompt_tokens) + sum(completion_tokens)) / dur_s,
)
return metrics, actual_output_lens
async def benchmark(
backend: str,
api_url: str,
model_id: str,
tokenizer: PreTrainedTokenizerBase,
input_requests: List[Tuple[str, int, int]],
logprobs: Optional[int],
best_of: int,
use_beam_search: bool,
disable_tqdm: bool,
selected_percentile_metrics: List[str],
    selected_percentiles: List[float],
concurrency_num: int,
ignore_eos: bool,
):
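    """Run the benchmark against an already-running server.

    All requests are handed to a ConcurrentExecutor that keeps at most
    `concurrency_num` of them in flight; once every request has completed,
    client- and server-side metrics are computed and printed.
    """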
    assert backend == "vllm", "Only the vllm backend is supported in concurrent mode."
    assert concurrency_num >= 1, f"concurrency_num must be at least 1, but got {concurrency_num}."
pbar = None if disable_tqdm else tqdm(total=len(input_requests), desc="Infer")
    # Run serving in concurrent mode; 'concurrency_num' bounds the number of
    # requests in flight at any time.
executor = ConcurrentExecutor(concurrency_num=concurrency_num,
input_requests=input_requests)
    # Configure the request payload (config_pyload is the executor's API).
executor.config_pyload(model=model_id,
api_url=api_url,
logprobs=logprobs,
best_of=best_of,
use_beam_search=use_beam_search,
include_usage=True,
ignore_eos=ignore_eos)
benchmark_start_time = time.perf_counter()
    # Execute all requests; run() blocks until every request has completed.
outputs: List[MluRequestFuncOutput] = executor.run(pbar=pbar)
if pbar is not None:
pbar.close()
benchmark_duration = time.perf_counter() - benchmark_start_time
metrics, actual_output_lens = calculate_metrics(
input_requests=input_requests,
outputs=outputs,
dur_s=benchmark_duration,
tokenizer=tokenizer,
selected_percentile_metrics=selected_percentile_metrics,
selected_percentiles=selected_percentiles,
)
print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
print("{s:{c}^{n}}".format(s=' Client Metrics ', n=50, c='#'))
print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
print("{:<40} {:<10.2f}".format("Benchmark duration (s):",
benchmark_duration))
print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
print("{:<40} {:<10}".format("Total generated tokens:",
metrics.total_output))
print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
metrics.request_throughput))
print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
metrics.output_throughput))
print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):",
metrics.total_token_throughput))
result = {
"duration": benchmark_duration,
"completed": metrics.completed,
"total_input_tokens": metrics.total_input,
"total_output_tokens": metrics.total_output,
"request_throughput": metrics.request_throughput,
"output_throughput": metrics.output_throughput,
"total_token_throughput": metrics.total_token_throughput,
"input_lens": [output.prompt_len for output in outputs],
"output_lens": actual_output_lens,
"ttfts": [output.ttft for output in outputs],
"itls": [output.itl for output in outputs],
"generated_texts": [output.generated_text for output in outputs],
"errors": [output.error for output in outputs],
}
def process_one_metric(
# E.g., "ttft"
metric_attribute_name: str,
# E.g., "TTFT"
metric_name: str,
# E.g., "Time to First Token"
metric_header: str,
):
        # Print statistics for the specified metric and record them in
        # the result dict.
if metric_attribute_name not in selected_percentile_metrics:
return
print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-'))
print("{:<40} {:<10.2f}".format(
f"Mean {metric_name} (ms):",
getattr(metrics, f"mean_{metric_attribute_name}_ms")))
print("{:<40} {:<10.2f}".format(
f"Median {metric_name} (ms):",
getattr(metrics, f"median_{metric_attribute_name}_ms")))
result[f"mean_{metric_attribute_name}_ms"] = getattr(
metrics, f"mean_{metric_attribute_name}_ms")
result[f"median_{metric_attribute_name}_ms"] = getattr(
metrics, f"median_{metric_attribute_name}_ms")
result[f"std_{metric_attribute_name}_ms"] = getattr(
metrics, f"std_{metric_attribute_name}_ms")
for p, value in getattr(metrics,
f"percentiles_{metric_attribute_name}_ms"):
p_word = str(int(p)) if int(p) == p else str(p)
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):",
value))
result[f"p{p_word}_{metric_attribute_name}_ms"] = value
process_one_metric("ttft", "TTFT", "Time to First Token")
process_one_metric("tpot", "TPOT",
"Time per Output Token (excl. 1st token)")
process_one_metric("itl", "ITL", "Inter-token Latency")
process_one_metric("e2el", "E2EL", "End-to-end Latency")
print("{s:{c}^{n}}".format(s=' Server Metrics ', n=50, c='#'))
print("{:<40} {:<10}".format("Total input tokens:",
metrics.prompt_tokens))
print("{:<40} {:<10}".format("Total generated tokens:",
metrics.completion_tokens))
print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
metrics.server_output_throughput))
print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):",
metrics.server_total_token_throughput))
process_one_metric("time_in_queue", "IQL", "In-Queue Latency")
process_one_metric("time_schedule", "SL", "Schedule Latency")
process_one_metric("time_ttft", "STTFT", "Time to First Token")
process_one_metric("time_tpot", "STPOT", "Time per Output Token")
process_one_metric("time_e2e", "SE2EL", "End-to-end Latency")
print("=" * 50)
return result
def main(args: argparse.Namespace):
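    """Sample the requested dataset, run the benchmark, and optionally save results."""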
print(args)
random.seed(args.seed)
np.random.seed(args.seed)
backend = args.backend
model_id = args.model
tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
if args.base_url is not None:
api_url = f"{args.base_url}{args.endpoint}"
base_url = f"{args.base_url}"
else:
api_url = f"http://{args.host}:{args.port}{args.endpoint}"
base_url = f"http://{args.host}:{args.port}"
tokenizer = get_tokenizer(tokenizer_id,
trust_remote_code=args.trust_remote_code)
if args.dataset is not None:
warnings.warn(
"The '--dataset' argument will be deprecated in the next "
"release. Please use '--dataset-name' and "
"'--dataset-path' in the future runs.",
stacklevel=2)
input_requests = sample_sharegpt_requests(
dataset_path=args.dataset,
num_requests=args.num_prompts,
tokenizer=tokenizer,
fixed_output_len=args.sharegpt_output_len,
)
elif args.dataset_name == "sharegpt":
input_requests = sample_sharegpt_requests(
dataset_path=args.dataset_path,
num_requests=args.num_prompts,
tokenizer=tokenizer,
fixed_output_len=args.sharegpt_output_len,
)
elif args.dataset_name == "sonnet":
        # Do not format the prompt; pass it to the chat messages directly.
if args.backend == "openai-chat":
input_requests = sample_sonnet_requests(
dataset_path=args.dataset_path,
num_requests=args.num_prompts,
input_len=args.sonnet_input_len,
output_len=args.sonnet_output_len,
prefix_len=args.sonnet_prefix_len,
tokenizer=tokenizer,
)
input_requests = [(prompt, prompt_len, output_len)
for prompt, prompt_formatted, prompt_len,
output_len in input_requests]
else:
assert (
tokenizer.chat_template or tokenizer.default_chat_template
), "Tokenizer/model must have chat template for sonnet dataset."
input_requests = sample_sonnet_requests(
dataset_path=args.dataset_path,
num_requests=args.num_prompts,
input_len=args.sonnet_input_len,
output_len=args.sonnet_output_len,
prefix_len=args.sonnet_prefix_len,
tokenizer=tokenizer,
)
input_requests = [(prompt_formatted, prompt_len, output_len)
for prompt, prompt_formatted, prompt_len,
output_len in input_requests]
elif args.dataset_name == "random":
input_requests = sample_random_requests(
prefix_len=args.random_prefix_len,
input_len=args.random_input_len,
output_len=args.random_output_len,
num_prompts=args.num_prompts,
range_ratio=args.random_range_ratio,
tokenizer=tokenizer,
)
else:
raise ValueError(f"Unknown dataset: {args.dataset_name}")
benchmark_result = asyncio.run(
benchmark(
backend=backend,
api_url=api_url,
model_id=model_id,
tokenizer=tokenizer,
input_requests=input_requests,
logprobs=args.logprobs,
best_of=args.best_of,
use_beam_search=args.use_beam_search,
disable_tqdm=args.disable_tqdm,
selected_percentile_metrics=args.percentile_metrics.split(","),
selected_percentiles=[
float(p) for p in args.metric_percentiles.split(",")
],
concurrency_num=args.concurrency_num,
ignore_eos=args.ignore_eos,
))
# Save config and results to json
if args.save_result:
result_json: Dict[str, Any] = {}
# Setup
current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
result_json["date"] = current_dt
result_json["backend"] = backend
result_json["model_id"] = model_id
result_json["tokenizer_id"] = tokenizer_id
result_json["best_of"] = args.best_of
result_json["use_beam_search"] = args.use_beam_search
result_json["num_prompts"] = args.num_prompts
# Metadata
if args.metadata:
for item in args.metadata:
if "=" in item:
kvstring = item.split("=")
result_json[kvstring[0].strip()] = kvstring[1].strip()
else:
raise ValueError(
"Invalid metadata format. Please use KEY=VALUE format."
)
        # Traffic: this script drives load via a fixed client concurrency
        # rather than a request rate, so record concurrency_num instead.
        result_json["concurrency_num"] = args.concurrency_num
# Merge with benchmark result
result_json = {**result_json, **benchmark_result}
# Save to file
base_model_id = model_id.split("/")[-1]
file_name = f"{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json" #noqa
if args.result_filename:
file_name = args.result_filename
if args.result_dir:
file_name = os.path.join(args.result_dir, file_name)
with open(file_name, "w") as outfile:
json.dump(result_json, outfile)
if __name__ == "__main__":
parser = FlexibleArgumentParser(
description="Benchmark the online serving throughput.")
parser.add_argument(
"--backend",
type=str,
default="vllm",
choices=list(ASYNC_REQUEST_FUNCS.keys()),
)
parser.add_argument(
"--base-url",
type=str,
default=None,
help="Server or API base url if not using http host and port.",
)
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8000)
parser.add_argument(
"--endpoint",
type=str,
default="/v1/completions",
help="API endpoint.",
)
parser.add_argument(
"--dataset",
type=str,
default=None,
help="Path to the ShareGPT dataset, will be deprecated in the "
"next release.",
)
parser.add_argument(
"--dataset-name",
type=str,
default="sharegpt",
choices=["sharegpt", "sonnet", "random"],
help="Name of the dataset to benchmark on.",
)
parser.add_argument("--dataset-path",
type=str,
default=None,
help="Path to the dataset.")
parser.add_argument(
"--model",
type=str,
required=True,
help="Name of the model.",
)
parser.add_argument(
"--tokenizer",
type=str,
help=
"Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
)
parser.add_argument(
"--best-of",
type=int,
default=1,
help="Generates `best_of` sequences per prompt and "
"returns the best one.",
)
parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument(
"--num-prompts",
type=int,
default=1000,
help="Number of prompts to process.",
)
parser.add_argument(
"--sharegpt-output-len",
type=int,
default=None,
help="Output length for each request. Overrides the output length "
"from the ShareGPT dataset.")
parser.add_argument(
"--sonnet-input-len",
type=int,
default=550,
help=
"Number of input tokens per request, used only for sonnet dataset.",
)
parser.add_argument(
"--sonnet-output-len",
type=int,
default=150,
help=
"Number of output tokens per request, used only for sonnet dataset.",
)
parser.add_argument(
"--logprobs",
type=int,
default=None,
help=("Number of logprobs-per-token to compute & return as part of "
"the request. If unspecified, then either (1) if beam search "
"is disabled, no logprobs are computed & a single dummy "
"logprob is returned for each token; or (2) if beam search "
"is enabled 1 logprob per token is computed"),
)
parser.add_argument(
"--sonnet-prefix-len",
type=int,
default=200,
help=
"Number of prefix tokens per request, used only for sonnet dataset.",
)
parser.add_argument(
"--random-input-len",
type=int,
default=1024,
help=
"Number of input tokens per request, used only for random sampling.",
)
parser.add_argument(
"--random-output-len",
type=int,
default=128,
help=
"Number of output tokens per request, used only for random sampling.",
)
parser.add_argument(
"--random-range-ratio",
type=float,
default=1.0,
help="Range of sampled ratio of input/output length, "
"used only for random sampling.",
)
parser.add_argument(
"--random-prefix-len",
type=int,
default=0,
help="Number of fixed prefix tokens before random "
" context. The length range of context in a random "
" request is [random-prefix-len, "
" random-prefix-len + random-prefix-len * random-range-ratio).")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument(
"--trust-remote-code",
action="store_true",
help="Trust remote code from huggingface",
)
parser.add_argument(
"--disable-tqdm",
action="store_true",
help="Specify to disable tqdm progress bar.",
)
parser.add_argument(
"--save-result",
action="store_true",
help="Specify to save benchmark results to a json file",
)
parser.add_argument(
"--metadata",
metavar="KEY=VALUE",
nargs="*",
help="Key-value pairs (e.g, --metadata version=0.3.3 tp=1) "
"for metadata of this run to be saved in the result JSON file "
"for record keeping purposes.",
)
parser.add_argument(
"--result-dir",
type=str,
default=None,
help="Specify directory to save benchmark json results."
"If not specified, results are saved in the current directory.",
)
parser.add_argument(
"--result-filename",
type=str,
default=None,
help="Specify the filename to save benchmark json results."
"If not specified, results will be saved in "
"{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json"
" format.",
)
parser.add_argument(
"--percentile-metrics",
type=str,
default="ttft,tpot,itl,e2el,time_in_queue,time_schedule,time_ttft,time_e2e,time_tpot",
help="Comma-seperated list of selected metrics to report percentils. "
"This argument specifies the metrics to report percentiles. "
"Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". "
"Default value is \"ttft,tpot,itl,e2el\".")
parser.add_argument(
"--metric-percentiles",
type=str,
default="99",
help="Comma-seperated list of percentiles for selected metrics. "
"To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". "
"Default value is \"99\". "
"Use \"--percentile-metrics\" to select metrics.",
)
parser.add_argument(
"--concurrency-num",
type=int,
default=1,
help="Number of concurrency in client. If this is 1, "
"then 'request_rate' with be enable. "
"Otherwise, we run serving test with concurrent mode.",
)
parser.add_argument("--ignore-eos",
action='store_true',
                        help="If set, the server ignores the EOS token and "
                        "decodes until the maximum output length is reached.")
args = parser.parse_args()
main(args)