Upgrade to vllm 0.17.0 corex v4.1 overlay

2026-04-29 19:38:22 +08:00
parent 8fac6062e4
commit 938d0854a5
430 changed files with 35969 additions and 14511 deletions

View File

@@ -31,6 +31,7 @@ from tempfile import NamedTemporaryFile
from typing import Any, cast
import numpy as np
from huggingface_hub import snapshot_download
from PIL import Image
from typing_extensions import deprecated
@@ -60,6 +61,8 @@ except ImportError:
logger = logging.getLogger(__name__)
DEFAULT_NUM_PROMPTS = 1000
# -----------------------------------------------------------------------------
# Data Classes
# -----------------------------------------------------------------------------
@@ -303,9 +306,11 @@ def process_image(image: Any) -> Mapping[str, Any]:
a JPEG in memory. - Encodes the JPEG data as a base64 string. - Returns
a dictionary with the image as a base64 data URL.
3. String input: - Treats the string as a URL or local file path. -
Prepends "file://" if the string doesn't start with "http://" or
"file://". - Returns a dictionary with the image URL.
3. String input: - Treats the string as a URL, local file path, or base64
encoded data. - If string starts with "data:image/", treats as base64.
- If string starts with "http://", "https://", or "file://", treats as URL.
- Otherwise treats as local file path and prepends "file://".
- Returns a dictionary with the image URL or base64 data.
Raises:
ValueError: If the input is not a supported type.
@@ -325,14 +330,14 @@ def process_image(image: Any) -> Mapping[str, Any]:
if isinstance(image, str):
image_url = (
image
if image.startswith(("http://", "https://", "file://"))
if image.startswith(("http://", "https://", "file://", "data:image/"))
else f"file://{image}"
)
return {"type": "image_url", "image_url": {"url": image_url}}
raise ValueError(
f"Invalid image input {image}. Must be a PIL.Image.Image"
" or str or dictionary with raw image bytes."
f"Invalid image input {image}. Must be a PIL.Image.Image, "
"str (URL, file path, or base64 data URL), or dictionary with raw image bytes."
)
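To make the new dispatch concrete, here is a minimal standalone sketch of the string branch above (illustrative only; the real process_image also handles PIL images and raw-bytes dictionaries):
def _image_url_from_str(image: str) -> dict:
    # URLs and base64 data URLs pass through unchanged; anything else is
    # treated as a local path and prefixed with "file://".
    image_url = (
        image
        if image.startswith(("http://", "https://", "file://", "data:image/"))
        else f"file://{image}"
    )
    return {"type": "image_url", "image_url": {"url": image_url}}
assert _image_url_from_str("/tmp/cat.jpg")["image_url"]["url"] == "file:///tmp/cat.jpg"
assert _image_url_from_str("data:image/jpeg;base64,AAAA")["image_url"]["url"].startswith("data:image/")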
@@ -1338,7 +1343,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
parser.add_argument(
"--num-prompts",
type=int,
default=1000,
default=DEFAULT_NUM_PROMPTS,
help="Number of prompts to process.",
)
parser.add_argument(
@@ -2676,6 +2681,14 @@ class MMVUDataset(HuggingFaceDataset):
+ (" ".join(f"{k}.{v}" for k, v in x["choices"].items())),
}
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
self._remote_path_root = (
f"https://huggingface.co/datasets/{self.hf_name}/resolve/main"
)
self._local_path_root = snapshot_download(self.hf_name, repo_type="dataset")
def sample(
self,
tokenizer: TokenizerLike,
@@ -2698,7 +2711,9 @@ class MMVUDataset(HuggingFaceDataset):
break
prompt = parser_fn(item)
mm_content = process_video(item["video"])
mm_content = process_video(
item["video"].replace(self._remote_path_root, self._local_path_root)
)
prompt_len = len(tokenizer.encode(prompt))
if enable_multimodal_chat:
# Note: when chat is enabled the request prompt_len is no longer

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Benchmark library utilities."""

View File

@@ -0,0 +1,802 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""The request function for API endpoints."""
import io
import json
import os
import sys
import time
import traceback
from collections.abc import Awaitable
from dataclasses import dataclass, field
from typing import Any, Literal, Protocol
import aiohttp
import regex as re
from tqdm.asyncio import tqdm
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
class StreamedResponseHandler:
"""Handles streaming HTTP responses by accumulating chunks until complete
messages are available."""
def __init__(self):
self.buffer = ""
def add_chunk(self, chunk_bytes: bytes) -> list[str]:
"""Add a chunk of bytes to the buffer and return any complete
messages."""
chunk_str = chunk_bytes.decode("utf-8")
self.buffer += chunk_str
messages = []
# Split by double newlines (SSE message separator)
while "\n\n" in self.buffer:
message, self.buffer = self.buffer.split("\n\n", 1)
message = message.strip()
if message:
messages.append(message)
# If the buffer is not empty, check whether it already holds a complete
# message by removing the "data: " prefix and testing for valid JSON.
if self.buffer.startswith("data: "):
message_content = self.buffer.removeprefix("data: ").strip()
if message_content == "[DONE]":
messages.append(self.buffer.strip())
self.buffer = ""
elif message_content:
try:
json.loads(message_content)
messages.append(self.buffer.strip())
self.buffer = ""
except json.JSONDecodeError:
# Incomplete JSON, wait for more chunks.
pass
return messages
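For illustration, an SSE event split across two network chunks is only surfaced once it can be parsed (a small usage sketch of the class above):
handler = StreamedResponseHandler()
# The first chunk ends mid-JSON, so nothing complete is returned yet.
assert handler.add_chunk(b'data: {"choices": [') == []
# The second chunk completes the event (terminated by the SSE separator).
assert handler.add_chunk(b'{"text": "hi"}]}\n\n') == [
    'data: {"choices": [{"text": "hi"}]}'
]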
@dataclass
class RequestFuncInput:
"""The input for the request function."""
prompt: str | list[str]
api_url: str
prompt_len: int
output_len: int
model: str
model_name: str | None = None
logprobs: int | None = None
extra_headers: dict | None = None
extra_body: dict | None = None
multi_modal_content: dict | list[dict] | None = None
ignore_eos: bool = False
language: str | None = None
request_id: str | None = None
@dataclass
class RequestFuncOutput:
"""The output of the request function including metrics."""
generated_text: str = ""
success: bool = False
latency: float = 0.0
output_tokens: int = 0
ttft: float = 0.0 # Time to first token
itl: list[float] = field(default_factory=list) # list of inter-token latencies
tpot: float = 0.0 # avg next-token latencies
prompt_len: int = 0
error: str = ""
start_time: float = 0.0
input_audio_duration: float = 0.0 # in seconds
class RequestFunc(Protocol):
def __call__(
self,
request_func_input: RequestFuncInput,
session: aiohttp.ClientSession,
pbar: tqdm | None = None,
) -> Awaitable[RequestFuncOutput]: ...
def _validate_api_url(
api_url: str,
api_name: str,
expected_suffixes: str | set[str],
) -> None:
if isinstance(expected_suffixes, str):
expected_suffixes = {expected_suffixes}
expected_suffixes = {*expected_suffixes, "profile"}
if not api_url.endswith(tuple(expected_suffixes)):
raise ValueError(f"{api_name} URL must end with one of: {expected_suffixes}.")
def _update_payload_common(
payload: dict[str, Any],
request_func_input: RequestFuncInput,
) -> None:
if request_func_input.ignore_eos:
payload["ignore_eos"] = request_func_input.ignore_eos
if request_func_input.extra_body:
payload.update(request_func_input.extra_body)
def _update_headers_common(
headers: dict[str, Any],
request_func_input: RequestFuncInput,
) -> None:
if request_func_input.extra_headers:
headers |= request_func_input.extra_headers
if request_func_input.request_id:
headers["x-request-id"] = request_func_input.request_id
def _get_headers(content_type: str | None = None) -> dict[str, str]:
headers = {}
if content_type:
headers["Content-Type"] = content_type
api_key = os.environ.get("OPENAI_API_KEY")
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
return headers
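For example, with an API key present in the environment (the key below is a made-up placeholder):
os.environ["OPENAI_API_KEY"] = "sk-test"  # hypothetical key, for illustration
headers = _get_headers("application/json")
# -> {"Content-Type": "application/json", "Authorization": "Bearer sk-test"}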
async def async_request_openai_completions(
request_func_input: RequestFuncInput,
session: aiohttp.ClientSession,
pbar: tqdm | None = None,
) -> RequestFuncOutput:
"""The async request function for the OpenAI Completions API.
Args:
request_func_input: The input for the request function.
pbar: The progress bar to display the progress.
Returns:
The output of the request function.
"""
api_url = request_func_input.api_url
_validate_api_url(api_url, "OpenAI Completions API", "completions")
payload = {
"model": request_func_input.model_name
if request_func_input.model_name
else request_func_input.model,
"prompt": request_func_input.prompt,
"repetition_penalty": 1.0,
"max_tokens": request_func_input.output_len,
"logprobs": request_func_input.logprobs,
"stream": True,
"stream_options": {
"include_usage": True,
},
}
_update_payload_common(payload, request_func_input)
headers = _get_headers()
_update_headers_common(headers, request_func_input)
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
generated_text = ""
st = time.perf_counter()
output.start_time = st
most_recent_timestamp = st
try:
async with session.post(url=api_url, json=payload, headers=headers) as response:
if response.status == 200:
first_chunk_received = False
handler = StreamedResponseHandler()
async for chunk_bytes in response.content.iter_any():
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue
messages = handler.add_chunk(chunk_bytes)
for message in messages:
# NOTE: SSE comments (often used as pings) start with
# a colon. These are not JSON data payloads and
# should be skipped.
if message.startswith(":"):
continue
chunk = message.removeprefix("data: ")
if chunk != "[DONE]":
data = json.loads(chunk)
# NOTE: Some Completions APIs send a final usage-summary
# response without a token, so verify that a token was
# actually generated.
if choices := data.get("choices"):
# Note that text could be empty here
# e.g. for special tokens
text = choices[0].get("text")
timestamp = time.perf_counter()
# First token
if not first_chunk_received:
first_chunk_received = True
ttft = time.perf_counter() - st
output.ttft = ttft
# Decoding phase
else:
output.itl.append(timestamp - most_recent_timestamp)
most_recent_timestamp = timestamp
generated_text += text or ""
elif usage := data.get("usage"):
output.output_tokens = usage.get("completion_tokens")
if first_chunk_received:
output.success = True
else:
output.success = False
output.error = (
"Never received a valid chunk to calculate TTFT."
"This response will be marked as failed!"
)
output.generated_text = generated_text
output.latency = most_recent_timestamp - st
else:
output.error = response.reason or ""
output.success = False
except Exception:
output.success = False
exc_info = sys.exc_info()
output.error = "".join(traceback.format_exception(*exc_info))
if pbar:
pbar.update(1)
return output
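A minimal driver for this function could look as follows; the endpoint URL and model name are placeholders, and a server must already be listening for the call to succeed:
import asyncio
async def _demo() -> None:
    request_input = RequestFuncInput(
        prompt="Hello, world!",
        api_url="http://localhost:8000/v1/completions",  # placeholder URL
        prompt_len=4,
        output_len=16,
        model="my-model",  # placeholder model name
    )
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        output = await async_request_openai_completions(request_input, session)
        print(output.success, output.ttft, output.generated_text)
# asyncio.run(_demo())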
def _get_chat_content(
request_func_input: RequestFuncInput,
mm_position: Literal["first", "last"] = "last",
) -> list[dict[str, Any]]:
text_contents = [{"type": "text", "text": request_func_input.prompt}]
mm_contents = []
if request_func_input.multi_modal_content:
mm_content = request_func_input.multi_modal_content
if isinstance(mm_content, list):
mm_contents.extend(request_func_input.multi_modal_content)
elif isinstance(mm_content, dict):
mm_contents.append(request_func_input.multi_modal_content)
else:
raise TypeError(
"multi_modal_content must be a dict or list[dict] for openai-chat"
)
if mm_position == "first":
return mm_contents + text_contents
return text_contents + mm_contents
async def async_request_openai_chat_completions(
request_func_input: RequestFuncInput,
session: aiohttp.ClientSession,
pbar: tqdm | None = None,
mm_position: Literal["first", "last"] = "last",
) -> RequestFuncOutput:
api_url = request_func_input.api_url
_validate_api_url(api_url, "OpenAI Chat Completions API", "chat/completions")
content = _get_chat_content(request_func_input, mm_position=mm_position)
payload = {
"model": request_func_input.model_name
if request_func_input.model_name
else request_func_input.model,
"messages": [
{"role": "user", "content": content},
],
"max_completion_tokens": request_func_input.output_len,
"stream": True,
"stream_options": {
"include_usage": True,
},
}
_update_payload_common(payload, request_func_input)
headers = _get_headers("application/json")
_update_headers_common(headers, request_func_input)
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
generated_text = ""
ttft = 0.0
st = time.perf_counter()
output.start_time = st
most_recent_timestamp = st
try:
async with session.post(url=api_url, json=payload, headers=headers) as response:
if response.status == 200:
handler = StreamedResponseHandler()
async for chunk_bytes in response.content.iter_any():
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue
messages = handler.add_chunk(chunk_bytes)
for message in messages:
# NOTE: SSE comments (often used as pings) start with
# a colon. These are not JSON data payloads and
# should be skipped.
if message.startswith(":"):
continue
chunk = message.removeprefix("data: ")
if chunk != "[DONE]":
timestamp = time.perf_counter()
data = json.loads(chunk)
if choices := data.get("choices"):
content = choices[0]["delta"].get("content")
# First token
if ttft == 0.0:
ttft = timestamp - st
output.ttft = ttft
# Decoding phase
else:
output.itl.append(timestamp - most_recent_timestamp)
generated_text += content or ""
elif usage := data.get("usage"):
output.output_tokens = usage.get("completion_tokens")
most_recent_timestamp = timestamp
output.generated_text = generated_text
output.success = True
output.latency = most_recent_timestamp - st
else:
output.error = response.reason or ""
output.success = False
except Exception:
output.success = False
exc_info = sys.exc_info()
output.error = "".join(traceback.format_exception(*exc_info))
if pbar:
pbar.update(1)
return output
async def async_request_openai_audio(
request_func_input: RequestFuncInput,
session: aiohttp.ClientSession,
pbar: tqdm | None = None,
) -> RequestFuncOutput:
# Lazy import without PlaceholderModule to avoid vllm dep.
import soundfile
api_url = request_func_input.api_url
_validate_api_url(api_url, "OpenAI Audio API", {"transcriptions", "translations"})
content = [{"type": "text", "text": request_func_input.prompt}]
payload = {
"model": request_func_input.model_name
if request_func_input.model_name
else request_func_input.model,
"max_completion_tokens": request_func_input.output_len,
"stream": True,
"language": "en",
# Flattened due to multipart/form-data
"stream_include_usage": True,
"stream_continuous_usage_stats": True,
}
_update_payload_common(payload, request_func_input)
headers = _get_headers()
_update_headers_common(headers, request_func_input)
# Send audio file
def to_bytes(y, sr):
buffer = io.BytesIO()
soundfile.write(buffer, y, sr, format="WAV")
buffer.seek(0)
return buffer
mm_audio = request_func_input.multi_modal_content
if not isinstance(mm_audio, dict) or "audio" not in mm_audio:
raise TypeError("multi_modal_content must be a dict containing 'audio'")
with to_bytes(*mm_audio["audio"]) as f:
form = aiohttp.FormData()
form.add_field("file", f, content_type="audio/wav")
for key, value in payload.items():
form.add_field(key, str(value))
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
output.input_audio_duration = soundfile.info(f).duration
f.seek(0)
generated_text = ""
ttft = 0.0
st = time.perf_counter()
output.start_time = st
most_recent_timestamp = st
try:
async with session.post(
url=api_url, data=form, headers=headers
) as response:
if response.status == 200:
handler = StreamedResponseHandler()
async for chunk_bytes in response.content.iter_any():
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue
messages = handler.add_chunk(chunk_bytes)
for message in messages:
if type(message) is bytes:
message = message.decode("utf-8")
chunk = message.removeprefix("data: ")
if chunk != "[DONE]":
timestamp = time.perf_counter()
data = json.loads(chunk)
if choices := data.get("choices"):
content = choices[0]["delta"].get("content")
# First token
if ttft == 0.0:
ttft = timestamp - st
output.ttft = ttft
# Decoding phase
else:
output.itl.append(
timestamp - most_recent_timestamp
)
generated_text += content or ""
elif usage := data.get("usage"):
output.output_tokens = usage.get(
"completion_tokens"
)
most_recent_timestamp = timestamp
output.generated_text = generated_text
output.success = True
output.latency = most_recent_timestamp - st
else:
output.error = response.reason or ""
output.success = False
except Exception:
output.success = False
exc_info = sys.exc_info()
output.error = "".join(traceback.format_exception(*exc_info))
if pbar:
pbar.update(1)
return output
async def _run_pooling_request(
session: aiohttp.ClientSession,
api_url: str,
payload: dict[str, Any],
headers: dict[str, Any],
pbar: tqdm | None = None,
) -> RequestFuncOutput:
output = RequestFuncOutput()
st = time.perf_counter()
output.start_time = st
try:
async with session.post(url=api_url, headers=headers, json=payload) as response:
if response.status == 200:
output.ttft = output.latency = time.perf_counter() - st
if payload.get("encoding_format", "float") == "bytes":
metadata = json.loads(response.headers["metadata"])
usage = metadata.get("usage", {})
else:
data = await response.json()
usage = data.get("usage", {})
output.success = True
output.generated_text = ""
output.prompt_len = usage.get("prompt_tokens", 0)
else:
output.success = False
output.error = response.reason or ""
except Exception as e:
output.success = False
output.error = str(e)
if pbar:
pbar.update(1)
return output
async def async_request_openai_embeddings(
request_func_input: RequestFuncInput,
session: aiohttp.ClientSession,
pbar: tqdm | None = None,
) -> RequestFuncOutput:
api_url = request_func_input.api_url
_validate_api_url(api_url, "OpenAI Embeddings API", "embeddings")
payload = {
"model": request_func_input.model_name
if request_func_input.model_name
else request_func_input.model,
"input": request_func_input.prompt,
# Many embedding models have short context length,
# this is to avoid dropping some of the requests.
"truncate_prompt_tokens": -1,
}
_update_payload_common(payload, request_func_input)
headers = _get_headers("application/json")
_update_headers_common(headers, request_func_input)
return await _run_pooling_request(
session,
api_url,
payload=payload,
headers=headers,
pbar=pbar,
)
async def async_request_vllm_rerank(
request_func_input: RequestFuncInput,
session: aiohttp.ClientSession,
pbar: tqdm | None = None,
) -> RequestFuncOutput:
api_url = request_func_input.api_url
_validate_api_url(api_url, "vLLM score API", "rerank")
assert (
isinstance(request_func_input.prompt, list)
and len(request_func_input.prompt) > 1
)
payload = {
"model": request_func_input.model_name
if request_func_input.model_name
else request_func_input.model,
"query": request_func_input.prompt[0],
"documents": request_func_input.prompt[1:],
# Many reranker models have short context length,
# this is to avoid dropping some of the requests.
"truncate_prompt_tokens": -1,
}
headers = _get_headers("application/json")
_update_headers_common(headers, request_func_input)
return await _run_pooling_request(
session,
api_url,
payload=payload,
headers=headers,
pbar=pbar,
)
async def async_request_openai_embeddings_chat(
request_func_input: RequestFuncInput,
session: aiohttp.ClientSession,
pbar: tqdm | None = None,
mm_position: Literal["first", "last"] = "last",
) -> RequestFuncOutput:
api_url = request_func_input.api_url
_validate_api_url(api_url, "OpenAI Embeddings API", "embeddings")
content = _get_chat_content(request_func_input, mm_position=mm_position)
payload = {
"model": request_func_input.model_name
if request_func_input.model_name
else request_func_input.model,
"messages": [
{"role": "user", "content": content},
],
# Many embedding models have short context length,
# this is to avoid dropping some of the requests.
"truncate_prompt_tokens": -1,
}
_update_payload_common(payload, request_func_input)
headers = _get_headers("application/json")
_update_headers_common(headers, request_func_input)
return await _run_pooling_request(
session,
api_url,
payload=payload,
headers=headers,
pbar=pbar,
)
def _try_extract_request_idx(request_func_input: RequestFuncInput):
if request_func_input.request_id:
match = re.search(r"(\d+)$", request_func_input.request_id)
if match:
try:
return int(match.group(1))
except ValueError:
pass
return None
def _preprocess_clip(request_func_input: RequestFuncInput):
if request_func_input.multi_modal_content:
# Image input
request_func_input.prompt = ""
def _preprocess_vlm2vec(request_func_input: RequestFuncInput):
if request_func_input.multi_modal_content:
request_idx = _try_extract_request_idx(request_func_input)
# Adjust the ratio manually if needed.
use_image_only_prompt = request_idx is None or request_idx % 2 == 0
if use_image_only_prompt:
# Image input
request_func_input.prompt = "Represent the given image."
else:
# Text+Image input
request_func_input.prompt = (
f"Represent the given image with the following question: "
f"{request_func_input.prompt}"
)
async def async_request_openai_embeddings_clip(
request_func_input: RequestFuncInput,
session: aiohttp.ClientSession,
pbar: tqdm | None = None,
) -> RequestFuncOutput:
_preprocess_clip(request_func_input)
return await async_request_openai_embeddings_chat(
request_func_input,
session,
pbar=pbar,
)
async def async_request_openai_embeddings_vlm2vec(
request_func_input: RequestFuncInput,
session: aiohttp.ClientSession,
pbar: tqdm | None = None,
) -> RequestFuncOutput:
_preprocess_vlm2vec(request_func_input)
return await async_request_openai_embeddings_chat(
request_func_input,
session,
pbar=pbar,
mm_position="first",
)
async def async_request_infinity_embeddings(
request_func_input: RequestFuncInput,
session: aiohttp.ClientSession,
pbar: tqdm | None = None,
) -> RequestFuncOutput:
api_url = request_func_input.api_url
_validate_api_url(api_url, "Infinity Embeddings API", "embeddings")
payload = {
"model": request_func_input.model_name
if request_func_input.model_name
else request_func_input.model,
}
if request_func_input.prompt:
payload["input"] = request_func_input.prompt
else:
mm_content = request_func_input.multi_modal_content
assert isinstance(mm_content, dict)
mm_type = mm_content["type"]
payload["input"] = mm_content[mm_type]["url"]
payload["modality"] = mm_type.split("_", 1)[0]
_update_payload_common(payload, request_func_input)
headers = _get_headers("application/json")
_update_headers_common(headers, request_func_input)
return await _run_pooling_request(
session,
api_url,
payload=payload,
headers=headers,
pbar=pbar,
)
async def async_request_infinity_embeddings_clip(
request_func_input: RequestFuncInput,
session: aiohttp.ClientSession,
pbar: tqdm | None = None,
) -> RequestFuncOutput:
_preprocess_clip(request_func_input)
return await async_request_infinity_embeddings(
request_func_input,
session,
pbar=pbar,
)
async def async_request_vllm_pooling(
request_func_input: RequestFuncInput,
session: aiohttp.ClientSession,
pbar: tqdm | None = None,
) -> RequestFuncOutput:
api_url = request_func_input.api_url
_validate_api_url(api_url, "vLLM Pooling API", "pooling")
payload = {
"model": request_func_input.model_name
if request_func_input.model_name
else request_func_input.model,
"truncate_prompt_tokens": -1,
}
payload = payload | request_func_input.prompt
_update_payload_common(payload, request_func_input)
headers = _get_headers("application/json")
_update_headers_common(headers, request_func_input)
return await _run_pooling_request(
session,
api_url,
payload=payload,
headers=headers,
pbar=pbar,
)
# TODO: Add more request functions for different API protocols.
ASYNC_REQUEST_FUNCS: dict[str, RequestFunc] = {
"vllm": async_request_openai_completions,
"openai": async_request_openai_completions,
"openai-chat": async_request_openai_chat_completions,
"openai-audio": async_request_openai_audio,
"openai-embeddings": async_request_openai_embeddings,
"openai-embeddings-chat": async_request_openai_embeddings_chat,
"openai-embeddings-clip": async_request_openai_embeddings_clip,
"openai-embeddings-vlm2vec": async_request_openai_embeddings_vlm2vec,
# Infinity embedding server: https://github.com/michaelfeil/infinity
"infinity-embeddings": async_request_infinity_embeddings,
"infinity-embeddings-clip": async_request_infinity_embeddings_clip,
# (Infinity embedding server does not support vlm2vec)
"vllm-pooling": async_request_vllm_pooling,
"vllm-rerank": async_request_vllm_rerank,
}
OPENAI_COMPATIBLE_BACKENDS = [
k
for k, v in ASYNC_REQUEST_FUNCS.items()
if v in (async_request_openai_completions, async_request_openai_chat_completions)
]
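Callers select a request function from the registry by backend name, e.g. (hypothetical backend value):
backend = "openai-chat"  # e.g. taken from a --backend CLI flag
request_func = ASYNC_REQUEST_FUNCS[backend]
# request_func satisfies the RequestFunc protocol; whether the backend
# speaks the OpenAI completions/chat protocol can be checked against
# OPENAI_COMPATIBLE_BACKENDS.
is_openai_compatible = backend in OPENAI_COMPATIBLE_BACKENDS  # True here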

View File

@@ -0,0 +1,79 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utilities for checking endpoint readiness."""
import asyncio
import time
import aiohttp
from tqdm.asyncio import tqdm
from vllm.logger import init_logger
from .endpoint_request_func import RequestFunc, RequestFuncInput, RequestFuncOutput
logger = init_logger(__name__)
async def wait_for_endpoint(
request_func: RequestFunc,
test_input: RequestFuncInput,
session: aiohttp.ClientSession,
timeout_seconds: int = 600,
retry_interval: int = 5,
) -> RequestFuncOutput:
"""
Wait for an endpoint to become available before starting benchmarks.
Args:
request_func: The async request function to call
test_input: The RequestFuncInput to test with
timeout_seconds: Maximum time to wait in seconds (default: 10 minutes)
retry_interval: Time between retries in seconds (default: 5 seconds)
Returns:
RequestFuncOutput: The successful response
Raises:
ValueError: If the endpoint doesn't become available within the timeout
"""
deadline = time.perf_counter() + timeout_seconds
output = RequestFuncOutput(success=False)
print(f"Waiting for endpoint to become up in {timeout_seconds} seconds")
with tqdm(
total=timeout_seconds,
bar_format="{desc} |{bar}| {elapsed} elapsed, {remaining} remaining",
unit="s",
) as pbar:
while True:
# update progress bar
remaining = deadline - time.perf_counter()
elapsed = timeout_seconds - remaining
update_amount = min(elapsed - pbar.n, timeout_seconds - pbar.n)
pbar.update(update_amount)
pbar.refresh()
if remaining <= 0:
pbar.close()
break
# ping the endpoint using request_func
try:
output = await request_func(
request_func_input=test_input, session=session
)
if output.success:
pbar.close()
return output
else:
err_last_line = str(output.error).rstrip().rsplit("\n", 1)[-1]
logger.warning("Endpoint is not ready. Error='%s'", err_last_line)
except aiohttp.ClientConnectorError:
pass
# retry after a delay
sleep_duration = min(retry_interval, remaining)
if sleep_duration > 0:
await asyncio.sleep(sleep_duration)
return output
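Typical usage pairs this helper with one of the request functions from endpoint_request_func; the URL and model below are placeholders:
from .endpoint_request_func import async_request_openai_completions
async def _wait_demo() -> None:
    test_input = RequestFuncInput(
        prompt="ping",
        api_url="http://localhost:8000/v1/completions",  # placeholder URL
        prompt_len=1,
        output_len=1,
        model="my-model",  # placeholder model name
    )
    async with aiohttp.ClientSession() as session:
        output = await wait_for_endpoint(
            async_request_openai_completions,
            test_input,
            session,
            timeout_seconds=120,
        )
        print("endpoint ready:", output.success)
# asyncio.run(_wait_demo())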

View File

@@ -0,0 +1,131 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import json
import math
import os
from contextlib import contextmanager
from typing import Any
def extract_field(
args: argparse.Namespace, extra_info: dict[str, Any], field_name: str
) -> str:
if field_name in extra_info:
return extra_info[field_name]
v = args
# For example, args.compilation_config.mode
for nested_field in field_name.split("."):
if not hasattr(v, nested_field):
return ""
v = getattr(v, nested_field)
return v
def use_compile(args: argparse.Namespace, extra_info: dict[str, Any]) -> bool:
"""
Check if the benchmark is run with torch.compile
"""
return not (
extract_field(args, extra_info, "compilation_config.mode") == "0"
or "eager" in getattr(args, "output_json", "")
or "eager" in getattr(args, "result_filename", "")
)
def convert_to_pytorch_benchmark_format(
args: argparse.Namespace, metrics: dict[str, list], extra_info: dict[str, Any]
) -> list:
"""
Save the benchmark results in the format used by the PyTorch OSS benchmark
database, with one metric per record:
https://github.com/pytorch/pytorch/wiki/How-to-integrate-with-PyTorch-OSS-benchmark-database
"""
records = []
if not os.environ.get("SAVE_TO_PYTORCH_BENCHMARK_FORMAT", False):
return records
for name, benchmark_values in metrics.items():
if not isinstance(benchmark_values, list):
raise TypeError(
f"benchmark_values for metric '{name}' must be a list, "
f"but got {type(benchmark_values).__name__}"
)
record = {
"benchmark": {
"name": "vLLM benchmark",
"extra_info": {
"args": vars(args),
"compilation_config.mode": extract_field(
args, extra_info, "compilation_config.mode"
),
"optimization_level": extract_field(
args, extra_info, "optimization_level"
),
# A boolean field used by vLLM benchmark HUD dashboard
"use_compile": use_compile(args, extra_info),
},
},
"model": {
"name": args.model,
},
"metric": {
"name": name,
"benchmark_values": benchmark_values,
"extra_info": extra_info,
},
}
tp = record["benchmark"]["extra_info"]["args"].get("tensor_parallel_size")
# Save tensor_parallel_size parameter if it's part of the metadata
if not tp and "tensor_parallel_size" in extra_info:
record["benchmark"]["extra_info"]["args"]["tensor_parallel_size"] = (
extra_info["tensor_parallel_size"]
)
records.append(record)
return records
class InfEncoder(json.JSONEncoder):
def clear_inf(self, o: Any):
if isinstance(o, dict):
return {
str(k)
if not isinstance(k, (str, int, float, bool, type(None)))
else k: self.clear_inf(v)
for k, v in o.items()
}
elif isinstance(o, list):
return [self.clear_inf(v) for v in o]
elif isinstance(o, float) and math.isinf(o):
return "inf"
return o
def iterencode(self, o: Any, *args, **kwargs) -> Any:
return super().iterencode(self.clear_inf(o), *args, **kwargs)
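A quick check of the encoder's behavior:
record = {"p99_ttft_ms": math.inf, (0, 99): "non-string key"}
# Infinite floats become the string "inf"; keys that JSON cannot represent
# are stringified instead of raising a TypeError.
print(json.dumps(record, cls=InfEncoder))
# -> {"p99_ttft_ms": "inf", "(0, 99)": "non-string key"}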
def write_to_json(filename: str, records: list) -> None:
with open(filename, "w") as f:
json.dump(
records,
f,
cls=InfEncoder,
default=lambda o: f"<{type(o).__name__} is not JSON serializable>",
)
@contextmanager
def default_vllm_config():
"""Set a default VllmConfig for cases that directly test CustomOps or pathways
that use get_current_vllm_config() outside of a full engine context.
"""
from vllm.config import VllmConfig, set_current_vllm_config
with set_current_vllm_config(VllmConfig()):
yield
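Intended usage is a plain with-block (sketch):
with default_vllm_config():
    # Code here can call get_current_vllm_config() and receive the default
    # VllmConfig without spinning up a full engine.
    ...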

vllm/benchmarks/plot.py Normal file
View File

@@ -0,0 +1,316 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Generate plots for benchmark results."""
from pathlib import Path
from typing import Any
from vllm.utils.import_utils import PlaceholderModule
try:
import plotly.express as px
import plotly.io as pio
except ImportError:
_plotly = PlaceholderModule("plotly")
px = _plotly.placeholder_attr("express")
pio = _plotly.placeholder_attr("io")
try:
import matplotlib.pyplot as plt
except ImportError:
_matplotlib = PlaceholderModule("matplotlib")
plt = _matplotlib.placeholder_attr("pyplot")
def generate_timeline_plot(
results: list[dict[str, Any]],
output_path: Path,
colors: list[str] | None = None,
itl_thresholds: list[float] | None = None,
labels: list[str] | None = None,
) -> None:
"""
Generate an HTML timeline plot from benchmark results.
Args:
results: List of per-request result dictionaries containing:
- start_time: Request start time (seconds)
- ttft: Time to first token (seconds)
- itl: List of inter-token latencies (seconds)
- latency: Total request latency (seconds)
- prompt_len: Number of prompt tokens
- output_tokens: Number of output tokens
output_path: Path where the HTML file will be saved
colors: List of colors for ITL categories (default: green, orange, red)
itl_thresholds: ITL thresholds in seconds (default: [0.025, 0.050])
labels: Labels for ITL categories (default based on thresholds)
"""
# Set defaults
if colors is None:
colors = ["#109618", "#FF7F0E", "#D62728"]
if itl_thresholds is None:
itl_thresholds = [0.025, 0.050]
if labels is None:
labels = [
f"ITL < {itl_thresholds[0] * 1000:.0f}ms",
f"{itl_thresholds[0] * 1000:.0f}ms ≤ ITL < {itl_thresholds[1] * 1000:.0f}ms", # noqa
f"ITL ≥ {itl_thresholds[1] * 1000:.0f}ms",
]
labels_colors = {"TTFT": "#636EFA", **dict(zip(labels, colors))}
labels_order = ["TTFT"] + labels
timeline_data = construct_timeline_data(results, itl_thresholds, labels)
if not timeline_data:
print("No timeline data to plot")
return
# Create the plot
fig = px.timeline(
timeline_data,
x_start="start",
x_end="end",
y="request_id",
color="type",
color_discrete_map=labels_colors,
category_orders={"type": labels_order},
hover_data=[
"prompt_tokens",
"output_tokens",
"req_start_time",
"req_finish_time",
"segment_start",
"segment_end",
"duration",
],
)
# Customize hover template to show only time without date
fig.update_traces(
hovertemplate="<b>%{y}</b><br>"
"Type: %{fullData.name}<br>"
"Start: %{customdata[4]}<br>"
"End: %{customdata[5]}<br>"
"Duration: %{customdata[6]}<br>"
"Prompt Tokens: %{customdata[0]}<br>"
"Output Tokens: %{customdata[1]}<br>"
"Request Start Time: %{customdata[2]}<br>"
"Request End Time: %{customdata[3]}<br>"
"<extra></extra>"
)
fig.update_yaxes(autorange="reversed")
fig.update_layout(
xaxis_title="Time",
yaxis_title="Request ID",
showlegend=True,
)
# Save to HTML
pio.write_html(fig, str(output_path))
print(f"Timeline plot saved to: {output_path}")
def construct_timeline_data(
requests_data: list[dict[str, Any]],
itl_thresholds: list[float],
labels: list[str],
) -> list[dict[str, Any]]:
"""
Construct timeline data from request results.
Args:
requests_data: List of per-request result dictionaries
itl_thresholds: ITL thresholds in seconds
labels: Labels for ITL categories
Returns:
List of timeline segments for plotting
"""
def tostr(sec_time: float) -> str:
"""Convert seconds to HH:MM:SS.mmm format."""
h = int(sec_time // 3600)
assert h < 100, "time seems to last more than 100 hours"
m = int((sec_time % 3600) // 60)
s = sec_time % 60
return f"{h:02d}:{m:02d}:{s:06.3f}"
def itl_type(itl: float) -> str:
"""Categorize ITL based on thresholds."""
if itl < itl_thresholds[0]:
return labels[0]
elif itl < itl_thresholds[1]:
return labels[1]
else:
return labels[2]
# Find the earliest start time to use as t0
t0 = None
for request in requests_data:
start_time = request.get("start_time")
if start_time is not None and (t0 is None or start_time < t0):
t0 = start_time
if t0 is None:
return []
timeline_data = []
for i, request in enumerate(requests_data):
start_time = request.get("start_time")
ttft = request.get("ttft")
itl = request.get("itl", [])
latency = request.get("latency")
prompt_len = request.get("prompt_len", 0)
output_tokens = request.get("output_tokens", 0)
# Skip requests without required data
if start_time is None or ttft is None or latency is None:
continue
# Normalize start time
start_time = start_time - t0
start_time_str = tostr(start_time)
# TTFT segment
ttft_end = start_time + ttft
ttft_end_str = tostr(ttft_end)
timeline_data.append(
{
"request_id": f"Req {i}",
"start": start_time_str,
"end": ttft_end_str,
"type": "TTFT",
"prompt_tokens": prompt_len,
"output_tokens": output_tokens,
"req_start_time": tostr(start_time),
"req_finish_time": tostr(start_time + latency),
"segment_start": start_time_str,
"segment_end": ttft_end_str,
"duration": f"{ttft:.3f}s",
}
)
# ITL segments
prev_time = ttft_end
prev_time_str = ttft_end_str
for itl_value in itl:
itl_end = prev_time + itl_value
itl_end_str = tostr(itl_end)
timeline_data.append(
{
"request_id": f"Req {i}",
"start": prev_time_str,
"end": itl_end_str,
"type": itl_type(itl_value),
"prompt_tokens": prompt_len,
"output_tokens": output_tokens,
"req_start_time": tostr(start_time),
"req_finish_time": tostr(start_time + latency),
"segment_start": prev_time_str,
"segment_end": itl_end_str,
"duration": f"{itl_value:.3f}s",
}
)
prev_time = itl_end
prev_time_str = itl_end_str
return timeline_data
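A tiny synthetic request illustrates the segmentation (hypothetical numbers; custom labels for brevity):
segments = construct_timeline_data(
    [
        {
            "start_time": 100.0,
            "ttft": 0.5,
            "itl": [0.02, 0.08],
            "latency": 0.6,
            "prompt_len": 8,
            "output_tokens": 3,
        }
    ],
    itl_thresholds=[0.025, 0.050],
    labels=["fast", "medium", "slow"],
)
# One TTFT segment plus one segment per inter-token latency:
# 0.02 s < 0.025 s -> "fast"; 0.08 s >= 0.050 s -> "slow".
print([s["type"] for s in segments])  # ['TTFT', 'fast', 'slow']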
def generate_dataset_stats_plot(
results: list[dict[str, Any]],
output_path: Path,
) -> None:
"""
Generate a matplotlib figure with dataset statistics.
Creates a figure with 4 subplots:
- Top-left: Prompt tokens distribution (histogram)
- Top-right: Output tokens distribution (histogram)
- Bottom-left: Prompt+output tokens distribution (histogram)
- Bottom-right: Stacked bar chart (request_id vs tokens)
Args:
results: List of per-request result dictionaries containing:
- prompt_len: Number of prompt tokens
- output_tokens: Number of output tokens
output_path: Path where the figure will be saved
"""
# Extract data
prompt_tokens = []
output_tokens = []
total_tokens = []
for request in results:
prompt_len = request.get("prompt_len", 0)
output_len = request.get("output_tokens", 0)
prompt_tokens.append(prompt_len)
output_tokens.append(output_len)
total_tokens.append(prompt_len + output_len)
if not prompt_tokens:
print("No data available for dataset statistics plot")
return
# Create figure with 4 subplots
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(14, 10))
# Top-left: Prompt tokens distribution
ax1.hist(prompt_tokens, bins=30, color="steelblue", edgecolor="black", alpha=0.7)
ax1.set_xlabel("Prompt Tokens")
ax1.set_ylabel("Frequency")
ax1.set_title("Prompt Tokens Distribution")
ax1.grid(True, alpha=0.3)
# Top-right: Output tokens distribution
ax2.hist(output_tokens, bins=30, color="coral", edgecolor="black", alpha=0.7)
ax2.set_xlabel("Output Tokens")
ax2.set_ylabel("Frequency")
ax2.set_title("Output Tokens Distribution")
ax2.grid(True, alpha=0.3)
# Bottom-left: Prompt+output tokens distribution
ax3.hist(
total_tokens, bins=30, color="mediumseagreen", edgecolor="black", alpha=0.7
)
ax3.set_xlabel("Total Tokens (Prompt + Output)")
ax3.set_ylabel("Frequency")
ax3.set_title("Total Tokens Distribution")
ax3.grid(True, alpha=0.3)
# Bottom-right: Stacked bar chart
request_ids = list(range(len(prompt_tokens)))
ax4.bar(
request_ids, prompt_tokens, label="Prompt Tokens", color="steelblue", alpha=0.7
)
ax4.bar(
request_ids,
output_tokens,
bottom=prompt_tokens,
label="Output Tokens",
color="coral",
alpha=0.7,
)
ax4.set_xlabel("Request ID")
ax4.set_ylabel("Tokens")
ax4.set_title("Tokens per Request (Stacked)")
ax4.legend()
ax4.grid(True, alpha=0.3, axis="y")
# Adjust layout to prevent overlap
plt.tight_layout()
# Save figure
plt.savefig(str(output_path), dpi=150, bbox_inches="tight")
plt.close(fig)
print(f"Dataset statistics plot saved to: {output_path}")

View File

@@ -34,6 +34,7 @@ from collections.abc import AsyncGenerator, Iterable
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Any, Literal
import aiohttp
@@ -1183,6 +1184,49 @@ def save_to_pytorch_benchmark_format(
write_to_json(pt_file, pt_records)
def compute_result_filename(
args: argparse.Namespace,
model_id: str,
label: str,
current_dt: str,
) -> str | None:
"""Compute the result filename based on benchmark configuration.
Args:
args: Command line arguments containing result configuration
model_id: The model identifier
label: The benchmark label
current_dt: Current datetime string
Returns:
The computed filename path or None if no result saving is requested
"""
if not (args.plot_timeline or args.save_result or args.append_result):
return None
base_model_id = model_id.split("/")[-1]
max_concurrency_str = (
f"-concurrency{args.max_concurrency}"
if args.max_concurrency is not None
else ""
)
label = label or args.backend
if args.ramp_up_strategy is not None:
file_name = f"{label}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
else:
file_name = f"{label}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
if args.result_filename:
file_name = args.result_filename
if args.result_dir:
os.makedirs(args.result_dir, exist_ok=True)
file_name = os.path.join(args.result_dir, file_name)
return file_name
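For example, a hypothetical namespace with the relevant fields yields:
args = argparse.Namespace(
    plot_timeline=False, save_result=True, append_result=False,
    max_concurrency=32, backend="vllm", ramp_up_strategy=None,
    request_rate=4.0, result_filename=None, result_dir=None,
)
print(compute_result_filename(args, "org/my-model", "", "20260429-193822"))
# -> vllm-4.0qps-concurrency32-my-model-20260429-193822.json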
def add_cli_args(parser: argparse.ArgumentParser):
add_dataset_parser(parser)
parser.add_argument(
@@ -1277,6 +1321,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
- "slow" will always use the slow tokenizer.\n
- "mistral" will always use the tokenizer from `mistral_common`.\n
- "deepseek_v32" will always use the tokenizer from `deepseek_v32`.\n
- "qwen_vl" will always use the tokenizer from `qwen_vl`.\n
- Other custom values can be supported via plugins.""",
)
parser.add_argument("--use-beam-search", action="store_true")
@@ -1535,6 +1580,30 @@ def add_cli_args(parser: argparse.ArgumentParser):
"connecting to servers with self-signed certificates.",
)
parser.add_argument(
"--plot-timeline",
action="store_true",
help="Generate an HTML timeline plot showing request execution. "
"The plot will be saved alongside the results JSON file.",
)
parser.add_argument(
"--timeline-itl-thresholds",
type=float,
nargs=2,
default=[25.0, 50.0],
metavar=("THRESHOLD1", "THRESHOLD2"),
help="ITL thresholds in milliseconds for timeline plot coloring. "
"Specify two values to categorize inter-token latencies into three groups: "
"below first threshold (green), between thresholds (orange), "
"and above second threshold (red). Default: 25 50 (milliseconds).",
)
parser.add_argument(
"--plot-dataset-stats",
action="store_true",
help="Generate a matplotlib figure with dataset statistics showing "
"prompt tokens, output tokens, and combined token distributions.",
)
def main(args: argparse.Namespace) -> dict[str, Any]:
return asyncio.run(main_async(args))
@@ -1770,6 +1839,86 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
# Merge with benchmark result
result_json = {**result_json, **benchmark_result}
# Compute file_name once before using it for plots or saving results
file_name = compute_result_filename(args, model_id, label, current_dt)
# Generate timeline plot if requested
if args.plot_timeline:
try:
from vllm.benchmarks.plot import generate_timeline_plot
# Prepare per-request data for timeline
per_request_data = []
start_times = benchmark_result.get("start_times", [])
ttfts = benchmark_result.get("ttfts", [])
itls = benchmark_result.get("itls", [])
input_lens = benchmark_result.get("input_lens", [])
output_lens = benchmark_result.get("output_lens", [])
if start_times and ttfts and itls:
for i in range(len(start_times)):
# Calculate latency as ttft + sum of all itls
latency = ttfts[i] + sum(itls[i]) if itls[i] else ttfts[i]
per_request_data.append(
{
"start_time": start_times[i],
"ttft": ttfts[i],
"itl": itls[i],
"latency": latency,
"prompt_len": input_lens[i],
"output_tokens": output_lens[i],
}
)
timeline_path = Path(file_name).with_suffix(".timeline.html")
# Convert thresholds from milliseconds to seconds
itl_thresholds_sec = [t / 1000.0 for t in args.timeline_itl_thresholds]
generate_timeline_plot(
per_request_data, timeline_path, itl_thresholds=itl_thresholds_sec
)
else:
warnings.warn(
"Timeline plot requires detailed metrics. "
"Ensure the benchmark completed successfully.",
stacklevel=2,
)
except Exception as e:
warnings.warn(f"Failed to generate timeline plot: {e}", stacklevel=2)
# Generate dataset statistics plot if requested
if args.plot_dataset_stats:
try:
from vllm.benchmarks.plot import generate_dataset_stats_plot
# Prepare per-request data for dataset stats
per_request_data = []
input_lens = benchmark_result.get("input_lens", [])
output_lens = benchmark_result.get("output_lens", [])
if input_lens and output_lens:
for req_input_len, req_output_len in zip(input_lens, output_lens):
per_request_data.append(
{
"prompt_len": req_input_len,
"output_tokens": req_output_len,
}
)
stats_path = Path(file_name).with_suffix(".dataset_stats.png")
generate_dataset_stats_plot(per_request_data, stats_path)
else:
warnings.warn(
"Dataset statistics plot requires input and "
"output length data. Ensure the benchmark completed "
"successfully.",
stacklevel=2,
)
except Exception as e:
warnings.warn(
f"Failed to generate dataset statistics plot: {e}", stacklevel=2
)
if not args.save_detailed:
# Remove fields with too many data points
for field in [
@@ -1786,24 +1935,8 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
if field in benchmark_result:
del benchmark_result[field]
# Save to file
if args.save_result or args.append_result:
base_model_id = model_id.split("/")[-1]
max_concurrency_str = (
f"-concurrency{args.max_concurrency}"
if args.max_concurrency is not None
else ""
)
label = label or args.backend
if args.ramp_up_strategy is not None:
file_name = f"{label}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
else:
file_name = f"{label}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
if args.result_filename:
file_name = args.result_filename
if args.result_dir:
os.makedirs(args.result_dir, exist_ok=True)
file_name = os.path.join(args.result_dir, file_name)
with open(
file_name, mode="a+" if args.append_result else "w", encoding="utf-8"
) as outfile:

View File

@@ -10,14 +10,14 @@ from .plot_pareto import SweepPlotParetoArgs
from .plot_pareto import main as plot_pareto_main
from .serve import SweepServeArgs
from .serve import main as serve_main
from .serve_sla import SweepServeSLAArgs
from .serve_sla import main as serve_sla_main
from .serve_workload import SweepServeWorkloadArgs
from .serve_workload import main as serve_workload_main
from .startup import SweepStartupArgs
from .startup import main as startup_main
SUBCOMMANDS = (
(SweepServeArgs, serve_main),
(SweepServeSLAArgs, serve_sla_main),
(SweepServeWorkloadArgs, serve_workload_main),
(SweepStartupArgs, startup_main),
(SweepPlotArgs, plot_main),
(SweepPlotParetoArgs, plot_pareto_main),

View File

@@ -324,6 +324,11 @@ def _plot_fig(
df = filter_by.apply(df)
df = bin_by.apply(df)
if len(df) == 0:
print(f"No data to plot. Filters: {filter_by}")
print("[END FIGURE]")
return
# Sort by curve_by columns alphabetically for consistent legend ordering
if curve_by:
df = df.sort_values(by=curve_by)
@@ -494,7 +499,7 @@ class SweepPlotArgs:
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
output_dir = Path(args.OUTPUT_DIR)
output_dir = Path(args.EXPERIMENT_DIR)
if not output_dir.exists():
raise ValueError(f"No parameter sweep results under {output_dir}")
@@ -526,11 +531,9 @@ class SweepPlotArgs:
@classmethod
def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser.add_argument(
"OUTPUT_DIR",
"EXPERIMENT_DIR",
type=str,
default="results",
help="The directory containing the results to plot, "
"i.e., the `--output-dir` argument to the parameter sweep script.",
help="The directory containing the sweep results to plot.",
)
parser.add_argument(
"--fig-dir",
@@ -570,13 +573,13 @@ class SweepPlotArgs:
parser.add_argument(
"--var-x",
type=str,
default="request_throughput",
default="total_token_throughput",
help="The variable for the x-axis.",
)
parser.add_argument(
"--var-y",
type=str,
default="p99_ttft_ms",
default="median_ttft_ms",
help="The variable for the y-axis",
)
parser.add_argument(

View File

@@ -325,7 +325,7 @@ class SweepPlotParetoArgs:
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
output_dir = Path(args.OUTPUT_DIR)
output_dir = Path(args.EXPERIMENT_DIR)
if not output_dir.exists():
raise ValueError(f"No parameter sweep results under {output_dir}")
@@ -342,9 +342,8 @@ class SweepPlotParetoArgs:
@classmethod
def add_cli_args(cls, parser: argparse.ArgumentParser):
parser.add_argument(
"OUTPUT_DIR",
"EXPERIMENT_DIR",
type=str,
default="results",
help="The directory containing the sweep results to plot.",
)
parser.add_argument(

View File

@@ -4,6 +4,7 @@ import argparse
import contextlib
import json
import shlex
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
@@ -135,17 +136,21 @@ def run_benchmark(
def _get_comb_base_path(
output_dir: Path,
experiment_dir: Path,
serve_comb: ParameterSweepItem,
bench_comb: ParameterSweepItem,
*,
extra_parts: tuple[str, ...] = (),
):
parts = list[str]()
if serve_comb:
parts.extend(("SERVE-", serve_comb.name))
if bench_comb:
parts.extend(("BENCH-", bench_comb.name))
if extra_parts:
parts.extend(extra_parts)
return output_dir / sanitize_filename("-".join(parts))
return experiment_dir / sanitize_filename("-".join(parts))
def _get_comb_run_path(base_path: Path, run_number: int | None):
@@ -158,10 +163,10 @@ def _get_comb_run_path(base_path: Path, run_number: int | None):
def _comb_needs_server(
serve_comb: ParameterSweepItem,
bench_combs: ParameterSweep,
output_dir: Path,
experiment_dir: Path,
):
for bench_comb in bench_combs:
base_path = _get_comb_base_path(output_dir, serve_comb, bench_comb)
base_path = _get_comb_base_path(experiment_dir, serve_comb, bench_comb)
if not _get_comb_run_path(base_path, run_number=None).exists():
return True
@@ -175,11 +180,11 @@ def server_ctx(
show_stdout: bool,
serve_comb: ParameterSweepItem,
bench_params: ParameterSweep,
output_dir: Path,
experiment_dir: Path,
dry_run: bool,
server_ready_timeout: int = 300,
):
if not _comb_needs_server(serve_comb, bench_params, output_dir):
if not _comb_needs_server(serve_comb, bench_params, experiment_dir):
return contextlib.nullcontext()
return run_server(
@@ -211,10 +216,10 @@ def run_comb(
*,
serve_comb: ParameterSweepItem,
bench_comb: ParameterSweepItem,
link_vars: list[tuple[str, str]],
base_path: Path,
num_runs: int,
dry_run: bool,
link_vars: list[tuple[str, str]],
):
if not _comb_is_valid(serve_comb, bench_comb, link_vars):
return None
@@ -253,10 +258,10 @@ def run_combs(
server_ready_timeout: int,
serve_params: ParameterSweep,
bench_params: ParameterSweep,
output_dir: Path,
link_vars: list[tuple[str, str]],
experiment_dir: Path,
num_runs: int,
dry_run: bool,
link_vars: list[tuple[str, str]],
):
all_data = list[dict[str, object]]()
for serve_comb in serve_params:
@@ -266,22 +271,22 @@ def run_combs(
show_stdout=show_stdout,
serve_comb=serve_comb,
bench_params=bench_params,
output_dir=output_dir,
experiment_dir=experiment_dir,
dry_run=dry_run,
server_ready_timeout=server_ready_timeout,
) as server:
for bench_comb in bench_params:
base_path = _get_comb_base_path(output_dir, serve_comb, bench_comb)
base_path = _get_comb_base_path(experiment_dir, serve_comb, bench_comb)
comb_data = run_comb(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb,
link_vars=link_vars,
base_path=base_path,
num_runs=num_runs,
dry_run=dry_run,
link_vars=link_vars,
)
if comb_data is not None:
@@ -291,7 +296,7 @@ def run_combs(
return None
combined_df = pd.DataFrame.from_records(all_data)
combined_df.to_csv(output_dir / "summary.csv")
combined_df.to_csv(experiment_dir / "summary.csv")
return combined_df
@@ -305,11 +310,12 @@ class SweepServeArgs:
server_ready_timeout: int
serve_params: ParameterSweep
bench_params: ParameterSweep
link_vars: list[tuple[str, str]]
output_dir: Path
experiment_name: str
num_runs: int
dry_run: bool
resume: str | None
link_vars: list[tuple[str, str]]
resume: bool
parser_name: ClassVar[str] = "serve"
parser_help: ClassVar[str] = "Run vLLM server benchmark under multiple settings."
@@ -336,6 +342,11 @@ class SweepServeArgs:
link_vars = cls.parse_link_vars(args.link_vars)
if args.experiment_name:
experiment_name = args.experiment_name
else:
experiment_name = datetime.now().strftime("%Y%m%d_%H%M%S")
num_runs = args.num_runs
if num_runs < 1:
raise ValueError("`num_runs` should be at least 1.")
@@ -347,11 +358,12 @@ class SweepServeArgs:
show_stdout=args.show_stdout,
serve_params=serve_params,
bench_params=bench_params,
link_vars=link_vars,
output_dir=Path(args.output_dir),
experiment_name=experiment_name,
num_runs=num_runs,
dry_run=args.dry_run,
resume=args.resume,
link_vars=link_vars,
server_ready_timeout=args.server_ready_timeout,
)
@@ -388,6 +400,7 @@ class SweepServeArgs:
default=300,
help="Timeout in seconds to wait for the server to become ready.",
)
parser.add_argument(
"--serve-params",
type=str,
@@ -398,6 +411,16 @@ class SweepServeArgs:
"If both `serve_params` and `bench_params` are given, "
"this script will iterate over their Cartesian product.",
)
parser.add_argument(
"--link-vars",
type=str,
default="",
help=(
"Comma-separated list of linked variables between serve and bench, "
"e.g. max_num_seqs=max_concurrency,max_model_len=random_input_len"
),
)
parser.add_argument(
"--bench-params",
type=str,
@@ -413,7 +436,15 @@ class SweepServeArgs:
"--output-dir",
type=str,
default="results",
help="The directory to which results are written.",
help="The main directory to which results are written.",
)
parser.add_argument(
"-e",
"--experiment-name",
type=str,
default=None,
help="The name of this experiment (defaults to current timestamp). "
"Results will be stored under `output_dir/experiment_name`.",
)
parser.add_argument(
"--num-runs",
@@ -429,21 +460,10 @@ class SweepServeArgs:
)
parser.add_argument(
"--resume",
type=str,
default=None,
help="Set this to the name of a directory under `output_dir` (which is a "
"timestamp) to resume a previous execution of this script, i.e., only run "
"parameter combinations for which there are still no output files.",
)
parser.add_argument(
"--link-vars",
type=str,
default="",
help=(
"Comma-separated list of linked variables between serve and bench, "
"e.g. max_num_seqs=max_concurrency,max_model_len=random_input_len"
),
action="store_true",
help="Resume a previous execution of this script, i.e., only run "
"parameter combinations for which there are still no output files "
"under `output_dir/experiment_name`.",
)
return parser
@@ -458,33 +478,52 @@ class SweepServeArgs:
pairs.append((a.strip(), b.strip()))
return pairs
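Based on the tail of parse_link_vars shown above and the --link-vars help text, a call like the following should yield whitespace-stripped pairs (sketch):
pairs = SweepServeArgs.parse_link_vars(
    "max_num_seqs=max_concurrency, max_model_len=random_input_len"
)
print(pairs)
# -> [('max_num_seqs', 'max_concurrency'), ('max_model_len', 'random_input_len')]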
def resolve_experiment_dir(self) -> Path:
experiment_dir = self.output_dir / self.experiment_name
if self.resume:
if not experiment_dir.exists():
raise ValueError(f"Cannot resume from non-existent {experiment_dir=}")
else:
if experiment_dir.exists():
raise ValueError(f"Cannot overwrite existing {experiment_dir=}")
return experiment_dir
@contextmanager
def run_ctx(self, experiment_dir: Path):
if self.dry_run:
yield
print(f"Experiment will be saved at: {experiment_dir}")
return
try:
yield
print(f"Experiment has been saved at: {experiment_dir}")
except BaseException as exc:
raise RuntimeError(
"The script was terminated early. Use `--resume` "
"to continue the script from its last checkpoint."
) from exc
def run_main(args: SweepServeArgs):
timestamp = args.resume or datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = args.output_dir / timestamp
experiment_dir = args.resolve_experiment_dir()
if args.resume and not output_dir.exists():
raise ValueError(f"Cannot resume from non-existent directory ({output_dir})")
try:
with args.run_ctx(experiment_dir):
return run_combs(
serve_cmd=args.serve_cmd,
bench_cmd=args.bench_cmd,
link_vars=args.link_vars,
after_bench_cmd=args.after_bench_cmd,
show_stdout=args.show_stdout,
server_ready_timeout=args.server_ready_timeout,
serve_params=args.serve_params,
bench_params=args.bench_params,
output_dir=output_dir,
experiment_dir=experiment_dir,
num_runs=args.num_runs,
dry_run=args.dry_run,
link_vars=args.link_vars,
)
except BaseException as exc:
raise RuntimeError(
f"The script was terminated early. Use `--resume {timestamp}` "
f"to continue the script from its last checkpoint."
) from exc
def main(args: argparse.Namespace):

View File

@@ -1,305 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import math
from dataclasses import asdict, dataclass
from datetime import datetime
from pathlib import Path
from typing import ClassVar, Literal, get_args
import numpy as np
from typing_extensions import assert_never
from vllm.utils.import_utils import PlaceholderModule
from .param_sweep import ParameterSweep, ParameterSweepItem
from .serve import (
SweepServeArgs,
_get_comb_base_path,
run_comb,
server_ctx,
)
from .server import ServerProcess
try:
import pandas as pd
except ImportError:
pd = PlaceholderModule("pandas")
SLAVariable = Literal["request_rate", "max_concurrency"]
def _estimate_sla_value(run_data: dict[str, object], sla_variable: SLAVariable):
request_throughput = float(run_data["request_throughput"]) # type: ignore
if sla_variable == "request_rate":
return request_throughput
if sla_variable == "max_concurrency":
mean_latency_ms = float(run_data["mean_e2el_ms"]) # type: ignore
return request_throughput * mean_latency_ms / 1000
assert_never(sla_variable)
def _estimate_sla_avg(runs: list[dict[str, object]], sla_variable: SLAVariable):
return sum(_estimate_sla_value(run, sla_variable) for run in runs) / len(runs)
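The max_concurrency estimate is effectively Little's law: concurrency is roughly throughput times mean latency. A quick numeric check with made-up run data:
run = {"request_throughput": 20.0, "mean_e2el_ms": 500.0}
# 20 req/s at a mean end-to-end latency of 0.5 s keeps ~10 requests in flight.
print(_estimate_sla_value(run, "max_concurrency"))  # 10.0
print(_estimate_sla_avg([run, run], "request_rate"))  # 20.0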
def run_comb_sla(
server: ServerProcess | None,
bench_cmd: list[str],
*,
serve_comb: ParameterSweepItem,
bench_comb: ParameterSweepItem,
output_dir: Path,
num_runs: int,
dry_run: bool,
link_vars: list[tuple[str, str]],
sla_variable: SLAVariable,
sla_value: int,
) -> list[dict[str, object]] | None:
bench_comb_sla = bench_comb | {sla_variable: sla_value}
return run_comb(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb_sla,
base_path=_get_comb_base_path(output_dir, serve_comb, bench_comb_sla),
num_runs=num_runs,
dry_run=dry_run,
link_vars=link_vars,
)
def explore_sla(
server: ServerProcess | None,
bench_cmd: list[str],
*,
serve_comb: ParameterSweepItem,
bench_comb: ParameterSweepItem,
sla_variable: SLAVariable,
sla_iters: int,
output_dir: Path,
num_runs: int,
dry_run: bool,
link_vars: list[tuple[str, str]],
):
print("[SLA START]")
print(f"Serve parameters: {serve_comb.as_text() or '(None)'}")
print(f"Bench parameters: {bench_comb.as_text() or '(None)'}")
print(f"Number of SLA iterations: {sla_iters}")
if sla_iters < 2:
raise ValueError("`sla_iters` should be at least 2")
serial_comb_data = run_comb_sla(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb,
output_dir=output_dir,
num_runs=num_runs,
dry_run=dry_run,
link_vars=link_vars,
sla_variable=sla_variable,
sla_value=1,
)
batch_comb_data = run_comb_sla(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb,
output_dir=output_dir,
num_runs=num_runs,
dry_run=dry_run,
link_vars=link_vars,
sla_variable=sla_variable,
sla_value=int(bench_comb.get("num_prompts", 1000)), # type: ignore
)
if serial_comb_data is None or batch_comb_data is None:
if dry_run:
print("Omitting intermediate SLA iterations.")
print("[SLA END]")
return
serial_sla_value = math.ceil(_estimate_sla_avg(serial_comb_data, sla_variable))
print(f"Serial inference: {sla_variable}={serial_sla_value}")
batch_sla_value = math.floor(_estimate_sla_avg(batch_comb_data, sla_variable))
print(f"Batch inference: {sla_variable}={batch_sla_value}")
# Avoid duplicated runs for intermediate values if the range between
# `serial_sla_value` and `batch_sla_value` is small
inter_sla_values = np.linspace(serial_sla_value, batch_sla_value, sla_iters)[1:-1]
inter_sla_values = sorted(set(map(round, inter_sla_values)))
inter_combs_data: list[dict[str, object]] = []
for inter_sla_value in inter_sla_values:
print(f"Exploring: {sla_variable}={inter_sla_value}")
inter_comb_data = run_comb_sla(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb,
output_dir=output_dir,
num_runs=num_runs,
dry_run=dry_run,
link_vars=link_vars,
sla_variable=sla_variable,
sla_value=inter_sla_value,
)
if inter_comb_data is not None:
inter_combs_data.extend(inter_comb_data)
print("[SLA END]")
return serial_comb_data + inter_combs_data + batch_comb_data
def run_slas(
serve_cmd: list[str],
bench_cmd: list[str],
after_bench_cmd: list[str],
*,
show_stdout: bool,
server_ready_timeout: int,
serve_params: ParameterSweep,
bench_params: ParameterSweep,
sla_variable: SLAVariable,
sla_iters: int,
output_dir: Path,
num_runs: int,
dry_run: bool,
link_vars: list[tuple[str, str]],
):
if any(bench_comb.has_param(sla_variable) for bench_comb in bench_params):
raise ValueError(
f"You should not override `{sla_variable}` in `bench_params` in SLA mode, "
"since it is supposed to be determined automatically."
)
all_data = list[dict[str, object]]()
for serve_comb in serve_params:
with server_ctx(
serve_cmd,
after_bench_cmd,
show_stdout=show_stdout,
server_ready_timeout=server_ready_timeout,
serve_comb=serve_comb,
bench_params=bench_params,
output_dir=output_dir,
dry_run=dry_run,
) as server:
for bench_comb in bench_params:
comb_data = explore_sla(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb,
sla_variable=sla_variable,
sla_iters=sla_iters,
output_dir=output_dir,
num_runs=num_runs,
dry_run=dry_run,
link_vars=link_vars,
)
if comb_data is not None:
all_data.extend(comb_data)
if dry_run:
return None
combined_df = pd.DataFrame.from_records(all_data)
combined_df.to_csv(output_dir / "summary.csv")
return combined_df
@dataclass
class SweepServeSLAArgs(SweepServeArgs):
sla_variable: SLAVariable
sla_iters: int
parser_name: ClassVar[str] = "serve_sla"
parser_help: ClassVar[str] = (
"Explore the latency-throughput space for determining SLAs."
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
# NOTE: Don't use super() as `from_cli_args` calls `cls()`
base_args = SweepServeArgs.from_cli_args(args)
return cls(
**asdict(base_args),
sla_variable=args.sla_variable,
sla_iters=args.sla_iters,
)
@classmethod
def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser = super().add_cli_args(parser)
sla_group = parser.add_argument_group("sla options")
sla_group.add_argument(
"--sla-variable",
type=str,
choices=get_args(SLAVariable),
default="request_rate",
help="The variable to adjust in each iteration.",
)
sla_group.add_argument(
"--sla-iters",
type=int,
default=10,
help="Number of iterations used to explore the latency-throughput space. "
"This includes the first two iterations used to interpolate the value of "
"`sla_variable` for remaining iterations.",
)
return parser
def run_main(args: SweepServeSLAArgs):
timestamp = args.resume or datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = args.output_dir / timestamp
if args.resume and not output_dir.exists():
raise ValueError(f"Cannot resume from non-existent directory ({output_dir})")
try:
return run_slas(
serve_cmd=args.serve_cmd,
bench_cmd=args.bench_cmd,
after_bench_cmd=args.after_bench_cmd,
show_stdout=args.show_stdout,
server_ready_timeout=args.server_ready_timeout,
serve_params=args.serve_params,
bench_params=args.bench_params,
sla_variable=args.sla_variable,
sla_iters=args.sla_iters,
output_dir=output_dir,
num_runs=args.num_runs,
dry_run=args.dry_run,
link_vars=args.link_vars,
)
except BaseException as exc:
raise RuntimeError(
f"The script was terminated early. Use `--resume {timestamp}` "
f"to continue the script from its last checkpoint."
) from exc
def main(args: argparse.Namespace):
run_main(SweepServeSLAArgs.from_cli_args(args))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description=SweepServeSLAArgs.parser_help)
SweepServeSLAArgs.add_cli_args(parser)
main(parser.parse_args())

View File

@@ -0,0 +1,328 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import math
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import ClassVar, Literal, get_args
import numpy as np
from typing_extensions import assert_never
from vllm.benchmarks.datasets import DEFAULT_NUM_PROMPTS
from vllm.utils.import_utils import PlaceholderModule
from .param_sweep import ParameterSweep, ParameterSweepItem
from .serve import (
SweepServeArgs,
_get_comb_base_path,
run_comb,
server_ctx,
)
from .server import ServerProcess
try:
import pandas as pd
except ImportError:
pd = PlaceholderModule("pandas")
WorkloadVariable = Literal["request_rate", "max_concurrency"]
def _estimate_workload_value(
run_data: dict[str, object],
workload_var: WorkloadVariable,
):
request_throughput = float(run_data["request_throughput"]) # type: ignore
if workload_var == "request_rate":
return request_throughput
if workload_var == "max_concurrency":
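# Little's law (L = lambda * W): mean in-flight requests ~= request
# throughput (req/s) times mean end-to-end latency (s), hence the
# division by 1000 to convert milliseconds to seconds.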
mean_latency_ms = float(run_data["mean_e2el_ms"]) # type: ignore
return request_throughput * mean_latency_ms / 1000
assert_never(workload_var)
def _estimate_workload_avg(
runs: list[dict[str, object]],
workload_var: WorkloadVariable,
):
total = sum(_estimate_workload_value(run, workload_var) for run in runs)
return total / len(runs)
def run_comb_workload(
server: ServerProcess | None,
bench_cmd: list[str],
*,
serve_comb: ParameterSweepItem,
bench_comb: ParameterSweepItem,
link_vars: list[tuple[str, str]],
experiment_dir: Path,
num_runs: int,
dry_run: bool,
workload_var: WorkloadVariable,
workload_value: int,
) -> list[dict[str, object]] | None:
bench_comb_workload = bench_comb | {workload_var: workload_value}
return run_comb(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb_workload,
link_vars=link_vars,
base_path=_get_comb_base_path(
experiment_dir,
serve_comb,
bench_comb,
extra_parts=("WL-", f"{workload_var}={workload_value}"),
),
num_runs=num_runs,
dry_run=dry_run,
)
def explore_comb_workloads(
server: ServerProcess | None,
bench_cmd: list[str],
*,
serve_comb: ParameterSweepItem,
bench_comb: ParameterSweepItem,
link_vars: list[tuple[str, str]],
workload_var: WorkloadVariable,
workload_iters: int,
experiment_dir: Path,
num_runs: int,
dry_run: bool,
):
print("[WL START]")
print(f"Serve parameters: {serve_comb.as_text() or '(None)'}")
print(f"Bench parameters: {bench_comb.as_text() or '(None)'}")
print(f"Number of workload iterations: {workload_iters}")
if workload_iters < 2:
raise ValueError("`workload_iters` should be at least 2")
dataset_size = DEFAULT_NUM_PROMPTS
if "num_prompts" in bench_comb:
dataset_size = int(bench_comb["num_prompts"]) # type: ignore
else:
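# Fall back to scanning the raw bench command; both spellings,
# "--num-prompts N" and "--num-prompts=N", are handled below.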
for i, arg in enumerate(bench_cmd):
if arg == "--num-prompts" and i + 1 < len(bench_cmd):
dataset_size = int(bench_cmd[i + 1])
break
elif arg.startswith("--num-prompts="):
dataset_size = int(arg.split("=", 1)[1])
break
print(f"Dataset size: {dataset_size}")
serial_workload_data = run_comb_workload(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb | {"max_concurrency": 1},
link_vars=link_vars,
experiment_dir=experiment_dir,
num_runs=num_runs,
dry_run=dry_run,
workload_var=workload_var,
workload_value=1,
)
batch_workload_data = run_comb_workload(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb | {"max_concurrency": dataset_size},
link_vars=link_vars,
experiment_dir=experiment_dir,
num_runs=num_runs,
dry_run=dry_run,
workload_var=workload_var,
workload_value=dataset_size,
)
if serial_workload_data is None or batch_workload_data is None:
if dry_run:
print("Omitting intermediate Workload iterations.")
print("[WL END]")
return
serial_workload_value = math.ceil(
_estimate_workload_avg(serial_workload_data, workload_var)
)
print(f"Serial inference: {workload_var}={serial_workload_value}")
batch_workload_value = math.floor(
_estimate_workload_avg(batch_workload_data, workload_var)
)
print(f"Batch inference: {workload_var}={batch_workload_value}")
# Avoid duplicated runs for intermediate values if the range between
# `serial_workload_value` and `batch_workload_value` is small
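# (e.g. with bounds 1 and 10 over 10 iters, the eight interior points
# round to the levels 2..9; numbers are illustrative only)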
inter_workload_values = np.linspace(
serial_workload_value, batch_workload_value, workload_iters
)[1:-1]
inter_workload_values = sorted(set(map(round, inter_workload_values)))
inter_workloads_data: list[dict[str, object]] = []
for inter_workload_value in inter_workload_values:
print(f"Exploring: {workload_var}={inter_workload_value}")
inter_workload_data = run_comb_workload(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb,
link_vars=link_vars,
experiment_dir=experiment_dir,
num_runs=num_runs,
dry_run=dry_run,
workload_var=workload_var,
workload_value=inter_workload_value,
)
if inter_workload_data is not None:
inter_workloads_data.extend(inter_workload_data)
print("[WL END]")
return serial_workload_data + inter_workloads_data + batch_workload_data
def explore_combs_workloads(
serve_cmd: list[str],
bench_cmd: list[str],
after_bench_cmd: list[str],
*,
show_stdout: bool,
server_ready_timeout: int,
serve_params: ParameterSweep,
bench_params: ParameterSweep,
link_vars: list[tuple[str, str]],
workload_var: WorkloadVariable,
workload_iters: int,
experiment_dir: Path,
num_runs: int,
dry_run: bool,
):
if any(bench_comb.has_param(workload_var) for bench_comb in bench_params):
raise ValueError(
f"You should not override `{workload_var}` in `bench_params` "
"since it is supposed to be explored automatically."
)
all_data = list[dict[str, object]]()
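# Start one server per serve-parameter combination and reuse it across
# all bench combinations, so server startup cost is paid only once per
# combination.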
for serve_comb in serve_params:
with server_ctx(
serve_cmd,
after_bench_cmd,
show_stdout=show_stdout,
server_ready_timeout=server_ready_timeout,
serve_comb=serve_comb,
bench_params=bench_params,
experiment_dir=experiment_dir,
dry_run=dry_run,
) as server:
for bench_comb in bench_params:
comb_data = explore_comb_workloads(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb,
link_vars=link_vars,
workload_var=workload_var,
workload_iters=workload_iters,
experiment_dir=experiment_dir,
num_runs=num_runs,
dry_run=dry_run,
)
if comb_data is not None:
all_data.extend(comb_data)
if dry_run:
return None
combined_df = pd.DataFrame.from_records(all_data)
combined_df.to_csv(experiment_dir / "summary.csv")
return combined_df
@dataclass
class SweepServeWorkloadArgs(SweepServeArgs):
workload_var: WorkloadVariable
workload_iters: int
parser_name: ClassVar[str] = "serve_workload"
parser_help: ClassVar[str] = (
"Explore the latency-throughput tradeoff for different workload levels."
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
# NOTE: Call `SweepServeArgs.from_cli_args` directly instead of using
# `super()`: `from_cli_args` constructs via `cls()`, so `super()` would
# still instantiate this subclass without its extra fields.
base_args = SweepServeArgs.from_cli_args(args)
return cls(
**asdict(base_args),
workload_var=args.workload_var,
workload_iters=args.workload_iters,
)
@classmethod
def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser = super().add_cli_args(parser)
workload_group = parser.add_argument_group("workload options")
workload_group.add_argument(
"--workload-var",
type=str,
choices=get_args(WorkloadVariable),
default="request_rate",
help="The variable to adjust in each iteration.",
)
workload_group.add_argument(
"--workload-iters",
type=int,
default=10,
help="Number of workload levels to explore. "
"This includes the first two iterations used to interpolate the value of "
"`workload_var` for remaining iterations.",
)
return parser
def run_main(args: SweepServeWorkloadArgs):
experiment_dir = args.resolve_experiment_dir()
with args.run_ctx(experiment_dir):
return explore_combs_workloads(
serve_cmd=args.serve_cmd,
bench_cmd=args.bench_cmd,
after_bench_cmd=args.after_bench_cmd,
show_stdout=args.show_stdout,
server_ready_timeout=args.server_ready_timeout,
serve_params=args.serve_params,
bench_params=args.bench_params,
link_vars=args.link_vars,
workload_var=args.workload_var,
workload_iters=args.workload_iters,
experiment_dir=experiment_dir,
num_runs=args.num_runs,
dry_run=args.dry_run,
)
def main(args: argparse.Namespace):
run_main(SweepServeWorkloadArgs.from_cli_args(args))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description=SweepServeWorkloadArgs.parser_help)
SweepServeWorkloadArgs.add_cli_args(parser)
main(parser.parse_args())
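# A minimal invocation sketch (file names are illustrative;
# --serve-params/--bench-params are assumed to be defined by the base
# SweepServeArgs parser):
#   python serve_workload.py \
#       --serve-params serve_params.json \
#       --bench-params bench_params.json \
#       --workload-var max_concurrency \
#       --workload-iters 10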

View File

@@ -4,6 +4,7 @@ import argparse
import json
import shlex
import subprocess
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime
from functools import lru_cache
@@ -111,7 +112,7 @@ def _apply_output_json(cmd: list[str], output_path: Path) -> list[str]:
def _get_comb_base_path(
output_dir: Path,
experiment_dir: Path,
serve_comb: ParameterSweepItem,
startup_comb: ParameterSweepItem,
) -> Path:
@@ -120,7 +121,8 @@ def _get_comb_base_path(
parts.extend(("SERVE-", serve_comb.name))
if startup_comb:
parts.extend(("STARTUP-", startup_comb.name))
return output_dir / sanitize_filename("-".join(parts))
return experiment_dir / sanitize_filename("-".join(parts))
def _get_comb_run_path(base_path: Path, run_number: int | None) -> Path:
@@ -225,7 +227,7 @@ def run_combs(
*,
serve_params: ParameterSweep,
startup_params: ParameterSweep,
output_dir: Path,
experiment_dir: Path,
num_runs: int,
show_stdout: bool,
dry_run: bool,
@@ -233,7 +235,7 @@ def run_combs(
all_data = list[dict[str, object]]()
for serve_comb in serve_params:
for startup_comb in startup_params:
base_path = _get_comb_base_path(output_dir, serve_comb, startup_comb)
base_path = _get_comb_base_path(experiment_dir, serve_comb, startup_comb)
comb_data = run_comb(
startup_cmd,
serve_comb=serve_comb,
@@ -250,7 +252,7 @@ def run_combs(
return None
combined_df = pd.DataFrame.from_records(all_data)
combined_df.to_csv(output_dir / "summary.csv")
combined_df.to_csv(experiment_dir / "summary.csv")
return combined_df
@@ -260,11 +262,11 @@ class SweepStartupArgs:
serve_params: ParameterSweep
startup_params: ParameterSweep
output_dir: Path
experiment_name: str
num_runs: int
show_stdout: bool
dry_run: bool
resume: str | None
strict_params: bool
resume: bool
parser_name: ClassVar[str] = "startup"
parser_help: ClassVar[str] = (
@@ -286,13 +288,19 @@ class SweepStartupArgs:
startup_params = ParameterSweep.from_records([{}])
supported = _get_supported_startup_keys()
strict_params = args.strict_params
serve_params = _filter_params(
serve_params, supported=supported, strict=args.strict_params
serve_params, supported=supported, strict=strict_params
)
startup_params = _filter_params(
startup_params, supported=supported, strict=args.strict_params
startup_params, supported=supported, strict=strict_params
)
if args.experiment_name:
experiment_name = args.experiment_name
else:
experiment_name = datetime.now().strftime("%Y%m%d_%H%M%S")
if args.num_runs < 1:
raise ValueError("`num_runs` should be at least 1.")
@@ -301,11 +309,11 @@ class SweepStartupArgs:
serve_params=serve_params,
startup_params=startup_params,
output_dir=Path(args.output_dir),
experiment_name=experiment_name,
num_runs=args.num_runs,
show_stdout=args.show_stdout,
dry_run=args.dry_run,
resume=args.resume,
strict_params=args.strict_params,
)
@classmethod
@@ -316,6 +324,7 @@ class SweepStartupArgs:
default="vllm bench startup",
help="The command used to run the startup benchmark.",
)
parser.add_argument(
"--serve-params",
type=str,
@@ -331,12 +340,27 @@ class SweepStartupArgs:
help="Path to JSON file containing parameter combinations "
"for the `vllm bench startup` command.",
)
parser.add_argument(
"--strict-params",
action="store_true",
help="If set, unknown parameters in sweep files raise an error "
"instead of being ignored.",
)
parser.add_argument(
"-o",
"--output-dir",
type=str,
default="results",
help="The directory to which results are written.",
help="The main directory to which results are written.",
)
parser.add_argument(
"-e",
"--experiment-name",
type=str,
default=None,
help="The name of this experiment (defaults to current timestamp). "
"Results will be stored under `output_dir/experiment_name`.",
)
parser.add_argument(
"--num-runs",
@@ -357,43 +381,56 @@ class SweepStartupArgs:
)
parser.add_argument(
"--resume",
type=str,
default=None,
help="Set this to the name of a directory under `output_dir` (which is a "
"timestamp) to resume a previous execution of this script, i.e., only run "
"parameter combinations for which there are still no output files.",
)
parser.add_argument(
"--strict-params",
action="store_true",
help="If set, unknown parameters in sweep files raise an error "
"instead of being ignored.",
help="Resume a previous execution of this script, i.e., only run "
"parameter combinations for which there are still no output files "
"under `output_dir/experiment_name`.",
)
return parser
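# Resolves `output_dir/experiment_name`: when resuming, the directory
# must already exist; otherwise it must not yet exist, to avoid
# overwriting earlier results.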
def resolve_experiment_dir(self) -> Path:
experiment_dir = self.output_dir / self.experiment_name
if self.resume:
if not experiment_dir.exists():
raise ValueError(f"Cannot resume from non-existent {experiment_dir=}")
else:
if experiment_dir.exists():
raise ValueError(f"Cannot overwrite existing {experiment_dir=}")
return experiment_dir
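# Wraps a whole experiment run: in dry-run mode it only announces the
# target directory; otherwise it converts any early termination into a
# reminder to re-run with `--resume`.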
@contextmanager
def run_ctx(self, experiment_dir: Path):
if self.dry_run:
yield
print(f"Experiment will be saved at: {experiment_dir}")
return
try:
yield
print(f"Experiment has been saved at: {experiment_dir}")
except BaseException as exc:
raise RuntimeError(
"The script was terminated early. Use `--resume` "
"to continue the script from its last checkpoint."
) from exc
def run_main(args: SweepStartupArgs):
timestamp = args.resume or datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = args.output_dir / timestamp
experiment_dir = args.resolve_experiment_dir()
if args.resume and not output_dir.exists():
raise ValueError(f"Cannot resume from non-existent directory ({output_dir})")
try:
with args.run_ctx(experiment_dir):
return run_combs(
startup_cmd=args.startup_cmd,
serve_params=args.serve_params,
startup_params=args.startup_params,
output_dir=output_dir,
experiment_dir=experiment_dir,
num_runs=args.num_runs,
show_stdout=args.show_stdout,
dry_run=args.dry_run,
)
except BaseException as exc:
raise RuntimeError(
f"The script was terminated early. Use `--resume {timestamp}` "
f"to continue the script from its last checkpoint."
) from exc
def main(args: argparse.Namespace):