This commit is contained in:
root
2026-04-09 11:19:36 +08:00
parent 809cecae09
commit 8082d5f4b2
2579 changed files with 3675 additions and 0 deletions

View File

Binary file not shown.

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,172 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Benchmark the latency of processing a single batch of requests."""
import argparse
import dataclasses
import json
import os
import time
from typing import Any
import numpy as np
from tqdm import tqdm
import vllm.envs as envs
from vllm.benchmarks.lib.utils import convert_to_pytorch_benchmark_format, write_to_json
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptType
from vllm.sampling_params import BeamSearchParams
def save_to_pytorch_benchmark_format(
    args: argparse.Namespace, results: dict[str, Any]
) -> None:
    """Export the latency results as PyTorch-benchmark records.

    The raw latencies become the ``latency`` metric; the average latency and
    the percentiles are attached as extra information. Nothing is written
    when the converter yields no records.

    Args:
        args: Parsed CLI arguments; ``args.output_json`` determines the
            destination file name.
        results: Result mapping holding the ``latencies``, ``avg_latency``
            and ``percentiles`` entries.
    """
    latency_metric = {"latency": results["latencies"]}
    summary = {key: results[key] for key in ("avg_latency", "percentiles")}
    records = convert_to_pytorch_benchmark_format(
        args=args,
        metrics=latency_metric,
        extra_info=summary,
    )
    if not records:
        return

    # Sibling file of the JSON output: same stem, ".pytorch.json" suffix.
    stem, _ = os.path.splitext(args.output_json)
    write_to_json(f"{stem}.pytorch.json", records)
def add_cli_args(parser: argparse.ArgumentParser):
    """Register the latency-benchmark options followed by the engine options.

    Args:
        parser: The argument parser to populate in place.
    """
    # Plain integer knobs that share the same declaration shape.
    for flag, default in (
        ("--input-len", 32),
        ("--output-len", 128),
        ("--batch-size", 8),
    ):
        parser.add_argument(flag, type=int, default=default)

    parser.add_argument(
        "--n",
        type=int,
        default=1,
        help="Number of generated sequences per prompt.",
    )
    parser.add_argument("--use-beam-search", action="store_true")

    # Iteration counts.
    parser.add_argument(
        "--num-iters-warmup",
        type=int,
        default=10,
        help="Number of iterations to run for warmup.",
    )
    parser.add_argument(
        "--num-iters",
        type=int,
        default=30,
        help="Number of iterations to run.",
    )

    # Profiling and output.
    parser.add_argument(
        "--profile",
        action="store_true",
        help="profile the generation process of a single batch",
    )
    parser.add_argument(
        "--output-json",
        type=str,
        default=None,
        help="Path to save the latency results in JSON format.",
    )
    parser.add_argument(
        "--disable-detokenize",
        action="store_true",
        help=(
            "Do not detokenize responses (i.e. do not include "
            "detokenization time in the latency measurement)"
        ),
    )

    parser = EngineArgs.add_cli_args(parser)
    # Prefix caching is on by default in V1 and would skew the latency
    # numbers, so this benchmark turns it off unless explicitly requested.
    parser.set_defaults(enable_prefix_caching=False)
def main(args: argparse.Namespace):
    """Run the single-batch latency benchmark described by *args*.

    Builds an engine from the CLI arguments, runs the warmup iterations,
    then either performs one profiled run (``--profile``) or times
    ``--num-iters`` runs, prints the average and percentile latencies and
    optionally writes them to ``--output-json``.

    Raises:
        OSError: If ``--profile`` is given while ``VLLM_TORCH_PROFILER_DIR``
            is not set.
    """
    if args.profile and not envs.VLLM_TORCH_PROFILER_DIR:
        raise OSError(
            "The environment variable 'VLLM_TORCH_PROFILER_DIR' is not set. "
            "Please set it to a valid path to use torch profiler."
        )
    engine_args = EngineArgs.from_cli_args(args)

    # Lazy import to avoid importing LLM when the bench command is not selected.
    from vllm import LLM, SamplingParams

    # NOTE(woosuk): If the request cannot be processed in a single batch,
    # the engine will automatically process the request in multiple batches.
    llm = LLM(**dataclasses.asdict(engine_args))
    assert llm.llm_engine.model_config.max_model_len >= (
        args.input_len + args.output_len
    ), (
        "Please ensure that max_model_len is greater than"
        " the sum of input_len and output_len."
    )

    sampling_params = SamplingParams(
        n=args.n,
        temperature=1.0,
        top_p=1.0,
        ignore_eos=True,
        max_tokens=args.output_len,
        detokenize=not args.disable_detokenize,
    )

    # Random token ids: one row of `input_len` tokens per request in the batch.
    token_rows = np.random.randint(
        10000, size=(args.batch_size, args.input_len)
    ).tolist()
    dummy_prompts: list[PromptType] = []
    for row in token_rows:
        dummy_prompts.append({"prompt_token_ids": row})

    def llm_generate():
        """Process the dummy batch once (beam search or sampling)."""
        if args.use_beam_search:
            llm.beam_search(
                dummy_prompts,
                BeamSearchParams(
                    beam_width=args.n,
                    max_tokens=args.output_len,
                    ignore_eos=True,
                ),
            )
            return
        llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False)

    def run_to_completion(profile_dir: str | None = None):
        """Run one batch; return its latency, or None for a profiled run."""
        if not profile_dir:
            began = time.perf_counter()
            llm_generate()
            finished = time.perf_counter()
            return finished - began

        llm.start_profile()
        llm_generate()
        llm.stop_profile()
        return None

    print("Warming up...")
    for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
        run_to_completion(profile_dir=None)

    if args.profile:
        profile_dir = envs.VLLM_TORCH_PROFILER_DIR
        print(f"Profiling (results will be saved to '{profile_dir}')...")
        run_to_completion(profile_dir=profile_dir)
        return

    # Benchmark.
    samples = [
        run_to_completion(profile_dir=None)
        for _ in tqdm(range(args.num_iters), desc="Profiling iterations")
    ]
    latencies = np.array(samples)
    percentages = [10, 25, 50, 75, 90, 99]
    percentiles = np.percentile(latencies, percentages)

    print(f"Avg latency: {np.mean(latencies)} seconds")
    for idx, percentage in enumerate(percentages):
        print(f"{percentage}% percentile latency: {percentiles[idx]} seconds")

    # Output JSON results if specified
    if not args.output_json:
        return
    results = {
        "avg_latency": np.mean(latencies),
        "latencies": latencies.tolist(),
        "percentiles": dict(zip(percentages, percentiles.tolist())),
    }
    with open(args.output_json, "w") as f:
        json.dump(results, f, indent=4)
    save_to_pytorch_benchmark_format(args, results)

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Benchmark library utilities."""

View File

@@ -0,0 +1,777 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""The request function for API endpoints."""
import io
import json
import os
import sys
import time
import traceback
from collections.abc import Awaitable
from dataclasses import dataclass, field
from typing import Any, Literal, Protocol
import aiohttp
import regex as re
from tqdm.asyncio import tqdm
# Total client-side timeout: 6 hours, expressed in seconds.
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
class StreamedResponseHandler:
    """Accumulate streamed HTTP chunks and hand back complete SSE messages.

    Bytes are decoded incrementally, so a multi-byte UTF-8 character that is
    split across two network chunks is reassembled instead of raising
    ``UnicodeDecodeError``.
    """

    def __init__(self):
        # Local import keeps this class self-contained; `codecs` is stdlib.
        import codecs

        # Text received so far that does not yet form a complete message.
        self.buffer = ""
        # Stateful decoder: holds back trailing bytes of an incomplete
        # character until the next chunk supplies the rest.
        self._decoder = codecs.getincrementaldecoder("utf-8")()

    def add_chunk(self, chunk_bytes: bytes) -> list[str]:
        """Add a chunk of bytes to the buffer and return any complete messages.

        Args:
            chunk_bytes: Raw bytes read from the response stream.

        Returns:
            The stripped, non-empty messages completed by this chunk, in
            arrival order. The list is empty when more data is needed.

        Raises:
            UnicodeDecodeError: If the stream contains bytes that are not
                valid UTF-8 (a merely truncated character is not an error).
        """
        self.buffer += self._decoder.decode(chunk_bytes)

        messages = []

        # Split by double newlines (SSE message separator).
        while "\n\n" in self.buffer:
            message, self.buffer = self.buffer.split("\n\n", 1)
            message = message.strip()
            if message:
                messages.append(message)

        # The remainder may already be a complete message whose trailing
        # separator is missing: accept it when the payload after the
        # "data: " prefix is "[DONE]" or parses as JSON.
        if self.buffer.startswith("data: "):
            message_content = self.buffer.removeprefix("data: ").strip()
            if message_content == "[DONE]":
                messages.append(self.buffer.strip())
                self.buffer = ""
            elif message_content:
                try:
                    json.loads(message_content)
                except json.JSONDecodeError:
                    # Incomplete JSON, wait for more chunks.
                    pass
                else:
                    messages.append(self.buffer.strip())
                    self.buffer = ""

        return messages
@dataclass
class RequestFuncInput:
"""The input for the request function."""
prompt: str | list[str]
api_url: str
prompt_len: int
output_len: int
model: str
model_name: str | None = None
logprobs: int | None = None
extra_headers: dict | None = None
extra_body: dict | None = None
multi_modal_content: dict | list[dict] | None = None
ignore_eos: bool = False
language: str | None = None
request_id: str | None = None
@dataclass
class RequestFuncOutput:
    """Result and timing metrics produced by a request function."""

    generated_text: str = ""
    success: bool = False
    latency: float = 0.0
    output_tokens: int = 0
    # Time to first token.
    ttft: float = 0.0
    # Inter-token latencies, one entry per token after the first.
    itl: list[float] = field(default_factory=list)
    # Average next-token latency.
    tpot: float = 0.0
    prompt_len: int = 0
    error: str = ""
    start_time: float = 0.0
class RequestFunc(Protocol):
    """Structural type of the async request functions defined in this module."""

    def __call__(
        self,
        request_func_input: RequestFuncInput,
        session: aiohttp.ClientSession,
        pbar: tqdm | None = None,
    ) -> Awaitable[RequestFuncOutput]:
        """Send one request and resolve to its ``RequestFuncOutput``."""
        ...
def _validate_api_url(
api_url: str,
api_name: str,
expected_suffixes: str | set[str],
) -> None:
if isinstance(expected_suffixes, str):
expected_suffixes = {expected_suffixes}
expected_suffixes = {*expected_suffixes, "profile"}
if not api_url.endswith(tuple(expected_suffixes)):
raise ValueError(f"{api_name} URL must end with one of: {expected_suffixes}.")
def _update_payload_common(
    payload: dict[str, Any],
    request_func_input: RequestFuncInput,
) -> None:
    """Apply the request-wide payload options to *payload* in place.

    ``ignore_eos`` is only added when it is truthy; ``extra_body`` entries are
    merged last and therefore override keys already present.
    """
    ignore_eos = request_func_input.ignore_eos
    if ignore_eos:
        payload["ignore_eos"] = ignore_eos

    extra_body = request_func_input.extra_body
    if extra_body:
        payload.update(extra_body)
def _update_headers_common(
    headers: dict[str, Any],
    request_func_input: RequestFuncInput,
) -> None:
    """Apply the request-wide header options to *headers* in place.

    ``extra_headers`` are merged first; a request id, when present, is then
    written to ``x-request-id``.
    """
    extra_headers = request_func_input.extra_headers
    if extra_headers:
        headers.update(extra_headers)

    request_id = request_func_input.request_id
    if request_id:
        headers["x-request-id"] = request_id
async def async_request_openai_completions(
    request_func_input: RequestFuncInput,
    session: aiohttp.ClientSession,
    pbar: tqdm | None = None,
) -> RequestFuncOutput:
    """The async request function for the OpenAI Completions API.

    Streams the completion and records time-to-first-token, inter-token
    latencies and total latency. Exceptions raised while sending or reading
    the request are captured in ``output.error`` rather than propagated.

    Args:
        request_func_input: The input for the request function.
        session: The aiohttp session used to send the request.
        pbar: The progress bar to display the progress.

    Returns:
        The output of the request function.

    Raises:
        ValueError: If the API URL does not end with an accepted suffix.
    """
    api_url = request_func_input.api_url
    _validate_api_url(api_url, "OpenAI Completions API", "completions")

    payload = {
        "model": request_func_input.model_name
        if request_func_input.model_name
        else request_func_input.model,
        "prompt": request_func_input.prompt,
        "temperature": 0.0,
        "repetition_penalty": 1.0,
        "max_tokens": request_func_input.output_len,
        "logprobs": request_func_input.logprobs,
        "stream": True,
        "stream_options": {
            "include_usage": True,
        },
    }
    _update_payload_common(payload, request_func_input)

    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
    }
    _update_headers_common(headers, request_func_input)

    output = RequestFuncOutput()
    output.prompt_len = request_func_input.prompt_len

    generated_text = ""
    st = time.perf_counter()
    output.start_time = st
    most_recent_timestamp = st
    try:
        async with session.post(url=api_url, json=payload, headers=headers) as response:
            if response.status == 200:
                first_chunk_received = False
                handler = StreamedResponseHandler()

                async for chunk_bytes in response.content.iter_any():
                    chunk_bytes = chunk_bytes.strip()
                    if not chunk_bytes:
                        continue

                    messages = handler.add_chunk(chunk_bytes)
                    for message in messages:
                        # NOTE: SSE comments (often used as pings) start with
                        # a colon. These are not JSON data payload and should
                        # be skipped.
                        if message.startswith(":"):
                            continue

                        chunk = message.removeprefix("data: ")
                        if chunk == "[DONE]":
                            continue

                        data = json.loads(chunk)

                        # NOTE: Some completion API might have a last
                        # usage summary response without a token so we
                        # want to check a token was generated
                        if choices := data.get("choices"):
                            # Note that text could be empty here
                            # e.g. for special tokens
                            text = choices[0].get("text")
                            timestamp = time.perf_counter()
                            if not first_chunk_received:
                                # First token. Reuse `timestamp` so TTFT,
                                # ITL and latency share one clock reading.
                                first_chunk_received = True
                                output.ttft = timestamp - st
                            else:
                                # Decoding phase
                                output.itl.append(timestamp - most_recent_timestamp)

                            most_recent_timestamp = timestamp
                            generated_text += text or ""
                        elif usage := data.get("usage"):
                            output.output_tokens = usage.get("completion_tokens")

                if first_chunk_received:
                    output.success = True
                else:
                    output.success = False
                    output.error = (
                        "Never received a valid chunk to calculate TTFT. "
                        "This response will be marked as failed!"
                    )
                output.generated_text = generated_text
                output.latency = most_recent_timestamp - st
            else:
                output.error = response.reason or ""
                output.success = False
    except Exception:
        # Best effort: report the failure through the output object so one
        # bad request does not abort the caller.
        output.success = False
        exc_info = sys.exc_info()
        output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output
def _get_chat_content(
    request_func_input: RequestFuncInput,
    mm_position: Literal["first", "last"] = "last",
) -> list[dict[str, Any]]:
    """Build the chat message content list from the prompt and multimodal parts.

    Args:
        request_func_input: Source of the text prompt and optional
            multimodal content.
        mm_position: Whether the multimodal parts come before ("first") or
            after the text part.

    Returns:
        The text part and the multimodal parts in the requested order.

    Raises:
        TypeError: If the multimodal content is truthy but neither a dict
            nor a list.
    """
    text_contents = [{"type": "text", "text": request_func_input.prompt}]

    mm_content = request_func_input.multi_modal_content
    mm_contents = []
    if mm_content:
        if isinstance(mm_content, list):
            mm_contents.extend(mm_content)
        elif isinstance(mm_content, dict):
            mm_contents.append(mm_content)
        else:
            raise TypeError(
                "multi_modal_content must be a dict or list[dict] for openai-chat"
            )

    ordered = (
        (mm_contents, text_contents)
        if mm_position == "first"
        else (text_contents, mm_contents)
    )
    return ordered[0] + ordered[1]
async def async_request_openai_chat_completions(
    request_func_input: RequestFuncInput,
    session: aiohttp.ClientSession,
    pbar: tqdm | None = None,
    mm_position: Literal["first", "last"] = "last",
) -> RequestFuncOutput:
    """The async request function for the OpenAI Chat Completions API.

    Args:
        request_func_input: The input for the request function.
        session: The aiohttp session used to send the request.
        pbar: The progress bar to display the progress.
        mm_position: Where multimodal parts go relative to the text part.

    Returns:
        The output of the request function; failures while sending or
        reading the request are reported through ``output.error``.
    """
    api_url = request_func_input.api_url
    _validate_api_url(api_url, "OpenAI Chat Completions API", "chat/completions")

    user_message = {
        "role": "user",
        "content": _get_chat_content(request_func_input, mm_position=mm_position),
    }
    payload = {
        "model": request_func_input.model_name or request_func_input.model,
        "messages": [user_message],
        "temperature": 0.0,
        "max_completion_tokens": request_func_input.output_len,
        "stream": True,
        "stream_options": {
            "include_usage": True,
        },
    }
    _update_payload_common(payload, request_func_input)

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
    }
    _update_headers_common(headers, request_func_input)

    output = RequestFuncOutput()
    output.prompt_len = request_func_input.prompt_len

    generated_text = ""
    ttft = 0.0
    st = time.perf_counter()
    output.start_time = st
    most_recent_timestamp = st
    try:
        async with session.post(url=api_url, json=payload, headers=headers) as response:
            if response.status != 200:
                output.error = response.reason or ""
                output.success = False
            else:
                sse = StreamedResponseHandler()
                async for raw in response.content.iter_any():
                    raw = raw.strip()
                    if not raw:
                        continue

                    for message in sse.add_chunk(raw):
                        # SSE comment lines (often pings) start with a colon
                        # and carry no JSON payload.
                        if message.startswith(":"):
                            continue

                        body = message.removeprefix("data: ")
                        if body == "[DONE]":
                            continue

                        timestamp = time.perf_counter()
                        data = json.loads(body)

                        if choices := data.get("choices"):
                            delta_text = choices[0]["delta"].get("content")
                            if ttft == 0.0:
                                # First token
                                ttft = timestamp - st
                                output.ttft = ttft
                            else:
                                # Decoding phase
                                output.itl.append(timestamp - most_recent_timestamp)
                            generated_text += delta_text or ""
                        elif usage := data.get("usage"):
                            output.output_tokens = usage.get("completion_tokens")

                        most_recent_timestamp = timestamp

                output.generated_text = generated_text
                output.success = True
                output.latency = most_recent_timestamp - st
    except Exception:
        output.success = False
        output.error = "".join(traceback.format_exception(*sys.exc_info()))

    if pbar:
        pbar.update(1)
    return output
async def async_request_openai_audio(
    request_func_input: RequestFuncInput,
    session: aiohttp.ClientSession,
    pbar: tqdm | None = None,
) -> RequestFuncOutput:
    """The async request function for the OpenAI audio APIs.

    Uploads the audio as a WAV file in a multipart form, streams the
    response and records time-to-first-token, inter-token latencies and
    total latency.

    Args:
        request_func_input: The input for the request function. Its
            ``multi_modal_content`` must be a dict with an ``"audio"`` entry
            whose items are unpacked as the data and sample-rate arguments
            of ``soundfile.write``.
        session: The aiohttp session used to send the request.
        pbar: The progress bar to display the progress.

    Returns:
        The output of the request function; failures while sending or
        reading the request are reported through ``output.error``.

    Raises:
        ValueError: If the API URL does not end with an accepted suffix.
        TypeError: If ``multi_modal_content`` is not a dict containing
            ``"audio"``.
    """
    # Lazy import without PlaceholderModule to avoid vllm dep.
    import soundfile

    api_url = request_func_input.api_url
    _validate_api_url(api_url, "OpenAI Audio API", {"transcriptions", "translations"})

    payload = {
        "model": request_func_input.model_name
        if request_func_input.model_name
        else request_func_input.model,
        "temperature": 0.0,
        "max_completion_tokens": request_func_input.output_len,
        "stream": True,
        "language": "en",
        # Flattened due to multipart/form-data
        "stream_include_usage": True,
        "stream_continuous_usage_stats": True,
    }
    _update_payload_common(payload, request_func_input)

    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
    }
    _update_headers_common(headers, request_func_input)

    # Send audio file
    def to_bytes(y, sr):
        """Encode the audio as WAV into an in-memory buffer rewound to 0."""
        buffer = io.BytesIO()
        soundfile.write(buffer, y, sr, format="WAV")
        buffer.seek(0)
        return buffer

    mm_audio = request_func_input.multi_modal_content
    if not isinstance(mm_audio, dict) or "audio" not in mm_audio:
        raise TypeError("multi_modal_content must be a dict containing 'audio'")

    with to_bytes(*mm_audio["audio"]) as f:
        form = aiohttp.FormData()
        form.add_field("file", f, content_type="audio/wav")
        # Multipart form fields are sent as strings.
        for key, value in payload.items():
            form.add_field(key, str(value))

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        ttft = 0.0
        st = time.perf_counter()
        output.start_time = st
        most_recent_timestamp = st
        try:
            async with session.post(
                url=api_url, data=form, headers=headers
            ) as response:
                if response.status == 200:
                    handler = StreamedResponseHandler()

                    async for chunk_bytes in response.content.iter_any():
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        messages = handler.add_chunk(chunk_bytes)
                        for message in messages:
                            # NOTE: SSE comments (often used as pings) start
                            # with a colon. These are not JSON data payload
                            # and should be skipped.
                            if message.startswith(":"):
                                continue

                            # `add_chunk` already returns decoded strings, so
                            # the message must not be decoded again.
                            chunk = message.removeprefix("data: ")
                            if chunk == "[DONE]":
                                continue

                            timestamp = time.perf_counter()
                            data = json.loads(chunk)

                            if choices := data.get("choices"):
                                content = choices[0]["delta"].get("content")
                                if ttft == 0.0:
                                    # First token
                                    ttft = timestamp - st
                                    output.ttft = ttft
                                else:
                                    # Decoding phase
                                    output.itl.append(
                                        timestamp - most_recent_timestamp
                                    )

                                generated_text += content or ""
                            elif usage := data.get("usage"):
                                output.output_tokens = usage.get(
                                    "completion_tokens"
                                )

                            most_recent_timestamp = timestamp

                    output.generated_text = generated_text
                    output.success = True
                    output.latency = most_recent_timestamp - st
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            # Best effort: report the failure through the output object so
            # one bad request does not abort the caller.
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output
async def _run_pooling_request(
    session: aiohttp.ClientSession,
    api_url: str,
    payload: dict[str, Any],
    headers: dict[str, Any],
    pbar: tqdm | None = None,
) -> RequestFuncOutput:
    """POST a non-streaming (pooling-style) request and time it.

    On HTTP 200 both ``ttft`` and ``latency`` are set to the time until the
    response arrived, and ``prompt_len`` is taken from the reported usage.
    Usage is read from the ``metadata`` response header when the payload
    asks for ``encoding_format == "bytes"``, otherwise from the JSON body.

    Args:
        session: The aiohttp session used to send the request.
        api_url: Endpoint URL.
        payload: JSON payload to send.
        headers: HTTP headers to send.
        pbar: The progress bar to display the progress.

    Returns:
        The request output; any exception is reported via ``output.error``.
    """
    output = RequestFuncOutput()
    st = time.perf_counter()
    output.start_time = st

    try:
        async with session.post(url=api_url, headers=headers, json=payload) as response:
            if response.status != 200:
                output.success = False
                output.error = response.reason or ""
            else:
                elapsed = time.perf_counter() - st
                output.ttft = elapsed
                output.latency = elapsed

                if payload.get("encoding_format", "float") == "bytes":
                    usage_source = json.loads(response.headers["metadata"])
                else:
                    usage_source = await response.json()
                usage = usage_source.get("usage", {})

                output.success = True
                output.generated_text = ""
                output.prompt_len = usage.get("prompt_tokens", 0)
    except Exception as e:
        output.success = False
        output.error = str(e)

    if pbar:
        pbar.update(1)
    return output
async def async_request_openai_embeddings(
    request_func_input: RequestFuncInput,
    session: aiohttp.ClientSession,
    pbar: tqdm | None = None,
) -> RequestFuncOutput:
    """The async request function for the OpenAI Embeddings API (text input)."""
    api_url = request_func_input.api_url
    _validate_api_url(api_url, "OpenAI Embeddings API", "embeddings")

    payload = {
        "model": request_func_input.model_name or request_func_input.model,
        "input": request_func_input.prompt,
        # Embedding models often have a short context length; asking for
        # truncation avoids dropping some of the requests.
        "truncate_prompt_tokens": -1,
    }
    _update_payload_common(payload, request_func_input)

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
    }
    _update_headers_common(headers, request_func_input)

    result = await _run_pooling_request(
        session,
        api_url,
        payload=payload,
        headers=headers,
        pbar=pbar,
    )
    return result
async def async_request_vllm_rerank(
    request_func_input: RequestFuncInput,
    session: aiohttp.ClientSession,
    pbar: tqdm | None = None,
) -> RequestFuncOutput:
    """The async request function for the vLLM rerank endpoint.

    The prompt must be a list with at least two entries: the first is the
    query, the remaining ones are the documents.
    """
    api_url = request_func_input.api_url
    _validate_api_url(api_url, "vLLM score API", "rerank")

    prompt = request_func_input.prompt
    assert (
        isinstance(prompt, list)
        and len(prompt) > 1
    )
    query, documents = prompt[0], prompt[1:]

    payload = {
        "model": request_func_input.model_name or request_func_input.model,
        "query": query,
        "documents": documents,
        # Reranker models often have a short context length; asking for
        # truncation avoids dropping some of the requests.
        "truncate_prompt_tokens": -1,
    }

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
    }
    _update_headers_common(headers, request_func_input)

    result = await _run_pooling_request(
        session,
        api_url,
        payload=payload,
        headers=headers,
        pbar=pbar,
    )
    return result
async def async_request_openai_embeddings_chat(
    request_func_input: RequestFuncInput,
    session: aiohttp.ClientSession,
    pbar: tqdm | None = None,
    mm_position: Literal["first", "last"] = "last",
) -> RequestFuncOutput:
    """Request embeddings using a chat-style ``messages`` payload.

    Args:
        request_func_input: The input for the request function.
        session: The aiohttp session used to send the request.
        pbar: The progress bar to display the progress.
        mm_position: Where multimodal parts go relative to the text part.
    """
    api_url = request_func_input.api_url
    _validate_api_url(api_url, "OpenAI Embeddings API", "embeddings")

    user_message = {
        "role": "user",
        "content": _get_chat_content(request_func_input, mm_position=mm_position),
    }
    payload = {
        "model": request_func_input.model_name or request_func_input.model,
        "messages": [user_message],
        # Embedding models often have a short context length; asking for
        # truncation avoids dropping some of the requests.
        "truncate_prompt_tokens": -1,
    }
    _update_payload_common(payload, request_func_input)

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
    }
    _update_headers_common(headers, request_func_input)

    result = await _run_pooling_request(
        session,
        api_url,
        payload=payload,
        headers=headers,
        pbar=pbar,
    )
    return result
def _try_extract_request_idx(request_func_input: RequestFuncInput):
    """Return the trailing run of digits of the request id as an int.

    Returns None when there is no request id, when it does not end in
    digits, or when the digits cannot be converted.
    """
    request_id = request_func_input.request_id
    if not request_id:
        return None

    trailing_digits = re.search(r"(\d+)$", request_id)
    if trailing_digits is None:
        return None

    try:
        return int(trailing_digits.group(1))
    except ValueError:
        return None
def _preprocess_clip(request_func_input: RequestFuncInput):
    """Blank the text prompt in place when multimodal content is present."""
    if not request_func_input.multi_modal_content:
        return
    # Image input
    request_func_input.prompt = ""
def _preprocess_vlm2vec(request_func_input: RequestFuncInput):
    """Rewrite the prompt in place for multimodal requests.

    Requests whose index (trailing digits of the request id) is even or
    unknown get an image-only instruction; the others get an instruction
    that embeds the original prompt as a question.
    """
    if not request_func_input.multi_modal_content:
        return

    request_idx = _try_extract_request_idx(request_func_input)

    # Adjust the ratio manually if needed.
    if request_idx is None or request_idx % 2 == 0:
        # Image input
        request_func_input.prompt = "Represent the given image."
        return

    # Text+Image input
    request_func_input.prompt = (
        f"Represent the given image with the following question: "
        f"{request_func_input.prompt}"
    )
async def async_request_openai_embeddings_clip(
    request_func_input: RequestFuncInput,
    session: aiohttp.ClientSession,
    pbar: tqdm | None = None,
) -> RequestFuncOutput:
    """Apply the CLIP prompt preprocessing, then request chat-style embeddings."""
    _preprocess_clip(request_func_input)

    result = await async_request_openai_embeddings_chat(
        request_func_input,
        session,
        pbar=pbar,
    )
    return result
async def async_request_openai_embeddings_vlm2vec(
    request_func_input: RequestFuncInput,
    session: aiohttp.ClientSession,
    pbar: tqdm | None = None,
) -> RequestFuncOutput:
    """Apply the VLM2Vec prompt preprocessing, then request chat-style embeddings.

    Multimodal parts are placed before the text part.
    """
    _preprocess_vlm2vec(request_func_input)

    result = await async_request_openai_embeddings_chat(
        request_func_input,
        session,
        pbar=pbar,
        mm_position="first",
    )
    return result
async def async_request_infinity_embeddings(
    request_func_input: RequestFuncInput,
    session: aiohttp.ClientSession,
    pbar: tqdm | None = None,
) -> RequestFuncOutput:
    """The async request function for the Infinity embeddings API.

    A non-empty prompt is sent as the input. Otherwise the input is the URL
    found in the multimodal content, and the modality is the part of its
    ``type`` before the first underscore.
    """
    api_url = request_func_input.api_url
    _validate_api_url(api_url, "Infinity Embeddings API", "embeddings")

    payload = {
        "model": request_func_input.model_name or request_func_input.model,
    }

    prompt = request_func_input.prompt
    if prompt:
        payload["input"] = prompt
    else:
        mm_content = request_func_input.multi_modal_content
        assert isinstance(mm_content, dict)

        mm_type = mm_content["type"]
        payload["input"] = mm_content[mm_type]["url"]
        payload["modality"] = mm_type.split("_", 1)[0]

    _update_payload_common(payload, request_func_input)

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
    }
    _update_headers_common(headers, request_func_input)

    result = await _run_pooling_request(
        session,
        api_url,
        payload=payload,
        headers=headers,
        pbar=pbar,
    )
    return result
async def async_request_infinity_embeddings_clip(
    request_func_input: RequestFuncInput,
    session: aiohttp.ClientSession,
    pbar: tqdm | None = None,
) -> RequestFuncOutput:
    """Apply the CLIP prompt preprocessing, then call the Infinity embeddings API."""
    _preprocess_clip(request_func_input)

    result = await async_request_infinity_embeddings(
        request_func_input,
        session,
        pbar=pbar,
    )
    return result
# TODO: Add more request functions for different API protocols.
# Registry mapping each backend name to its request function.
ASYNC_REQUEST_FUNCS: dict[str, RequestFunc] = {
    # Text generation
    "vllm": async_request_openai_completions,
    "openai": async_request_openai_completions,
    "openai-chat": async_request_openai_chat_completions,
    "openai-audio": async_request_openai_audio,
    # Embeddings
    "openai-embeddings": async_request_openai_embeddings,
    "openai-embeddings-chat": async_request_openai_embeddings_chat,
    "openai-embeddings-clip": async_request_openai_embeddings_clip,
    "openai-embeddings-vlm2vec": async_request_openai_embeddings_vlm2vec,
    # Infinity embedding server: https://github.com/michaelfeil/infinity
    "infinity-embeddings": async_request_infinity_embeddings,
    "infinity-embeddings-clip": async_request_infinity_embeddings_clip,
    # (Infinity embedding server does not support vlm2vec)
    # Reranking
    "vllm-rerank": async_request_vllm_rerank,
}

# Backend names served by the completions or chat-completions functions.
OPENAI_COMPATIBLE_BACKENDS = [
    backend
    for backend, request_func in ASYNC_REQUEST_FUNCS.items()
    if request_func
    in (async_request_openai_completions, async_request_openai_chat_completions)
]

View File

@@ -0,0 +1,72 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utilities for checking endpoint readiness."""
import asyncio
import time
import aiohttp
from tqdm.asyncio import tqdm
from .endpoint_request_func import RequestFunc, RequestFuncInput, RequestFuncOutput
async def wait_for_endpoint(
    request_func: RequestFunc,
    test_input: RequestFuncInput,
    session: aiohttp.ClientSession,
    timeout_seconds: int = 600,
    retry_interval: int = 5,
) -> RequestFuncOutput:
    """
    Wait for an endpoint to become available before starting benchmarks.

    The endpoint is probed with ``request_func`` until a probe succeeds or
    the timeout elapses; a progress bar shows the elapsed time.

    Args:
        request_func: The async request function to call
        test_input: The RequestFuncInput to test with
        session: The aiohttp session passed to ``request_func``
        timeout_seconds: Maximum time to wait in seconds (default: 10 minutes)
        retry_interval: Time between retries in seconds (default: 5 seconds)

    Returns:
        RequestFuncOutput: The successful response. If the endpoint does not
        become available within the timeout, the output of the last failed
        attempt is returned instead (``success`` is False); no exception is
        raised, so callers must check ``success``.
    """
    deadline = time.perf_counter() + timeout_seconds
    output = RequestFuncOutput(success=False)
    print(f"Waiting for endpoint to become up in {timeout_seconds} seconds")

    with tqdm(
        total=timeout_seconds,
        bar_format="{desc} |{bar}| {elapsed} elapsed, {remaining} remaining",
        unit="s",
    ) as pbar:
        while True:
            # update progress bar; the advance is capped so the bar never
            # goes past its total
            remaining = deadline - time.perf_counter()
            elapsed = timeout_seconds - remaining
            update_amount = min(elapsed - pbar.n, timeout_seconds - pbar.n)
            pbar.update(update_amount)
            pbar.refresh()
            if remaining <= 0:
                pbar.close()
                break

            # ping the endpoint using request_func
            try:
                output = await request_func(
                    request_func_input=test_input, session=session
                )
                if output.success:
                    pbar.close()
                    return output
            except aiohttp.ClientConnectorError:
                # Endpoint not reachable yet; retry below.
                pass

            # retry after a delay. The probe itself may have taken a while,
            # so re-measure the time left to avoid sleeping past the deadline.
            remaining = deadline - time.perf_counter()
            sleep_duration = min(retry_interval, remaining)
            if sleep_duration > 0:
                await asyncio.sleep(sleep_duration)

    return output

View File

@@ -0,0 +1,79 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import json
import math
import os
from typing import Any
def convert_to_pytorch_benchmark_format(
    args: argparse.Namespace, metrics: dict[str, list], extra_info: dict[str, Any]
) -> list:
    """
    Convert the benchmark results to the format used by PyTorch OSS benchmark,
    with one metric per record.
    https://github.com/pytorch/pytorch/wiki/How-to-integrate-with-PyTorch-OSS-benchmark-database

    Args:
        args: Parsed CLI arguments; must provide ``model``. A copy of its
            attributes is stored in every record.
        metrics: Mapping of metric name to its list of benchmark values.
        extra_info: Extra metadata attached to every metric.

    Returns:
        One record per metric, or an empty list when the
        ``SAVE_TO_PYTORCH_BENCHMARK_FORMAT`` environment variable is unset
        or empty.
    """
    records = []
    if not os.environ.get("SAVE_TO_PYTORCH_BENCHMARK_FORMAT", False):
        return records

    for name, benchmark_values in metrics.items():
        # `vars(args)` is the Namespace's live attribute dict. Copy it so
        # that the update below neither mutates the caller's `args` nor
        # makes all records share one dict.
        args_snapshot = dict(vars(args))

        # Save tensor_parallel_size parameter if it's part of the metadata
        tp = args_snapshot.get("tensor_parallel_size")
        if not tp and "tensor_parallel_size" in extra_info:
            args_snapshot["tensor_parallel_size"] = extra_info["tensor_parallel_size"]

        record = {
            "benchmark": {
                "name": "vLLM benchmark",
                "extra_info": {
                    "args": args_snapshot,
                },
            },
            "model": {
                "name": args.model,
            },
            "metric": {
                "name": name,
                "benchmark_values": benchmark_values,
                "extra_info": extra_info,
            },
        }
        records.append(record)

    return records
class InfEncoder(json.JSONEncoder):
    """JSON encoder that writes infinite floats as strings.

    The standard encoder emits ``Infinity`` / ``-Infinity``, which is not
    valid JSON. This encoder replaces them with ``"inf"`` / ``"-inf"`` before
    encoding, and stringifies dictionary keys of unsupported types.
    """

    def clear_inf(self, o: Any):
        """Return a copy of *o* with infinities and unsupported keys replaced.

        Dicts, lists and tuples are traversed recursively (tuples come back
        as lists, which is how JSON encodes them anyway). Any other value is
        returned unchanged.
        """
        if isinstance(o, dict):
            return {
                (
                    k
                    if isinstance(k, (str, int, float, bool, type(None)))
                    else str(k)
                ): self.clear_inf(v)
                for k, v in o.items()
            }
        if isinstance(o, (list, tuple)):
            return [self.clear_inf(v) for v in o]
        if isinstance(o, float) and math.isinf(o):
            # Keep the sign so +inf and -inf remain distinguishable.
            return "inf" if o > 0 else "-inf"
        return o

    def iterencode(self, o: Any, *args, **kwargs) -> Any:
        """Encode *o* after sanitizing it with ``clear_inf``."""
        return super().iterencode(self.clear_inf(o), *args, **kwargs)
def write_to_json(filename: str, records: list) -> None:
    """Serialize *records* as JSON into *filename*.

    Encoding goes through ``InfEncoder``; objects the encoder cannot handle
    are replaced by a placeholder string naming their type.
    """

    def _placeholder(o: Any) -> str:
        return f"<{type(o).__name__} is not JSON serializable>"

    with open(filename, "w") as f:
        json.dump(records, f, cls=InfEncoder, default=_placeholder)

1531
vllm_old/benchmarks/serve.py Normal file

File diff suppressed because it is too large Load Diff

View File

View File

@@ -0,0 +1,38 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG
from .plot import SweepPlotArgs
from .plot import main as plot_main
from .serve import SweepServeArgs
from .serve import main as serve_main
from .serve_sla import SweepServeSLAArgs
from .serve_sla import main as serve_sla_main
# (argument class, entry point) pairs, one per `sweep` subcommand, in the
# order they are registered on the parser.
SUBCOMMANDS = (
    (SweepServeArgs, serve_main),
    (SweepServeSLAArgs, serve_sla_main),
    (SweepPlotArgs, plot_main),
)
def add_cli_args(parser: argparse.ArgumentParser):
    """Register one required subparser per entry of ``SUBCOMMANDS``.

    Each subparser stores its entry point under ``dispatch_function`` so
    that ``main`` can dispatch on the parsed arguments.
    """
    subparsers = parser.add_subparsers(required=True, dest="sweep_type")

    for cmd, entrypoint in SUBCOMMANDS:
        sub = subparsers.add_parser(
            cmd.parser_name,
            description=cmd.parser_help,
            usage=f"vllm bench sweep {cmd.parser_name} [options]",
        )
        sub.set_defaults(dispatch_function=entrypoint)
        cmd.add_cli_args(sub)

        # The epilog is set last, after the subcommand's own arguments
        # have been registered.
        subcmd_label = f"sweep {cmd.parser_name}"
        sub.epilog = VLLM_SUBCMD_PARSER_EPILOG.format(subcmd=subcmd_label)
def main(args: argparse.Namespace):
    """Invoke the entry point stored in ``args.dispatch_function`` with *args*."""
    entrypoint = args.dispatch_function
    entrypoint(args)

View File

@@ -0,0 +1,91 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import os
from typing import Any
class ParameterSweep(list["ParameterSweepItem"]):
    """A list of parameter combinations to sweep over."""

    @classmethod
    def read_json(cls, filepath: os.PathLike):
        """Load a sweep from a JSON file containing a list of dictionaries."""
        with open(filepath, "rb") as f:
            raw = f.read()
        return cls.from_records(json.loads(raw))

    @classmethod
    def from_records(cls, records: list[dict[str, object]]):
        """Build a sweep from already-parsed records.

        Raises:
            TypeError: If *records* is not a list.
        """
        if isinstance(records, list):
            return cls([ParameterSweepItem.from_record(entry) for entry in records])

        raise TypeError(
            f"The parameter sweep should be a list of dictionaries, "
            f"but found type: {type(records)}"
        )
class ParameterSweepItem(dict[str, object]):
    """One parameter combination: a mapping of parameter name to value."""

    @classmethod
    def from_record(cls, record: dict[str, object]):
        """Wrap a parsed record.

        Raises:
            TypeError: If *record* is not a dictionary.
        """
        if isinstance(record, dict):
            return cls(record)

        raise TypeError(
            f"Each item in the parameter sweep should be a dictionary, "
            f"but found type: {type(record)}"
        )

    def __or__(self, other: dict[str, Any]):
        """Merge with *other*, keeping this class as the result type."""
        merged = super().__or__(other)
        return type(self)(merged)

    # In JSON, we prefer "_"
    def _iter_param_key_candidates(self, param_key: str):
        """Yield the spellings of *param_key*: as given, with "_", with "-"."""
        head, dot, tail = param_key.partition(".")
        if dot:
            # Inner config arguments are not converted by the CLI, so only
            # the part before the first dot is varied.
            for head_variant in self._iter_param_key_candidates(head):
                yield head_variant + "." + tail
            return

        yield from (
            param_key,
            param_key.replace("-", "_"),
            param_key.replace("_", "-"),
        )

    # In CLI, we prefer "-"
    def _iter_cmd_key_candidates(self, param_key: str):
        """Yield the ``--flag`` spellings of *param_key*, dashed form first."""
        spellings = tuple(self._iter_param_key_candidates(param_key))
        for spelling in spellings[::-1]:
            yield "--" + spelling

    def _normalize_cmd_key(self, param_key: str):
        """Return the preferred ``--flag`` spelling of *param_key*."""
        for flag in self._iter_cmd_key_candidates(param_key):
            return flag

    def has_param(self, param_key: str) -> bool:
        """Tell whether any spelling of *param_key* is a key of this item."""
        for spelling in self._iter_param_key_candidates(param_key):
            if spelling in self:
                return True
        return False

    def apply_to_cmd(self, cmd: list[str]) -> list[str]:
        """Return a copy of *cmd* with this item's parameters applied.

        A flag already present in the command is updated in place: boolean
        values replace the flag token itself (``--x`` or ``--no-x``), other
        values overwrite the token that follows the flag. Parameters not
        present are appended.
        """
        updated = list(cmd)

        for name, value in self.items():
            # Locate the first spelling of the flag that occurs in the command.
            position = None
            for flag in self._iter_cmd_key_candidates(name):
                if flag in updated:
                    position = updated.index(flag)
                    break

            if isinstance(value, bool):
                switch = self._normalize_cmd_key(name if value else "no-" + name)
                if position is None:
                    updated.append(switch)
                else:
                    updated[position] = switch
            elif position is None:
                updated.extend([self._normalize_cmd_key(name), str(value)])
            else:
                updated[position + 1] = str(value)

        return updated

    def as_text(self, sep: str = ", ") -> str:
        """Render the item as ``key=value`` pairs joined by *sep*."""
        pairs = [f"{k}={v}" for k, v in self.items()]
        return sep.join(pairs)

View File

@@ -0,0 +1,580 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import json
from abc import ABC, abstractmethod
from concurrent.futures import ProcessPoolExecutor
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from types import TracebackType
from typing import ClassVar
from typing_extensions import Self, override
from vllm.utils.collection_utils import full_groupby
from vllm.utils.import_utils import PlaceholderModule
from .utils import sanitize_filename
# The plotting dependencies are optional. When any of them is missing, bind
# placeholders under the same names the rest of this module uses
# (`plt`, `pd`, `sns`), so that a missing package does not turn into a
# `NameError` at the first use of `sns`.
try:
    import matplotlib.pyplot as plt
    import pandas as pd
    import seaborn as sns
except ImportError:
    plt = PlaceholderModule("matplotlib").placeholder_attr("pyplot")
    pd = PlaceholderModule("pandas")
    sns = PlaceholderModule("seaborn")
@dataclass
class PlotFilterBase(ABC):
    """Base class of a row filter written as ``<var><op><target>``."""

    # Name of the DataFrame column the filter reads.
    var: str
    # Right-hand side of the expression, kept as text.
    target: str

    @classmethod
    def parse_str(cls, s: str):
        """Parse one filter expression, e.g. ``max_concurrency<1000``.

        The operators in ``PLOT_FILTERS`` are tried in their declared order
        and the first one contained in ``s`` selects the filter class. The
        target has surrounding single and double quotes removed.

        Raises:
            ValueError: If ``s`` contains none of the known operators.
        """
        for op_key in PLOT_FILTERS:
            if op_key in s:
                # Split on the first occurrence only, so a target that
                # contains the operator again (e.g. ``name=='a==b'``) does
                # not break the unpacking.
                key, _, value = s.partition(op_key)
                return PLOT_FILTERS[op_key](
                    key,
                    value.strip("'").strip('"'),
                )

        raise ValueError(
            f"Invalid operator for plot filter '{s}'. "
            f"Valid operators are: {sorted(PLOT_FILTERS)}",
        )

    @abstractmethod
    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
        """Applies this filter to a DataFrame."""
        raise NotImplementedError
@dataclass
class PlotEqualTo(PlotFilterBase):
    """Keep the rows whose ``var`` column equals ``target``."""

    @override
    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
        """Return the rows of ``df`` whose ``var`` value equals the target."""
        # Compare as a number when the target parses as one; otherwise
        # compare against the raw text.
        try:
            wanted = float(self.target)
        except ValueError:
            wanted = self.target

        keep = df[self.var] == wanted
        return df[keep]
@dataclass
class PlotLessThan(PlotFilterBase):
    """Keep the rows whose ``var`` column is below ``target``."""

    @override
    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
        """Return the rows of ``df`` with ``var`` strictly less than the target."""
        column = df[self.var]
        keep = column < float(self.target)
        return df[keep]
@dataclass
class PlotLessThanOrEqualTo(PlotFilterBase):
    """Keep the rows whose ``var`` column does not exceed ``target``."""

    @override
    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
        """Return the rows of ``df`` with ``var`` less than or equal to the target."""
        column = df[self.var]
        keep = column <= float(self.target)
        return df[keep]
@dataclass
class PlotGreaterThan(PlotFilterBase):
    """Keep the rows whose ``var`` column is above ``target``."""

    @override
    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
        """Return the rows of ``df`` with ``var`` strictly greater than the target."""
        column = df[self.var]
        keep = column > float(self.target)
        return df[keep]
@dataclass
class PlotGreaterThanOrEqualTo(PlotFilterBase):
    """Keep the rows whose ``var`` column is at least ``target``."""

    @override
    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
        """Return the rows of ``df`` with ``var`` greater than or equal to the target."""
        column = df[self.var]
        keep = column >= float(self.target)
        return df[keep]
# Maps each filter operator to the class that implements it.
# NOTE: The ordering is important! `PlotFilterBase.parse_str` tries the keys
# in this order and uses the first one contained in the expression, so the
# two-character operators must come before the one-character operators
# they contain ("<=" before "<", ">=" before ">").
PLOT_FILTERS: dict[str, type[PlotFilterBase]] = {
"==": PlotEqualTo,
"<=": PlotLessThanOrEqualTo,
">=": PlotGreaterThanOrEqualTo,
"<": PlotLessThan,
">": PlotGreaterThan,
}
class PlotFilters(list[PlotFilterBase]):
    """A sequence of filters applied one after another."""

    @classmethod
    def parse_str(cls, s: str):
        """Parse a comma-separated list of filter expressions.

        An empty string yields an empty filter list.
        """
        if s:
            return cls(map(PlotFilterBase.parse_str, s.split(",")))

        return cls()

    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
        """Pass ``df`` through every filter in order and return the result."""
        filtered = df
        for plot_filter in self:
            filtered = plot_filter.apply(filtered)

        return filtered
@dataclass
class PlotBinner:
    """Rounds a column down to a multiple of ``bin_size``."""

    # Name of the DataFrame column to bin.
    var: str
    # Width of each bin.
    bin_size: float

    @classmethod
    def parse_str(cls, s: str):
        """Parse one binning expression, e.g. ``request_throughput%1``.

        Raises:
            ValueError: If ``s`` contains none of the operators in
                ``PLOT_BINNERS``.
        """
        matched = next((op for op in PLOT_BINNERS if op in s), None)
        if matched is None:
            raise ValueError(
                f"Invalid operator for plot binner '{s}'. "
                f"Valid operators are: {sorted(PLOT_BINNERS)}",
            )

        key, value = s.split(matched)
        binner_cls = PLOT_BINNERS[matched]
        return binner_cls(key, float(value.removeprefix(matched)))

    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
        """Applies this binner to a DataFrame."""
        # Work on a copy so that the caller's frame is left untouched.
        binned = df.copy()
        column = binned[self.var]
        binned[self.var] = (column // self.bin_size) * self.bin_size
        return binned
# Maps each binning operator to the class that implements it;
# `PlotBinner.parse_str` looks the operator up in this table.
PLOT_BINNERS: dict[str, type[PlotBinner]] = {
"%": PlotBinner,
}
class PlotBinners(list[PlotBinner]):
    """A sequence of binners applied one after another."""

    @classmethod
    def parse_str(cls, s: str):
        """Parse a comma-separated list of binning expressions.

        An empty string yields an empty binner list.
        """
        if s:
            return cls(map(PlotBinner.parse_str, s.split(",")))

        return cls()

    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
        """Pass ``df`` through every binner in order and return the result."""
        binned = df
        for binner in self:
            binned = binner.apply(binned)

        return binned
def _json_load_bytes(path: Path) -> list[dict[str, object]]:
with path.open("rb") as f:
return json.load(f)
def _get_metric(run_data: dict[str, object], metric_key: str):
try:
return run_data[metric_key]
except KeyError as exc:
raise ValueError(f"Cannot find metric {metric_key!r} in {run_data=}") from exc
def _get_group(run_data: dict[str, object], group_keys: list[str]):
    """Return the ``(key, str(value))`` pairs of ``run_data`` for ``group_keys``.

    The result is a tuple, in the order of ``group_keys``.
    """
    pairs = []
    for key in group_keys:
        pairs.append((key, str(_get_metric(run_data, key))))
    return tuple(pairs)
def _get_fig_path(fig_dir: Path, group: tuple[tuple[str, str], ...]):
    """Return the PNG path inside ``fig_dir`` for the figure of ``group``.

    The name is built from the group's ``key=value`` pairs, or is
    ``figure.png`` for an empty group, and is passed through
    ``sanitize_filename``.
    """
    if group:
        parts = ["FIGURE-", *(f"{k}={v}" for k, v in group)]
    else:
        parts = ["figure"]

    filename = "-".join(parts) + ".png"
    return fig_dir / sanitize_filename(filename)
class DummyExecutor:
map = map
def __enter__(self) -> Self:
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
exc_traceback: TracebackType | None,
) -> None:
return None
def _make_group_labels(df: "pd.DataFrame", keys: list[str]):
    """Return one label per row of ``df`` describing its values for ``keys``.

    Each label is the ``key=value`` pairs joined by newlines. With no keys the
    constant ``"(All)"`` is returned instead.
    """
    if not keys:
        return "(All)"

    return pd.concat(
        [k + "=" + df[k].astype(str) for k in keys],
        axis=1,
    ).agg("\n".join, axis=1)


def _plot_fig(
    fig_dir: Path,
    fig_group_data: tuple[tuple[tuple[str, str], ...], list[dict[str, object]]],
    row_by: list[str],
    col_by: list[str],
    curve_by: list[str],
    *,
    var_x: str,
    var_y: str,
    filter_by: PlotFilters,
    bin_by: PlotBinners,
    scale_x: str | None,
    scale_y: str | None,
    dry_run: bool,
):
    """Draw and save the figure for one group of runs.

    Args:
        fig_dir: Directory in which the figure is saved.
        fig_group_data: Pair of the figure's group key and its run records.
        row_by: Variables that split the figure into rows.
        col_by: Variables that split the figure into columns.
        curve_by: Variables that split each subplot into curves.
        var_x: Column plotted on the x-axis.
        var_y: Column plotted on the y-axis.
        filter_by: Filters applied to the records before plotting.
        bin_by: Binners applied after the filters.
        scale_x: Value passed as ``xscale`` to the grid, if set.
        scale_y: Value passed as ``yscale`` to the grid, if set.
        dry_run: If true, only print the figure summary and return.

    Raises:
        ValueError: If any of the requested variables is not a column of the
            records.
    """
    fig_group, fig_data = fig_group_data

    # The grid dimensions are only computed for the printed summary.
    row_groups = full_groupby(
        fig_data,
        key=lambda item: _get_group(item, row_by),
    )
    num_rows = len(row_groups)
    num_cols = max(
        len(full_groupby(row_data, key=lambda item: _get_group(item, col_by)))
        for _, row_data in row_groups
    )

    fig_path = _get_fig_path(fig_dir, fig_group)

    print("[BEGIN FIGURE]")
    print(f"Group: {dict(fig_group)}")
    print(f"Grid: {num_rows} rows x {num_cols} cols")
    print(f"Output file: {fig_path}")

    if dry_run:
        print("[END FIGURE]")
        return

    df = pd.DataFrame.from_records(fig_data)

    if var_x not in df.columns:
        raise ValueError(
            f"Cannot find {var_x=!r} in parameter sweep results. "
            f"Available variables: {df.columns.tolist()}"
        )
    if var_y not in df.columns:
        raise ValueError(
            f"Cannot find {var_y=!r} in parameter sweep results. "
            f"Available variables: {df.columns.tolist()}"
        )
    for k in row_by:
        if k not in df.columns:
            raise ValueError(
                f"Cannot find row_by={k!r} in parameter sweep results. "
                f"Available variables: {df.columns.tolist()}"
            )
    for k in col_by:
        if k not in df.columns:
            raise ValueError(
                f"Cannot find col_by={k!r} in parameter sweep results. "
                f"Available variables: {df.columns.tolist()}"
            )
    for k in curve_by:
        if k not in df.columns:
            raise ValueError(
                f"Cannot find curve_by={k!r} in parameter sweep results. "
                f"Available variables: {df.columns.tolist()}"
            )

    df = filter_by.apply(df)
    df = bin_by.apply(df)

    df["row_group"] = _make_group_labels(df, row_by)
    df["col_group"] = _make_group_labels(df, col_by)

    g = sns.FacetGrid(df, row="row_group", col="col_group")

    if row_by and col_by:
        g.set_titles("{row_name}\n{col_name}")
    elif row_by:
        g.set_titles("{row_name}")
    elif col_by:
        g.set_titles("{col_name}")
    else:
        g.set_titles("")

    if scale_x:
        g.set(xscale=scale_x)
    if scale_y:
        g.set(yscale=scale_y)

    if len(curve_by) <= 3:
        # Up to three variables map onto hue, style and size respectively;
        # the padding fills the unused ones with None.
        hue, style, size, *_ = (*curve_by, None, None, None)

        g.map_dataframe(
            sns.lineplot,
            x=var_x,
            y=var_y,
            hue=hue,
            style=style,
            size=size,
            markers=True,
        )

        g.add_legend(title=hue)
    else:
        # With more variables, fold them into one combined label for hue.
        # The column is added to `df` in place, after the grid was built
        # from that same frame.
        df["curve_group"] = _make_group_labels(df, curve_by)

        g.map_dataframe(
            sns.lineplot,
            x=var_x,
            y=var_y,
            hue="curve_group",
            markers=True,
        )

        g.add_legend()

    g.savefig(fig_path)
    plt.close(g.figure)

    print("[END FIGURE]")
def plot(
    output_dir: Path,
    fig_dir: Path,
    fig_by: list[str],
    row_by: list[str],
    col_by: list[str],
    curve_by: list[str],
    *,
    var_x: str,
    var_y: str,
    filter_by: PlotFilters,
    bin_by: PlotBinners,
    scale_x: str | None,
    scale_y: str | None,
    dry_run: bool,
):
    """Plot every ``summary.json`` found under ``output_dir``.

    The records are grouped by ``fig_by`` and one figure is drawn per group
    by ``_plot_fig``; the remaining arguments are forwarded to it unchanged.
    A single group is drawn in this process, several groups in a process
    pool.

    Raises:
        ValueError: If no records are found under ``output_dir``.
    """
    all_data = [
        run_data
        for path in output_dir.rglob("**/summary.json")
        for run_data in _json_load_bytes(path)
    ]

    if not all_data:
        raise ValueError(f"Did not find any parameter sweep results under {output_dir}")

    fig_dir.mkdir(parents=True, exist_ok=True)

    fig_groups = full_groupby(
        all_data,
        key=lambda item: _get_group(item, fig_by),
    )

    with DummyExecutor() if len(fig_groups) <= 1 else ProcessPoolExecutor() as executor:
        results = executor.map(
            partial(
                _plot_fig,
                fig_dir,
                row_by=row_by,
                col_by=col_by,
                curve_by=curve_by,
                var_x=var_x,
                var_y=var_y,
                filter_by=filter_by,
                bin_by=bin_by,
                scale_x=scale_x,
                scale_y=scale_y,
                dry_run=dry_run,
            ),
            fig_groups,
        )

        # Retrieve every result so that all workers run and any exception
        # they raised is re-raised here. `all()` must not be used for this:
        # `_plot_fig` returns None, so `all()` would stop at the first
        # result and abandon the rest of the iterator.
        for _ in results:
            pass
@dataclass
class SweepPlotArgs:
    """Arguments of the ``plot`` sub-command, parsed from the command line."""

    # Directory that is searched for parameter sweep results.
    output_dir: Path
    # Directory in which the figures are saved.
    fig_dir: Path
    # Variables that select the figure, row, column and curve of each point.
    fig_by: list[str]
    row_by: list[str]
    col_by: list[str]
    curve_by: list[str]
    # Columns plotted on the x- and y-axis.
    var_x: str
    var_y: str
    # Filters and binners applied to the data before plotting.
    filter_by: PlotFilters
    bin_by: PlotBinners
    # Optional axis scales.
    scale_x: str | None
    scale_y: str | None
    # Whether to only print what would be plotted.
    dry_run: bool

    parser_name: ClassVar[str] = "plot"
    parser_help: ClassVar[str] = "Plot performance curves from parameter sweep results."

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build the arguments from a parsed namespace.

        Comma-separated options are split into lists; an empty or missing
        value gives an empty list. ``fig_dir`` is resolved relative to
        ``OUTPUT_DIR``.

        Raises:
            ValueError: If ``OUTPUT_DIR`` does not exist.
        """
        output_dir = Path(args.OUTPUT_DIR)
        if not output_dir.exists():
            raise ValueError(f"No parameter sweep results under {output_dir}")

        curve_by = [] if not args.curve_by else args.curve_by.split(",")
        row_by = [] if not args.row_by else args.row_by.split(",")
        col_by = [] if not args.col_by else args.col_by.split(",")
        fig_by = [] if not args.fig_by else args.fig_by.split(",")

        return cls(
            output_dir=output_dir,
            fig_dir=output_dir / args.fig_dir,
            fig_by=fig_by,
            row_by=row_by,
            col_by=col_by,
            curve_by=curve_by,
            var_x=args.var_x,
            var_y=args.var_y,
            filter_by=PlotFilters.parse_str(args.filter_by),
            bin_by=PlotBinners.parse_str(args.bin_by),
            scale_x=args.scale_x,
            scale_y=args.scale_y,
            dry_run=args.dry_run,
        )

    @classmethod
    def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Register the options of this sub-command on ``parser`` and return it."""
        parser.add_argument(
            "OUTPUT_DIR",
            type=str,
            # A positional argument only falls back to its default when it
            # is optional, i.e. declared with nargs="?".
            nargs="?",
            default="results",
            help="The directory containing the results to plot, "
            "i.e., the `--output-dir` argument to the parameter sweep script.",
        )
        parser.add_argument(
            "--fig-dir",
            type=str,
            default="",
            help="The directory to save the figures, relative to `OUTPUT_DIR`. "
            "By default, the same directory is used.",
        )
        parser.add_argument(
            "--fig-by",
            type=str,
            default="",
            help="A comma-separated list of variables, such that a separate figure "
            "is created for each combination of these variables.",
        )
        parser.add_argument(
            "--row-by",
            type=str,
            default="",
            help="A comma-separated list of variables, such that a separate row "
            "is created for each combination of these variables.",
        )
        parser.add_argument(
            "--col-by",
            type=str,
            default="",
            help="A comma-separated list of variables, such that a separate column "
            "is created for each combination of these variables.",
        )
        parser.add_argument(
            "--curve-by",
            type=str,
            default=None,
            help="A comma-separated list of variables, such that a separate curve "
            "is created for each combination of these variables.",
        )
        parser.add_argument(
            "--var-x",
            type=str,
            default="request_throughput",
            help="The variable for the x-axis.",
        )
        parser.add_argument(
            "--var-y",
            type=str,
            default="p99_e2el_ms",
            help="The variable for the y-axis",
        )
        parser.add_argument(
            "--filter-by",
            type=str,
            default="",
            help="A comma-separated list of statements indicating values to filter by. "
            "This is useful to remove outliers. "
            "Example: `max_concurrency<1000,max_num_batched_tokens<=4096` means "
            "plot only the points where `max_concurrency` is less than 1000 and "
            "`max_num_batched_tokens` is no greater than 4096.",
        )
        parser.add_argument(
            "--bin-by",
            type=str,
            default="",
            help="A comma-separated list of statements indicating values to bin by. "
            "This is useful to avoid plotting points that are too close together. "
            "Example: `request_throughput%%1` means "
            "use a bin size of 1 for the `request_throughput` variable.",
        )
        parser.add_argument(
            "--scale-x",
            type=str,
            default=None,
            help="The scale to use for the x-axis. "
            "Currently only accepts string values such as 'log' and 'sqrt'. "
            "See also: https://seaborn.pydata.org/generated/seaborn.objects.Plot.scale.html",
        )
        parser.add_argument(
            "--scale-y",
            type=str,
            default=None,
            help="The scale to use for the y-axis. "
            "Currently only accepts string values such as 'log' and 'sqrt'. "
            "See also: https://seaborn.pydata.org/generated/seaborn.objects.Plot.scale.html",
        )
        parser.add_argument(
            "--dry-run",
            action="store_true",
            help="If set, prints the information about each figure to plot, "
            "then exits without drawing them.",
        )

        return parser
def run_main(args: SweepPlotArgs):
    """Call ``plot`` with the fields of ``args`` and return its result."""
    return plot(
        args.output_dir,
        args.fig_dir,
        args.fig_by,
        args.row_by,
        args.col_by,
        args.curve_by,
        var_x=args.var_x,
        var_y=args.var_y,
        filter_by=args.filter_by,
        bin_by=args.bin_by,
        scale_x=args.scale_x,
        scale_y=args.scale_y,
        dry_run=args.dry_run,
    )
def main(args: argparse.Namespace):
    """Entry point of the ``plot`` sub-command for a parsed namespace."""
    plot_args = SweepPlotArgs.from_cli_args(args)
    run_main(plot_args)
if __name__ == "__main__":
    # `add_cli_args` returns the parser it was given.
    parser = SweepPlotArgs.add_cli_args(
        argparse.ArgumentParser(description=SweepPlotArgs.parser_help)
    )
    main(parser.parse_args())

View File

@@ -0,0 +1,416 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import contextlib
import json
import shlex
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import ClassVar
from vllm.utils.import_utils import PlaceholderModule
from .param_sweep import ParameterSweep, ParameterSweepItem
from .server import ServerProcess
from .utils import sanitize_filename
try:
import pandas as pd
except ImportError:
pd = PlaceholderModule("pandas")
@contextlib.contextmanager
def run_server(
    serve_cmd: list[str],
    after_bench_cmd: list[str],
    *,
    show_stdout: bool,
    serve_overrides: ParameterSweepItem,
    dry_run: bool,
):
    """Context manager that runs the server with ``serve_overrides`` applied.

    Yields a ``ServerProcess`` for the overridden command, or ``None`` on a
    dry run, in which case no process is created. The begin/end markers are
    printed around the managed block.
    """
    server_cmd = serve_overrides.apply_to_cmd(serve_cmd)

    print("[BEGIN SERVER]")
    print(f"Server overrides: {serve_overrides}")
    print(f"Server command: {server_cmd}")

    if dry_run:
        yield None
    else:
        with ServerProcess(
            server_cmd,
            after_bench_cmd,
            show_stdout=show_stdout,
        ) as server:
            yield server

    print("[END SERVER]")
def _update_run_data(
run_data: dict[str, object],
serve_overrides: ParameterSweepItem,
bench_overrides: ParameterSweepItem,
run_number: int,
):
run_data["run_number"] = run_number
run_data.update(serve_overrides)
run_data.update(bench_overrides)
return run_data
def run_benchmark(
    server: ServerProcess | None,
    bench_cmd: list[str],
    *,
    serve_overrides: ParameterSweepItem,
    bench_overrides: ParameterSweepItem,
    run_number: int,
    output_path: Path,
    dry_run: bool,
):
    """Run one benchmark and return its annotated result.

    If ``output_path`` already exists, its content is loaded instead of
    running the benchmark. Otherwise the benchmark is run through ``server``
    and the result, annotated with the overrides and run number, is written
    back to ``output_path``.

    Returns:
        The annotated result, or ``None`` on a dry run without existing
        results.

    Raises:
        ValueError: If there is no server, no existing result, and this is
            not a dry run.
    """
    result_flags = [
        "--percentile-metrics",
        "ttft,tpot,itl,e2el",
        "--save-result",
        "--result-dir",
        str(output_path.parent),
        "--result-filename",
        output_path.name,
    ]
    benchmark_cmd = bench_overrides.apply_to_cmd(bench_cmd) + result_flags

    print("[BEGIN BENCHMARK]")
    print(f"Benchmark overrides: {bench_overrides}")
    print(f"Run Number: {run_number}")
    print(f"Benchmark command: {benchmark_cmd}")
    print(f"Output file: {output_path}")

    run_data: dict[str, object]

    if output_path.exists():
        print("Found existing results. Skipping.")

        with output_path.open("rb") as f:
            run_data = json.load(f)

        return _update_run_data(
            run_data,
            serve_overrides,
            bench_overrides,
            run_number,
        )

    if server is None:
        if dry_run:
            print("[END BENCHMARK]")
            return None

        raise ValueError(f"Cannot find results at {output_path}")

    output_path.parent.mkdir(parents=True, exist_ok=True)

    server.run_subcommand(benchmark_cmd)
    server.after_bench()

    with output_path.open("rb") as f:
        run_data = json.load(f)

    run_data = _update_run_data(
        run_data,
        serve_overrides,
        bench_overrides,
        run_number,
    )

    # Persist the annotated result in place of the raw benchmark output.
    with output_path.open("w") as f:
        json.dump(run_data, f, indent=4)

    print("[END BENCHMARK]")

    return run_data
def _get_comb_base_path(
    output_dir: Path,
    serve_comb: ParameterSweepItem,
    bench_comb: ParameterSweepItem,
):
    """Return the directory of one serve/bench combination under ``output_dir``.

    The name lists the non-empty combinations, tagged ``SERVE-`` and
    ``BENCH-``, and is passed through ``sanitize_filename``.
    """
    parts: list[str] = []

    for tag, comb in (("SERVE-", serve_comb), ("BENCH-", bench_comb)):
        if comb:
            parts += [tag, comb.as_text(sep="-")]

    dirname = "-".join(parts)
    return output_dir / sanitize_filename(dirname)
def _get_comb_run_path(base_path: Path, run_number: int | None):
if run_number is None:
return base_path / "summary.json"
return base_path / f"run={run_number}.json"
def _comb_needs_server(
    serve_comb: ParameterSweepItem,
    bench_combs: ParameterSweep,
    output_dir: Path,
):
    """Tell whether any benchmark combination still lacks its summary file."""
    return any(
        not _get_comb_run_path(
            _get_comb_base_path(output_dir, serve_comb, bench_comb),
            run_number=None,
        ).exists()
        for bench_comb in bench_combs
    )
def run_comb(
    server: ServerProcess | None,
    bench_cmd: list[str],
    *,
    serve_comb: ParameterSweepItem,
    bench_comb: ParameterSweepItem,
    base_path: Path,
    num_runs: int,
    dry_run: bool,
):
    """Run ``num_runs`` benchmarks for one combination and write their summary.

    Returns:
        The list of run results, or ``None`` on a dry run, in which case no
        summary is written.
    """
    # The generator is consumed in order, so runs execute one after another.
    outcomes = (
        run_benchmark(
            server,
            bench_cmd,
            serve_overrides=serve_comb,
            bench_overrides=bench_comb,
            run_number=run_number,
            output_path=_get_comb_run_path(base_path, run_number),
            dry_run=dry_run,
        )
        for run_number in range(num_runs)
    )
    comb_data = [outcome for outcome in outcomes if outcome is not None]

    if dry_run:
        return None

    summary_path = _get_comb_run_path(base_path, run_number=None)
    with summary_path.open("w") as f:
        json.dump(comb_data, f, indent=4)

    return comb_data
def run_combs(
    serve_cmd: list[str],
    bench_cmd: list[str],
    after_bench_cmd: list[str],
    *,
    show_stdout: bool,
    serve_params: ParameterSweep,
    bench_params: ParameterSweep,
    output_dir: Path,
    num_runs: int,
    dry_run: bool,
):
    """Benchmark every pair of server and benchmark combinations.

    One server context is entered per server combination; it is replaced by
    a null context when all of its summaries already exist.

    Returns:
        A DataFrame of all run results, also written to ``summary.csv`` in
        ``output_dir``; ``None`` on a dry run.
    """
    all_data = list[dict[str, object]]()

    for serve_comb in serve_params:
        if _comb_needs_server(serve_comb, bench_params, output_dir):
            server_ctx = run_server(
                serve_cmd,
                after_bench_cmd,
                show_stdout=show_stdout,
                serve_overrides=serve_comb,
                dry_run=dry_run,
            )
        else:
            server_ctx = contextlib.nullcontext()

        with server_ctx as server:
            for bench_comb in bench_params:
                comb_data = run_comb(
                    server,
                    bench_cmd,
                    serve_comb=serve_comb,
                    bench_comb=bench_comb,
                    base_path=_get_comb_base_path(output_dir, serve_comb, bench_comb),
                    num_runs=num_runs,
                    dry_run=dry_run,
                )

                if comb_data is None:
                    continue
                all_data.extend(comb_data)

    if dry_run:
        return None

    combined_df = pd.DataFrame.from_records(all_data)
    combined_df.to_csv(output_dir / "summary.csv")

    return combined_df
@dataclass
class SweepServeArgs:
    """Arguments of the ``serve`` sub-command, parsed from the command line."""

    # Tokenized server, benchmark and post-benchmark commands.
    serve_cmd: list[str]
    bench_cmd: list[str]
    after_bench_cmd: list[str]
    # Whether the standard output of subcommands is logged.
    show_stdout: bool
    # Parameter combinations for the server and the benchmark.
    serve_params: ParameterSweep
    bench_params: ParameterSweep
    # Directory to which results are written.
    output_dir: Path
    # Number of runs per parameter combination.
    num_runs: int
    # Whether to only print the commands.
    dry_run: bool
    # Name of a previous run's directory to resume, if any.
    resume: str | None

    parser_name: ClassVar[str] = "serve"
    parser_help: ClassVar[str] = "Run vLLM server benchmark under multiple settings."

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build the arguments from a parsed namespace.

        Commands are tokenized with ``shlex.split``. A missing parameter file
        is replaced by a sweep holding one empty combination.

        Raises:
            ValueError: If ``num_runs`` is less than 1.
        """
        serve_cmd = shlex.split(args.serve_cmd)
        bench_cmd = shlex.split(args.bench_cmd)
        after_bench_cmd = (
            shlex.split(args.after_bench_cmd)
            if args.after_bench_cmd is not None
            else []
        )

        # An empty record means: run the command without any modification.
        serve_params = (
            ParameterSweep.read_json(args.serve_params)
            if args.serve_params
            else ParameterSweep.from_records([{}])
        )
        bench_params = (
            ParameterSweep.read_json(args.bench_params)
            if args.bench_params
            else ParameterSweep.from_records([{}])
        )

        num_runs = args.num_runs
        if num_runs < 1:
            raise ValueError("`num_runs` should be at least 1.")

        return cls(
            serve_cmd=serve_cmd,
            bench_cmd=bench_cmd,
            after_bench_cmd=after_bench_cmd,
            show_stdout=args.show_stdout,
            serve_params=serve_params,
            bench_params=bench_params,
            output_dir=Path(args.output_dir),
            num_runs=num_runs,
            dry_run=args.dry_run,
            resume=args.resume,
        )

    @classmethod
    def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Register the options of this sub-command on ``parser`` and return it."""
        add = parser.add_argument

        add(
            "--serve-cmd",
            type=str,
            required=True,
            help="The command used to run the server: `vllm serve ...`",
        )
        add(
            "--bench-cmd",
            type=str,
            required=True,
            help="The command used to run the benchmark: `vllm bench serve ...`",
        )
        add(
            "--after-bench-cmd",
            type=str,
            default=None,
            help="After a benchmark run is complete, invoke this command instead of "
            "the default `ServerWrapper.clear_cache()`.",
        )
        add(
            "--show-stdout",
            action="store_true",
            help="If set, logs the standard output of subcommands. "
            "Useful for debugging but can be quite spammy.",
        )
        add(
            "--serve-params",
            type=str,
            default=None,
            help="Path to JSON file containing a list of parameter combinations "
            "for the `vllm serve` command. "
            "If both `serve_params` and `bench_params` are given, "
            "this script will iterate over their Cartesian product.",
        )
        add(
            "--bench-params",
            type=str,
            default=None,
            help="Path to JSON file containing a list of parameter combinations "
            "for the `vllm bench serve` command. "
            "If both `serve_params` and `bench_params` are given, "
            "this script will iterate over their Cartesian product.",
        )
        add(
            "-o",
            "--output-dir",
            type=str,
            default="results",
            help="The directory to which results are written.",
        )
        add(
            "--num-runs",
            type=int,
            default=3,
            help="Number of runs per parameter combination.",
        )
        add(
            "--dry-run",
            action="store_true",
            help="If set, prints the commands to run, "
            "then exits without executing them.",
        )
        add(
            "--resume",
            type=str,
            default=None,
            help="Set this to the name of a directory under `output_dir` (which is a "
            "timestamp) to resume a previous execution of this script, i.e., only run "
            "parameter combinations for which there are still no output files.",
        )

        return parser
def run_main(args: SweepServeArgs):
    """Run the sweep in a timestamped (or resumed) directory under ``output_dir``.

    Raises:
        ValueError: If ``resume`` names a directory that does not exist.
        RuntimeError: If the sweep stops early for any reason; the message
            gives the ``--resume`` value to continue with.
    """
    if args.resume:
        timestamp = args.resume
    else:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    output_dir = args.output_dir / timestamp

    if args.resume and not output_dir.exists():
        raise ValueError(f"Cannot resume from non-existent directory ({output_dir})")

    try:
        result = run_combs(
            args.serve_cmd,
            args.bench_cmd,
            args.after_bench_cmd,
            show_stdout=args.show_stdout,
            serve_params=args.serve_params,
            bench_params=args.bench_params,
            output_dir=output_dir,
            num_runs=args.num_runs,
            dry_run=args.dry_run,
        )
    except BaseException as exc:
        # BaseException also covers interruptions such as Ctrl-C.
        raise RuntimeError(
            f"The script was terminated early. Use `--resume {timestamp}` "
            f"to continue the script from its last checkpoint."
        ) from exc

    return result
def main(args: argparse.Namespace):
    """Entry point of the ``serve`` sub-command for a parsed namespace."""
    serve_args = SweepServeArgs.from_cli_args(args)
    run_main(serve_args)
if __name__ == "__main__":
    # `add_cli_args` returns the parser it was given.
    parser = SweepServeArgs.add_cli_args(
        argparse.ArgumentParser(description=SweepServeArgs.parser_help)
    )
    main(parser.parse_args())

View File

@@ -0,0 +1,492 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import contextlib
import json
import math
from dataclasses import asdict, dataclass
from datetime import datetime
from pathlib import Path
from typing import ClassVar, Literal, get_args
from typing_extensions import assert_never
from vllm.utils.import_utils import PlaceholderModule
from .param_sweep import ParameterSweep, ParameterSweepItem
from .serve import SweepServeArgs, run_benchmark, run_server
from .server import ServerProcess
from .sla_sweep import SLASweep, SLASweepItem
from .utils import sanitize_filename
try:
import pandas as pd
except ImportError:
pd = PlaceholderModule("pandas")
def _get_sla_base_path(
    output_dir: Path,
    serve_comb: ParameterSweepItem,
    bench_comb: ParameterSweepItem,
):
    """Return the directory of one serve/bench combination under ``output_dir``.

    The name lists the non-empty combinations, tagged ``SERVE-`` and
    ``BENCH-``, and is passed through ``sanitize_filename``.
    """
    parts: list[str] = []

    for tag, comb in (("SERVE-", serve_comb), ("BENCH-", bench_comb)):
        if comb:
            parts += [tag, comb.as_text(sep="-")]

    dirname = "-".join(parts)
    return output_dir / sanitize_filename(dirname)
def _get_sla_iter_path(
base_path: Path,
sla_comb: SLASweepItem,
sla_variable: str,
sla_value: int | None,
):
if sla_value is None:
prefix = sla_comb.as_text(sep="-")
return base_path / f"SLA--{prefix}.json"
return base_path / f"{sla_variable}={sla_value}"
def _get_sla_run_path(iter_path: Path, run_number: int | None):
if run_number is None:
return iter_path / "summary.json"
return iter_path / f"run={run_number}.json"
def _sla_needs_server(
    serve_comb: ParameterSweepItem,
    bench_combs: ParameterSweep,
    sla_combs: SLASweep,
    sla_variable: str,
    output_dir: Path,
):
    """Tell whether any bench/SLA combination still lacks its results file."""
    for bench_comb in bench_combs:
        base_path = _get_sla_base_path(output_dir, serve_comb, bench_comb)

        if any(
            not _get_sla_iter_path(
                base_path,
                sla_comb,
                sla_variable,
                sla_value=None,
            ).exists()
            for sla_comb in sla_combs
        ):
            return True

    return False
def run_sla(
    server: ServerProcess | None,
    bench_cmd: list[str],
    *,
    serve_comb: ParameterSweepItem,
    bench_comb: ParameterSweepItem,
    iter_path: Path,
    num_runs: int,
    dry_run: bool,
):
    """Run ``num_runs`` benchmarks for one tested value and write their summary.

    Returns:
        The list of run results, or ``None`` on a dry run, in which case no
        summary is written.
    """
    # The generator is consumed in order, so runs execute one after another.
    outcomes = (
        run_benchmark(
            server,
            bench_cmd,
            serve_overrides=serve_comb,
            bench_overrides=bench_comb,
            run_number=run_number,
            output_path=_get_sla_run_path(iter_path, run_number),
            dry_run=dry_run,
        )
        for run_number in range(num_runs)
    )
    iter_data = [outcome for outcome in outcomes if outcome is not None]

    if dry_run:
        return None

    summary_path = _get_sla_run_path(iter_path, run_number=None)
    with summary_path.open("w") as f:
        json.dump(iter_data, f, indent=4)

    return iter_data
# Names of the benchmark parameters that the SLA search is able to tune;
# also used as the choices of the `--sla-variable` option.
SLAVariable = Literal["request_rate", "max_concurrency"]
def _estimate_sla_value(run_data: dict[str, object], sla_variable: SLAVariable):
request_throughput = float(run_data["request_throughput"]) # type: ignore
if sla_variable == "request_rate":
return request_throughput
if sla_variable == "max_concurrency":
mean_latency_ms = float(run_data["mean_e2el_ms"]) # type: ignore
return request_throughput * mean_latency_ms / 1000
assert_never(sla_variable)
def _estimate_sla_bounds(
    server: ServerProcess | None,
    bench_cmd: list[str],
    *,
    serve_comb: ParameterSweepItem,
    bench_comb: ParameterSweepItem,
    sla_comb: SLASweepItem,
    base_path: Path,
    num_runs: int,
    dry_run: bool,
    sla_variable: SLAVariable,
    init_value: int,
    max_value: int,
):
    """Bracket the largest value of ``sla_variable`` that meets the SLA.

    Starting from ``init_value``, the value is doubled after every passing
    test until a test fails or the next value would reach ``max_value``.

    Returns:
        ``(sla_data, (max_passing, min_failing))``. ``sla_data`` holds the
        run results of every test. ``max_passing`` is the largest value that
        passed, or 0 if none did. ``min_failing`` is the value that failed;
        if no test failed it is ``max_value``, raised to ``max_passing`` if
        needed so that the range is never inverted.
    """
    sla_data = list[dict[str, object]]()

    max_passing: int = 0
    # Until a failure is observed, `max_value` is the only known upper
    # bound. Leaving this at 0 would hand the caller an inverted range
    # whenever every tested value passes.
    min_failing: int = max_value

    val: int = init_value
    assert val > 0

    while True:
        print(f"Testing {sla_variable}: {val} req/s")

        iter_data = run_sla(
            server,
            bench_cmd,
            serve_comb=serve_comb,
            bench_comb=bench_comb | {sla_variable: val},
            iter_path=_get_sla_iter_path(base_path, sla_comb, sla_variable, val),
            num_runs=num_runs,
            dry_run=dry_run,
        )

        assert iter_data is not None
        sla_data.extend(iter_data)

        # Each criterion is checked against the mean of its metric over
        # the runs of this test.
        iter_data_mean = {
            k: sum(float(run_data[k]) for run_data in iter_data) / len(iter_data)  # type: ignore
            for k in sla_comb
        }

        sla_results = [
            criterion.print_and_validate(iter_data_mean, k)
            for k, criterion in sla_comb.items()
        ]

        if all(sla_results):
            print("SLA criteria are met.")
            max_passing = val
            val *= 2
        else:
            print("SLA criteria are not met.")
            min_failing = val
            break

        if val >= max_value:
            break

    # `init_value` may itself exceed `max_value`; keep the range ordered.
    return sla_data, (max_passing, max(min_failing, max_passing))
def _find_sla_value(
    server: ServerProcess | None,
    bench_cmd: list[str],
    *,
    serve_comb: ParameterSweepItem,
    bench_comb: ParameterSweepItem,
    sla_comb: SLASweepItem,
    base_path: Path,
    num_runs: int,
    dry_run: bool,
    sla_variable: SLAVariable,
    min_value: int,
    max_value: int,
):
    """Binary-search the largest value of ``sla_variable`` that meets the SLA.

    ``min_value`` is treated as passing and ``max_value`` as failing; only
    values strictly between the two are tested.

    Returns:
        ``(sla_data, value)``, where ``sla_data`` holds the run results of
        every test made here and ``value`` is the final lower bound. When
        the bounds are adjacent, equal or inverted, nothing is tested and
        ``min_value`` is returned.
    """
    sla_data = list[dict[str, object]]()

    left: int = min_value
    right: int = max_value

    # Checking the gap before each test avoids re-testing a bound that is
    # already known, probing 0 for the range [0, 1], and moving `left`
    # downwards when the range arrives inverted.
    while right - left > 1:
        val = (left + right) // 2
        print(f"Testing {sla_variable}: {val} req/s")

        iter_data = run_sla(
            server,
            bench_cmd,
            serve_comb=serve_comb,
            bench_comb=bench_comb | {sla_variable: val},
            iter_path=_get_sla_iter_path(base_path, sla_comb, sla_variable, val),
            num_runs=num_runs,
            dry_run=dry_run,
        )

        assert iter_data is not None
        sla_data.extend(iter_data)

        # Each criterion is checked against the mean of its metric over
        # the runs of this test.
        iter_data_mean = {
            k: sum(float(run_data[k]) for run_data in iter_data) / len(iter_data)  # type: ignore
            for k in sla_comb
        }

        sla_results = [
            criterion.print_and_validate(iter_data_mean, k)
            for k, criterion in sla_comb.items()
        ]

        if all(sla_results):
            print("SLA criteria are met.")
            left = val
        else:
            print("SLA criteria are not met.")
            right = val

    return sla_data, left
def search_sla(
    server: ServerProcess | None,
    bench_cmd: list[str],
    *,
    serve_comb: ParameterSweepItem,
    bench_comb: ParameterSweepItem,
    sla_comb: SLASweepItem,
    sla_variable: SLAVariable,
    sla_inf_value: int = 65536,  # The value that represents infinite QPS
    base_path: Path,
    num_runs: int,
    dry_run: bool,
):
    """Search the largest value of ``sla_variable`` that satisfies ``sla_comb``.

    Three phases: a run at ``sla_inf_value`` to estimate a starting value, a
    doubling phase to bracket the answer, and a binary search inside the
    bracket. The results of all phases are written to one JSON file.

    Returns:
        The concatenated run results, or ``None`` on a dry run.
    """
    print("[SLA START]")
    print(f"SLA criteria: {sla_comb.as_text()}")

    sla_data_0 = run_sla(
        server,
        bench_cmd,
        serve_comb=serve_comb,
        bench_comb=bench_comb | {sla_variable: sla_inf_value},
        iter_path=_get_sla_iter_path(base_path, sla_comb, sla_variable, sla_inf_value),
        num_runs=num_runs,
        dry_run=dry_run,
    )
    if sla_data_0 is None:
        assert dry_run
        print("Omitting SLA search.")
        print("[SLA END]")
        return None

    estimates = [_estimate_sla_value(item, sla_variable) for item in sla_data_0]
    sla_init_value = math.ceil(sum(estimates) / len(sla_data_0))
    print(f"Initial {sla_variable} to search: {sla_init_value} req/s.")

    # Keyword arguments shared by both search phases.
    shared = dict(
        serve_comb=serve_comb,
        bench_comb=bench_comb,
        sla_comb=sla_comb,
        base_path=base_path,
        num_runs=num_runs,
        dry_run=dry_run,
        sla_variable=sla_variable,
    )

    sla_data_1, (sla_min, sla_max) = _estimate_sla_bounds(
        server,
        bench_cmd,
        init_value=sla_init_value,
        max_value=sla_inf_value,
        **shared,
    )
    print(f"Range of {sla_variable} to search: [{sla_min}, {sla_max}] req/s.")

    sla_data_2, sla_value = _find_sla_value(
        server,
        bench_cmd,
        min_value=sla_min,
        max_value=sla_max,
        **shared,
    )

    sla_data = sla_data_0 + sla_data_1 + sla_data_2
    print(f"Maximum {sla_variable} for SLA: {sla_value} req/s.")

    results_path = _get_sla_iter_path(
        base_path,
        sla_comb,
        sla_variable,
        sla_value=None,
    )
    with results_path.open("w") as f:
        json.dump(sla_data, f, indent=4)

    print("[SLA END]")

    return sla_data
def run_slas(
    serve_cmd: list[str],
    bench_cmd: list[str],
    after_bench_cmd: list[str],
    *,
    show_stdout: bool,
    serve_params: ParameterSweep,
    bench_params: ParameterSweep,
    sla_params: SLASweep,
    sla_variable: SLAVariable,
    output_dir: Path,
    num_runs: int,
    dry_run: bool,
):
    """Run the SLA search for every server, benchmark and SLA combination.

    One server context is entered per server combination; it is replaced by
    a null context when all of its results files already exist.

    Returns:
        A DataFrame of all run results, also written to ``summary.csv`` in
        ``output_dir``; ``None`` on a dry run.

    Raises:
        ValueError: If any benchmark combination overrides ``sla_variable``.
    """
    for bench_comb in bench_params:
        if bench_comb.has_param(sla_variable):
            raise ValueError(
                f"You should not override `{sla_variable}` in `bench_params` in SLA mode, "
                "since it is supposed to be determined automatically."
            )

    all_data = list[dict[str, object]]()

    for serve_comb in serve_params:
        if _sla_needs_server(
            serve_comb,
            bench_params,
            sla_params,
            sla_variable,
            output_dir,
        ):
            server_ctx = run_server(
                serve_cmd,
                after_bench_cmd,
                show_stdout=show_stdout,
                serve_overrides=serve_comb,
                dry_run=dry_run,
            )
        else:
            server_ctx = contextlib.nullcontext()

        with server_ctx as server:
            for bench_comb in bench_params:
                for sla_comb in sla_params:
                    comb_data = search_sla(
                        server,
                        bench_cmd,
                        serve_comb=serve_comb,
                        bench_comb=bench_comb,
                        sla_comb=sla_comb,
                        sla_variable=sla_variable,
                        base_path=_get_sla_base_path(
                            output_dir, serve_comb, bench_comb
                        ),
                        num_runs=num_runs,
                        dry_run=dry_run,
                    )

                    if comb_data is None:
                        continue
                    all_data.extend(comb_data)

    if dry_run:
        return None

    combined_df = pd.DataFrame.from_records(all_data)
    combined_df.to_csv(output_dir / "summary.csv")

    return combined_df
@dataclass
class SweepServeSLAArgs(SweepServeArgs):
    """Arguments of the ``serve_sla`` sub-command."""

    # SLA constraint combinations to satisfy.
    sla_params: SLASweep
    # The benchmark parameter tuned to satisfy them.
    sla_variable: SLAVariable

    parser_name: ClassVar[str] = "serve_sla"
    parser_help: ClassVar[str] = "Tune a variable to meet SLAs under multiple settings."

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build the arguments from a parsed namespace.

        The base fields are taken from ``SweepServeArgs.from_cli_args``. A
        missing SLA parameter file is replaced by an empty sweep.
        """
        # NOTE: Don't use super() as `from_cli_args` calls `cls()`
        base_fields = asdict(SweepServeArgs.from_cli_args(args))

        sla_params = (
            SLASweep.read_json(args.sla_params)
            if args.sla_params
            else SLASweep.from_records([])
        )

        return cls(
            **base_fields,
            sla_params=sla_params,
            sla_variable=args.sla_variable,
        )

    @classmethod
    def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Register the base options plus an ``sla options`` group on ``parser``."""
        parser = super().add_cli_args(parser)

        group = parser.add_argument_group("sla options")
        group.add_argument(
            "--sla-params",
            type=str,
            required=True,
            help="Path to JSON file containing a list of SLA constraints to satisfy. "
            'Each constraint is expressed in `{"<KEY>": "<OP><VALUE>"}` format, '
            'e.g.: `{"p99_e2el_ms": "<=500"}` means that '
            "the E2E latency should be less than 500ms 99%% of the time. "
            "Setting this option runs this script in SLA mode, which searches for "
            "the maximum `sla_variable` that satisfies the constraints for "
            "each combination of `serve_params`, `bench_params`, and `sla_params`.",
        )
        group.add_argument(
            "--sla-variable",
            type=str,
            choices=get_args(SLAVariable),
            default="request_rate",
            help="Whether to tune request rate or maximum concurrency to satisfy "
            "the SLA constraints.",
        )

        return parser
def run_main(args: SweepServeSLAArgs):
    """Resolve the output directory and run the SLA sweep.

    The run directory is `args.output_dir / <timestamp>`, where the timestamp
    is `args.resume` when given and the current local time otherwise.

    Raises:
        ValueError: If resuming from a directory that does not exist.
        RuntimeError: If the sweep is interrupted by any exception (including
            `KeyboardInterrupt`); the message tells the user how to resume.
    """
    if args.resume:
        timestamp = args.resume
    else:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    output_dir = args.output_dir / timestamp

    resuming_missing_dir = bool(args.resume) and not output_dir.exists()
    if resuming_missing_dir:
        raise ValueError(f"Cannot resume from non-existent directory ({output_dir})")

    try:
        result = run_slas(
            serve_cmd=args.serve_cmd,
            bench_cmd=args.bench_cmd,
            after_bench_cmd=args.after_bench_cmd,
            show_stdout=args.show_stdout,
            serve_params=args.serve_params,
            bench_params=args.bench_params,
            sla_params=args.sla_params,
            sla_variable=args.sla_variable,
            output_dir=output_dir,
            num_runs=args.num_runs,
            dry_run=args.dry_run,
        )
    except BaseException as exc:
        # Deliberately broad: any early termination should point the user
        # at the checkpoint they can resume from.
        hint = (
            f"The script was terminated early. Use `--resume {timestamp}` "
            "to continue the script from its last checkpoint."
        )
        raise RuntimeError(hint) from exc

    return result
def main(args: argparse.Namespace):
    """CLI entry point: convert the parsed namespace and run the sweep."""
    sweep_args = SweepServeSLAArgs.from_cli_args(args)
    run_main(sweep_args)
# Script entry point: build the parser, register the sweep/SLA options,
# then hand the parsed arguments to `main`.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description=SweepServeSLAArgs.parser_help)
    SweepServeSLAArgs.add_cli_args(parser)
    main(parser.parse_args())

View File

@@ -0,0 +1,114 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import os
import signal
import subprocess
from types import TracebackType
import requests
from typing_extensions import Self
class ServerProcess:
def __init__(
self,
server_cmd: list[str],
after_bench_cmd: list[str],
*,
show_stdout: bool,
) -> None:
super().__init__()
self.server_cmd = server_cmd
self.after_bench_cmd = after_bench_cmd
self.show_stdout = show_stdout
def __enter__(self) -> Self:
self.start()
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
exc_traceback: TracebackType | None,
) -> None:
self.stop()
def start(self):
# Create new process for clean termination
self._server_process = subprocess.Popen(
self.server_cmd,
start_new_session=True,
stdout=None if self.show_stdout else subprocess.DEVNULL,
# Need `VLLM_SERVER_DEV_MODE=1` for `_reset_caches`
env=os.environ | {"VLLM_SERVER_DEV_MODE": "1"},
)
def stop(self):
server_process = self._server_process
if server_process.poll() is None:
# In case only some processes have been terminated
with contextlib.suppress(ProcessLookupError):
# We need to kill both API Server and Engine processes
os.killpg(os.getpgid(server_process.pid), signal.SIGKILL)
def run_subcommand(self, cmd: list[str]):
return subprocess.run(
cmd,
stdout=None if self.show_stdout else subprocess.DEVNULL,
check=True,
)
def after_bench(self) -> None:
if not self.after_bench_cmd:
self.reset_caches()
return
self.run_subcommand(self.after_bench_cmd)
def _get_vllm_server_address(self) -> str:
server_cmd = self.server_cmd
for host_key in ("--host",):
if host_key in server_cmd:
host = server_cmd[server_cmd.index(host_key) + 1]
break
else:
host = "localhost"
for port_key in ("-p", "--port"):
if port_key in server_cmd:
port = int(server_cmd[server_cmd.index(port_key) + 1])
break
else:
port = 8000 # The default value in vllm serve
return f"http://{host}:{port}"
def reset_caches(self) -> None:
server_cmd = self.server_cmd
# Use `.endswith()` to match `/bin/...`
if server_cmd[0].endswith("vllm"):
server_address = self._get_vllm_server_address()
print(f"Resetting caches at {server_address}")
res = requests.post(f"{server_address}/reset_prefix_cache")
res.raise_for_status()
res = requests.post(f"{server_address}/reset_mm_cache")
res.raise_for_status()
elif server_cmd[0].endswith("infinity_emb"):
if "--vector-disk-cache" in server_cmd:
raise NotImplementedError(
"Infinity server uses caching but does not expose a method "
"to reset the cache"
)
else:
raise NotImplementedError(
f"No implementation of `reset_caches` for `{server_cmd[0]}` server. "
"Please specify a custom command via `--after-bench-cmd`."
)

View File

@@ -0,0 +1,132 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing_extensions import override
@dataclass
class SLACriterionBase(ABC):
    """Base class of a single SLA constraint comparing a metric to `target`."""

    # Threshold that the measured metric is compared against.
    target: float

    @abstractmethod
    def validate(self, actual: float) -> bool:
        """Return `True` if this criterion is met; otherwise `False`."""
        raise NotImplementedError

    @abstractmethod
    def format_cond(self, lhs: str) -> str:
        """Render the condition as text with `lhs` on the left-hand side."""
        raise NotImplementedError

    def print_and_validate(
        self,
        metrics: dict[str, float],
        metrics_key: str,
    ) -> bool:
        """Check `metrics[metrics_key]` against this criterion and print the outcome."""
        observed = metrics[metrics_key]
        passed = self.validate(observed)

        condition = self.format_cond(f"{metrics_key} = {observed:.2f}")
        verdict = "PASSED" if passed else "FAILED"
        print(f"Validating SLA: {condition} | {verdict}")

        return passed
@dataclass
class SLALessThan(SLACriterionBase):
    """Criterion met when the measured value is strictly below `target`."""

    @override
    def validate(self, actual: float) -> bool:
        limit = self.target
        return actual < limit

    @override
    def format_cond(self, lhs: str) -> str:
        rhs = f"{self.target:.2f}"
        return f"{lhs}<{rhs}"
@dataclass
class SLALessThanOrEqualTo(SLACriterionBase):
    """Criterion met when the measured value does not exceed `target`."""

    @override
    def validate(self, actual: float) -> bool:
        limit = self.target
        return actual <= limit

    @override
    def format_cond(self, lhs: str) -> str:
        rhs = f"{self.target:.2f}"
        return f"{lhs}<={rhs}"
@dataclass
class SLAGreaterThan(SLACriterionBase):
    """Criterion met when the measured value is strictly above `target`."""

    @override
    def validate(self, actual: float) -> bool:
        limit = self.target
        return actual > limit

    @override
    def format_cond(self, lhs: str) -> str:
        rhs = f"{self.target:.2f}"
        return f"{lhs}>{rhs}"
@dataclass
class SLAGreaterThanOrEqualTo(SLACriterionBase):
    """Criterion met when the measured value is at least `target`."""

    @override
    def validate(self, actual: float) -> bool:
        limit = self.target
        return actual >= limit

    @override
    def format_cond(self, lhs: str) -> str:
        rhs = f"{self.target:.2f}"
        return f"{lhs}>={rhs}"
# Maps an operator prefix to the criterion class that implements it.
# NOTE: The ordering is important! Match longer op_keys first
# (the parser tests prefixes in dict order, so "<" placed before "<="
# would match "<=500" and leave "=500" as the number).
SLA_CRITERIA: dict[str, type[SLACriterionBase]] = {
    "<=": SLALessThanOrEqualTo,
    ">=": SLAGreaterThanOrEqualTo,
    "<": SLALessThan,
    ">": SLAGreaterThan,
}
class SLASweep(list["SLASweepItem"]):
    """A list of SLA constraint sets, each parsed into an `SLASweepItem`."""

    @classmethod
    def read_json(cls, filepath: os.PathLike):
        """Load a sweep from a JSON file containing a list of records."""
        with open(filepath, "rb") as stream:
            payload = json.load(stream)

        return cls.from_records(payload)

    @classmethod
    def from_records(cls, records: list[dict[str, str]]):
        """Build a sweep from already-decoded records.

        Raises:
            TypeError: If `records` is not a list.
        """
        if isinstance(records, list):
            return cls([SLASweepItem.from_record(entry) for entry in records])

        raise TypeError(
            "The SLA sweep should be a list of dictionaries, "
            f"but found type: {type(records)}"
        )
class SLASweepItem(dict[str, SLACriterionBase]):
    """One set of SLA constraints, keyed by metric name."""

    @classmethod
    def from_record(cls, record: dict[str, str]):
        """Parse a `{"<metric>": "<op><value>"}` record into criteria.

        Raises:
            TypeError: If `record` is not a dict or a constraint is not a
                string (e.g. a bare JSON number).
            ValueError: If a constraint has no recognised operator or its
                value is not a number.
        """
        if not isinstance(record, dict):
            raise TypeError(
                f"Each SLA sweep item should be a dictionary, "
                f"but found type: {type(record)}"
            )

        sla_criteria: dict[str, SLACriterionBase] = {}

        for metric_key, metric_value in record.items():
            # Constraints come from user-written JSON; a bare number such as
            # `{"p99": 500}` would otherwise fail on `.startswith`.
            if not isinstance(metric_value, str):
                raise TypeError(
                    f"SLA constraint for '{metric_key}' should be a string in "
                    f"'<OP><VALUE>' format, but found type: {type(metric_value)}"
                )

            constraint = metric_value.strip()
            for op_key in SLA_CRITERIA:
                if constraint.startswith(op_key):
                    raw_target = constraint.removeprefix(op_key)
                    try:
                        target = float(raw_target)
                    except ValueError as exc:
                        raise ValueError(
                            f"Invalid value for "
                            f"SLA constraint '{metric_key}={metric_value}'. "
                            f"Expected a number after '{op_key}'."
                        ) from exc

                    sla_criteria[metric_key] = SLA_CRITERIA[op_key](target)
                    break
            else:
                raise ValueError(
                    f"Invalid operator for "
                    f"SLA constraint '{metric_key}={metric_value}'. "
                    f"Valid operators are: {sorted(SLA_CRITERIA)}",
                )

        return cls(sla_criteria)

    def as_text(self, sep: str = ", ") -> str:
        """Render all constraints as `metric<op>value` joined by `sep`."""
        return sep.join(v.format_cond(k) for k, v in self.items())

View File

@@ -0,0 +1,4 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
def sanitize_filename(filename: str) -> str:
    """Make `filename` safe to use as a single path component.

    Slashes become `_`, `..` becomes `__`, and surrounding single quotes and
    then double quotes are stripped.
    """
    without_separators = filename.replace("/", "_")
    without_parent_refs = without_separators.replace("..", "__")
    # Order matters: single quotes are stripped before double quotes.
    without_single_quotes = without_parent_refs.strip("'")
    return without_single_quotes.strip('"')

View File

@@ -0,0 +1,799 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Benchmark offline inference throughput."""
import argparse
import dataclasses
import json
import os
import random
import time
import warnings
from typing import Any
import torch
import uvloop
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase
from vllm.benchmarks.datasets import (
AIMODataset,
BurstGPTDataset,
ConversationDataset,
InstructCoderDataset,
MultiModalConversationDataset,
PrefixRepetitionRandomDataset,
RandomDataset,
SampleRequest,
ShareGPTDataset,
SonnetDataset,
VisionArenaDataset,
)
from vllm.benchmarks.lib.utils import convert_to_pytorch_benchmark_format, write_to_json
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams
from vllm.utils.async_utils import merge_async_iterators
def run_vllm(
    requests: list[SampleRequest],
    n: int,
    engine_args: EngineArgs,
    do_profile: bool,
    disable_detokenize: bool = False,
) -> tuple[float, list[RequestOutput] | None]:
    """Benchmark offline generation with the synchronous `LLM` class.

    Returns:
        `(elapsed_seconds, outputs)`, where the elapsed time covers the
        generation call (and profiler start/stop when `do_profile` is set).
    """
    from vllm import LLM, SamplingParams

    llm = LLM(**dataclasses.asdict(engine_args))
    assert all(
        (req.prompt_len + req.expected_output_len)
        <= llm.llm_engine.model_config.max_model_len
        for req in requests
    ), (
        "Please ensure that max_model_len is greater than the sum of"
        " prompt_len and expected_output_len for all requests."
    )

    # Build one prompt and one SamplingParams per request.
    prompts: list[TextPrompt | TokensPrompt] = []
    sampling_params: list[SamplingParams] = []
    for req in requests:
        prompt: TextPrompt | TokensPrompt
        if "prompt_token_ids" in req.prompt:
            prompt = TokensPrompt(prompt_token_ids=req.prompt["prompt_token_ids"])
        else:
            prompt = TextPrompt(prompt=req.prompt)

        if req.multi_modal_data:
            assert isinstance(req.multi_modal_data, dict)
            prompt["multi_modal_data"] = req.multi_modal_data
        prompts.append(prompt)

        params = SamplingParams(
            n=n,
            temperature=1.0,
            top_p=1.0,
            ignore_eos=True,
            max_tokens=req.expected_output_len,
            detokenize=not disable_detokenize,
        )
        sampling_params.append(params)

    lora_requests: list[LoRARequest] | None = (
        [req.lora_request for req in requests] if engine_args.enable_lora else None
    )

    # Beam search is currently hard-wired off; its branch is kept as-is.
    use_beam_search = False

    outputs = None
    if use_beam_search:
        assert lora_requests is None, "BeamSearch API does not support LoRA"
        prompts = [req.prompt for req in requests]
        # output_len should be the same for all requests.
        output_len = requests[0].expected_output_len
        for req in requests:
            assert req.expected_output_len == output_len

        start = time.perf_counter()
        if do_profile:
            llm.start_profile()
        beam_params = BeamSearchParams(
            beam_width=n,
            max_tokens=output_len,
            ignore_eos=True,
        )
        llm.beam_search(prompts, beam_params)
        if do_profile:
            llm.stop_profile()
        end = time.perf_counter()
    else:
        start = time.perf_counter()
        if do_profile:
            llm.start_profile()
        outputs = llm.generate(
            prompts, sampling_params, lora_request=lora_requests, use_tqdm=True
        )
        if do_profile:
            llm.stop_profile()
        end = time.perf_counter()

    elapsed = end - start
    return elapsed, outputs
def run_vllm_chat(
    requests: list[SampleRequest],
    n: int,
    engine_args: EngineArgs,
    do_profile: bool,
    disable_detokenize: bool = False,
) -> tuple[float, list[RequestOutput]]:
    """
    Run vLLM chat benchmark. This function is recommended ONLY for benchmarking
    multimodal models as it properly handles multimodal inputs and chat
    formatting. For non-multimodal models, use run_vllm() instead.

    Returns `(elapsed_seconds, outputs)` for the `llm.chat` call.
    """
    from vllm import LLM, SamplingParams

    llm = LLM(**dataclasses.asdict(engine_args))

    assert all(
        (req.prompt_len + req.expected_output_len)
        <= llm.llm_engine.model_config.max_model_len
        for req in requests
    ), (
        "Please ensure that max_model_len is greater than the sum of "
        "prompt_len and expected_output_len for all requests."
    )

    # Chat prompts are passed through untouched; only sampling is configured.
    prompts = [req.prompt for req in requests]
    sampling_params: list[SamplingParams] = [
        SamplingParams(
            n=n,
            temperature=1.0,
            top_p=1.0,
            ignore_eos=True,
            max_tokens=req.expected_output_len,
            detokenize=not disable_detokenize,
        )
        for req in requests
    ]

    start = time.perf_counter()
    if do_profile:
        llm.start_profile()
    outputs = llm.chat(prompts, sampling_params, use_tqdm=True)
    if do_profile:
        llm.stop_profile()
    end = time.perf_counter()

    elapsed = end - start
    return elapsed, outputs
async def run_vllm_async(
    requests: list[SampleRequest],
    n: int,
    engine_args: AsyncEngineArgs,
    do_profile: bool,
    disable_frontend_multiprocessing: bool = False,
    disable_detokenize: bool = False,
) -> float:
    """Benchmark generation through the async engine client.

    All requests are submitted up front and their streams are drained
    concurrently. Returns the elapsed wall-clock seconds.
    """
    from vllm import SamplingParams
    from vllm.entrypoints.openai.api_server import (
        build_async_engine_client_from_engine_args,
    )

    async with build_async_engine_client_from_engine_args(
        engine_args,
        disable_frontend_multiprocessing=disable_frontend_multiprocessing,
    ) as llm:
        model_config = llm.model_config
        assert all(
            (req.prompt_len + req.expected_output_len) <= model_config.max_model_len
            for req in requests
        ), (
            "Please ensure that max_model_len is greater than the sum of"
            " prompt_len and expected_output_len for all requests."
        )

        # Build the per-request prompt, sampling params and LoRA request.
        prompts: list[TextPrompt | TokensPrompt] = []
        sampling_params: list[SamplingParams] = []
        lora_requests: list[LoRARequest | None] = []
        for req in requests:
            prompt: TextPrompt | TokensPrompt
            if "prompt_token_ids" in req.prompt:
                prompt = TokensPrompt(prompt_token_ids=req.prompt["prompt_token_ids"])
            else:
                prompt = TextPrompt(prompt=req.prompt)

            if req.multi_modal_data:
                assert isinstance(req.multi_modal_data, dict)
                prompt["multi_modal_data"] = req.multi_modal_data

            params = SamplingParams(
                n=n,
                temperature=1.0,
                top_p=1.0,
                ignore_eos=True,
                max_tokens=req.expected_output_len,
                detokenize=not disable_detokenize,
            )
            sampling_params.append(params)
            prompts.append(prompt)
            lora_requests.append(req.lora_request)

        streams = []

        start = time.perf_counter()
        if do_profile:
            await llm.start_profile()
        submissions = zip(prompts, sampling_params, lora_requests)
        for idx, (prompt, params, lora) in enumerate(submissions):
            stream = llm.generate(
                prompt, params, lora_request=lora, request_id=f"test{idx}"
            )
            streams.append(stream)

        # Drain every stream; the results themselves are not needed.
        merged = merge_async_iterators(*streams)
        async for _idx, _res in merged:
            pass
        if do_profile:
            await llm.stop_profile()
        end = time.perf_counter()
        return end - start
def run_hf(
    requests: list[SampleRequest],
    model: str,
    tokenizer: PreTrainedTokenizerBase,
    n: int,
    max_batch_size: int,
    trust_remote_code: bool,
    disable_detokenize: bool = False,
) -> float:
    """Benchmark generation with a Hugging Face causal LM on CUDA.

    Requests are greedily packed into batches of at most `max_batch_size`,
    as long as the longest prompt plus the longest output stays within 2048
    tokens. Returns the elapsed wall-clock seconds.
    """
    llm = AutoModelForCausalLM.from_pretrained(
        model, dtype=torch.float16, trust_remote_code=trust_remote_code
    )
    if llm.config.model_type == "llama":
        # To enable padding in the HF backend.
        tokenizer.pad_token = tokenizer.eos_token
    llm = llm.cuda()

    pbar = tqdm(total=len(requests))
    start = time.perf_counter()

    batch: list[str] = []
    max_prompt_len = 0
    max_output_len = 0
    last_index = len(requests) - 1

    for index, current in enumerate(requests):
        # Add the prompt to the batch.
        batch.append(current.prompt)
        max_prompt_len = max(max_prompt_len, current.prompt_len)
        max_output_len = max(max_output_len, current.expected_output_len)

        has_room = len(batch) < max_batch_size
        if has_room and index != last_index:
            # Keep accumulating if the next request still fits the budget.
            upcoming = requests[index + 1]
            padded_prompt_len = max(max_prompt_len, upcoming.prompt_len)
            padded_output_len = max(max_output_len, upcoming.expected_output_len)
            if padded_prompt_len + padded_output_len <= 2048:
                continue

        # Generate the sequences.
        encoded = tokenizer(batch, return_tensors="pt", padding=True)
        input_ids = encoded.input_ids
        llm_outputs = llm.generate(
            input_ids=input_ids.cuda(),
            do_sample=True,
            num_return_sequences=n,
            temperature=1.0,
            top_p=1.0,
            use_cache=True,
            max_new_tokens=max_output_len,
        )
        if not disable_detokenize:
            # Include the decoding time.
            tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
        pbar.update(len(batch))

        # Start a fresh batch.
        batch = []
        max_prompt_len = 0
        max_output_len = 0

    end = time.perf_counter()
    return end - start
def save_to_pytorch_benchmark_format(
    args: argparse.Namespace, results: dict[str, Any]
) -> None:
    """Write the throughput results in PyTorch benchmark format.

    Nothing is written when the converter returns no records. Otherwise the
    records go to `args.output_json` with its extension replaced by
    `.pytorch.json`.
    """
    metrics = {
        "requests_per_second": [results["requests_per_second"]],
        "tokens_per_second": [results["tokens_per_second"]],
    }
    extra_info = {
        key: results[key]
        for key in ["elapsed_time", "num_requests", "total_num_tokens"]
    }

    pt_records = convert_to_pytorch_benchmark_format(
        args=args,
        metrics=metrics,
        extra_info=extra_info,
    )
    if not pt_records:
        return

    stem, _ext = os.path.splitext(args.output_json)
    write_to_json(f"{stem}.pytorch.json", pt_records)
def get_requests(args, tokenizer):
    """Sample benchmark requests from the dataset selected on the CLI.

    Picks the dataset class from `args.dataset_name` / `args.dataset_path`,
    samples `args.num_prompts` requests from it, and keeps only this
    data-parallel rank's share.

    Raises:
        ValueError: If the dataset name is unknown, or the dataset name is
            `hf` and `args.dataset_path` is not one of the supported paths.
    """
    # Common parameters for all dataset types.
    common_kwargs = {
        "dataset_path": args.dataset_path,
        "random_seed": args.seed,
    }
    sample_kwargs = {
        "tokenizer": tokenizer,
        "lora_path": args.lora_path,
        "max_loras": args.max_loras,
        "num_requests": args.num_prompts,
        "input_len": args.input_len,
        "output_len": args.output_len,
    }

    if args.dataset_path is None or args.dataset_name == "random":
        sample_kwargs["range_ratio"] = args.random_range_ratio
        sample_kwargs["prefix_len"] = args.prefix_len
        dataset_cls = RandomDataset
    elif args.dataset_name == "sharegpt":
        dataset_cls = ShareGPTDataset
        if args.backend == "vllm-chat":
            sample_kwargs["enable_multimodal_chat"] = True
    elif args.dataset_name == "sonnet":
        assert tokenizer.chat_template or tokenizer.default_chat_template, (
            "Tokenizer/model must have chat template for sonnet dataset."
        )
        dataset_cls = SonnetDataset
        sample_kwargs["prefix_len"] = args.prefix_len
        sample_kwargs["return_prompt_formatted"] = True
    elif args.dataset_name == "burstgpt":
        dataset_cls = BurstGPTDataset
    elif args.dataset_name == "hf":
        if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = VisionArenaDataset
            common_kwargs["dataset_subset"] = None
            common_kwargs["dataset_split"] = "train"
            sample_kwargs["enable_multimodal_chat"] = True
        elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = InstructCoderDataset
            common_kwargs["dataset_split"] = "train"
        elif args.dataset_path in MultiModalConversationDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = MultiModalConversationDataset
            common_kwargs["dataset_subset"] = args.hf_subset
            common_kwargs["dataset_split"] = args.hf_split
            sample_kwargs["enable_multimodal_chat"] = True
        elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = ConversationDataset
            common_kwargs["dataset_subset"] = args.hf_subset
            common_kwargs["dataset_split"] = args.hf_split
            sample_kwargs["enable_multimodal_chat"] = True
        elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = AIMODataset
            common_kwargs["dataset_subset"] = None
            common_kwargs["dataset_split"] = "train"
        else:
            # Without this branch `dataset_cls` stays unbound and the call
            # below fails with an unrelated UnboundLocalError.
            raise ValueError(f"{args.dataset_path} is not supported by hf dataset.")
    elif args.dataset_name == "prefix_repetition":
        dataset_cls = PrefixRepetitionRandomDataset
        sample_kwargs["prefix_len"] = args.prefix_repetition_prefix_len
        sample_kwargs["suffix_len"] = args.prefix_repetition_suffix_len
        sample_kwargs["num_prefixes"] = args.prefix_repetition_num_prefixes
        sample_kwargs["output_len"] = args.prefix_repetition_output_len
    else:
        raise ValueError(f"Unknown dataset name: {args.dataset_name}")

    # Remove None values so that the dataset's own defaults apply.
    sample_kwargs = {k: v for k, v in sample_kwargs.items() if v is not None}
    requests = dataset_cls(**common_kwargs).sample(**sample_kwargs)
    requests = filter_requests_for_dp(requests, args.data_parallel_size)
    return requests
def filter_requests_for_dp(requests, data_parallel_size):
    """Keep only the requests assigned to this data-parallel rank.

    With a single replica the input is returned unchanged. Otherwise the
    rank is derived from the `RANK` / `WORLD_SIZE` environment variables and
    requests are dealt out round-robin by index.
    """
    # Note(zhuohan): The way we get data_parallel_rank is hacky and only
    # works for external launcher mode. Should be cleaned up and deprecated
    # in the future with a better vLLM distributed process design.
    if data_parallel_size == 1:
        return requests

    global_rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    ranks_per_replica = world_size // data_parallel_size
    replica_index = global_rank // ranks_per_replica

    selected = []
    for position, request in enumerate(requests):
        if position % data_parallel_size == replica_index:
            selected.append(request)
    return selected
def validate_args(args):
    """
    Validate command-line arguments.

    Fills in defaults on `args` in place (`dataset_path` from the deprecated
    `--dataset`, `tokenizer` from `model`, `dataset_name="random"` when no
    dataset path is given), emits warnings for options that will be ignored,
    and rejects unsupported combinations.

    Raises:
        ValueError: For an unsupported backend, dataset or option combination.
        AssertionError: If an `hf` dataset is used with the wrong backend.
    """
    # === Deprecation and Defaulting ===
    if args.dataset is not None:
        warnings.warn(
            "The '--dataset' argument will be deprecated in the next release. "
            "Please use '--dataset-name' and '--dataset-path' instead.",
            stacklevel=2,
        )
        args.dataset_path = args.dataset

    if not getattr(args, "tokenizer", None):
        args.tokenizer = args.model

    # === Backend Validation ===
    valid_backends = {"vllm", "hf", "mii", "vllm-chat"}
    if args.backend not in valid_backends:
        raise ValueError(f"Unsupported backend: {args.backend}")

    # === Dataset Configuration ===
    if (
        not args.dataset
        and not args.dataset_path
        and args.dataset_name not in {"prefix_repetition"}
    ):
        print("When dataset path is not set, it will default to random dataset")
        args.dataset_name = "random"
        if args.input_len is None:
            raise ValueError("input_len must be provided for a random dataset")

    # === Dataset Name Specific Checks ===
    # --hf-subset and --hf-split: only used
    # when dataset_name is 'hf'
    if args.dataset_name != "hf" and (
        getattr(args, "hf_subset", None) is not None
        or getattr(args, "hf_split", None) is not None
    ):
        # Adjacent literals (not a backslash continuation inside the string)
        # keep the message independent of source indentation.
        warnings.warn(
            "--hf-subset and --hf-split will be ignored "
            "since --dataset-name is not 'hf'.",
            stacklevel=2,
        )
    elif args.dataset_name == "hf":
        if args.dataset_path in (
            VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys()
            | MultiModalConversationDataset.SUPPORTED_DATASET_PATHS
            | ConversationDataset.SUPPORTED_DATASET_PATHS
        ):
            assert args.backend == "vllm-chat", (
                f"{args.dataset_path} needs to use vllm-chat as the backend."
            )
        elif args.dataset_path in (
            InstructCoderDataset.SUPPORTED_DATASET_PATHS
            | AIMODataset.SUPPORTED_DATASET_PATHS
        ):
            assert args.backend == "vllm", (
                f"{args.dataset_path} needs to use vllm as the backend."
            )
        else:
            raise ValueError(f"{args.dataset_path} is not supported by hf dataset.")

    # --random-range-ratio: only used when dataset_name is 'random'
    if args.dataset_name != "random" and args.random_range_ratio is not None:
        warnings.warn(
            "--random-range-ratio will be ignored since "
            "--dataset-name is not 'random'.",
            stacklevel=2,
        )

    # --prefix-len: only used when dataset_name is 'random', 'sonnet', or not
    # set.
    if (
        args.dataset_name not in {"random", "sonnet", None}
        and args.prefix_len is not None
    ):
        warnings.warn(
            "--prefix-len will be ignored since --dataset-name "
            "is not 'random', 'sonnet', or not set.",
            stacklevel=2,
        )

    # === LoRA Settings ===
    if getattr(args, "enable_lora", False) and args.backend != "vllm":
        raise ValueError("LoRA benchmarking is only supported for vLLM backend")
    if getattr(args, "enable_lora", False) and args.lora_path is None:
        raise ValueError("LoRA path must be provided when enable_lora is True")

    # === Backend-specific Validations ===
    if args.backend == "hf" and args.hf_max_batch_size is None:
        raise ValueError("HF max batch size is required for HF backend")
    if args.backend != "hf" and args.hf_max_batch_size is not None:
        raise ValueError("HF max batch size is only for HF backend.")

    if (
        args.backend in {"hf", "mii"}
        and getattr(args, "quantization", None) is not None
    ):
        raise ValueError("Quantization is only for vLLM backend.")

    if args.backend == "mii" and args.dtype != "auto":
        raise ValueError("dtype must be auto for MII backend.")
    if args.backend == "mii" and args.n != 1:
        raise ValueError("n must be 1 for MII backend.")
    if args.backend == "mii" and args.tokenizer != args.model:
        raise ValueError("Tokenizer must be the same as the model for MII backend.")

    if args.data_parallel_size > 1 and (
        args.distributed_executor_backend != "external_launcher" or args.async_engine
    ):
        # --data-parallel is not supported fully.
        # Old issue: https://github.com/vllm-project/vllm/issues/16222
        # Currently we only support data parallel with external launcher
        # mode (i.e., launch with torchrun).
        raise ValueError(
            "Data parallel is only supported with external launcher mode "
            "with synchronous engine in offline benchmark, "
            "please use benchmark serving instead"
        )
def add_cli_args(parser: argparse.ArgumentParser):
    """Register the throughput-benchmark options on `parser`.

    Adds the benchmark-specific options (backend, dataset selection, request
    shape, output, LoRA, profiling, prefix-repetition dataset) and then the
    engine options via `AsyncEngineArgs.add_cli_args`. The parser is modified
    in place; nothing is returned.
    """
    parser.add_argument(
        "--backend",
        type=str,
        choices=["vllm", "hf", "mii", "vllm-chat"],
        default="vllm",
    )
    parser.add_argument(
        "--dataset-name",
        type=str,
        choices=["sharegpt", "random", "sonnet", "burstgpt", "hf", "prefix_repetition"],
        help="Name of the dataset to benchmark on.",
        default="sharegpt",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default=None,
        # Adjacent literals (not a backslash continuation inside the string)
        # keep the help text independent of source indentation.
        help="Path to the ShareGPT dataset, will be deprecated in "
        "the next release. The dataset is expected to "
        "be a json in form of list[dict[..., conversations: "
        "list[dict[..., value: <prompt_or_response>]]]]",
    )
    parser.add_argument(
        "--dataset-path", type=str, default=None, help="Path to the dataset"
    )
    parser.add_argument(
        "--input-len",
        type=int,
        default=None,
        help="Input prompt length for each request",
    )
    parser.add_argument(
        "--output-len",
        type=int,
        default=None,
        help="Output length for each request. Overrides the "
        "output length from the dataset.",
    )
    parser.add_argument(
        "--n", type=int, default=1, help="Number of generated sequences per prompt."
    )
    parser.add_argument(
        "--num-prompts", type=int, default=1000, help="Number of prompts to process."
    )
    parser.add_argument(
        "--hf-max-batch-size",
        type=int,
        default=None,
        help="Maximum batch size for HF backend.",
    )
    parser.add_argument(
        "--output-json",
        type=str,
        default=None,
        help="Path to save the throughput results in JSON format.",
    )
    parser.add_argument(
        "--async-engine",
        action="store_true",
        default=False,
        help="Use vLLM async engine rather than LLM class.",
    )
    parser.add_argument(
        "--disable-frontend-multiprocessing",
        action="store_true",
        default=False,
        help="Disable decoupled async engine frontend.",
    )
    parser.add_argument(
        "--disable-detokenize",
        action="store_true",
        help=(
            "Do not detokenize the response (i.e. do not include "
            "detokenization time in the measurement)"
        ),
    )
    # LoRA
    parser.add_argument(
        "--lora-path",
        type=str,
        default=None,
        help="Path to the lora adapters to use. This can be an absolute path, "
        "a relative path, or a Hugging Face model identifier.",
    )
    parser.add_argument(
        "--prefix-len",
        type=int,
        default=0,
        help="Number of fixed prefix tokens before the random "
        "context in a request (default: 0).",
    )
    # random dataset
    parser.add_argument(
        "--random-range-ratio",
        type=float,
        default=0.0,
        help="Range ratio for sampling input/output length, "
        "used only for RandomDataset. Must be in the range [0, 1) to define "
        "a symmetric sampling range "
        "[length * (1 - range_ratio), length * (1 + range_ratio)].",
    )

    # hf dataset
    parser.add_argument(
        "--hf-subset", type=str, default=None, help="Subset of the HF dataset."
    )
    parser.add_argument(
        "--hf-split", type=str, default=None, help="Split of the HF dataset."
    )
    parser.add_argument(
        "--profile",
        action="store_true",
        default=False,
        help="Use Torch Profiler. The env variable "
        "VLLM_TORCH_PROFILER_DIR must be set to enable profiler.",
    )

    # prefix repetition dataset
    prefix_repetition_group = parser.add_argument_group(
        "prefix repetition dataset options"
    )
    prefix_repetition_group.add_argument(
        "--prefix-repetition-prefix-len",
        type=int,
        default=None,
        help="Number of prefix tokens per request, used only for prefix "
        "repetition dataset.",
    )
    prefix_repetition_group.add_argument(
        "--prefix-repetition-suffix-len",
        type=int,
        default=None,
        help="Number of suffix tokens per request, used only for prefix "
        "repetition dataset. Total input length is prefix_len + suffix_len.",
    )
    prefix_repetition_group.add_argument(
        "--prefix-repetition-num-prefixes",
        type=int,
        default=None,
        help="Number of prefixes to generate, used only for prefix repetition "
        "dataset. Prompts per prefix is num_requests // num_prefixes.",
    )
    prefix_repetition_group.add_argument(
        "--prefix-repetition-output-len",
        type=int,
        default=None,
        help="Number of output tokens per request, used only for prefix "
        "repetition dataset.",
    )

    # Engine options are registered last.
    parser = AsyncEngineArgs.add_cli_args(parser)
def main(args: argparse.Namespace):
    """Run the offline throughput benchmark described by `args`.

    Validates the arguments, samples the requests, runs them on the selected
    backend, prints the throughput figures, and optionally writes them to
    `args.output_json` (plus the PyTorch-benchmark-format companion file).
    """
    if args.tokenizer is None:
        args.tokenizer = args.model
    validate_args(args)
    # Default the seed so that request sampling is reproducible.
    if args.seed is None:
        args.seed = 0
    random.seed(args.seed)
    # Sample the requests.
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer, trust_remote_code=args.trust_remote_code
    )
    requests = get_requests(args, tokenizer)
    is_multi_modal = any(request.multi_modal_data is not None for request in requests)
    # Only the synchronous vllm and vllm-chat paths return request outputs.
    request_outputs: list[RequestOutput] | None = None
    if args.backend == "vllm":
        if args.async_engine:
            elapsed_time = uvloop.run(
                run_vllm_async(
                    requests,
                    args.n,
                    AsyncEngineArgs.from_cli_args(args),
                    disable_frontend_multiprocessing=args.disable_frontend_multiprocessing,
                    disable_detokenize=args.disable_detokenize,
                    do_profile=args.profile,
                )
            )
        else:
            elapsed_time, request_outputs = run_vllm(
                requests,
                args.n,
                EngineArgs.from_cli_args(args),
                disable_detokenize=args.disable_detokenize,
                do_profile=args.profile,
            )
    elif args.backend == "hf":
        assert args.tensor_parallel_size == 1
        if args.profile:
            raise NotImplementedError("Profiling not implemented yet for backend='hf'.")
        elapsed_time = run_hf(
            requests,
            args.model,
            tokenizer,
            args.n,
            args.hf_max_batch_size,
            args.trust_remote_code,
            args.disable_detokenize,
        )
    elif args.backend == "vllm-chat":
        elapsed_time, request_outputs = run_vllm_chat(
            requests,
            args.n,
            EngineArgs.from_cli_args(args),
            disable_detokenize=args.disable_detokenize,
            do_profile=args.profile,
        )
    else:
        raise ValueError(f"Unknown backend: {args.backend}")

    if request_outputs:
        # Note: with the vllm and vllm-chat backends,
        # we have request_outputs, which we use to count tokens.
        total_prompt_tokens = 0
        total_output_tokens = 0
        for ro in request_outputs:
            if not isinstance(ro, RequestOutput):
                continue
            total_prompt_tokens += (
                len(ro.prompt_token_ids) if ro.prompt_token_ids else 0
            )
            total_output_tokens += sum(len(o.token_ids) for o in ro.outputs if o)
        total_num_tokens = total_prompt_tokens + total_output_tokens
    else:
        # No outputs available: fall back to the lengths recorded on the
        # sampled requests.
        total_num_tokens = sum(r.prompt_len + r.expected_output_len for r in requests)
        total_output_tokens = sum(r.expected_output_len for r in requests)
        total_prompt_tokens = total_num_tokens - total_output_tokens

    if is_multi_modal and args.backend != "vllm-chat":
        print(
            "\033[91mWARNING\033[0m: Multi-modal request with "
            f"{args.backend} backend detected. The "
            "following metrics are not accurate because image tokens are not"
            " counted. See vllm-project/vllm/issues/9778 for details."
        )
        # TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
        # vllm-chat backend counts the image tokens now

    print(
        f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
        f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
        f"{total_output_tokens / elapsed_time:.2f} output tokens/s"
    )
    print(f"Total num prompt tokens: {total_prompt_tokens}")
    print(f"Total num output tokens: {total_output_tokens}")

    # Output JSON results if specified
    if args.output_json:
        results = {
            "elapsed_time": elapsed_time,
            "num_requests": len(requests),
            "total_num_tokens": total_num_tokens,
            "requests_per_second": len(requests) / elapsed_time,
            "tokens_per_second": total_num_tokens / elapsed_time,
        }
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)
        save_to_pytorch_benchmark_format(args, results)