Upgrade to vLLM 0.17.0 (CoreX v4.1 overlay)
@@ -31,6 +31,7 @@ from tempfile import NamedTemporaryFile
 from typing import Any, cast

 import numpy as np
+from huggingface_hub import snapshot_download
 from PIL import Image
 from typing_extensions import deprecated


 logger = logging.getLogger(__name__)

+DEFAULT_NUM_PROMPTS = 1000
+
 # -----------------------------------------------------------------------------
 # Data Classes
 # -----------------------------------------------------------------------------
@@ -303,9 +306,11 @@ def process_image(image: Any) -> Mapping[str, Any]:
        a JPEG in memory. - Encodes the JPEG data as a base64 string. - Returns
        a dictionary with the image as a base64 data URL.

-    3. String input: - Treats the string as a URL or local file path. -
-       Prepends "file://" if the string doesn't start with "http://" or
-       "file://". - Returns a dictionary with the image URL.
+    3. String input: - Treats the string as a URL, local file path, or base64
+       encoded data. - If string starts with "data:image/", treats as base64.
+       - If string starts with "http://", "https://", or "file://", treats as URL.
+       - Otherwise treats as local file path and prepends "file://".
+       - Returns a dictionary with the image URL or base64 data.

    Raises:
        ValueError: If the input is not a supported type.
@@ -325,14 +330,14 @@ def process_image(image: Any) -> Mapping[str, Any]:
    if isinstance(image, str):
        image_url = (
            image
-            if image.startswith(("http://", "https://", "file://"))
+            if image.startswith(("http://", "https://", "file://", "data:image/"))
            else f"file://{image}"
        )
        return {"type": "image_url", "image_url": {"url": image_url}}

    raise ValueError(
-        f"Invalid image input {image}. Must be a PIL.Image.Image"
-        " or str or dictionary with raw image bytes."
+        f"Invalid image input {image}. Must be a PIL.Image.Image, "
+        "str (URL, file path, or base64 data URL), or dictionary with raw image bytes."
    )

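Illustration (not part of this commit): a self-contained sketch of the string branch updated above; the helper name below is hypothetical.

# Hypothetical helper mirroring the updated string branch (illustrative only).
def _image_str_to_content(image: str) -> dict:
    image_url = (
        image
        if image.startswith(("http://", "https://", "file://", "data:image/"))
        else f"file://{image}"
    )
    return {"type": "image_url", "image_url": {"url": image_url}}

print(_image_str_to_content("https://example.com/cat.jpg")["image_url"]["url"])
# -> https://example.com/cat.jpg (URL passed through)
print(_image_str_to_content("/tmp/cat.jpg")["image_url"]["url"])
# -> file:///tmp/cat.jpg (local path gets the file:// prefix)
print(_image_str_to_content("data:image/jpeg;base64,/9j/4AAQ")["image_url"]["url"])
# -> data:image/jpeg;base64,/9j/4AAQ (base64 data URL passed through after this change)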
@@ -1338,7 +1343,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
    parser.add_argument(
        "--num-prompts",
        type=int,
-        default=1000,
+        default=DEFAULT_NUM_PROMPTS,
        help="Number of prompts to process.",
    )
    parser.add_argument(
@@ -2676,6 +2681,14 @@ class MMVUDataset(HuggingFaceDataset):
        + (" ".join(f"{k}.{v}" for k, v in x["choices"].items())),
    }

+    def __init__(self, **kwargs) -> None:
+        super().__init__(**kwargs)
+
+        self._remote_path_root = (
+            f"https://huggingface.co/datasets/{self.hf_name}/resolve/main"
+        )
+        self._local_path_root = snapshot_download(self.hf_name, repo_type="dataset")
+
    def sample(
        self,
        tokenizer: TokenizerLike,
@@ -2698,7 +2711,9 @@ class MMVUDataset(HuggingFaceDataset):
                break

            prompt = parser_fn(item)
-            mm_content = process_video(item["video"])
+            mm_content = process_video(
+                item["video"].replace(self._remote_path_root, self._local_path_root)
+            )
            prompt_len = len(tokenizer.encode(prompt))
            if enable_multimodal_chat:
                # Note: when chat is enabled the request prompt_len is no longer

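Illustration (not part of this commit): the remote-to-local rewrite the MMVU hunks introduce, shown with made-up paths; the dataset id and snapshot location below are assumptions.

# Illustrative only: made-up snapshot path; in the diff above, _local_path_root
# comes from snapshot_download(self.hf_name, repo_type="dataset").
remote_path_root = "https://huggingface.co/datasets/yale-nlp/MMVU/resolve/main"
local_path_root = "/root/.cache/huggingface/hub/datasets--yale-nlp--MMVU/snapshots/abc123"

item_video = f"{remote_path_root}/videos/example_0001.mp4"
print(item_video.replace(remote_path_root, local_path_root))
# -> /root/.cache/huggingface/hub/datasets--yale-nlp--MMVU/snapshots/abc123/videos/example_0001.mp4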
vllm/benchmarks/lib/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Benchmark library utilities."""
vllm/benchmarks/lib/endpoint_request_func.py (new file, 802 lines)
@@ -0,0 +1,802 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""The request function for API endpoints."""
|
||||
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from collections.abc import Awaitable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Literal, Protocol
|
||||
|
||||
import aiohttp
|
||||
import regex as re
|
||||
from tqdm.asyncio import tqdm
|
||||
|
||||
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
|
||||
|
||||
|
||||
class StreamedResponseHandler:
|
||||
"""Handles streaming HTTP responses by accumulating chunks until complete
|
||||
messages are available."""
|
||||
|
||||
def __init__(self):
|
||||
self.buffer = ""
|
||||
|
||||
def add_chunk(self, chunk_bytes: bytes) -> list[str]:
|
||||
"""Add a chunk of bytes to the buffer and return any complete
|
||||
messages."""
|
||||
chunk_str = chunk_bytes.decode("utf-8")
|
||||
self.buffer += chunk_str
|
||||
|
||||
messages = []
|
||||
|
||||
# Split by double newlines (SSE message separator)
|
||||
while "\n\n" in self.buffer:
|
||||
message, self.buffer = self.buffer.split("\n\n", 1)
|
||||
message = message.strip()
|
||||
if message:
|
||||
messages.append(message)
|
||||
|
||||
# if self.buffer is not empty, check if it is a complete message
|
||||
# by removing data: prefix and check if it is a valid JSON
|
||||
if self.buffer.startswith("data: "):
|
||||
message_content = self.buffer.removeprefix("data: ").strip()
|
||||
if message_content == "[DONE]":
|
||||
messages.append(self.buffer.strip())
|
||||
self.buffer = ""
|
||||
elif message_content:
|
||||
try:
|
||||
json.loads(message_content)
|
||||
messages.append(self.buffer.strip())
|
||||
self.buffer = ""
|
||||
except json.JSONDecodeError:
|
||||
# Incomplete JSON, wait for more chunks.
|
||||
pass
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
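Illustration (not part of this commit): exercising StreamedResponseHandler directly, assuming the new module is importable at the path added above.

# Illustrative only: SSE chunks split mid-message are buffered until complete.
from vllm.benchmarks.lib.endpoint_request_func import StreamedResponseHandler

handler = StreamedResponseHandler()
print(handler.add_chunk(b'data: {"choices": [{"text": "Hel'))
# -> [] (incomplete JSON, kept in the buffer)
print(handler.add_chunk(b'lo"}]}\n\n'))
# -> ['data: {"choices": [{"text": "Hello"}]}']
print(handler.add_chunk(b"data: [DONE]\n\n"))
# -> ['data: [DONE]']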
@dataclass
|
||||
class RequestFuncInput:
|
||||
"""The input for the request function."""
|
||||
|
||||
prompt: str | list[str]
|
||||
api_url: str
|
||||
prompt_len: int
|
||||
output_len: int
|
||||
model: str
|
||||
model_name: str | None = None
|
||||
logprobs: int | None = None
|
||||
extra_headers: dict | None = None
|
||||
extra_body: dict | None = None
|
||||
multi_modal_content: dict | list[dict] | None = None
|
||||
ignore_eos: bool = False
|
||||
language: str | None = None
|
||||
request_id: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestFuncOutput:
|
||||
"""The output of the request function including metrics."""
|
||||
|
||||
generated_text: str = ""
|
||||
success: bool = False
|
||||
latency: float = 0.0
|
||||
output_tokens: int = 0
|
||||
ttft: float = 0.0 # Time to first token
|
||||
itl: list[float] = field(default_factory=list) # list of inter-token latencies
|
||||
tpot: float = 0.0 # avg next-token latencies
|
||||
prompt_len: int = 0
|
||||
error: str = ""
|
||||
start_time: float = 0.0
|
||||
input_audio_duration: float = 0.0 # in seconds
|
||||
|
||||
|
||||
class RequestFunc(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
request_func_input: RequestFuncInput,
|
||||
session: aiohttp.ClientSession,
|
||||
pbar: tqdm | None = None,
|
||||
) -> Awaitable[RequestFuncOutput]: ...
|
||||
|
||||
|
||||
def _validate_api_url(
|
||||
api_url: str,
|
||||
api_name: str,
|
||||
expected_suffixes: str | set[str],
|
||||
) -> None:
|
||||
if isinstance(expected_suffixes, str):
|
||||
expected_suffixes = {expected_suffixes}
|
||||
|
||||
expected_suffixes = {*expected_suffixes, "profile"}
|
||||
|
||||
if not api_url.endswith(tuple(expected_suffixes)):
|
||||
raise ValueError(f"{api_name} URL must end with one of: {expected_suffixes}.")
|
||||
|
||||
|
||||
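Illustration (not part of this commit): the suffix check above always accepts the profiling endpoint as well; module path assumed as above.

# Illustrative only: "profile" is implicitly added to the accepted suffixes.
from vllm.benchmarks.lib.endpoint_request_func import _validate_api_url

_validate_api_url("http://localhost:8000/v1/completions", "OpenAI Completions API", "completions")
_validate_api_url("http://localhost:8000/start_profile", "OpenAI Completions API", "completions")
# A mismatched suffix, e.g. ".../v1/embeddings" with "completions", raises ValueError.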
def _update_payload_common(
|
||||
payload: dict[str, Any],
|
||||
request_func_input: RequestFuncInput,
|
||||
) -> None:
|
||||
if request_func_input.ignore_eos:
|
||||
payload["ignore_eos"] = request_func_input.ignore_eos
|
||||
if request_func_input.extra_body:
|
||||
payload.update(request_func_input.extra_body)
|
||||
|
||||
|
||||
def _update_headers_common(
|
||||
headers: dict[str, Any],
|
||||
request_func_input: RequestFuncInput,
|
||||
) -> None:
|
||||
if request_func_input.extra_headers:
|
||||
headers |= request_func_input.extra_headers
|
||||
if request_func_input.request_id:
|
||||
headers["x-request-id"] = request_func_input.request_id
|
||||
|
||||
|
||||
def _get_headers(content_type: str | None = None) -> dict[str, str]:
|
||||
headers = {}
|
||||
if content_type:
|
||||
headers["Content-Type"] = content_type
|
||||
api_key = os.environ.get("OPENAI_API_KEY")
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
return headers
|
||||
|
||||
|
||||
async def async_request_openai_completions(
|
||||
request_func_input: RequestFuncInput,
|
||||
session: aiohttp.ClientSession,
|
||||
pbar: tqdm | None = None,
|
||||
) -> RequestFuncOutput:
|
||||
"""The async request function for the OpenAI Completions API.
|
||||
|
||||
Args:
|
||||
request_func_input: The input for the request function.
|
||||
pbar: The progress bar to display the progress.
|
||||
|
||||
Returns:
|
||||
The output of the request function.
|
||||
"""
|
||||
api_url = request_func_input.api_url
|
||||
_validate_api_url(api_url, "OpenAI Completions API", "completions")
|
||||
|
||||
payload = {
|
||||
"model": request_func_input.model_name
|
||||
if request_func_input.model_name
|
||||
else request_func_input.model,
|
||||
"prompt": request_func_input.prompt,
|
||||
"repetition_penalty": 1.0,
|
||||
"max_tokens": request_func_input.output_len,
|
||||
"logprobs": request_func_input.logprobs,
|
||||
"stream": True,
|
||||
"stream_options": {
|
||||
"include_usage": True,
|
||||
},
|
||||
}
|
||||
_update_payload_common(payload, request_func_input)
|
||||
|
||||
headers = _get_headers()
|
||||
_update_headers_common(headers, request_func_input)
|
||||
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
|
||||
generated_text = ""
|
||||
st = time.perf_counter()
|
||||
output.start_time = st
|
||||
most_recent_timestamp = st
|
||||
try:
|
||||
async with session.post(url=api_url, json=payload, headers=headers) as response:
|
||||
if response.status == 200:
|
||||
first_chunk_received = False
|
||||
handler = StreamedResponseHandler()
|
||||
|
||||
async for chunk_bytes in response.content.iter_any():
|
||||
chunk_bytes = chunk_bytes.strip()
|
||||
if not chunk_bytes:
|
||||
continue
|
||||
|
||||
messages = handler.add_chunk(chunk_bytes)
|
||||
for message in messages:
|
||||
# NOTE: SSE comments (often used as pings) start with
|
||||
# a colon. These are not JSON data payload and should
|
||||
# be skipped.
|
||||
if message.startswith(":"):
|
||||
continue
|
||||
|
||||
chunk = message.removeprefix("data: ")
|
||||
|
||||
if chunk != "[DONE]":
|
||||
data = json.loads(chunk)
|
||||
|
||||
# NOTE: Some completion API might have a last
|
||||
# usage summary response without a token so we
|
||||
# want to check a token was generated
|
||||
if choices := data.get("choices"):
|
||||
# Note that text could be empty here
|
||||
# e.g. for special tokens
|
||||
text = choices[0].get("text")
|
||||
timestamp = time.perf_counter()
|
||||
# First token
|
||||
if not first_chunk_received:
|
||||
first_chunk_received = True
|
||||
ttft = time.perf_counter() - st
|
||||
output.ttft = ttft
|
||||
|
||||
# Decoding phase
|
||||
else:
|
||||
output.itl.append(timestamp - most_recent_timestamp)
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
generated_text += text or ""
|
||||
elif usage := data.get("usage"):
|
||||
output.output_tokens = usage.get("completion_tokens")
|
||||
if first_chunk_received:
|
||||
output.success = True
|
||||
else:
|
||||
output.success = False
|
||||
output.error = (
|
||||
"Never received a valid chunk to calculate TTFT."
|
||||
"This response will be marked as failed!"
|
||||
)
|
||||
output.generated_text = generated_text
|
||||
output.latency = most_recent_timestamp - st
|
||||
else:
|
||||
output.error = response.reason or ""
|
||||
output.success = False
|
||||
except Exception:
|
||||
output.success = False
|
||||
exc_info = sys.exc_info()
|
||||
output.error = "".join(traceback.format_exception(*exc_info))
|
||||
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
return output
|
||||
|
||||
|
||||
def _get_chat_content(
|
||||
request_func_input: RequestFuncInput,
|
||||
mm_position: Literal["first", "last"] = "last",
|
||||
) -> list[dict[str, Any]]:
|
||||
text_contents = [{"type": "text", "text": request_func_input.prompt}]
|
||||
|
||||
mm_contents = []
|
||||
if request_func_input.multi_modal_content:
|
||||
mm_content = request_func_input.multi_modal_content
|
||||
if isinstance(mm_content, list):
|
||||
mm_contents.extend(request_func_input.multi_modal_content)
|
||||
elif isinstance(mm_content, dict):
|
||||
mm_contents.append(request_func_input.multi_modal_content)
|
||||
else:
|
||||
raise TypeError(
|
||||
"multi_modal_content must be a dict or list[dict] for openai-chat"
|
||||
)
|
||||
|
||||
if mm_position == "first":
|
||||
return mm_contents + text_contents
|
||||
|
||||
return text_contents + mm_contents
|
||||
|
||||
|
||||
async def async_request_openai_chat_completions(
|
||||
request_func_input: RequestFuncInput,
|
||||
session: aiohttp.ClientSession,
|
||||
pbar: tqdm | None = None,
|
||||
mm_position: Literal["first", "last"] = "last",
|
||||
) -> RequestFuncOutput:
|
||||
api_url = request_func_input.api_url
|
||||
_validate_api_url(api_url, "OpenAI Chat Completions API", "chat/completions")
|
||||
|
||||
content = _get_chat_content(request_func_input, mm_position=mm_position)
|
||||
|
||||
payload = {
|
||||
"model": request_func_input.model_name
|
||||
if request_func_input.model_name
|
||||
else request_func_input.model,
|
||||
"messages": [
|
||||
{"role": "user", "content": content},
|
||||
],
|
||||
"max_completion_tokens": request_func_input.output_len,
|
||||
"stream": True,
|
||||
"stream_options": {
|
||||
"include_usage": True,
|
||||
},
|
||||
}
|
||||
_update_payload_common(payload, request_func_input)
|
||||
|
||||
headers = _get_headers("application/json")
|
||||
_update_headers_common(headers, request_func_input)
|
||||
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
|
||||
generated_text = ""
|
||||
ttft = 0.0
|
||||
st = time.perf_counter()
|
||||
output.start_time = st
|
||||
most_recent_timestamp = st
|
||||
try:
|
||||
async with session.post(url=api_url, json=payload, headers=headers) as response:
|
||||
if response.status == 200:
|
||||
handler = StreamedResponseHandler()
|
||||
async for chunk_bytes in response.content.iter_any():
|
||||
chunk_bytes = chunk_bytes.strip()
|
||||
if not chunk_bytes:
|
||||
continue
|
||||
|
||||
messages = handler.add_chunk(chunk_bytes)
|
||||
for message in messages:
|
||||
# NOTE: SSE comments (often used as pings) start with
|
||||
# a colon. These are not JSON data payload and should
|
||||
# be skipped.
|
||||
if message.startswith(":"):
|
||||
continue
|
||||
|
||||
chunk = message.removeprefix("data: ")
|
||||
|
||||
if chunk != "[DONE]":
|
||||
timestamp = time.perf_counter()
|
||||
data = json.loads(chunk)
|
||||
|
||||
if choices := data.get("choices"):
|
||||
content = choices[0]["delta"].get("content")
|
||||
# First token
|
||||
if ttft == 0.0:
|
||||
ttft = timestamp - st
|
||||
output.ttft = ttft
|
||||
|
||||
# Decoding phase
|
||||
else:
|
||||
output.itl.append(timestamp - most_recent_timestamp)
|
||||
|
||||
generated_text += content or ""
|
||||
elif usage := data.get("usage"):
|
||||
output.output_tokens = usage.get("completion_tokens")
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
|
||||
output.generated_text = generated_text
|
||||
output.success = True
|
||||
output.latency = most_recent_timestamp - st
|
||||
else:
|
||||
output.error = response.reason or ""
|
||||
output.success = False
|
||||
except Exception:
|
||||
output.success = False
|
||||
exc_info = sys.exc_info()
|
||||
output.error = "".join(traceback.format_exception(*exc_info))
|
||||
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
return output
|
||||
|
||||
|
||||
async def async_request_openai_audio(
|
||||
request_func_input: RequestFuncInput,
|
||||
session: aiohttp.ClientSession,
|
||||
pbar: tqdm | None = None,
|
||||
) -> RequestFuncOutput:
|
||||
# Lazy import without PlaceholderModule to avoid vllm dep.
|
||||
import soundfile
|
||||
|
||||
api_url = request_func_input.api_url
|
||||
_validate_api_url(api_url, "OpenAI Audio API", {"transcriptions", "translations"})
|
||||
|
||||
content = [{"type": "text", "text": request_func_input.prompt}]
|
||||
payload = {
|
||||
"model": request_func_input.model_name
|
||||
if request_func_input.model_name
|
||||
else request_func_input.model,
|
||||
"max_completion_tokens": request_func_input.output_len,
|
||||
"stream": True,
|
||||
"language": "en",
|
||||
# Flattened due to multipart/form-data
|
||||
"stream_include_usage": True,
|
||||
"stream_continuous_usage_stats": True,
|
||||
}
|
||||
_update_payload_common(payload, request_func_input)
|
||||
|
||||
headers = _get_headers()
|
||||
_update_headers_common(headers, request_func_input)
|
||||
|
||||
# Send audio file
|
||||
def to_bytes(y, sr):
|
||||
buffer = io.BytesIO()
|
||||
soundfile.write(buffer, y, sr, format="WAV")
|
||||
buffer.seek(0)
|
||||
return buffer
|
||||
|
||||
mm_audio = request_func_input.multi_modal_content
|
||||
if not isinstance(mm_audio, dict) or "audio" not in mm_audio:
|
||||
raise TypeError("multi_modal_content must be a dict containing 'audio'")
|
||||
with to_bytes(*mm_audio["audio"]) as f:
|
||||
form = aiohttp.FormData()
|
||||
form.add_field("file", f, content_type="audio/wav")
|
||||
for key, value in payload.items():
|
||||
form.add_field(key, str(value))
|
||||
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
output.input_audio_duration = soundfile.info(f).duration
|
||||
f.seek(0)
|
||||
|
||||
generated_text = ""
|
||||
ttft = 0.0
|
||||
st = time.perf_counter()
|
||||
output.start_time = st
|
||||
most_recent_timestamp = st
|
||||
try:
|
||||
async with session.post(
|
||||
url=api_url, data=form, headers=headers
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
handler = StreamedResponseHandler()
|
||||
|
||||
async for chunk_bytes in response.content.iter_any():
|
||||
chunk_bytes = chunk_bytes.strip()
|
||||
if not chunk_bytes:
|
||||
continue
|
||||
|
||||
messages = handler.add_chunk(chunk_bytes)
|
||||
for message in messages:
|
||||
if type(message) is bytes:
|
||||
message = message.decode("utf-8")
|
||||
chunk = message.removeprefix("data: ")
|
||||
if chunk != "[DONE]":
|
||||
timestamp = time.perf_counter()
|
||||
data = json.loads(chunk)
|
||||
|
||||
if choices := data.get("choices"):
|
||||
content = choices[0]["delta"].get("content")
|
||||
# First token
|
||||
if ttft == 0.0:
|
||||
ttft = timestamp - st
|
||||
output.ttft = ttft
|
||||
|
||||
# Decoding phase
|
||||
else:
|
||||
output.itl.append(
|
||||
timestamp - most_recent_timestamp
|
||||
)
|
||||
|
||||
generated_text += content or ""
|
||||
elif usage := data.get("usage"):
|
||||
output.output_tokens = usage.get(
|
||||
"completion_tokens"
|
||||
)
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
|
||||
output.generated_text = generated_text
|
||||
output.success = True
|
||||
output.latency = most_recent_timestamp - st
|
||||
else:
|
||||
output.error = response.reason or ""
|
||||
output.success = False
|
||||
except Exception:
|
||||
output.success = False
|
||||
exc_info = sys.exc_info()
|
||||
output.error = "".join(traceback.format_exception(*exc_info))
|
||||
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
return output
|
||||
|
||||
|
||||
async def _run_pooling_request(
|
||||
session: aiohttp.ClientSession,
|
||||
api_url: str,
|
||||
payload: dict[str, Any],
|
||||
headers: dict[str, Any],
|
||||
pbar: tqdm | None = None,
|
||||
) -> RequestFuncOutput:
|
||||
output = RequestFuncOutput()
|
||||
st = time.perf_counter()
|
||||
output.start_time = st
|
||||
try:
|
||||
async with session.post(url=api_url, headers=headers, json=payload) as response:
|
||||
if response.status == 200:
|
||||
output.ttft = output.latency = time.perf_counter() - st
|
||||
|
||||
if payload.get("encoding_format", "float") == "bytes":
|
||||
metadata = json.loads(response.headers["metadata"])
|
||||
usage = metadata.get("usage", {})
|
||||
else:
|
||||
data = await response.json()
|
||||
usage = data.get("usage", {})
|
||||
|
||||
output.success = True
|
||||
output.generated_text = ""
|
||||
output.prompt_len = usage.get("prompt_tokens", 0)
|
||||
else:
|
||||
output.success = False
|
||||
output.error = response.reason or ""
|
||||
except Exception as e:
|
||||
output.success = False
|
||||
output.error = str(e)
|
||||
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
return output
|
||||
|
||||
|
||||
async def async_request_openai_embeddings(
|
||||
request_func_input: RequestFuncInput,
|
||||
session: aiohttp.ClientSession,
|
||||
pbar: tqdm | None = None,
|
||||
) -> RequestFuncOutput:
|
||||
api_url = request_func_input.api_url
|
||||
_validate_api_url(api_url, "OpenAI Embeddings API", "embeddings")
|
||||
|
||||
payload = {
|
||||
"model": request_func_input.model_name
|
||||
if request_func_input.model_name
|
||||
else request_func_input.model,
|
||||
"input": request_func_input.prompt,
|
||||
# Many embedding models have short context length,
|
||||
# this is to avoid dropping some of the requests.
|
||||
"truncate_prompt_tokens": -1,
|
||||
}
|
||||
_update_payload_common(payload, request_func_input)
|
||||
|
||||
headers = _get_headers("application/json")
|
||||
_update_headers_common(headers, request_func_input)
|
||||
|
||||
return await _run_pooling_request(
|
||||
session,
|
||||
api_url,
|
||||
payload=payload,
|
||||
headers=headers,
|
||||
pbar=pbar,
|
||||
)
|
||||
|
||||
|
||||
async def async_request_vllm_rerank(
|
||||
request_func_input: RequestFuncInput,
|
||||
session: aiohttp.ClientSession,
|
||||
pbar: tqdm | None = None,
|
||||
) -> RequestFuncOutput:
|
||||
api_url = request_func_input.api_url
|
||||
_validate_api_url(api_url, "vLLM score API", "rerank")
|
||||
|
||||
assert (
|
||||
isinstance(request_func_input.prompt, list)
|
||||
and len(request_func_input.prompt) > 1
|
||||
)
|
||||
|
||||
payload = {
|
||||
"model": request_func_input.model_name
|
||||
if request_func_input.model_name
|
||||
else request_func_input.model,
|
||||
"query": request_func_input.prompt[0],
|
||||
"documents": request_func_input.prompt[1:],
|
||||
# Many reranker models have short context length,
|
||||
# this is to avoid dropping some of the requests.
|
||||
"truncate_prompt_tokens": -1,
|
||||
}
|
||||
|
||||
headers = _get_headers("application/json")
|
||||
_update_headers_common(headers, request_func_input)
|
||||
|
||||
return await _run_pooling_request(
|
||||
session,
|
||||
api_url,
|
||||
payload=payload,
|
||||
headers=headers,
|
||||
pbar=pbar,
|
||||
)
|
||||
|
||||
|
||||
async def async_request_openai_embeddings_chat(
|
||||
request_func_input: RequestFuncInput,
|
||||
session: aiohttp.ClientSession,
|
||||
pbar: tqdm | None = None,
|
||||
mm_position: Literal["first", "last"] = "last",
|
||||
) -> RequestFuncOutput:
|
||||
api_url = request_func_input.api_url
|
||||
_validate_api_url(api_url, "OpenAI Embeddings API", "embeddings")
|
||||
|
||||
content = _get_chat_content(request_func_input, mm_position=mm_position)
|
||||
|
||||
payload = {
|
||||
"model": request_func_input.model_name
|
||||
if request_func_input.model_name
|
||||
else request_func_input.model,
|
||||
"messages": [
|
||||
{"role": "user", "content": content},
|
||||
],
|
||||
# Many embedding models have short context length,
|
||||
# this is to avoid dropping some of the requests.
|
||||
"truncate_prompt_tokens": -1,
|
||||
}
|
||||
_update_payload_common(payload, request_func_input)
|
||||
|
||||
headers = _get_headers("application/json")
|
||||
_update_headers_common(headers, request_func_input)
|
||||
|
||||
return await _run_pooling_request(
|
||||
session,
|
||||
api_url,
|
||||
payload=payload,
|
||||
headers=headers,
|
||||
pbar=pbar,
|
||||
)
|
||||
|
||||
|
||||
def _try_extract_request_idx(request_func_input: RequestFuncInput):
|
||||
if request_func_input.request_id:
|
||||
match = re.search(r"(\d+)$", request_func_input.request_id)
|
||||
if match:
|
||||
try:
|
||||
return int(match.group(1))
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _preprocess_clip(request_func_input: RequestFuncInput):
|
||||
if request_func_input.multi_modal_content:
|
||||
# Image input
|
||||
request_func_input.prompt = ""
|
||||
|
||||
|
||||
def _preprocess_vlm2vec(request_func_input: RequestFuncInput):
|
||||
if request_func_input.multi_modal_content:
|
||||
request_idx = _try_extract_request_idx(request_func_input)
|
||||
|
||||
# Adjust the ratio manually if needed.
|
||||
use_image_only_prompt = request_idx is None or request_idx % 2 == 0
|
||||
|
||||
if use_image_only_prompt:
|
||||
# Image input
|
||||
request_func_input.prompt = "Represent the given image."
|
||||
else:
|
||||
# Text+Image input
|
||||
request_func_input.prompt = (
|
||||
f"Represent the given image with the following question: "
|
||||
f"{request_func_input.prompt}"
|
||||
)
|
||||
|
||||
|
||||
async def async_request_openai_embeddings_clip(
|
||||
request_func_input: RequestFuncInput,
|
||||
session: aiohttp.ClientSession,
|
||||
pbar: tqdm | None = None,
|
||||
) -> RequestFuncOutput:
|
||||
_preprocess_clip(request_func_input)
|
||||
|
||||
return await async_request_openai_embeddings_chat(
|
||||
request_func_input,
|
||||
session,
|
||||
pbar=pbar,
|
||||
)
|
||||
|
||||
|
||||
async def async_request_openai_embeddings_vlm2vec(
|
||||
request_func_input: RequestFuncInput,
|
||||
session: aiohttp.ClientSession,
|
||||
pbar: tqdm | None = None,
|
||||
) -> RequestFuncOutput:
|
||||
_preprocess_vlm2vec(request_func_input)
|
||||
|
||||
return await async_request_openai_embeddings_chat(
|
||||
request_func_input,
|
||||
session,
|
||||
pbar=pbar,
|
||||
mm_position="first",
|
||||
)
|
||||
|
||||
|
||||
async def async_request_infinity_embeddings(
|
||||
request_func_input: RequestFuncInput,
|
||||
session: aiohttp.ClientSession,
|
||||
pbar: tqdm | None = None,
|
||||
) -> RequestFuncOutput:
|
||||
api_url = request_func_input.api_url
|
||||
_validate_api_url(api_url, "Infinity Embeddings API", "embeddings")
|
||||
|
||||
payload = {
|
||||
"model": request_func_input.model_name
|
||||
if request_func_input.model_name
|
||||
else request_func_input.model,
|
||||
}
|
||||
|
||||
if request_func_input.prompt:
|
||||
payload["input"] = request_func_input.prompt
|
||||
else:
|
||||
mm_content = request_func_input.multi_modal_content
|
||||
assert isinstance(mm_content, dict)
|
||||
|
||||
mm_type = mm_content["type"]
|
||||
payload["input"] = mm_content[mm_type]["url"]
|
||||
payload["modality"] = mm_type.split("_", 1)[0]
|
||||
|
||||
_update_payload_common(payload, request_func_input)
|
||||
|
||||
headers = _get_headers("application/json")
|
||||
_update_headers_common(headers, request_func_input)
|
||||
|
||||
return await _run_pooling_request(
|
||||
session,
|
||||
api_url,
|
||||
payload=payload,
|
||||
headers=headers,
|
||||
pbar=pbar,
|
||||
)
|
||||
|
||||
|
||||
async def async_request_infinity_embeddings_clip(
|
||||
request_func_input: RequestFuncInput,
|
||||
session: aiohttp.ClientSession,
|
||||
pbar: tqdm | None = None,
|
||||
) -> RequestFuncOutput:
|
||||
_preprocess_clip(request_func_input)
|
||||
|
||||
return await async_request_infinity_embeddings(
|
||||
request_func_input,
|
||||
session,
|
||||
pbar=pbar,
|
||||
)
|
||||
|
||||
|
||||
async def async_request_vllm_pooling(
|
||||
request_func_input: RequestFuncInput,
|
||||
session: aiohttp.ClientSession,
|
||||
pbar: tqdm | None = None,
|
||||
) -> RequestFuncOutput:
|
||||
api_url = request_func_input.api_url
|
||||
_validate_api_url(api_url, "vLLM Pooling API", "pooling")
|
||||
|
||||
payload = {
|
||||
"model": request_func_input.model_name
|
||||
if request_func_input.model_name
|
||||
else request_func_input.model,
|
||||
"truncate_prompt_tokens": -1,
|
||||
}
|
||||
|
||||
payload = payload | request_func_input.prompt
|
||||
|
||||
_update_payload_common(payload, request_func_input)
|
||||
|
||||
headers = _get_headers("application/json")
|
||||
_update_headers_common(headers, request_func_input)
|
||||
|
||||
return await _run_pooling_request(
|
||||
session,
|
||||
api_url,
|
||||
payload=payload,
|
||||
headers=headers,
|
||||
pbar=pbar,
|
||||
)
|
||||
|
||||
|
||||
# TODO: Add more request functions for different API protocols.
|
||||
ASYNC_REQUEST_FUNCS: dict[str, RequestFunc] = {
|
||||
"vllm": async_request_openai_completions,
|
||||
"openai": async_request_openai_completions,
|
||||
"openai-chat": async_request_openai_chat_completions,
|
||||
"openai-audio": async_request_openai_audio,
|
||||
"openai-embeddings": async_request_openai_embeddings,
|
||||
"openai-embeddings-chat": async_request_openai_embeddings_chat,
|
||||
"openai-embeddings-clip": async_request_openai_embeddings_clip,
|
||||
"openai-embeddings-vlm2vec": async_request_openai_embeddings_vlm2vec,
|
||||
# Infinity embedding server: https://github.com/michaelfeil/infinity
|
||||
"infinity-embeddings": async_request_infinity_embeddings,
|
||||
"infinity-embeddings-clip": async_request_infinity_embeddings_clip,
|
||||
# (Infinity embedding server does not support vlm2vec)
|
||||
"vllm-pooling": async_request_vllm_pooling,
|
||||
"vllm-rerank": async_request_vllm_rerank,
|
||||
}
|
||||
|
||||
OPENAI_COMPATIBLE_BACKENDS = [
|
||||
k
|
||||
for k, v in ASYNC_REQUEST_FUNCS.items()
|
||||
if v in (async_request_openai_completions, async_request_openai_chat_completions)
|
||||
]
|
||||
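Illustration (not part of this commit): how a benchmark driver might pick and await one of the registered request functions; the URL and model name below are placeholders and a running OpenAI-compatible server is assumed.

import asyncio

import aiohttp

from vllm.benchmarks.lib.endpoint_request_func import (
    AIOHTTP_TIMEOUT, ASYNC_REQUEST_FUNCS, RequestFuncInput)


async def _demo() -> None:
    request_func = ASYNC_REQUEST_FUNCS["openai"]
    req = RequestFuncInput(
        prompt="Hello, world!",
        api_url="http://localhost:8000/v1/completions",  # placeholder endpoint
        prompt_len=4,
        output_len=16,
        model="my-model",  # placeholder model name
    )
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        out = await request_func(request_func_input=req, session=session)
        print(out.success, out.ttft, out.generated_text)


if __name__ == "__main__":
    asyncio.run(_demo())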
vllm/benchmarks/lib/ready_checker.py (new file, 79 lines)
@@ -0,0 +1,79 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utilities for checking endpoint readiness."""

import asyncio
import time

import aiohttp
from tqdm.asyncio import tqdm

from vllm.logger import init_logger

from .endpoint_request_func import RequestFunc, RequestFuncInput, RequestFuncOutput

logger = init_logger(__name__)


async def wait_for_endpoint(
    request_func: RequestFunc,
    test_input: RequestFuncInput,
    session: aiohttp.ClientSession,
    timeout_seconds: int = 600,
    retry_interval: int = 5,
) -> RequestFuncOutput:
    """
    Wait for an endpoint to become available before starting benchmarks.

    Args:
        request_func: The async request function to call
        test_input: The RequestFuncInput to test with
        timeout_seconds: Maximum time to wait in seconds (default: 10 minutes)
        retry_interval: Time between retries in seconds (default: 5 seconds)

    Returns:
        RequestFuncOutput: The successful response

    Raises:
        ValueError: If the endpoint doesn't become available within the timeout
    """
    deadline = time.perf_counter() + timeout_seconds
    output = RequestFuncOutput(success=False)
    print(f"Waiting for endpoint to become up in {timeout_seconds} seconds")

    with tqdm(
        total=timeout_seconds,
        bar_format="{desc} |{bar}| {elapsed} elapsed, {remaining} remaining",
        unit="s",
    ) as pbar:
        while True:
            # update progress bar
            remaining = deadline - time.perf_counter()
            elapsed = timeout_seconds - remaining
            update_amount = min(elapsed - pbar.n, timeout_seconds - pbar.n)
            pbar.update(update_amount)
            pbar.refresh()
            if remaining <= 0:
                pbar.close()
                break

            # ping the endpoint using request_func
            try:
                output = await request_func(
                    request_func_input=test_input, session=session
                )
                if output.success:
                    pbar.close()
                    return output
                else:
                    err_last_line = str(output.error).rstrip().rsplit("\n", 1)[-1]
                    logger.warning("Endpoint is not ready. Error='%s'", err_last_line)
            except aiohttp.ClientConnectorError:
                pass

            # retry after a delay
            sleep_duration = min(retry_interval, remaining)
            if sleep_duration > 0:
                await asyncio.sleep(sleep_duration)

    return output
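Illustration (not part of this commit): waiting on a placeholder local server before benchmarking, reusing the request-function registry from the module above.

import asyncio

import aiohttp

from vllm.benchmarks.lib.endpoint_request_func import (
    ASYNC_REQUEST_FUNCS, RequestFuncInput)
from vllm.benchmarks.lib.ready_checker import wait_for_endpoint


async def _wait_demo() -> None:
    test_input = RequestFuncInput(
        prompt="ping",
        api_url="http://localhost:8000/v1/completions",  # placeholder endpoint
        prompt_len=1,
        output_len=1,
        model="my-model",  # placeholder model name
    )
    async with aiohttp.ClientSession() as session:
        result = await wait_for_endpoint(
            ASYNC_REQUEST_FUNCS["openai"], test_input, session,
            timeout_seconds=60, retry_interval=2,
        )
    print("endpoint ready:", result.success)


asyncio.run(_wait_demo())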
vllm/benchmarks/lib/utils.py (new file, 131 lines)
@@ -0,0 +1,131 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
|
||||
def extract_field(
|
||||
args: argparse.Namespace, extra_info: dict[str, Any], field_name: str
|
||||
) -> str:
|
||||
if field_name in extra_info:
|
||||
return extra_info[field_name]
|
||||
|
||||
v = args
|
||||
# For example, args.compilation_config.mode
|
||||
for nested_field in field_name.split("."):
|
||||
if not hasattr(v, nested_field):
|
||||
return ""
|
||||
v = getattr(v, nested_field)
|
||||
return v
|
||||
|
||||
|
||||
def use_compile(args: argparse.Namespace, extra_info: dict[str, Any]) -> bool:
|
||||
"""
|
||||
Check if the benchmark is run with torch.compile
|
||||
"""
|
||||
return not (
|
||||
extract_field(args, extra_info, "compilation_config.mode") == "0"
|
||||
or "eager" in getattr(args, "output_json", "")
|
||||
or "eager" in getattr(args, "result_filename", "")
|
||||
)
|
||||
|
||||
|
||||
def convert_to_pytorch_benchmark_format(
|
||||
args: argparse.Namespace, metrics: dict[str, list], extra_info: dict[str, Any]
|
||||
) -> list:
|
||||
"""
|
||||
Save the benchmark results in the format used by PyTorch OSS benchmark with
|
||||
one metric per record
|
||||
https://github.com/pytorch/pytorch/wiki/How-to-integrate-with-PyTorch-OSS-benchmark-database
|
||||
"""
|
||||
records = []
|
||||
if not os.environ.get("SAVE_TO_PYTORCH_BENCHMARK_FORMAT", False):
|
||||
return records
|
||||
|
||||
for name, benchmark_values in metrics.items():
|
||||
if not isinstance(benchmark_values, list):
|
||||
raise TypeError(
|
||||
f"benchmark_values for metric '{name}' must be a list, "
|
||||
f"but got {type(benchmark_values).__name__}"
|
||||
)
|
||||
|
||||
record = {
|
||||
"benchmark": {
|
||||
"name": "vLLM benchmark",
|
||||
"extra_info": {
|
||||
"args": vars(args),
|
||||
"compilation_config.mode": extract_field(
|
||||
args, extra_info, "compilation_config.mode"
|
||||
),
|
||||
"optimization_level": extract_field(
|
||||
args, extra_info, "optimization_level"
|
||||
),
|
||||
# A boolean field used by vLLM benchmark HUD dashboard
|
||||
"use_compile": use_compile(args, extra_info),
|
||||
},
|
||||
},
|
||||
"model": {
|
||||
"name": args.model,
|
||||
},
|
||||
"metric": {
|
||||
"name": name,
|
||||
"benchmark_values": benchmark_values,
|
||||
"extra_info": extra_info,
|
||||
},
|
||||
}
|
||||
|
||||
tp = record["benchmark"]["extra_info"]["args"].get("tensor_parallel_size")
|
||||
# Save tensor_parallel_size parameter if it's part of the metadata
|
||||
if not tp and "tensor_parallel_size" in extra_info:
|
||||
record["benchmark"]["extra_info"]["args"]["tensor_parallel_size"] = (
|
||||
extra_info["tensor_parallel_size"]
|
||||
)
|
||||
|
||||
records.append(record)
|
||||
|
||||
return records
|
||||
|
||||
|
||||
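Illustration (not part of this commit): the record shape produced above, with placeholder metric values; records are emitted only when SAVE_TO_PYTORCH_BENCHMARK_FORMAT is set.

import argparse
import json
import os

from vllm.benchmarks.lib.utils import convert_to_pytorch_benchmark_format

os.environ["SAVE_TO_PYTORCH_BENCHMARK_FORMAT"] = "1"
args = argparse.Namespace(model="my-model", output_json="", result_filename="")
records = convert_to_pytorch_benchmark_format(
    args,
    metrics={"request_throughput": [12.3]},   # placeholder values
    extra_info={"tensor_parallel_size": 1},
)
print(json.dumps(records, indent=2))  # one record per metric name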
class InfEncoder(json.JSONEncoder):
|
||||
def clear_inf(self, o: Any):
|
||||
if isinstance(o, dict):
|
||||
return {
|
||||
str(k)
|
||||
if not isinstance(k, (str, int, float, bool, type(None)))
|
||||
else k: self.clear_inf(v)
|
||||
for k, v in o.items()
|
||||
}
|
||||
elif isinstance(o, list):
|
||||
return [self.clear_inf(v) for v in o]
|
||||
elif isinstance(o, float) and math.isinf(o):
|
||||
return "inf"
|
||||
return o
|
||||
|
||||
def iterencode(self, o: Any, *args, **kwargs) -> Any:
|
||||
return super().iterencode(self.clear_inf(o), *args, **kwargs)
|
||||
|
||||
|
||||
def write_to_json(filename: str, records: list) -> None:
|
||||
with open(filename, "w") as f:
|
||||
json.dump(
|
||||
records,
|
||||
f,
|
||||
cls=InfEncoder,
|
||||
default=lambda o: f"<{type(o).__name__} is not JSON serializable>",
|
||||
)
|
||||
|
||||
|
||||
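Illustration (not part of this commit): InfEncoder turns non-finite floats into the string "inf" when results are written out; module path assumed as above.

from vllm.benchmarks.lib.utils import write_to_json

write_to_json("results.json", [{"p99_ttft_ms": float("inf"), "qps": 4.0}])
# results.json now contains: [{"p99_ttft_ms": "inf", "qps": 4.0}]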
@contextmanager
|
||||
def default_vllm_config():
|
||||
"""Set a default VllmConfig for cases that directly test CustomOps or pathways
|
||||
that use get_current_vllm_config() outside of a full engine context.
|
||||
"""
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
|
||||
with set_current_vllm_config(VllmConfig()):
|
||||
yield
|
||||
vllm/benchmarks/plot.py (new file, 316 lines)
@@ -0,0 +1,316 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Generate plots for benchmark results."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
|
||||
try:
|
||||
import plotly.express as px
|
||||
import plotly.io as pio
|
||||
except ImportError:
|
||||
_plotly = PlaceholderModule("plotly")
|
||||
px = _plotly.placeholder_attr("express")
|
||||
pio = _plotly.placeholder_attr("io")
|
||||
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
except ImportError:
|
||||
_matplotlib = PlaceholderModule("matplotlib")
|
||||
plt = _matplotlib.placeholder_attr("pyplot")
|
||||
|
||||
|
||||
def generate_timeline_plot(
|
||||
results: list[dict[str, Any]],
|
||||
output_path: Path,
|
||||
colors: list[str] | None = None,
|
||||
itl_thresholds: list[float] | None = None,
|
||||
labels: list[str] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Generate an HTML timeline plot from benchmark results.
|
||||
|
||||
Args:
|
||||
results: List of per-request result dictionaries containing:
|
||||
- start_time: Request start time (seconds)
|
||||
- ttft: Time to first token (seconds)
|
||||
- itl: List of inter-token latencies (seconds)
|
||||
- latency: Total request latency (seconds)
|
||||
- prompt_len: Number of prompt tokens
|
||||
- output_tokens: Number of output tokens
|
||||
output_path: Path where the HTML file will be saved
|
||||
colors: List of colors for ITL categories (default: green, orange, red, black)
|
||||
itl_thresholds: ITL thresholds in seconds (default: [1.0, 4.0, 6.0])
|
||||
labels: Labels for ITL categories (default based on thresholds)
|
||||
"""
|
||||
|
||||
# Set defaults
|
||||
if colors is None:
|
||||
colors = ["#109618", "#FF7F0E", "#D62728"]
|
||||
if itl_thresholds is None:
|
||||
itl_thresholds = [0.025, 0.050]
|
||||
if labels is None:
|
||||
labels = [
|
||||
f"ITL < {itl_thresholds[0] * 1000:.0f}ms",
|
||||
f"{itl_thresholds[0] * 1000:.0f}ms ≤ ITL < {itl_thresholds[1] * 1000:.0f}ms", # noqa
|
||||
f"ITL ≥ {itl_thresholds[1] * 1000:.0f}ms",
|
||||
]
|
||||
|
||||
labels_colors = {"TTFT": "#636EFA", **dict(zip(labels, colors))}
|
||||
labels_order = ["TTFT"] + labels
|
||||
|
||||
timeline_data = construct_timeline_data(results, itl_thresholds, labels)
|
||||
|
||||
if not timeline_data:
|
||||
print("No timeline data to plot")
|
||||
return
|
||||
|
||||
# Create the plot
|
||||
fig = px.timeline(
|
||||
timeline_data,
|
||||
x_start="start",
|
||||
x_end="end",
|
||||
y="request_id",
|
||||
color="type",
|
||||
color_discrete_map=labels_colors,
|
||||
category_orders={"type": labels_order},
|
||||
hover_data=[
|
||||
"prompt_tokens",
|
||||
"output_tokens",
|
||||
"req_start_time",
|
||||
"req_finish_time",
|
||||
"segment_start",
|
||||
"segment_end",
|
||||
"duration",
|
||||
],
|
||||
)
|
||||
|
||||
# Customize hover template to show only time without date
|
||||
fig.update_traces(
|
||||
hovertemplate="<b>%{y}</b><br>"
|
||||
"Type: %{fullData.name}<br>"
|
||||
"Start: %{customdata[4]}<br>"
|
||||
"End: %{customdata[5]}<br>"
|
||||
"Duration: %{customdata[6]}<br>"
|
||||
"Prompt Tokens: %{customdata[0]}<br>"
|
||||
"Output Tokens: %{customdata[1]}<br>"
|
||||
"Request Start Time: %{customdata[2]}<br>"
|
||||
"Request End Time: %{customdata[3]}<br>"
|
||||
"<extra></extra>"
|
||||
)
|
||||
|
||||
fig.update_yaxes(autorange="reversed")
|
||||
fig.update_layout(
|
||||
xaxis_title="Time",
|
||||
yaxis_title="Request ID",
|
||||
showlegend=True,
|
||||
)
|
||||
|
||||
# Save to HTML
|
||||
pio.write_html(fig, str(output_path))
|
||||
print(f"Timeline plot saved to: {output_path}")
|
||||
|
||||
|
||||
def construct_timeline_data(
|
||||
requests_data: list[dict[str, Any]],
|
||||
itl_thresholds: list[float],
|
||||
labels: list[str],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Construct timeline data from request results.
|
||||
|
||||
Args:
|
||||
requests_data: List of per-request result dictionaries
|
||||
itl_thresholds: ITL thresholds in seconds
|
||||
labels: Labels for ITL categories
|
||||
|
||||
Returns:
|
||||
List of timeline segments for plotting
|
||||
"""
|
||||
|
||||
def tostr(sec_time: float) -> str:
|
||||
"""Convert seconds to HH:MM:SS.mmm format."""
|
||||
h = int(sec_time // 3600)
|
||||
assert h < 100, "time seems to last more than 100 hours"
|
||||
m = int((sec_time % 3600) // 60)
|
||||
s = sec_time % 60
|
||||
return f"{h:02d}:{m:02d}:{s:06.3f}"
|
||||
|
||||
def itl_type(itl: float) -> str:
|
||||
"""Categorize ITL based on thresholds."""
|
||||
if itl < itl_thresholds[0]:
|
||||
return labels[0]
|
||||
elif itl < itl_thresholds[1]:
|
||||
return labels[1]
|
||||
else:
|
||||
return labels[2]
|
||||
|
||||
# Find the earliest start time to use as t0
|
||||
t0 = None
|
||||
for request in requests_data:
|
||||
start_time = request.get("start_time")
|
||||
if start_time is not None and (t0 is None or start_time < t0):
|
||||
t0 = start_time
|
||||
|
||||
if t0 is None:
|
||||
return []
|
||||
|
||||
timeline_data = []
|
||||
|
||||
for i, request in enumerate(requests_data):
|
||||
start_time = request.get("start_time")
|
||||
ttft = request.get("ttft")
|
||||
itl = request.get("itl", [])
|
||||
latency = request.get("latency")
|
||||
prompt_len = request.get("prompt_len", 0)
|
||||
output_tokens = request.get("output_tokens", 0)
|
||||
|
||||
# Skip requests without required data
|
||||
if start_time is None or ttft is None or latency is None:
|
||||
continue
|
||||
|
||||
# Normalize start time
|
||||
start_time = start_time - t0
|
||||
start_time_str = tostr(start_time)
|
||||
|
||||
# TTFT segment
|
||||
ttft_end = start_time + ttft
|
||||
ttft_end_str = tostr(ttft_end)
|
||||
|
||||
timeline_data.append(
|
||||
{
|
||||
"request_id": f"Req {i}",
|
||||
"start": start_time_str,
|
||||
"end": ttft_end_str,
|
||||
"type": "TTFT",
|
||||
"prompt_tokens": prompt_len,
|
||||
"output_tokens": output_tokens,
|
||||
"req_start_time": tostr(start_time),
|
||||
"req_finish_time": tostr(start_time + latency),
|
||||
"segment_start": start_time_str,
|
||||
"segment_end": ttft_end_str,
|
||||
"duration": f"{ttft:.3f}s",
|
||||
}
|
||||
)
|
||||
|
||||
# ITL segments
|
||||
prev_time = ttft_end
|
||||
prev_time_str = ttft_end_str
|
||||
|
||||
for itl_value in itl:
|
||||
itl_end = prev_time + itl_value
|
||||
itl_end_str = tostr(itl_end)
|
||||
|
||||
timeline_data.append(
|
||||
{
|
||||
"request_id": f"Req {i}",
|
||||
"start": prev_time_str,
|
||||
"end": itl_end_str,
|
||||
"type": itl_type(itl_value),
|
||||
"prompt_tokens": prompt_len,
|
||||
"output_tokens": output_tokens,
|
||||
"req_start_time": tostr(start_time),
|
||||
"req_finish_time": tostr(start_time + latency),
|
||||
"segment_start": prev_time_str,
|
||||
"segment_end": itl_end_str,
|
||||
"duration": f"{itl_value:.3f}s",
|
||||
}
|
||||
)
|
||||
|
||||
prev_time = itl_end
|
||||
prev_time_str = itl_end_str
|
||||
|
||||
return timeline_data
|
||||
|
||||
|
||||
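Illustration (not part of this commit): the segments produced for one synthetic request, assuming the module above is importable as vllm.benchmarks.plot.

from vllm.benchmarks.plot import construct_timeline_data

fake_results = [{
    "start_time": 100.0,   # synthetic absolute start time (seconds)
    "ttft": 0.20,
    "itl": [0.01, 0.06],   # one fast and one slow inter-token gap
    "latency": 0.27,
    "prompt_len": 32,
    "output_tokens": 3,
}]
segments = construct_timeline_data(
    fake_results, itl_thresholds=[0.025, 0.050], labels=["fast", "medium", "slow"]
)
for seg in segments:
    print(seg["type"], seg["start"], "->", seg["end"])
# Expected: a "TTFT" segment, then a "fast" and a "slow" ITL segment.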
def generate_dataset_stats_plot(
|
||||
results: list[dict[str, Any]],
|
||||
output_path: Path,
|
||||
) -> None:
|
||||
"""
|
||||
Generate a matplotlib figure with dataset statistics.
|
||||
|
||||
Creates a figure with 4 subplots:
|
||||
- Top-left: Prompt tokens distribution (histogram)
|
||||
- Top-right: Output tokens distribution (histogram)
|
||||
- Bottom-left: Prompt+output tokens distribution (histogram)
|
||||
- Bottom-right: Stacked bar chart (request_id vs tokens)
|
||||
|
||||
Args:
|
||||
results: List of per-request result dictionaries containing:
|
||||
- prompt_len: Number of prompt tokens
|
||||
- output_tokens: Number of output tokens
|
||||
output_path: Path where the figure will be saved
|
||||
"""
|
||||
# Extract data
|
||||
prompt_tokens = []
|
||||
output_tokens = []
|
||||
total_tokens = []
|
||||
|
||||
for request in results:
|
||||
prompt_len = request.get("prompt_len", 0)
|
||||
output_len = request.get("output_tokens", 0)
|
||||
|
||||
prompt_tokens.append(prompt_len)
|
||||
output_tokens.append(output_len)
|
||||
total_tokens.append(prompt_len + output_len)
|
||||
|
||||
if not prompt_tokens:
|
||||
print("No data available for dataset statistics plot")
|
||||
return
|
||||
|
||||
# Create figure with 4 subplots
|
||||
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(14, 10))
|
||||
|
||||
# Top-left: Prompt tokens distribution
|
||||
ax1.hist(prompt_tokens, bins=30, color="steelblue", edgecolor="black", alpha=0.7)
|
||||
ax1.set_xlabel("Prompt Tokens")
|
||||
ax1.set_ylabel("Frequency")
|
||||
ax1.set_title("Prompt Tokens Distribution")
|
||||
ax1.grid(True, alpha=0.3)
|
||||
|
||||
# Top-right: Output tokens distribution
|
||||
ax2.hist(output_tokens, bins=30, color="coral", edgecolor="black", alpha=0.7)
|
||||
ax2.set_xlabel("Output Tokens")
|
||||
ax2.set_ylabel("Frequency")
|
||||
ax2.set_title("Output Tokens Distribution")
|
||||
ax2.grid(True, alpha=0.3)
|
||||
|
||||
# Bottom-left: Prompt+output tokens distribution
|
||||
ax3.hist(
|
||||
total_tokens, bins=30, color="mediumseagreen", edgecolor="black", alpha=0.7
|
||||
)
|
||||
ax3.set_xlabel("Total Tokens (Prompt + Output)")
|
||||
ax3.set_ylabel("Frequency")
|
||||
ax3.set_title("Total Tokens Distribution")
|
||||
ax3.grid(True, alpha=0.3)
|
||||
|
||||
# Bottom-right: Stacked bar chart
|
||||
request_ids = list(range(len(prompt_tokens)))
|
||||
ax4.bar(
|
||||
request_ids, prompt_tokens, label="Prompt Tokens", color="steelblue", alpha=0.7
|
||||
)
|
||||
ax4.bar(
|
||||
request_ids,
|
||||
output_tokens,
|
||||
bottom=prompt_tokens,
|
||||
label="Output Tokens",
|
||||
color="coral",
|
||||
alpha=0.7,
|
||||
)
|
||||
ax4.set_xlabel("Request ID")
|
||||
ax4.set_ylabel("Tokens")
|
||||
ax4.set_title("Tokens per Request (Stacked)")
|
||||
ax4.legend()
|
||||
ax4.grid(True, alpha=0.3, axis="y")
|
||||
|
||||
# Adjust layout to prevent overlap
|
||||
plt.tight_layout()
|
||||
|
||||
# Save figure
|
||||
plt.savefig(str(output_path), dpi=150, bbox_inches="tight")
|
||||
plt.close(fig)
|
||||
|
||||
print(f"Dataset statistics plot saved to: {output_path}")
|
||||
@@ -34,6 +34,7 @@ from collections.abc import AsyncGenerator, Iterable
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
|
||||
import aiohttp
|
||||
@@ -1183,6 +1184,49 @@ def save_to_pytorch_benchmark_format(
|
||||
write_to_json(pt_file, pt_records)
|
||||
|
||||
|
||||
def compute_result_filename(
|
||||
args: argparse.Namespace,
|
||||
model_id: str,
|
||||
label: str,
|
||||
current_dt: str,
|
||||
) -> str | None:
|
||||
"""Compute the result filename based on benchmark configuration.
|
||||
|
||||
Args:
|
||||
args: Command line arguments containing result configuration
|
||||
model_id: The model identifier
|
||||
label: The benchmark label
|
||||
current_dt: Current datetime string
|
||||
|
||||
Returns:
|
||||
The computed filename path or None if no result saving is requested
|
||||
"""
|
||||
if not (args.plot_timeline or args.save_result or args.append_result):
|
||||
return None
|
||||
|
||||
base_model_id = model_id.split("/")[-1]
|
||||
max_concurrency_str = (
|
||||
f"-concurrency{args.max_concurrency}"
|
||||
if args.max_concurrency is not None
|
||||
else ""
|
||||
)
|
||||
label = label or args.backend
|
||||
|
||||
if args.ramp_up_strategy is not None:
|
||||
file_name = f"{label}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
|
||||
else:
|
||||
file_name = f"{label}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
|
||||
|
||||
if args.result_filename:
|
||||
file_name = args.result_filename
|
||||
|
||||
if args.result_dir:
|
||||
os.makedirs(args.result_dir, exist_ok=True)
|
||||
file_name = os.path.join(args.result_dir, file_name)
|
||||
|
||||
return file_name
|
||||
|
||||
|
||||
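Illustration (not part of this commit): the result-file naming convention encoded above, reproduced standalone with placeholder values.

# Placeholder values; mirrors the non-ramp-up f-string in compute_result_filename.
label = "openai"
request_rate = "inf"
max_concurrency_str = "-concurrency64"
base_model_id = "Llama-3.1-8B-Instruct"
current_dt = "20250101-120000"
print(f"{label}-{request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json")
# -> openai-infqps-concurrency64-Llama-3.1-8B-Instruct-20250101-120000.json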
def add_cli_args(parser: argparse.ArgumentParser):
|
||||
add_dataset_parser(parser)
|
||||
parser.add_argument(
|
||||
@@ -1277,6 +1321,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
|
||||
- "slow" will always use the slow tokenizer.\n
|
||||
- "mistral" will always use the tokenizer from `mistral_common`.\n
|
||||
- "deepseek_v32" will always use the tokenizer from `deepseek_v32`.\n
|
||||
- "qwen_vl" will always use the tokenizer from `qwen_vl`.\n
|
||||
- Other custom values can be supported via plugins.""",
|
||||
)
|
||||
parser.add_argument("--use-beam-search", action="store_true")
|
||||
@@ -1535,6 +1580,30 @@ def add_cli_args(parser: argparse.ArgumentParser):
|
||||
"connecting to servers with self-signed certificates.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--plot-timeline",
|
||||
action="store_true",
|
||||
help="Generate an HTML timeline plot showing request execution. "
|
||||
"The plot will be saved alongside the results JSON file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--timeline-itl-thresholds",
|
||||
type=float,
|
||||
nargs=2,
|
||||
default=[25.0, 50.0],
|
||||
metavar=("THRESHOLD1", "THRESHOLD2"),
|
||||
help="ITL thresholds in milliseconds for timeline plot coloring. "
|
||||
"Specify two values to categorize inter-token latencies into three groups: "
|
||||
"below first threshold (green), between thresholds (orange), "
|
||||
"and above second threshold (red). Default: 25 50 (milliseconds).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--plot-dataset-stats",
|
||||
action="store_true",
|
||||
help="Generate a matplotlib figure with dataset statistics showing "
|
||||
"prompt tokens, output tokens, and combined token distributions.",
|
||||
)
|
||||
|
||||
|
||||
def main(args: argparse.Namespace) -> dict[str, Any]:
|
||||
return asyncio.run(main_async(args))
|
||||
@@ -1770,6 +1839,86 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
|
||||
# Merge with benchmark result
|
||||
result_json = {**result_json, **benchmark_result}
|
||||
|
||||
# Compute file_name once before using it for plots or saving results
|
||||
file_name = compute_result_filename(args, model_id, label, current_dt)
|
||||
|
||||
# Generate timeline plot if requested
|
||||
if args.plot_timeline:
|
||||
try:
|
||||
from vllm.benchmarks.plot import generate_timeline_plot
|
||||
|
||||
# Prepare per-request data for timeline
|
||||
per_request_data = []
|
||||
start_times = benchmark_result.get("start_times", [])
|
||||
ttfts = benchmark_result.get("ttfts", [])
|
||||
itls = benchmark_result.get("itls", [])
|
||||
input_lens = benchmark_result.get("input_lens", [])
|
||||
output_lens = benchmark_result.get("output_lens", [])
|
||||
|
||||
if start_times and ttfts and itls:
|
||||
for i in range(len(start_times)):
|
||||
# Calculate latency as ttft + sum of all itls
|
||||
latency = ttfts[i] + sum(itls[i]) if itls[i] else ttfts[i]
|
||||
|
||||
per_request_data.append(
|
||||
{
|
||||
"start_time": start_times[i],
|
||||
"ttft": ttfts[i],
|
||||
"itl": itls[i],
|
||||
"latency": latency,
|
||||
"prompt_len": input_lens[i],
|
||||
"output_tokens": output_lens[i],
|
||||
}
|
||||
)
|
||||
|
||||
timeline_path = Path(file_name).with_suffix(".timeline.html")
|
||||
# Convert thresholds from milliseconds to seconds
|
||||
itl_thresholds_sec = [t / 1000.0 for t in args.timeline_itl_thresholds]
|
||||
generate_timeline_plot(
|
||||
per_request_data, timeline_path, itl_thresholds=itl_thresholds_sec
|
||||
)
|
||||
else:
|
||||
warnings.warn(
|
||||
"Timeline plot requires detailed metrics. "
|
||||
"Ensure the benchmark completed successfully.",
|
||||
stacklevel=2,
|
||||
)
|
||||
except Exception as e:
|
||||
warnings.warn(f"Failed to generate timeline plot: {e}", stacklevel=2)
|
||||
|
||||
# Generate dataset statistics plot if requested
|
||||
if args.plot_dataset_stats:
|
||||
try:
|
||||
from vllm.benchmarks.plot import generate_dataset_stats_plot
|
||||
|
||||
# Prepare per-request data for dataset stats
|
||||
per_request_data = []
|
||||
input_lens = benchmark_result.get("input_lens", [])
|
||||
output_lens = benchmark_result.get("output_lens", [])
|
||||
|
||||
if input_lens and output_lens:
|
||||
for req_input_len, req_output_len in zip(input_lens, output_lens):
|
||||
per_request_data.append(
|
||||
{
|
||||
"prompt_len": req_input_len,
|
||||
"output_tokens": req_output_len,
|
||||
}
|
||||
)
|
||||
|
||||
stats_path = Path(file_name).with_suffix(".dataset_stats.png")
|
||||
generate_dataset_stats_plot(per_request_data, stats_path)
|
||||
else:
|
||||
warnings.warn(
|
||||
"Dataset statistics plot requires input and "
|
||||
"output length data. Ensure the benchmark completed "
|
||||
"successfully.",
|
||||
stacklevel=2,
|
||||
)
|
||||
except Exception as e:
|
||||
warnings.warn(
|
||||
f"Failed to generate dataset statistics plot: {e}", stacklevel=2
|
||||
)
|
||||
|
||||
if not args.save_detailed:
|
||||
# Remove fields with too many data points
|
||||
for field in [
|
||||
@@ -1786,24 +1935,8 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
|
||||
if field in benchmark_result:
|
||||
del benchmark_result[field]
|
||||
|
||||
# Save to file
|
||||
# Save to file
|
||||
if args.save_result or args.append_result:
|
||||
base_model_id = model_id.split("/")[-1]
|
||||
max_concurrency_str = (
|
||||
f"-concurrency{args.max_concurrency}"
|
||||
if args.max_concurrency is not None
|
||||
else ""
|
||||
)
|
||||
label = label or args.backend
|
||||
if args.ramp_up_strategy is not None:
|
||||
file_name = f"{label}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
|
||||
else:
|
||||
file_name = f"{label}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
|
||||
if args.result_filename:
|
||||
file_name = args.result_filename
|
||||
if args.result_dir:
|
||||
os.makedirs(args.result_dir, exist_ok=True)
|
||||
file_name = os.path.join(args.result_dir, file_name)
|
||||
with open(
|
||||
file_name, mode="a+" if args.append_result else "w", encoding="utf-8"
|
||||
) as outfile:
|
||||
|
||||
@@ -10,14 +10,14 @@ from .plot_pareto import SweepPlotParetoArgs
|
||||
from .plot_pareto import main as plot_pareto_main
|
||||
from .serve import SweepServeArgs
|
||||
from .serve import main as serve_main
|
||||
from .serve_sla import SweepServeSLAArgs
|
||||
from .serve_sla import main as serve_sla_main
|
||||
from .serve_workload import SweepServeWorkloadArgs
|
||||
from .serve_workload import main as serve_workload_main
|
||||
from .startup import SweepStartupArgs
|
||||
from .startup import main as startup_main
|
||||
|
||||
SUBCOMMANDS = (
|
||||
(SweepServeArgs, serve_main),
|
||||
(SweepServeSLAArgs, serve_sla_main),
|
||||
(SweepServeWorkloadArgs, serve_workload_main),
|
||||
(SweepStartupArgs, startup_main),
|
||||
(SweepPlotArgs, plot_main),
|
||||
(SweepPlotParetoArgs, plot_pareto_main),
|
||||
|
||||
@@ -324,6 +324,11 @@ def _plot_fig(
|
||||
df = filter_by.apply(df)
|
||||
df = bin_by.apply(df)
|
||||
|
||||
if len(df) == 0:
|
||||
print(f"No data to plot. Filters: {filter_by}")
|
||||
print("[END FIGURE]")
|
||||
return
|
||||
|
||||
# Sort by curve_by columns alphabetically for consistent legend ordering
|
||||
if curve_by:
|
||||
df = df.sort_values(by=curve_by)
|
||||
@@ -494,7 +499,7 @@ class SweepPlotArgs:
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
output_dir = Path(args.OUTPUT_DIR)
|
||||
output_dir = Path(args.EXPERIMENT_DIR)
|
||||
if not output_dir.exists():
|
||||
raise ValueError(f"No parameter sweep results under {output_dir}")
|
||||
|
||||
@@ -526,11 +531,9 @@ class SweepPlotArgs:
|
||||
@classmethod
|
||||
def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
|
||||
parser.add_argument(
|
||||
"OUTPUT_DIR",
|
||||
"EXPERIMENT_DIR",
|
||||
type=str,
|
||||
default="results",
|
||||
help="The directory containing the results to plot, "
|
||||
"i.e., the `--output-dir` argument to the parameter sweep script.",
|
||||
help="The directory containing the sweep results to plot.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fig-dir",
|
||||
@@ -570,13 +573,13 @@ class SweepPlotArgs:
|
||||
parser.add_argument(
|
||||
"--var-x",
|
||||
type=str,
|
||||
default="request_throughput",
|
||||
default="total_token_throughput",
|
||||
help="The variable for the x-axis.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--var-y",
|
||||
type=str,
|
||||
default="p99_ttft_ms",
|
||||
default="median_ttft_ms",
|
||||
help="The variable for the y-axis",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
||||
@@ -325,7 +325,7 @@ class SweepPlotParetoArgs:
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
output_dir = Path(args.OUTPUT_DIR)
|
||||
output_dir = Path(args.EXPERIMENT_DIR)
|
||||
if not output_dir.exists():
|
||||
raise ValueError(f"No parameter sweep results under {output_dir}")
|
||||
|
||||
@@ -342,9 +342,8 @@ class SweepPlotParetoArgs:
|
||||
@classmethod
|
||||
def add_cli_args(cls, parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"OUTPUT_DIR",
|
||||
"EXPERIMENT_DIR",
|
||||
type=str,
|
||||
default="results",
|
||||
help="The directory containing the sweep results to plot.",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
||||
@@ -4,6 +4,7 @@ import argparse
|
||||
import contextlib
|
||||
import json
|
||||
import shlex
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
@@ -135,17 +136,21 @@ def run_benchmark(
|
||||
|
||||
|
||||
def _get_comb_base_path(
|
||||
output_dir: Path,
|
||||
experiment_dir: Path,
|
||||
serve_comb: ParameterSweepItem,
|
||||
bench_comb: ParameterSweepItem,
|
||||
*,
|
||||
extra_parts: tuple[str, ...] = (),
|
||||
):
|
||||
parts = list[str]()
|
||||
if serve_comb:
|
||||
parts.extend(("SERVE-", serve_comb.name))
|
||||
if bench_comb:
|
||||
parts.extend(("BENCH-", bench_comb.name))
|
||||
if extra_parts:
|
||||
parts.extend(extra_parts)
|
||||
|
||||
return output_dir / sanitize_filename("-".join(parts))
|
||||
return experiment_dir / sanitize_filename("-".join(parts))
|
||||
|
||||
|
||||
def _get_comb_run_path(base_path: Path, run_number: int | None):
|
||||
@@ -158,10 +163,10 @@ def _get_comb_run_path(base_path: Path, run_number: int | None):
|
||||
def _comb_needs_server(
|
||||
serve_comb: ParameterSweepItem,
|
||||
bench_combs: ParameterSweep,
|
||||
output_dir: Path,
|
||||
experiment_dir: Path,
|
||||
):
|
||||
for bench_comb in bench_combs:
|
||||
base_path = _get_comb_base_path(output_dir, serve_comb, bench_comb)
|
||||
base_path = _get_comb_base_path(experiment_dir, serve_comb, bench_comb)
|
||||
if not _get_comb_run_path(base_path, run_number=None).exists():
|
||||
return True
|
||||
|
||||
@@ -175,11 +180,11 @@ def server_ctx(
|
||||
show_stdout: bool,
|
||||
serve_comb: ParameterSweepItem,
|
||||
bench_params: ParameterSweep,
|
||||
output_dir: Path,
|
||||
experiment_dir: Path,
|
||||
dry_run: bool,
|
||||
server_ready_timeout: int = 300,
|
||||
):
|
||||
if not _comb_needs_server(serve_comb, bench_params, output_dir):
|
||||
if not _comb_needs_server(serve_comb, bench_params, experiment_dir):
|
||||
return contextlib.nullcontext()
|
||||
|
||||
return run_server(
|
||||
@@ -211,10 +216,10 @@ def run_comb(
|
||||
*,
|
||||
serve_comb: ParameterSweepItem,
|
||||
bench_comb: ParameterSweepItem,
|
||||
link_vars: list[tuple[str, str]],
|
||||
base_path: Path,
|
||||
num_runs: int,
|
||||
dry_run: bool,
|
||||
link_vars: list[tuple[str, str]],
|
||||
):
|
||||
if not _comb_is_valid(serve_comb, bench_comb, link_vars):
|
||||
return None
|
||||
@@ -253,10 +258,10 @@ def run_combs(
|
||||
server_ready_timeout: int,
|
||||
serve_params: ParameterSweep,
|
||||
bench_params: ParameterSweep,
|
||||
output_dir: Path,
|
||||
link_vars: list[tuple[str, str]],
|
||||
experiment_dir: Path,
|
||||
num_runs: int,
|
||||
dry_run: bool,
|
||||
link_vars: list[tuple[str, str]],
|
||||
):
|
||||
all_data = list[dict[str, object]]()
|
||||
for serve_comb in serve_params:
|
||||
@@ -266,22 +271,22 @@ def run_combs(
|
||||
show_stdout=show_stdout,
|
||||
serve_comb=serve_comb,
|
||||
bench_params=bench_params,
|
||||
output_dir=output_dir,
|
||||
experiment_dir=experiment_dir,
|
||||
dry_run=dry_run,
|
||||
server_ready_timeout=server_ready_timeout,
|
||||
) as server:
|
||||
for bench_comb in bench_params:
|
||||
base_path = _get_comb_base_path(output_dir, serve_comb, bench_comb)
|
||||
base_path = _get_comb_base_path(experiment_dir, serve_comb, bench_comb)
|
||||
|
||||
comb_data = run_comb(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_comb=serve_comb,
|
||||
bench_comb=bench_comb,
|
||||
link_vars=link_vars,
|
||||
base_path=base_path,
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
link_vars=link_vars,
|
||||
)
|
||||
|
||||
if comb_data is not None:
|
||||
@@ -291,7 +296,7 @@ def run_combs(
|
||||
return None
|
||||
|
||||
combined_df = pd.DataFrame.from_records(all_data)
|
||||
combined_df.to_csv(output_dir / "summary.csv")
|
||||
combined_df.to_csv(experiment_dir / "summary.csv")
|
||||
|
||||
return combined_df
|
||||
|
||||
@@ -305,11 +310,12 @@ class SweepServeArgs:
|
||||
server_ready_timeout: int
|
||||
serve_params: ParameterSweep
|
||||
bench_params: ParameterSweep
|
||||
link_vars: list[tuple[str, str]]
|
||||
output_dir: Path
|
||||
experiment_name: str
|
||||
num_runs: int
|
||||
dry_run: bool
|
||||
resume: str | None
|
||||
link_vars: list[tuple[str, str]]
|
||||
resume: bool
|
||||
|
||||
parser_name: ClassVar[str] = "serve"
|
||||
parser_help: ClassVar[str] = "Run vLLM server benchmark under multiple settings."
|
||||
@@ -336,6 +342,11 @@ class SweepServeArgs:
|
||||
|
||||
link_vars = cls.parse_link_vars(args.link_vars)
|
||||
|
||||
if args.experiment_name:
|
||||
experiment_name = args.experiment_name
|
||||
else:
|
||||
experiment_name = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
num_runs = args.num_runs
|
||||
if num_runs < 1:
|
||||
raise ValueError("`num_runs` should be at least 1.")
|
||||
@@ -347,11 +358,12 @@ class SweepServeArgs:
|
||||
show_stdout=args.show_stdout,
|
||||
serve_params=serve_params,
|
||||
bench_params=bench_params,
|
||||
link_vars=link_vars,
|
||||
output_dir=Path(args.output_dir),
|
||||
experiment_name=experiment_name,
|
||||
num_runs=num_runs,
|
||||
dry_run=args.dry_run,
|
||||
resume=args.resume,
|
||||
link_vars=link_vars,
|
||||
server_ready_timeout=args.server_ready_timeout,
|
||||
)
|
||||
|
||||
@@ -388,6 +400,7 @@ class SweepServeArgs:
|
||||
default=300,
|
||||
help="Timeout in seconds to wait for the server to become ready.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--serve-params",
|
||||
type=str,
|
||||
@@ -398,6 +411,16 @@ class SweepServeArgs:
|
||||
"If both `serve_params` and `bench_params` are given, "
|
||||
"this script will iterate over their Cartesian product.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--link-vars",
|
||||
type=str,
|
||||
default="",
|
||||
help=(
|
||||
"Comma-separated list of linked variables between serve and bench, "
|
||||
"e.g. max_num_seqs=max_concurrency,max_model_len=random_input_len"
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bench-params",
|
||||
type=str,
|
||||
@@ -413,7 +436,15 @@ class SweepServeArgs:
|
||||
"--output-dir",
|
||||
type=str,
|
||||
default="results",
|
||||
help="The directory to which results are written.",
|
||||
help="The main directory to which results are written.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-e",
|
||||
"--experiment-name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The name of this experiment (defaults to current timestamp). "
|
||||
"Results will be stored under `output_dir/experiment_name`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-runs",
|
||||
@@ -429,21 +460,10 @@ class SweepServeArgs:
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Set this to the name of a directory under `output_dir` (which is a "
|
||||
"timestamp) to resume a previous execution of this script, i.e., only run "
|
||||
"parameter combinations for which there are still no output files.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--link-vars",
|
||||
type=str,
|
||||
default="",
|
||||
help=(
|
||||
"Comma-separated list of linked variables between serve and bench, "
|
||||
"e.g. max_num_seqs=max_concurrency,max_model_len=random_input_len"
|
||||
),
|
||||
action="store_true",
|
||||
help="Resume a previous execution of this script, i.e., only run "
|
||||
"parameter combinations for which there are still no output files "
|
||||
"under `output_dir/experiment_name`.",
|
||||
)
|
||||
|
||||
return parser
|
||||
@@ -458,33 +478,52 @@ class SweepServeArgs:
|
||||
pairs.append((a.strip(), b.strip()))
|
||||
return pairs
|
||||
|
||||
def resolve_experiment_dir(self) -> Path:
|
||||
experiment_dir = self.output_dir / self.experiment_name
|
||||
|
||||
if self.resume:
|
||||
if not experiment_dir.exists():
|
||||
raise ValueError(f"Cannot resume from non-existent {experiment_dir=}")
|
||||
else:
|
||||
if experiment_dir.exists():
|
||||
raise ValueError(f"Cannot overwrite existing {experiment_dir=}")
|
||||
|
||||
return experiment_dir
|
||||
|
||||
@contextmanager
|
||||
def run_ctx(self, experiment_dir: Path):
|
||||
if self.dry_run:
|
||||
yield
|
||||
print(f"Experiment will be saved at: {experiment_dir}")
|
||||
return
|
||||
|
||||
try:
|
||||
yield
|
||||
print(f"Experiment has been saved at: {experiment_dir}")
|
||||
except BaseException as exc:
|
||||
raise RuntimeError(
|
||||
"The script was terminated early. Use `--resume` "
|
||||
"to continue the script from its last checkpoint."
|
||||
) from exc
|
||||
|
||||
|
||||
def run_main(args: SweepServeArgs):
|
||||
timestamp = args.resume or datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
output_dir = args.output_dir / timestamp
|
||||
experiment_dir = args.resolve_experiment_dir()
|
||||
|
||||
if args.resume and not output_dir.exists():
|
||||
raise ValueError(f"Cannot resume from non-existent directory ({output_dir})")
|
||||
|
||||
try:
|
||||
with args.run_ctx(experiment_dir):
|
||||
return run_combs(
|
||||
serve_cmd=args.serve_cmd,
|
||||
bench_cmd=args.bench_cmd,
|
||||
link_vars=args.link_vars,
|
||||
after_bench_cmd=args.after_bench_cmd,
|
||||
show_stdout=args.show_stdout,
|
||||
server_ready_timeout=args.server_ready_timeout,
|
||||
serve_params=args.serve_params,
|
||||
bench_params=args.bench_params,
|
||||
output_dir=output_dir,
|
||||
experiment_dir=experiment_dir,
|
||||
num_runs=args.num_runs,
|
||||
dry_run=args.dry_run,
|
||||
link_vars=args.link_vars,
|
||||
)
|
||||
except BaseException as exc:
|
||||
raise RuntimeError(
|
||||
f"The script was terminated early. Use `--resume {timestamp}` "
|
||||
f"to continue the script from its last checkpoint."
|
||||
) from exc
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
|
||||
@@ -1,305 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
import math
|
||||
from dataclasses import asdict, dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import ClassVar, Literal, get_args
|
||||
|
||||
import numpy as np
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
|
||||
from .param_sweep import ParameterSweep, ParameterSweepItem
|
||||
from .serve import (
|
||||
SweepServeArgs,
|
||||
_get_comb_base_path,
|
||||
run_comb,
|
||||
server_ctx,
|
||||
)
|
||||
from .server import ServerProcess
|
||||
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
pd = PlaceholderModule("pandas")
|
||||
|
||||
|
||||
SLAVariable = Literal["request_rate", "max_concurrency"]
|
||||
|
||||
|
||||
def _estimate_sla_value(run_data: dict[str, object], sla_variable: SLAVariable):
|
||||
request_throughput = float(run_data["request_throughput"]) # type: ignore
|
||||
if sla_variable == "request_rate":
|
||||
return request_throughput
|
||||
if sla_variable == "max_concurrency":
|
||||
mean_latency_ms = float(run_data["mean_e2el_ms"]) # type: ignore
|
||||
return request_throughput * mean_latency_ms / 1000
|
||||
|
||||
assert_never(sla_variable)
|
||||
|
||||
|
||||
def _estimate_sla_avg(runs: list[dict[str, object]], sla_variable: SLAVariable):
|
||||
return sum(_estimate_sla_value(run, sla_variable) for run in runs) / len(runs)
|
||||
|
||||
|
||||
def run_comb_sla(
|
||||
server: ServerProcess | None,
|
||||
bench_cmd: list[str],
|
||||
*,
|
||||
serve_comb: ParameterSweepItem,
|
||||
bench_comb: ParameterSweepItem,
|
||||
output_dir: Path,
|
||||
num_runs: int,
|
||||
dry_run: bool,
|
||||
link_vars: list[tuple[str, str]],
|
||||
sla_variable: SLAVariable,
|
||||
sla_value: int,
|
||||
) -> list[dict[str, object]] | None:
|
||||
bench_comb_sla = bench_comb | {sla_variable: sla_value}
|
||||
|
||||
return run_comb(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_comb=serve_comb,
|
||||
bench_comb=bench_comb_sla,
|
||||
base_path=_get_comb_base_path(output_dir, serve_comb, bench_comb_sla),
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
link_vars=link_vars,
|
||||
)
|
||||
|
||||
|
||||
def explore_sla(
|
||||
server: ServerProcess | None,
|
||||
bench_cmd: list[str],
|
||||
*,
|
||||
serve_comb: ParameterSweepItem,
|
||||
bench_comb: ParameterSweepItem,
|
||||
sla_variable: SLAVariable,
|
||||
sla_iters: int,
|
||||
output_dir: Path,
|
||||
num_runs: int,
|
||||
dry_run: bool,
|
||||
link_vars: list[tuple[str, str]],
|
||||
):
|
||||
print("[SLA START]")
|
||||
print(f"Serve parameters: {serve_comb.as_text() or '(None)'}")
|
||||
print(f"Bench parameters: {bench_comb.as_text() or '(None)'}")
|
||||
print(f"Number of SLA iterations: {sla_iters}")
|
||||
|
||||
if sla_iters < 2:
|
||||
raise ValueError("`sla_iters` should be at least 2")
|
||||
|
||||
serial_comb_data = run_comb_sla(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_comb=serve_comb,
|
||||
bench_comb=bench_comb,
|
||||
output_dir=output_dir,
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
link_vars=link_vars,
|
||||
sla_variable=sla_variable,
|
||||
sla_value=1,
|
||||
)
|
||||
batch_comb_data = run_comb_sla(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_comb=serve_comb,
|
||||
bench_comb=bench_comb,
|
||||
output_dir=output_dir,
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
link_vars=link_vars,
|
||||
sla_variable=sla_variable,
|
||||
sla_value=int(bench_comb.get("num_prompts", 1000)), # type: ignore
|
||||
)
|
||||
|
||||
if serial_comb_data is None or batch_comb_data is None:
|
||||
if dry_run:
|
||||
print("Omitting intermediate SLA iterations.")
|
||||
print("[SLA END]")
|
||||
|
||||
return
|
||||
|
||||
serial_sla_value = math.ceil(_estimate_sla_avg(serial_comb_data, sla_variable))
|
||||
print(f"Serial inference: {sla_variable}={serial_sla_value}")
|
||||
|
||||
batch_sla_value = math.floor(_estimate_sla_avg(batch_comb_data, sla_variable))
|
||||
print(f"Batch inference: {sla_variable}={batch_sla_value}")
|
||||
|
||||
# Avoid duplicated runs for intermediate values if the range between
|
||||
# `serial_sla_value` and `batch_sla_value` is small
|
||||
inter_sla_values = np.linspace(serial_sla_value, batch_sla_value, sla_iters)[1:-1]
|
||||
inter_sla_values = sorted(set(map(round, inter_sla_values)))
|
||||
|
||||
inter_combs_data: list[dict[str, object]] = []
|
||||
for inter_sla_value in inter_sla_values:
|
||||
print(f"Exploring: {sla_variable}={inter_sla_value}")
|
||||
inter_comb_data = run_comb_sla(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_comb=serve_comb,
|
||||
bench_comb=bench_comb,
|
||||
output_dir=output_dir,
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
link_vars=link_vars,
|
||||
sla_variable=sla_variable,
|
||||
sla_value=inter_sla_value,
|
||||
)
|
||||
if inter_comb_data is not None:
|
||||
inter_combs_data.extend(inter_comb_data)
|
||||
|
||||
print("[SLA END]")
|
||||
|
||||
return serial_comb_data + inter_combs_data + batch_comb_data
|
||||
|
||||
|
||||
def run_slas(
|
||||
serve_cmd: list[str],
|
||||
bench_cmd: list[str],
|
||||
after_bench_cmd: list[str],
|
||||
*,
|
||||
show_stdout: bool,
|
||||
server_ready_timeout: int,
|
||||
serve_params: ParameterSweep,
|
||||
bench_params: ParameterSweep,
|
||||
sla_variable: SLAVariable,
|
||||
sla_iters: int,
|
||||
output_dir: Path,
|
||||
num_runs: int,
|
||||
dry_run: bool,
|
||||
link_vars: list[tuple[str, str]],
|
||||
):
|
||||
if any(bench_comb.has_param(sla_variable) for bench_comb in bench_params):
|
||||
raise ValueError(
|
||||
f"You should not override `{sla_variable}` in `bench_params` in SLA mode, "
|
||||
"since it is supposed to be determined automatically."
|
||||
)
|
||||
|
||||
all_data = list[dict[str, object]]()
|
||||
for serve_comb in serve_params:
|
||||
with server_ctx(
|
||||
serve_cmd,
|
||||
after_bench_cmd,
|
||||
show_stdout=show_stdout,
|
||||
server_ready_timeout=server_ready_timeout,
|
||||
serve_comb=serve_comb,
|
||||
bench_params=bench_params,
|
||||
output_dir=output_dir,
|
||||
dry_run=dry_run,
|
||||
) as server:
|
||||
for bench_comb in bench_params:
|
||||
comb_data = explore_sla(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_comb=serve_comb,
|
||||
bench_comb=bench_comb,
|
||||
sla_variable=sla_variable,
|
||||
sla_iters=sla_iters,
|
||||
output_dir=output_dir,
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
link_vars=link_vars,
|
||||
)
|
||||
|
||||
if comb_data is not None:
|
||||
all_data.extend(comb_data)
|
||||
|
||||
if dry_run:
|
||||
return None
|
||||
|
||||
combined_df = pd.DataFrame.from_records(all_data)
|
||||
combined_df.to_csv(output_dir / "summary.csv")
|
||||
|
||||
return combined_df
|
||||
|
||||
|
||||
@dataclass
|
||||
class SweepServeSLAArgs(SweepServeArgs):
|
||||
sla_variable: SLAVariable
|
||||
sla_iters: int
|
||||
|
||||
parser_name: ClassVar[str] = "serve_sla"
|
||||
parser_help: ClassVar[str] = (
|
||||
"Explore the latency-throughput space for determining SLAs."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
# NOTE: Don't use super() as `from_cli_args` calls `cls()`
|
||||
base_args = SweepServeArgs.from_cli_args(args)
|
||||
|
||||
return cls(
|
||||
**asdict(base_args),
|
||||
sla_variable=args.sla_variable,
|
||||
sla_iters=args.sla_iters,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
|
||||
parser = super().add_cli_args(parser)
|
||||
|
||||
sla_group = parser.add_argument_group("sla options")
|
||||
sla_group.add_argument(
|
||||
"--sla-variable",
|
||||
type=str,
|
||||
choices=get_args(SLAVariable),
|
||||
default="request_rate",
|
||||
help="The variable to adjust in each iteration.",
|
||||
)
|
||||
sla_group.add_argument(
|
||||
"--sla-iters",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of iterations used to explore the latency-throughput space. "
|
||||
"This includes the first two iterations used to interpolate the value of "
|
||||
"`sla_variable` for remaining iterations.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def run_main(args: SweepServeSLAArgs):
|
||||
timestamp = args.resume or datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
output_dir = args.output_dir / timestamp
|
||||
|
||||
if args.resume and not output_dir.exists():
|
||||
raise ValueError(f"Cannot resume from non-existent directory ({output_dir})")
|
||||
|
||||
try:
|
||||
return run_slas(
|
||||
serve_cmd=args.serve_cmd,
|
||||
bench_cmd=args.bench_cmd,
|
||||
after_bench_cmd=args.after_bench_cmd,
|
||||
show_stdout=args.show_stdout,
|
||||
server_ready_timeout=args.server_ready_timeout,
|
||||
serve_params=args.serve_params,
|
||||
bench_params=args.bench_params,
|
||||
sla_variable=args.sla_variable,
|
||||
sla_iters=args.sla_iters,
|
||||
output_dir=output_dir,
|
||||
num_runs=args.num_runs,
|
||||
dry_run=args.dry_run,
|
||||
link_vars=args.link_vars,
|
||||
)
|
||||
except BaseException as exc:
|
||||
raise RuntimeError(
|
||||
f"The script was terminated early. Use `--resume {timestamp}` "
|
||||
f"to continue the script from its last checkpoint."
|
||||
) from exc
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
run_main(SweepServeSLAArgs.from_cli_args(args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description=SweepServeSLAArgs.parser_help)
|
||||
SweepServeSLAArgs.add_cli_args(parser)
|
||||
|
||||
main(parser.parse_args())
|
||||
328
vllm/benchmarks/sweep/serve_workload.py
Normal file
328
vllm/benchmarks/sweep/serve_workload.py
Normal file
@@ -0,0 +1,328 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
import math
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
from typing import ClassVar, Literal, get_args
|
||||
|
||||
import numpy as np
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.benchmarks.datasets import DEFAULT_NUM_PROMPTS
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
|
||||
from .param_sweep import ParameterSweep, ParameterSweepItem
|
||||
from .serve import (
|
||||
SweepServeArgs,
|
||||
_get_comb_base_path,
|
||||
run_comb,
|
||||
server_ctx,
|
||||
)
|
||||
from .server import ServerProcess
|
||||
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
pd = PlaceholderModule("pandas")
|
||||
|
||||
|
||||
WorkloadVariable = Literal["request_rate", "max_concurrency"]
|
||||
|
||||
|
||||
def _estimate_workload_value(
|
||||
run_data: dict[str, object],
|
||||
workload_var: WorkloadVariable,
|
||||
):
|
||||
request_throughput = float(run_data["request_throughput"]) # type: ignore
|
||||
if workload_var == "request_rate":
|
||||
return request_throughput
|
||||
if workload_var == "max_concurrency":
|
||||
mean_latency_ms = float(run_data["mean_e2el_ms"]) # type: ignore
|
||||
return request_throughput * mean_latency_ms / 1000
|
||||
|
||||
assert_never(workload_var)
|
||||
|
||||
|
||||
def _estimate_workload_avg(
|
||||
runs: list[dict[str, object]],
|
||||
workload_var: WorkloadVariable,
|
||||
):
|
||||
total = sum(_estimate_workload_value(run, workload_var) for run in runs)
|
||||
return total / len(runs)
|
||||
|
||||
|
||||
def run_comb_workload(
|
||||
server: ServerProcess | None,
|
||||
bench_cmd: list[str],
|
||||
*,
|
||||
serve_comb: ParameterSweepItem,
|
||||
bench_comb: ParameterSweepItem,
|
||||
link_vars: list[tuple[str, str]],
|
||||
experiment_dir: Path,
|
||||
num_runs: int,
|
||||
dry_run: bool,
|
||||
workload_var: WorkloadVariable,
|
||||
workload_value: int,
|
||||
) -> list[dict[str, object]] | None:
|
||||
bench_comb_workload = bench_comb | {workload_var: workload_value}
|
||||
|
||||
return run_comb(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_comb=serve_comb,
|
||||
bench_comb=bench_comb_workload,
|
||||
link_vars=link_vars,
|
||||
base_path=_get_comb_base_path(
|
||||
experiment_dir,
|
||||
serve_comb,
|
||||
bench_comb,
|
||||
extra_parts=("WL-", f"{workload_var}={workload_value}"),
|
||||
),
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
|
||||
|
||||
def explore_comb_workloads(
|
||||
server: ServerProcess | None,
|
||||
bench_cmd: list[str],
|
||||
*,
|
||||
serve_comb: ParameterSweepItem,
|
||||
bench_comb: ParameterSweepItem,
|
||||
link_vars: list[tuple[str, str]],
|
||||
workload_var: WorkloadVariable,
|
||||
workload_iters: int,
|
||||
experiment_dir: Path,
|
||||
num_runs: int,
|
||||
dry_run: bool,
|
||||
):
|
||||
print("[WL START]")
|
||||
print(f"Serve parameters: {serve_comb.as_text() or '(None)'}")
|
||||
print(f"Bench parameters: {bench_comb.as_text() or '(None)'}")
|
||||
print(f"Number of workload iterations: {workload_iters}")
|
||||
|
||||
if workload_iters < 2:
|
||||
raise ValueError("`workload_iters` should be at least 2")
|
||||
|
||||
dataset_size = DEFAULT_NUM_PROMPTS
|
||||
if "num_prompts" in bench_comb:
|
||||
dataset_size = int(bench_comb["num_prompts"]) # type: ignore
|
||||
else:
|
||||
for i, arg in enumerate(bench_cmd):
|
||||
if arg == "--num-prompts" and i + 1 < len(bench_cmd):
|
||||
dataset_size = int(bench_cmd[i + 1])
|
||||
break
|
||||
elif arg.startswith("--num-prompts="):
|
||||
dataset_size = int(arg.split("=", 1)[1])
|
||||
break
|
||||
|
||||
print(f"Dataset size: {dataset_size}")
|
||||
|
||||
serial_workload_data = run_comb_workload(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_comb=serve_comb,
|
||||
bench_comb=bench_comb | {"max_concurrency": 1},
|
||||
link_vars=link_vars,
|
||||
experiment_dir=experiment_dir,
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
workload_var=workload_var,
|
||||
workload_value=1,
|
||||
)
|
||||
batch_workload_data = run_comb_workload(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_comb=serve_comb,
|
||||
bench_comb=bench_comb | {"max_concurrency": dataset_size},
|
||||
link_vars=link_vars,
|
||||
experiment_dir=experiment_dir,
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
workload_var=workload_var,
|
||||
workload_value=dataset_size,
|
||||
)
|
||||
|
||||
if serial_workload_data is None or batch_workload_data is None:
|
||||
if dry_run:
|
||||
print("Omitting intermediate Workload iterations.")
|
||||
print("[WL END]")
|
||||
|
||||
return
|
||||
|
||||
serial_workload_value = math.ceil(
|
||||
_estimate_workload_avg(serial_workload_data, workload_var)
|
||||
)
|
||||
print(f"Serial inference: {workload_var}={serial_workload_value}")
|
||||
|
||||
batch_workload_value = math.floor(
|
||||
_estimate_workload_avg(batch_workload_data, workload_var)
|
||||
)
|
||||
print(f"Batch inference: {workload_var}={batch_workload_value}")
|
||||
|
||||
# Avoid duplicated runs for intermediate values if the range between
|
||||
# `serial_workload_value` and `batch_workload_value` is small
|
||||
inter_workload_values = np.linspace(
|
||||
serial_workload_value, batch_workload_value, workload_iters
|
||||
)[1:-1]
|
||||
inter_workload_values = sorted(set(map(round, inter_workload_values)))
|
||||
|
||||
inter_workloads_data: list[dict[str, object]] = []
|
||||
for inter_workload_value in inter_workload_values:
|
||||
print(f"Exploring: {workload_var}={inter_workload_value}")
|
||||
inter_workload_data = run_comb_workload(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_comb=serve_comb,
|
||||
bench_comb=bench_comb,
|
||||
link_vars=link_vars,
|
||||
experiment_dir=experiment_dir,
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
workload_var=workload_var,
|
||||
workload_value=inter_workload_value,
|
||||
)
|
||||
if inter_workload_data is not None:
|
||||
inter_workloads_data.extend(inter_workload_data)
|
||||
|
||||
print("[WL END]")
|
||||
|
||||
return serial_workload_data + inter_workloads_data + batch_workload_data
|
||||
|
||||
|
||||
def explore_combs_workloads(
|
||||
serve_cmd: list[str],
|
||||
bench_cmd: list[str],
|
||||
after_bench_cmd: list[str],
|
||||
*,
|
||||
show_stdout: bool,
|
||||
server_ready_timeout: int,
|
||||
serve_params: ParameterSweep,
|
||||
bench_params: ParameterSweep,
|
||||
link_vars: list[tuple[str, str]],
|
||||
workload_var: WorkloadVariable,
|
||||
workload_iters: int,
|
||||
experiment_dir: Path,
|
||||
num_runs: int,
|
||||
dry_run: bool,
|
||||
):
|
||||
if any(bench_comb.has_param(workload_var) for bench_comb in bench_params):
|
||||
raise ValueError(
|
||||
f"You should not override `{workload_var}` in `bench_params` "
|
||||
"since it is supposed to be explored automatically."
|
||||
)
|
||||
|
||||
all_data = list[dict[str, object]]()
|
||||
for serve_comb in serve_params:
|
||||
with server_ctx(
|
||||
serve_cmd,
|
||||
after_bench_cmd,
|
||||
show_stdout=show_stdout,
|
||||
server_ready_timeout=server_ready_timeout,
|
||||
serve_comb=serve_comb,
|
||||
bench_params=bench_params,
|
||||
experiment_dir=experiment_dir,
|
||||
dry_run=dry_run,
|
||||
) as server:
|
||||
for bench_comb in bench_params:
|
||||
comb_data = explore_comb_workloads(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_comb=serve_comb,
|
||||
bench_comb=bench_comb,
|
||||
link_vars=link_vars,
|
||||
workload_var=workload_var,
|
||||
workload_iters=workload_iters,
|
||||
experiment_dir=experiment_dir,
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
|
||||
if comb_data is not None:
|
||||
all_data.extend(comb_data)
|
||||
|
||||
if dry_run:
|
||||
return None
|
||||
|
||||
combined_df = pd.DataFrame.from_records(all_data)
|
||||
combined_df.to_csv(experiment_dir / "summary.csv")
|
||||
|
||||
return combined_df
|
||||
|
||||
|
||||
@dataclass
|
||||
class SweepServeWorkloadArgs(SweepServeArgs):
|
||||
workload_var: WorkloadVariable
|
||||
workload_iters: int
|
||||
|
||||
parser_name: ClassVar[str] = "serve_workload"
|
||||
parser_help: ClassVar[str] = (
|
||||
"Explore the latency-throughput tradeoff for different workload levels."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
# NOTE: Don't use super() as `from_cli_args` calls `cls()`
|
||||
base_args = SweepServeArgs.from_cli_args(args)
|
||||
|
||||
return cls(
|
||||
**asdict(base_args),
|
||||
workload_var=args.workload_var,
|
||||
workload_iters=args.workload_iters,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
|
||||
parser = super().add_cli_args(parser)
|
||||
|
||||
workload_group = parser.add_argument_group("workload options")
|
||||
workload_group.add_argument(
|
||||
"--workload-var",
|
||||
type=str,
|
||||
choices=get_args(WorkloadVariable),
|
||||
default="request_rate",
|
||||
help="The variable to adjust in each iteration.",
|
||||
)
|
||||
workload_group.add_argument(
|
||||
"--workload-iters",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of workload levels to explore. "
|
||||
"This includes the first two iterations used to interpolate the value of "
|
||||
"`workload_var` for remaining iterations.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def run_main(args: SweepServeWorkloadArgs):
|
||||
experiment_dir = args.resolve_experiment_dir()
|
||||
|
||||
with args.run_ctx(experiment_dir):
|
||||
return explore_combs_workloads(
|
||||
serve_cmd=args.serve_cmd,
|
||||
bench_cmd=args.bench_cmd,
|
||||
after_bench_cmd=args.after_bench_cmd,
|
||||
show_stdout=args.show_stdout,
|
||||
server_ready_timeout=args.server_ready_timeout,
|
||||
serve_params=args.serve_params,
|
||||
bench_params=args.bench_params,
|
||||
link_vars=args.link_vars,
|
||||
workload_var=args.workload_var,
|
||||
workload_iters=args.workload_iters,
|
||||
experiment_dir=experiment_dir,
|
||||
num_runs=args.num_runs,
|
||||
dry_run=args.dry_run,
|
||||
)
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
run_main(SweepServeWorkloadArgs.from_cli_args(args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description=SweepServeWorkloadArgs.parser_help)
|
||||
SweepServeWorkloadArgs.add_cli_args(parser)
|
||||
|
||||
main(parser.parse_args())
|
||||
@@ -4,6 +4,7 @@ import argparse
|
||||
import json
|
||||
import shlex
|
||||
import subprocess
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from functools import lru_cache
|
||||
@@ -111,7 +112,7 @@ def _apply_output_json(cmd: list[str], output_path: Path) -> list[str]:
|
||||
|
||||
|
||||
def _get_comb_base_path(
|
||||
output_dir: Path,
|
||||
experiment_dir: Path,
|
||||
serve_comb: ParameterSweepItem,
|
||||
startup_comb: ParameterSweepItem,
|
||||
) -> Path:
|
||||
@@ -120,7 +121,8 @@ def _get_comb_base_path(
|
||||
parts.extend(("SERVE-", serve_comb.name))
|
||||
if startup_comb:
|
||||
parts.extend(("STARTUP-", startup_comb.name))
|
||||
return output_dir / sanitize_filename("-".join(parts))
|
||||
|
||||
return experiment_dir / sanitize_filename("-".join(parts))
|
||||
|
||||
|
||||
def _get_comb_run_path(base_path: Path, run_number: int | None) -> Path:
|
||||
@@ -225,7 +227,7 @@ def run_combs(
|
||||
*,
|
||||
serve_params: ParameterSweep,
|
||||
startup_params: ParameterSweep,
|
||||
output_dir: Path,
|
||||
experiment_dir: Path,
|
||||
num_runs: int,
|
||||
show_stdout: bool,
|
||||
dry_run: bool,
|
||||
@@ -233,7 +235,7 @@ def run_combs(
|
||||
all_data = list[dict[str, object]]()
|
||||
for serve_comb in serve_params:
|
||||
for startup_comb in startup_params:
|
||||
base_path = _get_comb_base_path(output_dir, serve_comb, startup_comb)
|
||||
base_path = _get_comb_base_path(experiment_dir, serve_comb, startup_comb)
|
||||
comb_data = run_comb(
|
||||
startup_cmd,
|
||||
serve_comb=serve_comb,
|
||||
@@ -250,7 +252,7 @@ def run_combs(
|
||||
return None
|
||||
|
||||
combined_df = pd.DataFrame.from_records(all_data)
|
||||
combined_df.to_csv(output_dir / "summary.csv")
|
||||
combined_df.to_csv(experiment_dir / "summary.csv")
|
||||
return combined_df
|
||||
|
||||
|
||||
@@ -260,11 +262,11 @@ class SweepStartupArgs:
|
||||
serve_params: ParameterSweep
|
||||
startup_params: ParameterSweep
|
||||
output_dir: Path
|
||||
experiment_name: str
|
||||
num_runs: int
|
||||
show_stdout: bool
|
||||
dry_run: bool
|
||||
resume: str | None
|
||||
strict_params: bool
|
||||
resume: bool
|
||||
|
||||
parser_name: ClassVar[str] = "startup"
|
||||
parser_help: ClassVar[str] = (
|
||||
@@ -286,13 +288,19 @@ class SweepStartupArgs:
|
||||
startup_params = ParameterSweep.from_records([{}])
|
||||
|
||||
supported = _get_supported_startup_keys()
|
||||
strict_params = args.strict_params
|
||||
serve_params = _filter_params(
|
||||
serve_params, supported=supported, strict=args.strict_params
|
||||
serve_params, supported=supported, strict=strict_params
|
||||
)
|
||||
startup_params = _filter_params(
|
||||
startup_params, supported=supported, strict=args.strict_params
|
||||
startup_params, supported=supported, strict=strict_params
|
||||
)
|
||||
|
||||
if args.experiment_name:
|
||||
experiment_name = args.experiment_name
|
||||
else:
|
||||
experiment_name = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
if args.num_runs < 1:
|
||||
raise ValueError("`num_runs` should be at least 1.")
|
||||
|
||||
@@ -301,11 +309,11 @@ class SweepStartupArgs:
|
||||
serve_params=serve_params,
|
||||
startup_params=startup_params,
|
||||
output_dir=Path(args.output_dir),
|
||||
experiment_name=experiment_name,
|
||||
num_runs=args.num_runs,
|
||||
show_stdout=args.show_stdout,
|
||||
dry_run=args.dry_run,
|
||||
resume=args.resume,
|
||||
strict_params=args.strict_params,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -316,6 +324,7 @@ class SweepStartupArgs:
|
||||
default="vllm bench startup",
|
||||
help="The command used to run the startup benchmark.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--serve-params",
|
||||
type=str,
|
||||
@@ -331,12 +340,27 @@ class SweepStartupArgs:
|
||||
help="Path to JSON file containing parameter combinations "
|
||||
"for the `vllm bench startup` command.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--strict-params",
|
||||
action="store_true",
|
||||
help="If set, unknown parameters in sweep files raise an error "
|
||||
"instead of being ignored.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--output-dir",
|
||||
type=str,
|
||||
default="results",
|
||||
help="The directory to which results are written.",
|
||||
help="The main directory to which results are written.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-e",
|
||||
"--experiment-name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The name of this experiment (defaults to current timestamp). "
|
||||
"Results will be stored under `output_dir/experiment_name`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-runs",
|
||||
@@ -357,43 +381,56 @@ class SweepStartupArgs:
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Set this to the name of a directory under `output_dir` (which is a "
|
||||
"timestamp) to resume a previous execution of this script, i.e., only run "
|
||||
"parameter combinations for which there are still no output files.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--strict-params",
|
||||
action="store_true",
|
||||
help="If set, unknown parameters in sweep files raise an error "
|
||||
"instead of being ignored.",
|
||||
help="Resume a previous execution of this script, i.e., only run "
|
||||
"parameter combinations for which there are still no output files "
|
||||
"under `output_dir/experiment_name`.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
def resolve_experiment_dir(self) -> Path:
|
||||
experiment_dir = self.output_dir / self.experiment_name
|
||||
|
||||
if self.resume:
|
||||
if not experiment_dir.exists():
|
||||
raise ValueError(f"Cannot resume from non-existent {experiment_dir=}")
|
||||
else:
|
||||
if experiment_dir.exists():
|
||||
raise ValueError(f"Cannot overwrite existing {experiment_dir=}")
|
||||
|
||||
return experiment_dir
|
||||
|
||||
@contextmanager
|
||||
def run_ctx(self, experiment_dir: Path):
|
||||
if self.dry_run:
|
||||
yield
|
||||
print(f"Experiment will be saved at: {experiment_dir}")
|
||||
return
|
||||
|
||||
try:
|
||||
yield
|
||||
print(f"Experiment has been saved at: {experiment_dir}")
|
||||
except BaseException as exc:
|
||||
raise RuntimeError(
|
||||
"The script was terminated early. Use `--resume` "
|
||||
"to continue the script from its last checkpoint."
|
||||
) from exc
|
||||
|
||||
|
||||
def run_main(args: SweepStartupArgs):
|
||||
timestamp = args.resume or datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
output_dir = args.output_dir / timestamp
|
||||
experiment_dir = args.resolve_experiment_dir()
|
||||
|
||||
if args.resume and not output_dir.exists():
|
||||
raise ValueError(f"Cannot resume from non-existent directory ({output_dir})")
|
||||
|
||||
try:
|
||||
with args.run_ctx(experiment_dir):
|
||||
return run_combs(
|
||||
startup_cmd=args.startup_cmd,
|
||||
serve_params=args.serve_params,
|
||||
startup_params=args.startup_params,
|
||||
output_dir=output_dir,
|
||||
experiment_dir=experiment_dir,
|
||||
num_runs=args.num_runs,
|
||||
show_stdout=args.show_stdout,
|
||||
dry_run=args.dry_run,
|
||||
)
|
||||
except BaseException as exc:
|
||||
raise RuntimeError(
|
||||
f"The script was terminated early. Use `--resume {timestamp}` "
|
||||
f"to continue the script from its last checkpoint."
|
||||
) from exc
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
|
||||
Reference in New Issue
Block a user