Sync from v0.13
This commit is contained in:
0
vllm/benchmarks/__init__.py
Normal file
0
vllm/benchmarks/__init__.py
Normal file
3228
vllm/benchmarks/datasets.py
Normal file
3228
vllm/benchmarks/datasets.py
Normal file
File diff suppressed because it is too large
Load Diff
170
vllm/benchmarks/latency.py
Normal file
170
vllm/benchmarks/latency.py
Normal file
@@ -0,0 +1,170 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Benchmark the latency of processing a single batch of requests."""
|
||||
|
||||
import argparse
|
||||
import dataclasses
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from vllm.benchmarks.lib.utils import convert_to_pytorch_benchmark_format, write_to_json
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.inputs import PromptType
|
||||
from vllm.sampling_params import BeamSearchParams
|
||||
|
||||
|
||||
def save_to_pytorch_benchmark_format(
    args: argparse.Namespace, results: dict[str, Any]
) -> None:
    """Persist latency results in PyTorch benchmark format.

    Writes next to ``args.output_json`` with a ``.pytorch.json`` suffix;
    does nothing when the converter yields no records.
    """
    records = convert_to_pytorch_benchmark_format(
        args=args,
        metrics={"latency": results["latencies"]},
        extra_info={key: results[key] for key in ("avg_latency", "percentiles")},
    )
    if not records:
        return
    base, _ext = os.path.splitext(args.output_json)
    write_to_json(f"{base}.pytorch.json", records)
|
||||
|
||||
|
||||
def add_cli_args(parser: argparse.ArgumentParser):
    """Register the latency-benchmark CLI flags on *parser*.

    Also appends the vLLM engine arguments and turns prefix caching off by
    default so cached prefixes cannot skew the measured latencies.
    """
    # Plain integer knobs without help text.
    for flag, default in (
        ("--input-len", 32),
        ("--output-len", 128),
        ("--batch-size", 8),
    ):
        parser.add_argument(flag, type=int, default=default)

    parser.add_argument(
        "--n",
        type=int,
        default=1,
        help="Number of generated sequences per prompt.",
    )
    parser.add_argument("--use-beam-search", action="store_true")
    parser.add_argument(
        "--num-iters-warmup",
        type=int,
        default=10,
        help="Number of iterations to run for warmup.",
    )
    parser.add_argument(
        "--num-iters",
        type=int,
        default=30,
        help="Number of iterations to run.",
    )
    parser.add_argument(
        "--profile",
        action="store_true",
        help="profile the generation process of a single batch",
    )
    parser.add_argument(
        "--output-json",
        type=str,
        default=None,
        help="Path to save the latency results in JSON format.",
    )
    parser.add_argument(
        "--disable-detokenize",
        action="store_true",
        help=(
            "Do not detokenize responses (i.e. do not include "
            "detokenization time in the latency measurement)"
        ),
    )

    parser = EngineArgs.add_cli_args(parser)
    # V1 enables prefix caching by default which skews the latency
    # numbers. We need to disable prefix caching by default.
    parser.set_defaults(enable_prefix_caching=False)
|
||||
|
||||
|
||||
def main(args: argparse.Namespace) -> None:
    """Run the single-batch latency benchmark.

    Builds an LLM from the parsed engine args, generates a dummy batch of
    random token IDs, warms up, then either profiles one iteration or times
    ``args.num_iters`` iterations and reports mean/percentile latencies.
    Results are optionally written to ``args.output_json`` (and a PyTorch
    benchmark file alongside it).
    """
    engine_args = EngineArgs.from_cli_args(args)
    # Profiling requires the torch profiler to be configured up front.
    if args.profile and not engine_args.profiler_config.profiler == "torch":
        raise ValueError(
            "The torch profiler is not enabled. Please provide profiler_config."
        )

    # Lazy import to avoid importing LLM when the bench command is not selected.
    from vllm import LLM, SamplingParams

    # NOTE(woosuk): If the request cannot be processed in a single batch,
    # the engine will automatically process the request in multiple batches.
    llm = LLM(**dataclasses.asdict(engine_args))
    assert llm.llm_engine.model_config.max_model_len >= (
        args.input_len + args.output_len
    ), (
        "Please ensure that max_model_len is greater than"
        " the sum of input_len and output_len."
    )

    # ignore_eos + fixed max_tokens so every request generates exactly
    # args.output_len tokens, making iterations comparable.
    sampling_params = SamplingParams(
        n=args.n,
        temperature=1.0,
        top_p=1.0,
        ignore_eos=True,
        max_tokens=args.output_len,
        detokenize=not args.disable_detokenize,
    )
    # Random token IDs in [0, 10000) — content doesn't matter for latency.
    dummy_prompt_token_ids = np.random.randint(
        10000, size=(args.batch_size, args.input_len)
    )
    dummy_prompts: list[PromptType] = [
        {"prompt_token_ids": batch} for batch in dummy_prompt_token_ids.tolist()
    ]

    def llm_generate():
        # One generation pass over the dummy batch (beam search optional).
        if not args.use_beam_search:
            llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False)
        else:
            llm.beam_search(
                dummy_prompts,
                BeamSearchParams(
                    beam_width=args.n,
                    max_tokens=args.output_len,
                    ignore_eos=True,
                ),
            )

    def run_to_completion(profile_dir: str | None = None):
        # Returns the wall-clock latency of one pass, or None when profiling
        # (the profiler run is not timed).
        if profile_dir:
            llm.start_profile()
            llm_generate()
            llm.stop_profile()
        else:
            start_time = time.perf_counter()
            llm_generate()
            end_time = time.perf_counter()
            latency = end_time - start_time
            return latency

    print("Warming up...")
    for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
        run_to_completion(profile_dir=None)

    if args.profile:
        profile_dir = engine_args.profiler_config.torch_profiler_dir
        print(f"Profiling (results will be saved to '{profile_dir}')...")
        run_to_completion(profile_dir=profile_dir)
        return

    # Benchmark.
    latencies = []
    for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
        latencies.append(run_to_completion(profile_dir=None))
    latencies = np.array(latencies)
    percentages = [10, 25, 50, 75, 90, 99]
    percentiles = np.percentile(latencies, percentages)
    print(f"Avg latency: {np.mean(latencies)} seconds")
    for percentage, percentile in zip(percentages, percentiles):
        print(f"{percentage}% percentile latency: {percentile} seconds")

    # Output JSON results if specified
    if args.output_json:
        results = {
            "avg_latency": np.mean(latencies),
            "latencies": latencies.tolist(),
            "percentiles": dict(zip(percentages, percentiles.tolist())),
        }
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)
        save_to_pytorch_benchmark_format(args, results)
|
||||
3
vllm/benchmarks/lib/__init__.py
Normal file
3
vllm/benchmarks/lib/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Benchmark library utilities."""
|
||||
777
vllm/benchmarks/lib/endpoint_request_func.py
Normal file
777
vllm/benchmarks/lib/endpoint_request_func.py
Normal file
@@ -0,0 +1,777 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""The request function for API endpoints."""
|
||||
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from collections.abc import Awaitable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Literal, Protocol
|
||||
|
||||
import aiohttp
|
||||
import regex as re
|
||||
from tqdm.asyncio import tqdm
|
||||
|
||||
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
|
||||
|
||||
|
||||
class StreamedResponseHandler:
    """Accumulates streamed HTTP chunks and yields complete SSE messages.

    Network chunks may split a message at arbitrary byte boundaries, so
    partial data is buffered until a full message is available.
    """

    def __init__(self):
        self.buffer = ""

    def add_chunk(self, chunk_bytes: bytes) -> list[str]:
        """Append *chunk_bytes* to the buffer and return complete messages."""
        self.buffer += chunk_bytes.decode("utf-8")

        complete: list[str] = []

        # SSE messages are separated by a blank line ("\n\n").
        while "\n\n" in self.buffer:
            head, self.buffer = self.buffer.split("\n\n", 1)
            head = head.strip()
            if head:
                complete.append(head)

        # The remainder may already be a full message that merely lacks the
        # trailing separator: a "data: " payload that is either the [DONE]
        # sentinel or a complete JSON document.
        if self.buffer.startswith("data: "):
            payload = self.buffer.removeprefix("data: ").strip()
            if payload == "[DONE]":
                complete.append(self.buffer.strip())
                self.buffer = ""
            elif payload:
                try:
                    json.loads(payload)
                except json.JSONDecodeError:
                    # Incomplete JSON — keep buffering until more arrives.
                    pass
                else:
                    complete.append(self.buffer.strip())
                    self.buffer = ""

        return complete
|
||||
|
||||
|
||||
@dataclass
class RequestFuncInput:
    """The input for the request function."""

    # Prompt text; rerank endpoints pass a list (query first, then documents).
    prompt: str | list[str]
    # Full endpoint URL the request is sent to.
    api_url: str
    # Token length of the prompt (bookkeeping only; copied to the output).
    prompt_len: int
    # Requested number of output tokens (max_tokens/max_completion_tokens).
    output_len: int
    # Model id placed in the request payload.
    model: str
    # When set, used in the payload instead of `model`.
    model_name: str | None = None
    # Number of logprobs to request (completions API only).
    logprobs: int | None = None
    # Extra HTTP headers merged into the request headers.
    extra_headers: dict | None = None
    # Extra keys merged into the JSON payload.
    extra_body: dict | None = None
    # Multimodal message parts; expected shape depends on the backend
    # (dict with "audio" for the audio API, OpenAI content parts otherwise).
    multi_modal_content: dict | list[dict] | None = None
    # Ask the server to keep generating past EOS (fixed-length runs).
    ignore_eos: bool = False
    # Language hint; NOTE(review): not read by the request functions in
    # this module — presumably consumed elsewhere.
    language: str | None = None
    # Propagated as the `x-request-id` header when set.
    request_id: str | None = None
|
||||
|
||||
|
||||
@dataclass
class RequestFuncOutput:
    """The output of the request function including metrics."""

    # Full text accumulated from the streamed response.
    generated_text: str = ""
    # True only when the request completed and produced usable output.
    success: bool = False
    # End-to-end request latency in seconds.
    latency: float = 0.0
    # Completion token count as reported by the server's usage block.
    output_tokens: int = 0
    ttft: float = 0.0  # Time to first token
    itl: list[float] = field(default_factory=list)  # list of inter-token latencies
    # avg next-token latencies; NOTE(review): never assigned by the request
    # functions in this module — presumably computed downstream.
    tpot: float = 0.0
    # Prompt token length copied from the input (or server usage for pooling).
    prompt_len: int = 0
    # Error description / traceback when success is False.
    error: str = ""
    # perf_counter() timestamp taken just before the request was sent.
    start_time: float = 0.0
|
||||
|
||||
|
||||
class RequestFunc(Protocol):
    """Structural type of a benchmark request function.

    Implementations issue one request over *session*, fill a
    RequestFuncOutput with timing metrics, and advance *pbar* if given.
    """

    def __call__(
        self,
        request_func_input: RequestFuncInput,
        session: aiohttp.ClientSession,
        pbar: tqdm | None = None,
    ) -> Awaitable[RequestFuncOutput]: ...
|
||||
|
||||
|
||||
def _validate_api_url(
|
||||
api_url: str,
|
||||
api_name: str,
|
||||
expected_suffixes: str | set[str],
|
||||
) -> None:
|
||||
if isinstance(expected_suffixes, str):
|
||||
expected_suffixes = {expected_suffixes}
|
||||
|
||||
expected_suffixes = {*expected_suffixes, "profile"}
|
||||
|
||||
if not api_url.endswith(tuple(expected_suffixes)):
|
||||
raise ValueError(f"{api_name} URL must end with one of: {expected_suffixes}.")
|
||||
|
||||
|
||||
def _update_payload_common(
    payload: dict[str, Any],
    request_func_input: RequestFuncInput,
) -> None:
    """Apply request options shared by every API to *payload* in place."""
    if request_func_input.ignore_eos:
        payload["ignore_eos"] = request_func_input.ignore_eos
    extra = request_func_input.extra_body
    if extra:
        payload.update(extra)
|
||||
|
||||
|
||||
def _update_headers_common(
    headers: dict[str, Any],
    request_func_input: RequestFuncInput,
) -> None:
    """Apply shared HTTP headers (extras + request id) to *headers* in place."""
    extra = request_func_input.extra_headers
    if extra:
        headers |= extra
    request_id = request_func_input.request_id
    if request_id:
        headers["x-request-id"] = request_id
|
||||
|
||||
|
||||
async def async_request_openai_completions(
    request_func_input: RequestFuncInput,
    session: aiohttp.ClientSession,
    pbar: tqdm | None = None,
) -> RequestFuncOutput:
    """The async request function for the OpenAI Completions API.

    Streams the response, recording TTFT on the first token chunk and
    inter-token latencies thereafter. The request fails if no token chunk
    is ever received.

    Args:
        request_func_input: The input for the request function.
        pbar: The progress bar to display the progress.

    Returns:
        The output of the request function.
    """
    api_url = request_func_input.api_url
    _validate_api_url(api_url, "OpenAI Completions API", "completions")

    payload = {
        "model": request_func_input.model_name
        if request_func_input.model_name
        else request_func_input.model,
        "prompt": request_func_input.prompt,
        # Deterministic decoding so repeated runs are comparable.
        "temperature": 0.0,
        "repetition_penalty": 1.0,
        "max_tokens": request_func_input.output_len,
        "logprobs": request_func_input.logprobs,
        "stream": True,
        # Request a final usage chunk so output_tokens can be recorded.
        "stream_options": {
            "include_usage": True,
        },
    }
    _update_payload_common(payload, request_func_input)

    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
    }
    _update_headers_common(headers, request_func_input)

    output = RequestFuncOutput()
    output.prompt_len = request_func_input.prompt_len

    generated_text = ""
    st = time.perf_counter()
    output.start_time = st
    most_recent_timestamp = st
    try:
        async with session.post(url=api_url, json=payload, headers=headers) as response:
            if response.status == 200:
                first_chunk_received = False
                handler = StreamedResponseHandler()

                async for chunk_bytes in response.content.iter_any():
                    chunk_bytes = chunk_bytes.strip()
                    if not chunk_bytes:
                        continue

                    messages = handler.add_chunk(chunk_bytes)
                    for message in messages:
                        # NOTE: SSE comments (often used as pings) start with
                        # a colon. These are not JSON data payload and should
                        # be skipped.
                        if message.startswith(":"):
                            continue

                        chunk = message.removeprefix("data: ")

                        if chunk != "[DONE]":
                            data = json.loads(chunk)

                            # NOTE: Some completion API might have a last
                            # usage summary response without a token so we
                            # want to check a token was generated
                            if choices := data.get("choices"):
                                # Note that text could be empty here
                                # e.g. for special tokens
                                text = choices[0].get("text")
                                timestamp = time.perf_counter()
                                # First token
                                if not first_chunk_received:
                                    first_chunk_received = True
                                    ttft = time.perf_counter() - st
                                    output.ttft = ttft

                                # Decoding phase
                                else:
                                    output.itl.append(timestamp - most_recent_timestamp)

                                most_recent_timestamp = timestamp
                                generated_text += text or ""
                            elif usage := data.get("usage"):
                                output.output_tokens = usage.get("completion_tokens")
                if first_chunk_received:
                    output.success = True
                else:
                    output.success = False
                    output.error = (
                        "Never received a valid chunk to calculate TTFT."
                        "This response will be marked as failed!"
                    )
                output.generated_text = generated_text
                # Latency is measured up to the last token chunk, not the
                # trailing usage/[DONE] messages.
                output.latency = most_recent_timestamp - st
            else:
                output.error = response.reason or ""
                output.success = False
    except Exception:
        output.success = False
        exc_info = sys.exc_info()
        output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output
|
||||
|
||||
|
||||
def _get_chat_content(
    request_func_input: RequestFuncInput,
    mm_position: Literal["first", "last"] = "last",
) -> list[dict[str, Any]]:
    """Build the chat message "content" parts list.

    Multimodal parts are placed before or after the text part according to
    *mm_position*. Raises TypeError for unsupported multimodal content.
    """
    text_parts = [{"type": "text", "text": request_func_input.prompt}]

    mm_parts: list[dict[str, Any]] = []
    mm_content = request_func_input.multi_modal_content
    if mm_content:
        if isinstance(mm_content, list):
            mm_parts.extend(mm_content)
        elif isinstance(mm_content, dict):
            mm_parts.append(mm_content)
        else:
            raise TypeError(
                "multi_modal_content must be a dict or list[dict] for openai-chat"
            )

    if mm_position == "first":
        return mm_parts + text_parts

    return text_parts + mm_parts
|
||||
|
||||
|
||||
async def async_request_openai_chat_completions(
    request_func_input: RequestFuncInput,
    session: aiohttp.ClientSession,
    pbar: tqdm | None = None,
    mm_position: Literal["first", "last"] = "last",
) -> RequestFuncOutput:
    """The async request function for the OpenAI Chat Completions API.

    Streams the response, recording TTFT on the first delta and
    inter-token latencies thereafter. *mm_position* controls whether
    multimodal content parts precede or follow the text part.
    """
    api_url = request_func_input.api_url
    _validate_api_url(api_url, "OpenAI Chat Completions API", "chat/completions")

    content = _get_chat_content(request_func_input, mm_position=mm_position)

    payload = {
        "model": request_func_input.model_name
        if request_func_input.model_name
        else request_func_input.model,
        "messages": [
            {"role": "user", "content": content},
        ],
        # Deterministic decoding so repeated runs are comparable.
        "temperature": 0.0,
        "max_completion_tokens": request_func_input.output_len,
        "stream": True,
        # Request a final usage chunk so output_tokens can be recorded.
        "stream_options": {
            "include_usage": True,
        },
    }
    _update_payload_common(payload, request_func_input)

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
    }
    _update_headers_common(headers, request_func_input)

    output = RequestFuncOutput()
    output.prompt_len = request_func_input.prompt_len

    generated_text = ""
    ttft = 0.0
    st = time.perf_counter()
    output.start_time = st
    most_recent_timestamp = st
    try:
        async with session.post(url=api_url, json=payload, headers=headers) as response:
            if response.status == 200:
                handler = StreamedResponseHandler()
                async for chunk_bytes in response.content.iter_any():
                    chunk_bytes = chunk_bytes.strip()
                    if not chunk_bytes:
                        continue

                    messages = handler.add_chunk(chunk_bytes)
                    for message in messages:
                        # NOTE: SSE comments (often used as pings) start with
                        # a colon. These are not JSON data payload and should
                        # be skipped.
                        if message.startswith(":"):
                            continue

                        chunk = message.removeprefix("data: ")

                        if chunk != "[DONE]":
                            timestamp = time.perf_counter()
                            data = json.loads(chunk)

                            if choices := data.get("choices"):
                                content = choices[0]["delta"].get("content")
                                # First token
                                if ttft == 0.0:
                                    ttft = timestamp - st
                                    output.ttft = ttft

                                # Decoding phase
                                else:
                                    output.itl.append(timestamp - most_recent_timestamp)

                                generated_text += content or ""
                            elif usage := data.get("usage"):
                                output.output_tokens = usage.get("completion_tokens")

                            most_recent_timestamp = timestamp

                output.generated_text = generated_text
                output.success = True
                # Latency is measured up to the last data chunk received.
                output.latency = most_recent_timestamp - st
            else:
                output.error = response.reason or ""
                output.success = False
    except Exception:
        output.success = False
        exc_info = sys.exc_info()
        output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output
|
||||
|
||||
|
||||
async def async_request_openai_audio(
    request_func_input: RequestFuncInput,
    session: aiohttp.ClientSession,
    pbar: tqdm | None = None,
) -> RequestFuncOutput:
    """The async request function for OpenAI audio (transcriptions /
    translations) APIs.

    Sends the audio as multipart/form-data and streams the response,
    recording TTFT and inter-token latencies like the chat function.

    Raises:
        TypeError: if multi_modal_content is not a dict containing 'audio'.
    """
    # Lazy import without PlaceholderModule to avoid vllm dep.
    import soundfile

    api_url = request_func_input.api_url
    _validate_api_url(api_url, "OpenAI Audio API", {"transcriptions", "translations"})

    payload = {
        "model": request_func_input.model_name
        if request_func_input.model_name
        else request_func_input.model,
        "temperature": 0.0,
        "max_completion_tokens": request_func_input.output_len,
        "stream": True,
        "language": "en",
        # Flattened due to multipart/form-data
        "stream_include_usage": True,
        "stream_continuous_usage_stats": True,
    }
    _update_payload_common(payload, request_func_input)

    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
    }
    _update_headers_common(headers, request_func_input)

    # Send audio file
    def to_bytes(y, sr):
        buffer = io.BytesIO()
        soundfile.write(buffer, y, sr, format="WAV")
        buffer.seek(0)
        return buffer

    mm_audio = request_func_input.multi_modal_content
    if not isinstance(mm_audio, dict) or "audio" not in mm_audio:
        raise TypeError("multi_modal_content must be a dict containing 'audio'")
    with to_bytes(*mm_audio["audio"]) as f:
        form = aiohttp.FormData()
        form.add_field("file", f, content_type="audio/wav")
        for key, value in payload.items():
            form.add_field(key, str(value))

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        ttft = 0.0
        st = time.perf_counter()
        output.start_time = st
        most_recent_timestamp = st
        try:
            async with session.post(
                url=api_url, data=form, headers=headers
            ) as response:
                if response.status == 200:
                    handler = StreamedResponseHandler()

                    async for chunk_bytes in response.content.iter_any():
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        messages = handler.add_chunk(chunk_bytes)
                        for message in messages:
                            # NOTE: skip SSE comments (ping lines starting
                            # with a colon), consistent with the other
                            # streaming request functions in this module.
                            if message.startswith(":"):
                                continue

                            # BUG FIX: `message` is already a str (decoded by
                            # StreamedResponseHandler.add_chunk); the previous
                            # `message.decode("utf-8")` raised AttributeError
                            # on every streamed message.
                            chunk = message.removeprefix("data: ")
                            if chunk != "[DONE]":
                                timestamp = time.perf_counter()
                                data = json.loads(chunk)

                                if choices := data.get("choices"):
                                    content = choices[0]["delta"].get("content")
                                    # First token
                                    if ttft == 0.0:
                                        ttft = timestamp - st
                                        output.ttft = ttft

                                    # Decoding phase
                                    else:
                                        output.itl.append(
                                            timestamp - most_recent_timestamp
                                        )

                                    generated_text += content or ""
                                elif usage := data.get("usage"):
                                    output.output_tokens = usage.get(
                                        "completion_tokens"
                                    )

                                most_recent_timestamp = timestamp

                    output.generated_text = generated_text
                    output.success = True
                    output.latency = most_recent_timestamp - st
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

        if pbar:
            pbar.update(1)
        return output
|
||||
|
||||
|
||||
async def _run_pooling_request(
    session: aiohttp.ClientSession,
    api_url: str,
    payload: dict[str, Any],
    headers: dict[str, Any],
    pbar: tqdm | None = None,
) -> RequestFuncOutput:
    """POST a pooling (embeddings/rerank) request and time the response.

    Pooling endpoints return a single body rather than a token stream, so
    TTFT and total latency are recorded as the same value.
    """
    output = RequestFuncOutput()
    st = time.perf_counter()
    output.start_time = st
    try:
        async with session.post(url=api_url, headers=headers, json=payload) as response:
            if response.status == 200:
                output.ttft = output.latency = time.perf_counter() - st

                if payload.get("encoding_format", "float") == "bytes":
                    # Binary responses carry usage info in a metadata header.
                    metadata = json.loads(response.headers["metadata"])
                    usage = metadata.get("usage", {})
                else:
                    body = await response.json()
                    usage = body.get("usage", {})

                output.success = True
                output.generated_text = ""
                output.prompt_len = usage.get("prompt_tokens", 0)
            else:
                output.success = False
                output.error = response.reason or ""
    except Exception as e:
        output.success = False
        output.error = str(e)

    if pbar:
        pbar.update(1)
    return output
|
||||
|
||||
|
||||
async def async_request_openai_embeddings(
    request_func_input: RequestFuncInput,
    session: aiohttp.ClientSession,
    pbar: tqdm | None = None,
) -> RequestFuncOutput:
    """Benchmark request for the OpenAI Embeddings API (text input)."""
    api_url = request_func_input.api_url
    _validate_api_url(api_url, "OpenAI Embeddings API", "embeddings")

    model = request_func_input.model_name or request_func_input.model
    payload = {
        "model": model,
        "input": request_func_input.prompt,
        # Many embedding models have short context length,
        # this is to avoid dropping some of the requests.
        "truncate_prompt_tokens": -1,
    }
    _update_payload_common(payload, request_func_input)

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
    }
    _update_headers_common(headers, request_func_input)

    return await _run_pooling_request(
        session, api_url, payload=payload, headers=headers, pbar=pbar
    )
|
||||
|
||||
|
||||
async def async_request_vllm_rerank(
    request_func_input: RequestFuncInput,
    session: aiohttp.ClientSession,
    pbar: tqdm | None = None,
) -> RequestFuncOutput:
    """Benchmark request for the vLLM rerank API.

    The prompt must be a list whose first element is the query and whose
    remaining elements are the documents to rerank.
    """
    api_url = request_func_input.api_url
    _validate_api_url(api_url, "vLLM score API", "rerank")

    assert (
        isinstance(request_func_input.prompt, list)
        and len(request_func_input.prompt) > 1
    )

    payload = {
        "model": request_func_input.model_name
        if request_func_input.model_name
        else request_func_input.model,
        "query": request_func_input.prompt[0],
        "documents": request_func_input.prompt[1:],
        # Many reranker models have short context length,
        # this is to avoid dropping some of the requests.
        "truncate_prompt_tokens": -1,
    }
    # Consistency fix: apply the shared payload options (ignore_eos /
    # extra_body) like every other request builder in this module does;
    # previously extra_body was silently ignored for rerank requests.
    _update_payload_common(payload, request_func_input)

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
    }
    _update_headers_common(headers, request_func_input)

    return await _run_pooling_request(
        session,
        api_url,
        payload=payload,
        headers=headers,
        pbar=pbar,
    )
|
||||
|
||||
|
||||
async def async_request_openai_embeddings_chat(
    request_func_input: RequestFuncInput,
    session: aiohttp.ClientSession,
    pbar: tqdm | None = None,
    mm_position: Literal["first", "last"] = "last",
) -> RequestFuncOutput:
    """Benchmark request for chat-format (possibly multimodal) embeddings."""
    api_url = request_func_input.api_url
    _validate_api_url(api_url, "OpenAI Embeddings API", "embeddings")

    message_content = _get_chat_content(request_func_input, mm_position=mm_position)

    payload = {
        "model": request_func_input.model_name or request_func_input.model,
        "messages": [
            {"role": "user", "content": message_content},
        ],
        # Many embedding models have short context length,
        # this is to avoid dropping some of the requests.
        "truncate_prompt_tokens": -1,
    }
    _update_payload_common(payload, request_func_input)

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
    }
    _update_headers_common(headers, request_func_input)

    return await _run_pooling_request(
        session, api_url, payload=payload, headers=headers, pbar=pbar
    )
|
||||
|
||||
|
||||
def _try_extract_request_idx(request_func_input: RequestFuncInput):
    """Return the trailing integer of the request id, or None if absent."""
    request_id = request_func_input.request_id
    if not request_id:
        return None

    match = re.search(r"(\d+)$", request_id)
    if match is None:
        return None

    try:
        return int(match.group(1))
    except ValueError:
        # Defensive: a \d match may not always parse as int.
        return None
|
||||
|
||||
|
||||
def _preprocess_clip(request_func_input: RequestFuncInput):
    """Prepare a CLIP embedding request: image inputs get an empty prompt."""
    if not request_func_input.multi_modal_content:
        return
    # Image input: embed the image alone.
    request_func_input.prompt = ""
|
||||
|
||||
|
||||
def _preprocess_vlm2vec(request_func_input: RequestFuncInput):
    """Prepare a VLM2Vec request by rewriting the prompt for image inputs.

    Alternates by request-index parity between an image-only instruction
    and a text+image instruction; unknown indices use image-only.
    """
    if not request_func_input.multi_modal_content:
        return

    request_idx = _try_extract_request_idx(request_func_input)

    # Adjust the ratio manually if needed.
    if request_idx is None or request_idx % 2 == 0:
        # Image input
        request_func_input.prompt = "Represent the given image."
    else:
        # Text+Image input
        request_func_input.prompt = (
            f"Represent the given image with the following question: "
            f"{request_func_input.prompt}"
        )
|
||||
|
||||
|
||||
async def async_request_openai_embeddings_clip(
    request_func_input: RequestFuncInput,
    session: aiohttp.ClientSession,
    pbar: tqdm | None = None,
) -> RequestFuncOutput:
    """CLIP variant of the chat-format embeddings request."""
    _preprocess_clip(request_func_input)
    return await async_request_openai_embeddings_chat(
        request_func_input, session, pbar=pbar
    )
|
||||
|
||||
|
||||
async def async_request_openai_embeddings_vlm2vec(
    request_func_input: RequestFuncInput,
    session: aiohttp.ClientSession,
    pbar: tqdm | None = None,
) -> RequestFuncOutput:
    """VLM2Vec variant of the chat-format embeddings request.

    Multimodal parts are placed before the text part, matching the
    VLM2Vec prompt convention used by _preprocess_vlm2vec.
    """
    _preprocess_vlm2vec(request_func_input)
    return await async_request_openai_embeddings_chat(
        request_func_input, session, pbar=pbar, mm_position="first"
    )
|
||||
|
||||
|
||||
async def async_request_infinity_embeddings(
    request_func_input: RequestFuncInput,
    session: aiohttp.ClientSession,
    pbar: tqdm | None = None,
) -> RequestFuncOutput:
    """Benchmark request for the Infinity embeddings server API."""
    api_url = request_func_input.api_url
    _validate_api_url(api_url, "Infinity Embeddings API", "embeddings")

    payload = {
        "model": request_func_input.model_name or request_func_input.model,
    }

    if request_func_input.prompt:
        payload["input"] = request_func_input.prompt
    else:
        mm_content = request_func_input.multi_modal_content
        assert isinstance(mm_content, dict)

        mm_type = mm_content["type"]
        payload["input"] = mm_content[mm_type]["url"]
        # e.g. "image_url" -> modality "image"
        payload["modality"] = mm_type.split("_", 1)[0]

    _update_payload_common(payload, request_func_input)

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
    }
    _update_headers_common(headers, request_func_input)

    return await _run_pooling_request(
        session, api_url, payload=payload, headers=headers, pbar=pbar
    )
|
||||
|
||||
|
||||
async def async_request_infinity_embeddings_clip(
    request_func_input: RequestFuncInput,
    session: aiohttp.ClientSession,
    pbar: tqdm | None = None,
) -> RequestFuncOutput:
    """CLIP variant of the Infinity embeddings request."""
    _preprocess_clip(request_func_input)
    return await async_request_infinity_embeddings(
        request_func_input, session, pbar=pbar
    )
|
||||
|
||||
|
||||
# TODO: Add more request functions for different API protocols.
# Maps a backend name (CLI choice) to the request function implementing it.
ASYNC_REQUEST_FUNCS: dict[str, RequestFunc] = {
    "vllm": async_request_openai_completions,
    "openai": async_request_openai_completions,
    "openai-chat": async_request_openai_chat_completions,
    "openai-audio": async_request_openai_audio,
    "openai-embeddings": async_request_openai_embeddings,
    "openai-embeddings-chat": async_request_openai_embeddings_chat,
    "openai-embeddings-clip": async_request_openai_embeddings_clip,
    "openai-embeddings-vlm2vec": async_request_openai_embeddings_vlm2vec,
    # Infinity embedding server: https://github.com/michaelfeil/infinity
    "infinity-embeddings": async_request_infinity_embeddings,
    "infinity-embeddings-clip": async_request_infinity_embeddings_clip,
    # (Infinity embedding server does not support vlm2vec)
    "vllm-rerank": async_request_vllm_rerank,
}

# Backend names whose request function speaks the OpenAI completions or
# chat-completions protocol.
OPENAI_COMPATIBLE_BACKENDS = [
    k
    for k, v in ASYNC_REQUEST_FUNCS.items()
    if v in (async_request_openai_completions, async_request_openai_chat_completions)
]
|
||||
72
vllm/benchmarks/lib/ready_checker.py
Normal file
72
vllm/benchmarks/lib/ready_checker.py
Normal file
@@ -0,0 +1,72 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Utilities for checking endpoint readiness."""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
import aiohttp
|
||||
from tqdm.asyncio import tqdm
|
||||
|
||||
from .endpoint_request_func import RequestFunc, RequestFuncInput, RequestFuncOutput
|
||||
|
||||
|
||||
async def wait_for_endpoint(
    request_func: RequestFunc,
    test_input: RequestFuncInput,
    session: aiohttp.ClientSession,
    timeout_seconds: int = 600,
    retry_interval: int = 5,
) -> RequestFuncOutput:
    """
    Wait for an endpoint to become available before starting benchmarks.

    Args:
        request_func: The async request function to call
        test_input: The RequestFuncInput to test with
        session: The aiohttp session used to issue the probe requests
        timeout_seconds: Maximum time to wait in seconds (default: 10 minutes)
        retry_interval: Time between retries in seconds (default: 5 seconds)

    Returns:
        RequestFuncOutput: The first successful response, or the last
        (unsuccessful) output if the endpoint never came up in time.
        Note this function does NOT raise on timeout; callers must check
        ``output.success``.
    """
    deadline = time.perf_counter() + timeout_seconds
    output = RequestFuncOutput(success=False)
    print(f"Waiting for endpoint to become up in {timeout_seconds} seconds")

    with tqdm(
        total=timeout_seconds,
        bar_format="{desc} |{bar}| {elapsed} elapsed, {remaining} remaining",
        unit="s",
    ) as pbar:
        while True:
            # update progress bar: advance to `elapsed`, clamped so the bar
            # never runs past `timeout_seconds`
            remaining = deadline - time.perf_counter()
            elapsed = timeout_seconds - remaining
            update_amount = min(elapsed - pbar.n, timeout_seconds - pbar.n)
            pbar.update(update_amount)
            pbar.refresh()
            if remaining <= 0:
                pbar.close()
                break

            # ping the endpoint using request_func
            try:
                output = await request_func(
                    request_func_input=test_input, session=session
                )
                if output.success:
                    pbar.close()
                    return output
            # Connection refused just means the server is not up yet; any
            # other exception propagates to the caller.
            except aiohttp.ClientConnectorError:
                pass

            # retry after a delay (never sleeping past the deadline)
            sleep_duration = min(retry_interval, remaining)
            if sleep_duration > 0:
                await asyncio.sleep(sleep_duration)

    return output
|
||||
79
vllm/benchmarks/lib/utils.py
Normal file
79
vllm/benchmarks/lib/utils.py
Normal file
@@ -0,0 +1,79 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
|
||||
def convert_to_pytorch_benchmark_format(
    args: argparse.Namespace, metrics: dict[str, list], extra_info: dict[str, Any]
) -> list:
    """
    Convert benchmark results to the PyTorch OSS benchmark database format,
    emitting one record per metric.
    https://github.com/pytorch/pytorch/wiki/How-to-integrate-with-PyTorch-OSS-benchmark-database

    Returns an empty list unless the SAVE_TO_PYTORCH_BENCHMARK_FORMAT
    environment variable is set (to any non-empty value).
    """
    if not os.environ.get("SAVE_TO_PYTORCH_BENCHMARK_FORMAT", False):
        return []

    # vars(args) is the live namespace __dict__, shared by every record.
    cli_args = vars(args)

    records = []
    for metric_name, benchmark_values in metrics.items():
        record = {
            "benchmark": {
                "name": "vLLM benchmark",
                "extra_info": {
                    "args": cli_args,
                },
            },
            "model": {
                "name": args.model,
            },
            "metric": {
                "name": metric_name,
                "benchmark_values": benchmark_values,
                "extra_info": extra_info,
            },
        }

        # Save tensor_parallel_size parameter if it's part of the metadata
        # but absent (or falsy) in the CLI args.
        if (
            not cli_args.get("tensor_parallel_size")
            and "tensor_parallel_size" in extra_info
        ):
            cli_args["tensor_parallel_size"] = extra_info["tensor_parallel_size"]

        records.append(record)

    return records
|
||||
|
||||
|
||||
class InfEncoder(json.JSONEncoder):
    """JSON encoder that replaces non-finite floats with string markers.

    Standard JSON has no inf/nan literals; this encoder emits the strings
    "inf", "-inf", and "nan" instead, and stringifies non-primitive dict
    keys so the result is always valid JSON.
    """

    def clear_inf(self, o: Any):
        """Recursively sanitize *o* before encoding."""
        if isinstance(o, dict):
            return {
                str(k)
                if not isinstance(k, (str, int, float, bool, type(None)))
                else k: self.clear_inf(v)
                for k, v in o.items()
            }
        elif isinstance(o, list):
            return [self.clear_inf(v) for v in o]
        elif isinstance(o, float) and not math.isfinite(o):
            # Preserve the sign of infinities and mark NaN explicitly.
            # (Previously both +inf and -inf collapsed to "inf", and NaN was
            # emitted as the non-standard JSON literal NaN.)
            if math.isnan(o):
                return "nan"
            return "inf" if o > 0 else "-inf"
        return o

    def iterencode(self, o: Any, *args, **kwargs) -> Any:
        # Sanitize the whole structure up front, then defer to the default
        # encoder machinery.
        return super().iterencode(self.clear_inf(o), *args, **kwargs)
|
||||
|
||||
|
||||
def write_to_json(filename: str, records: list) -> None:
    """Write *records* to *filename* as JSON.

    Uses InfEncoder so non-finite floats become strings, and substitutes a
    placeholder string for objects json cannot serialize instead of raising.
    """
    payload = json.dumps(
        records,
        cls=InfEncoder,
        default=lambda o: f"<{type(o).__name__} is not JSON serializable>",
    )
    with open(filename, "w") as f:
        f.write(payload)
|
||||
1538
vllm/benchmarks/serve.py
Normal file
1538
vllm/benchmarks/serve.py
Normal file
File diff suppressed because it is too large
Load Diff
326
vllm/benchmarks/startup.py
Normal file
326
vllm/benchmarks/startup.py
Normal file
@@ -0,0 +1,326 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Benchmark the cold and warm startup time of vLLM models.
|
||||
|
||||
This script measures total startup time (including model loading, compilation,
|
||||
and cache operations) for both cold and warm scenarios:
|
||||
- Cold startup: Fresh start with no caches (temporary cache directories)
|
||||
- Warm startup: Using cached compilation and model info
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import dataclasses
|
||||
import json
|
||||
import multiprocessing
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from vllm.benchmarks.lib.utils import (
|
||||
convert_to_pytorch_benchmark_format,
|
||||
write_to_json,
|
||||
)
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
|
||||
|
||||
@contextmanager
def cold_startup():
    """
    Context manager to measure cold startup time:
    1. Uses a temporary directory for vLLM cache to avoid any pollution
       between cold startup iterations.
    2. Uses inductor's fresh_cache to clear torch.compile caches.
    """
    # Imported lazily so merely importing this module does not pull in
    # torch internals.
    from torch._inductor.utils import fresh_cache

    # Use temporary directory for caching to avoid any pollution between cold startups
    original_cache_root = os.environ.get("VLLM_CACHE_ROOT")
    temp_cache_dir = tempfile.mkdtemp(prefix="vllm_startup_bench_cold_")
    try:
        os.environ["VLLM_CACHE_ROOT"] = temp_cache_dir
        with fresh_cache():
            yield
    finally:
        # Clean up temporary cache directory
        shutil.rmtree(temp_cache_dir, ignore_errors=True)
        # Restore the caller's cache root.
        # NOTE(review): an originally-set-but-empty VLLM_CACHE_ROOT ("")
        # is popped rather than restored here — confirm that is acceptable.
        if original_cache_root:
            os.environ["VLLM_CACHE_ROOT"] = original_cache_root
        else:
            os.environ.pop("VLLM_CACHE_ROOT", None)
|
||||
|
||||
|
||||
def run_startup_in_subprocess(engine_args_dict, result_queue):
    """
    Run LLM startup in a subprocess and return timing metrics via a queue.
    This ensures complete isolation between iterations.

    On success, puts a dict with "total_startup_time" and "compilation_time"
    on *result_queue*. On failure, puts the sentinel None followed by the
    stringified exception (the protocol the parent process decodes).
    """
    try:
        # Import inside the subprocess to avoid issues with forking
        from vllm import LLM
        from vllm.engine.arg_utils import EngineArgs

        engine_args = EngineArgs(**engine_args_dict)

        # Measure total startup time
        start_time = time.perf_counter()

        llm = LLM(**dataclasses.asdict(engine_args))

        total_startup_time = time.perf_counter() - start_time

        # Extract compilation time if available; defaults to 0.0 when the
        # engine exposes no compilation config.
        compilation_time = 0.0
        if hasattr(llm.llm_engine, "vllm_config"):
            vllm_config = llm.llm_engine.vllm_config
            if (
                hasattr(vllm_config, "compilation_config")
                and vllm_config.compilation_config is not None
            ):
                compilation_time = vllm_config.compilation_config.compilation_time

        result_queue.put(
            {
                "total_startup_time": total_startup_time,
                "compilation_time": compilation_time,
            }
        )

    except Exception as e:
        # Failure protocol: None sentinel, then the error message.
        result_queue.put(None)
        result_queue.put(str(e))
|
||||
|
||||
|
||||
def save_to_pytorch_benchmark_format(
    args: argparse.Namespace, results: dict[str, Any]
) -> None:
    """Export startup benchmark *results* as PyTorch OSS benchmark records.

    Writes one "<output_json stem>.<label>.pytorch.json" file per measurement
    kind. A file is only written when the converter returns records (it
    returns [] unless SAVE_TO_PYTORCH_BENCHMARK_FORMAT is set).
    """
    base_name = os.path.splitext(args.output_json)[0]

    # The four measurement kinds share the same key and file naming scheme
    # ("avg_<label>_time", "<label>_times", "<label>_percentiles"), so drive
    # the export with one loop instead of four copy-pasted stanzas.
    labels = (
        "cold_startup",
        "cold_compilation",
        "warm_startup",
        "warm_compilation",
    )
    for label in labels:
        records = convert_to_pytorch_benchmark_format(
            args=args,
            metrics={f"avg_{label}_time": results[f"avg_{label}_time"]},
            extra_info={
                f"{label}_times": results[f"{label}_times"],
                f"{label}_percentiles": results[f"{label}_percentiles"],
            },
        )
        if records:
            write_to_json(f"{base_name}.{label}.pytorch.json", records)
|
||||
|
||||
|
||||
def add_cli_args(parser: argparse.ArgumentParser):
    """Register startup-benchmark CLI flags plus all engine args on *parser*."""
    parser.add_argument(
        "--num-iters-cold",
        type=int,
        default=5,
        help="Number of cold startup iterations.",
    )
    parser.add_argument(
        "--num-iters-warmup",
        type=int,
        default=3,
        help="Number of warmup iterations before benchmarking warm startups.",
    )
    parser.add_argument(
        "--num-iters-warm",
        type=int,
        default=5,
        help="Number of warm startup iterations.",
    )
    parser.add_argument(
        "--output-json",
        type=str,
        default=None,
        help="Path to save the startup time results in JSON format.",
    )

    # Engine configuration flags (model, parallelism, ...) come from vLLM.
    parser = EngineArgs.add_cli_args(parser)
    return parser
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
    """Benchmark cold and warm vLLM startup times and report statistics."""
    # Set multiprocessing start method to 'spawn' for clean process isolation
    # This ensures each subprocess starts fresh without inheriting state
    multiprocessing.set_start_method("spawn", force=True)

    engine_args = EngineArgs.from_cli_args(args)

    def create_llm_and_measure_startup():
        """
        Create LLM instance in a subprocess and measure startup time.
        Returns timing metrics, using subprocess for complete isolation.

        Raises RuntimeError when the subprocess reports a failure or exits
        without producing a result.
        """
        # Convert engine_args to dictionary for pickling
        engine_args_dict = dataclasses.asdict(engine_args)

        # Create a queue for inter-process communication
        result_queue = multiprocessing.Queue()
        process = multiprocessing.Process(
            target=run_startup_in_subprocess,
            args=(
                engine_args_dict,
                result_queue,
            ),
        )
        process.start()
        process.join()

        if not result_queue.empty():
            result = result_queue.get()
            # Failure protocol of run_startup_in_subprocess: a None sentinel
            # followed by the stringified exception.
            if result is None:
                if not result_queue.empty():
                    error_msg = result_queue.get()
                    raise RuntimeError(f"Subprocess failed: {error_msg}")
                else:
                    raise RuntimeError("Subprocess failed with unknown error")
            return result
        else:
            raise RuntimeError("Subprocess did not return a result")

    # NOTE(review): presumably needed so startup metrics (compilation config)
    # are observable in-process — confirm against vLLM env var docs.
    os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
    print("Setting VLLM_ENABLE_V1_MULTIPROCESSING=0 to collect startup metrics.\n")

    # --- Cold startups: fresh cache dir + fresh torch.compile cache each run.
    print("Measuring cold startup time...\n")
    cold_startup_times: list[float] = []
    cold_compilation_times: list[float] = []
    for i in tqdm(range(args.num_iters_cold), desc="Cold startup iterations"):
        with cold_startup():
            metrics = create_llm_and_measure_startup()
            cold_startup_times.append(metrics["total_startup_time"])
            cold_compilation_times.append(metrics["compilation_time"])

    # Warmup for warm startup (populate the caches; results are discarded)
    print("\nWarming up for warm startup measurement...\n")
    for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
        create_llm_and_measure_startup()

    # --- Warm startups: reuse whatever caches the warmup runs populated.
    print("\nMeasuring warm startup time...\n")
    warm_startup_times: list[float] = []
    warm_compilation_times: list[float] = []
    for i in tqdm(range(args.num_iters_warm), desc="Warm startup iterations"):
        metrics = create_llm_and_measure_startup()
        warm_startup_times.append(metrics["total_startup_time"])
        warm_compilation_times.append(metrics["compilation_time"])

    # Calculate statistics
    cold_startup_array = np.array(cold_startup_times)
    cold_compilation_array = np.array(cold_compilation_times)
    warm_startup_array = np.array(warm_startup_times)
    warm_compilation_array = np.array(warm_compilation_times)

    avg_cold_startup = np.mean(cold_startup_array)
    avg_cold_compilation = np.mean(cold_compilation_array)
    avg_warm_startup = np.mean(warm_startup_array)
    avg_warm_compilation = np.mean(warm_compilation_array)

    percentages = [10, 25, 50, 75, 90, 99]
    cold_startup_percentiles = np.percentile(cold_startup_array, percentages)
    cold_compilation_percentiles = np.percentile(cold_compilation_array, percentages)
    warm_startup_percentiles = np.percentile(warm_startup_array, percentages)
    warm_compilation_percentiles = np.percentile(warm_compilation_array, percentages)

    print("\n" + "=" * 60)
    print("STARTUP TIME BENCHMARK RESULTS")
    print("=" * 60)

    # Cold startup statistics
    print("\nCOLD STARTUP:")
    print(f"Avg total startup time: {avg_cold_startup:.2f} seconds")
    print(f"Avg compilation time: {avg_cold_compilation:.2f} seconds")
    print("Startup time percentiles:")
    for percentage, percentile in zip(percentages, cold_startup_percentiles):
        print(f"  {percentage}%: {percentile:.2f} seconds")
    print("Compilation time percentiles:")
    for percentage, percentile in zip(percentages, cold_compilation_percentiles):
        print(f"  {percentage}%: {percentile:.2f} seconds")

    # Warm startup statistics
    print("\nWARM STARTUP:")
    print(f"Avg total startup time: {avg_warm_startup:.2f} seconds")
    print(f"Avg compilation time: {avg_warm_compilation:.2f} seconds")
    print("Startup time percentiles:")
    for percentage, percentile in zip(percentages, warm_startup_percentiles):
        print(f"  {percentage}%: {percentile:.2f} seconds")
    print("Compilation time percentiles:")
    for percentage, percentile in zip(percentages, warm_compilation_percentiles):
        print(f"  {percentage}%: {percentile:.2f} seconds")

    print("=" * 60)

    # Output JSON results if specified
    if args.output_json:
        results = {
            "avg_cold_startup_time": float(avg_cold_startup),
            "avg_cold_compilation_time": float(avg_cold_compilation),
            "cold_startup_times": cold_startup_times,
            "cold_compilation_times": cold_compilation_times,
            "cold_startup_percentiles": dict(
                zip(percentages, cold_startup_percentiles.tolist())
            ),
            "cold_compilation_percentiles": dict(
                zip(percentages, cold_compilation_percentiles.tolist())
            ),
            "avg_warm_startup_time": float(avg_warm_startup),
            "avg_warm_compilation_time": float(avg_warm_compilation),
            "warm_startup_times": warm_startup_times,
            "warm_compilation_times": warm_compilation_times,
            "warm_startup_percentiles": dict(
                zip(percentages, warm_startup_percentiles.tolist())
            ),
            "warm_compilation_percentiles": dict(
                zip(percentages, warm_compilation_percentiles.tolist())
            ),
        }
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)
        save_to_pytorch_benchmark_format(args, results)
|
||||
0
vllm/benchmarks/sweep/__init__.py
Normal file
0
vllm/benchmarks/sweep/__init__.py
Normal file
41
vllm/benchmarks/sweep/cli.py
Normal file
41
vllm/benchmarks/sweep/cli.py
Normal file
@@ -0,0 +1,41 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
|
||||
from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG
|
||||
|
||||
from .plot import SweepPlotArgs
|
||||
from .plot import main as plot_main
|
||||
from .plot_pareto import SweepPlotParetoArgs
|
||||
from .plot_pareto import main as plot_pareto_main
|
||||
from .serve import SweepServeArgs
|
||||
from .serve import main as serve_main
|
||||
from .serve_sla import SweepServeSLAArgs
|
||||
from .serve_sla import main as serve_sla_main
|
||||
|
||||
# (cli-args class, entrypoint) pairs, one per "vllm bench sweep" sub-command.
# Each class supplies parser_name / parser_help / add_cli_args (see below).
SUBCOMMANDS = (
    (SweepServeArgs, serve_main),
    (SweepServeSLAArgs, serve_sla_main),
    (SweepPlotArgs, plot_main),
    (SweepPlotParetoArgs, plot_pareto_main),
)
|
||||
|
||||
|
||||
def add_cli_args(parser: argparse.ArgumentParser):
    """Register one subparser per sweep sub-command on *parser*.

    Each subparser gets its command's own CLI flags and a
    ``dispatch_function`` default pointing at the command's entrypoint,
    which ``main`` later invokes.
    """
    subparsers = parser.add_subparsers(required=True, dest="sweep_type")

    for cmd, entrypoint in SUBCOMMANDS:
        cmd_subparser = subparsers.add_parser(
            cmd.parser_name,
            description=cmd.parser_help,
            usage=f"vllm bench sweep {cmd.parser_name} [options]",
        )
        cmd_subparser.set_defaults(dispatch_function=entrypoint)
        cmd.add_cli_args(cmd_subparser)
        cmd_subparser.epilog = VLLM_SUBCMD_PARSER_EPILOG.format(
            subcmd=f"sweep {cmd.parser_name}"
        )
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
    """Invoke the sub-command entrypoint stored by set_defaults() during parsing."""
    args.dispatch_function(args)
|
||||
158
vllm/benchmarks/sweep/param_sweep.py
Normal file
158
vllm/benchmarks/sweep/param_sweep.py
Normal file
@@ -0,0 +1,158 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import json
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ParameterSweep(list["ParameterSweepItem"]):
    """A list of parameter combinations to benchmark, loaded from JSON."""

    @classmethod
    def read_json(cls, filepath: os.PathLike):
        """Load a sweep from a JSON file in either list or dict format."""
        with open(filepath, "rb") as f:
            data = json.load(f)

        # Support both list and dict formats
        if isinstance(data, dict):
            return cls.read_from_dict(data)

        return cls.from_records(data)

    @classmethod
    def read_from_dict(cls, data: dict[str, dict[str, object]]):
        """
        Read parameter sweep from a dict format where keys are names.

        Example:
            {
                "experiment1": {"max_tokens": 100, "temperature": 0.7},
                "experiment2": {"max_tokens": 200, "temperature": 0.9}
            }
        """
        records = [{"_benchmark_name": name, **params} for name, params in data.items()]
        return cls.from_records(records)

    @classmethod
    def from_records(cls, records: list[dict[str, object]]):
        """Build a sweep from a list of dicts.

        Raises:
            TypeError: if *records* is not a list.
            ValueError: if any explicit "_benchmark_name" is duplicated.
        """
        from collections import Counter

        if not isinstance(records, list):
            raise TypeError(
                f"The parameter sweep should be a list of dictionaries, "
                f"but found type: {type(records)}"
            )

        # Validate that all _benchmark_name values are unique if provided.
        # (Counter keeps this O(n); the previous list.count scan was O(n^2).)
        names = [r["_benchmark_name"] for r in records if "_benchmark_name" in r]
        duplicates = {name for name, count in Counter(names).items() if count > 1}
        if duplicates:
            raise ValueError(
                f"Duplicate _benchmark_name values found: {duplicates}. "
                f"All _benchmark_name values must be unique."
            )

        return cls(ParameterSweepItem.from_record(record) for record in records)
|
||||
|
||||
|
||||
class ParameterSweepItem(dict[str, object]):
    """One parameter combination: a mapping of parameter name -> value.

    The optional "_benchmark_name" key names the combination and is never
    treated as a CLI parameter.
    """

    @classmethod
    def from_record(cls, record: dict[str, object]):
        """Validate that *record* is a dict and wrap it."""
        if not isinstance(record, dict):
            raise TypeError(
                f"Each item in the parameter sweep should be a dictionary, "
                f"but found type: {type(record)}"
            )

        return cls(record)

    def __or__(self, other: dict[str, Any]):
        # Keep the subclass type when merging (plain dict.__or__ returns dict).
        return type(self)(super().__or__(other))

    @property
    def name(self) -> str:
        """
        Get the name for this parameter sweep item.

        Returns the '_benchmark_name' field if present, otherwise returns a text
        representation of all parameters.
        """
        if "_benchmark_name" in self:
            return self["_benchmark_name"]
        return self.as_text(sep="-")

    # In JSON, we prefer "_"
    def _iter_param_key_candidates(self, param_key: str):
        """Yield spellings of *param_key* with "-"/"_" swapped, original first."""
        # Inner config arguments are not converted by the CLI, so only the
        # first dotted segment gets alternative spellings.
        if "." in param_key:
            prefix, rest = param_key.split(".", 1)
            for prefix_candidate in self._iter_param_key_candidates(prefix):
                yield prefix_candidate + "." + rest

            return

        yield param_key
        yield param_key.replace("-", "_")
        yield param_key.replace("_", "-")

    # In CLI, we prefer "-"
    def _iter_cmd_key_candidates(self, param_key: str):
        """Yield "--" CLI flags for *param_key*, dash-preferred spelling first."""
        # reversed() flips the JSON preference order above, so the "-"
        # spelling comes out first for CLI use.
        for k in reversed(tuple(self._iter_param_key_candidates(param_key))):
            yield "--" + k

    def _normalize_cmd_key(self, param_key: str):
        """Return the preferred CLI flag spelling for *param_key*."""
        return next(self._iter_cmd_key_candidates(param_key))

    def has_param(self, param_key: str) -> bool:
        """True if any spelling of *param_key* is present in this item."""
        return any(k in self for k in self._iter_param_key_candidates(param_key))

    def _normalize_cmd_kv_pair(self, k: str, v: object) -> list[str]:
        """
        Normalize a key-value pair into command-line arguments.

        Returns a list containing either:
        - A single element for boolean flags (e.g., ['--flag'] or ['--flag=true'])
        - Two elements for key-value pairs (e.g., ['--key', 'value'])
        """
        if isinstance(v, bool):
            # For nested params (containing "."), use =true/false syntax
            if "." in k:
                return [f"{self._normalize_cmd_key(k)}={'true' if v else 'false'}"]
            else:
                # Top-level booleans use the --flag / --no-flag convention.
                return [self._normalize_cmd_key(k if v else "no-" + k)]
        else:
            return [self._normalize_cmd_key(k), str(v)]

    def apply_to_cmd(self, cmd: list[str]) -> list[str]:
        """Return a copy of *cmd* with this item's parameters applied.

        Existing flags are replaced in place; missing ones are appended.
        """
        cmd = list(cmd)

        for k, v in self.items():
            # Skip the '_benchmark_name' field, not a parameter
            if k == "_benchmark_name":
                continue

            # Serialize dict values as JSON
            if isinstance(v, dict):
                v = json.dumps(v)

            for k_candidate in self._iter_cmd_key_candidates(k):
                try:
                    k_idx = cmd.index(k_candidate)

                    # Replace existing parameter.
                    # NOTE(review): replacement assumes the existing flag is
                    # laid out the same way as the new value (value at
                    # k_idx + 1 for key-value pairs); a bare boolean flag
                    # followed by a positional arg would be corrupted here —
                    # confirm callers never mix the two forms.
                    normalized = self._normalize_cmd_kv_pair(k, v)
                    if len(normalized) == 1:
                        # Boolean flag
                        cmd[k_idx] = normalized[0]
                    else:
                        # Key-value pair
                        cmd[k_idx] = normalized[0]
                        cmd[k_idx + 1] = normalized[1]

                    break
                except ValueError:
                    # cmd.index raised: this spelling is absent, try the next.
                    continue
            else:
                # No spelling matched: add new parameter at the end.
                cmd.extend(self._normalize_cmd_kv_pair(k, v))

        return cmd

    def as_text(self, sep: str = ", ") -> str:
        """Render the parameters (minus "_benchmark_name") as "k=v" pairs."""
        return sep.join(f"{k}={v}" for k, v in self.items() if k != "_benchmark_name")
|
||||
675
vllm/benchmarks/sweep/plot.py
Normal file
675
vllm/benchmarks/sweep/plot.py
Normal file
@@ -0,0 +1,675 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from types import TracebackType
|
||||
from typing import ClassVar
|
||||
|
||||
from typing_extensions import Self, override
|
||||
|
||||
from vllm.utils.collection_utils import full_groupby
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
|
||||
from .utils import sanitize_filename
|
||||
|
||||
# Plotting dependencies are optional; fall back to placeholders that raise a
# helpful error on first use when they are not installed.
try:
    import matplotlib.pyplot as plt
    import pandas as pd
    import seaborn as sns
except ImportError:
    plt = PlaceholderModule("matplotlib").placeholder_attr("pyplot")
    pd = PlaceholderModule("pandas")
    # Bind the fallback to ``sns`` — the alias the try-branch establishes —
    # not ``seaborn``; otherwise ``sns`` is undefined whenever the import
    # fails, defeating the placeholder mechanism.
    sns = PlaceholderModule("seaborn")
|
||||
|
||||
|
||||
@dataclass
class PlotFilterBase(ABC):
    """A single ``<var><op><target>`` predicate over a results DataFrame."""

    var: str
    target: str

    @classmethod
    def parse_str(cls, s: str):
        """Parse e.g. ``"max_concurrency<=32"`` into a filter instance.

        Operators are matched in PLOT_FILTERS order, so longer operators
        (``<=``) are tried before their single-character prefixes (``<``).
        """
        for op_key, filter_cls in PLOT_FILTERS.items():
            if op_key in s:
                # partition() splits on the first occurrence only, so a target
                # that itself contains the operator no longer crashes the
                # 2-tuple unpack that str.split() produced.
                key, _, value = s.partition(op_key)
                return filter_cls(key, value.strip("'").strip('"'))
        raise ValueError(
            f"Invalid operator for plot filter '{s}'. "
            f"Valid operators are: {sorted(PLOT_FILTERS)}",
        )

    @abstractmethod
    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
        """Applies this filter to a DataFrame."""
        raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass
class PlotEqualTo(PlotFilterBase):
    """Keeps rows where ``var`` equals the target."""

    @override
    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
        # Compare numerically when the target parses as a float; otherwise
        # fall back to string comparison.
        try:
            target = float(self.target)
        except ValueError:
            target = self.target

        return df[df[self.var] == target]
|
||||
|
||||
|
||||
@dataclass
class PlotNotEqualTo(PlotFilterBase):
    """Keeps rows where ``var`` differs from the target."""

    @override
    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
        # Compare numerically when the target parses as a float; otherwise
        # fall back to string comparison.
        try:
            target = float(self.target)
        except ValueError:
            target = self.target

        return df[df[self.var] != target]
|
||||
|
||||
|
||||
@dataclass
class PlotLessThan(PlotFilterBase):
    """Keeps rows where ``var`` is strictly less than the numeric target."""

    @override
    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
        return df[df[self.var] < float(self.target)]
|
||||
|
||||
|
||||
@dataclass
class PlotLessThanOrEqualTo(PlotFilterBase):
    """Keeps rows where ``var`` is at most the numeric target."""

    @override
    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
        return df[df[self.var] <= float(self.target)]
|
||||
|
||||
|
||||
@dataclass
class PlotGreaterThan(PlotFilterBase):
    """Keeps rows where ``var`` is strictly greater than the numeric target."""

    @override
    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
        return df[df[self.var] > float(self.target)]
|
||||
|
||||
|
||||
@dataclass
class PlotGreaterThanOrEqualTo(PlotFilterBase):
    """Keeps rows where ``var`` is at least the numeric target."""

    @override
    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
        return df[df[self.var] >= float(self.target)]
|
||||
|
||||
|
||||
# NOTE: The ordering is important! Match longer op_keys first, so that e.g.
# "<=" is recognized before its prefix "<" when scanning a filter expression.
PLOT_FILTERS: dict[str, type[PlotFilterBase]] = {
    "==": PlotEqualTo,
    "!=": PlotNotEqualTo,
    "<=": PlotLessThanOrEqualTo,
    ">=": PlotGreaterThanOrEqualTo,
    "<": PlotLessThan,
    ">": PlotGreaterThan,
}
|
||||
|
||||
|
||||
class PlotFilters(list[PlotFilterBase]):
    """An ordered collection of row filters, applied left to right."""

    @classmethod
    def parse_str(cls, s: str):
        """Parse a comma-separated filter expression; empty -> no filters."""
        if not s:
            return cls()

        return cls(map(PlotFilterBase.parse_str, s.split(",")))

    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
        """Run *df* through every contained filter in order."""
        result = df
        for flt in self:
            result = flt.apply(result)

        return result
|
||||
|
||||
|
||||
@dataclass
class PlotBinner:
    """Bins a numeric column into fixed-width buckets (``var%bin_size``)."""

    var: str
    bin_size: float

    @classmethod
    def parse_str(cls, s: str):
        """Parse e.g. ``"request_rate%10"`` into a binner instance."""
        for op_key, binner_cls in PLOT_BINNERS.items():
            if op_key in s:
                # partition() splits on the first occurrence only, so a
                # malformed value with repeated operators fails in float()
                # with a clear message instead of crashing the 2-tuple
                # unpack that str.split() produced.
                key, _, value = s.partition(op_key)
                return binner_cls(key, float(value))
        raise ValueError(
            f"Invalid operator for plot binner '{s}'. "
            f"Valid operators are: {sorted(PLOT_BINNERS)}",
        )

    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
        """Applies this binner to a DataFrame (returns a modified copy)."""
        df = df.copy()
        # Floor each value down to the nearest multiple of bin_size.
        df[self.var] = df[self.var] // self.bin_size * self.bin_size
        return df
|
||||
|
||||
|
||||
# Maps binning operator tokens (the "var%size" syntax) to their binner class.
PLOT_BINNERS: dict[str, type[PlotBinner]] = {
    "%": PlotBinner,
}
|
||||
|
||||
|
||||
class PlotBinners(list[PlotBinner]):
    """An ordered collection of column binners, applied left to right."""

    @classmethod
    def parse_str(cls, s: str):
        """Parse a comma-separated binner expression; empty -> no binners."""
        if not s:
            return cls()

        return cls(map(PlotBinner.parse_str, s.split(",")))

    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
        """Run *df* through every contained binner in order."""
        result = df
        for binner in self:
            result = binner.apply(result)

        return result
|
||||
|
||||
|
||||
def _json_load_bytes(path: Path) -> list[dict[str, object]]:
|
||||
with path.open("rb") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def _convert_inf_nan_strings(data: list[dict[str, object]]) -> list[dict[str, object]]:
|
||||
"""
|
||||
Convert string values "inf", "-inf", and "nan" to their float equivalents.
|
||||
|
||||
This handles the case where JSON serialization represents inf/nan as strings.
|
||||
"""
|
||||
converted_data = []
|
||||
for record in data:
|
||||
converted_record = {}
|
||||
for key, value in record.items():
|
||||
if isinstance(value, str):
|
||||
if value in ["inf", "-inf", "nan"]:
|
||||
converted_record[key] = float(value)
|
||||
else:
|
||||
converted_record[key] = value
|
||||
else:
|
||||
converted_record[key] = value
|
||||
converted_data.append(converted_record)
|
||||
return converted_data
|
||||
|
||||
|
||||
def _get_metric(run_data: dict[str, object], metric_key: str):
|
||||
try:
|
||||
return run_data[metric_key]
|
||||
except KeyError as exc:
|
||||
raise ValueError(f"Cannot find metric {metric_key!r} in {run_data=}") from exc
|
||||
|
||||
|
||||
def _get_group(run_data: dict[str, object], group_keys: list[str]):
    """Build a hashable ((key, str(value)), ...) tuple identifying the group
    that *run_data* belongs to; raises ValueError if a key is missing."""
    return tuple((k, str(_get_metric(run_data, k))) for k in group_keys)
|
||||
|
||||
|
||||
def _get_fig_path(fig_dir: Path, group: tuple[tuple[str, str], ...], fig_name: str):
    """Build the output PNG path for a figure, encoding its group in the name."""
    # File name: "<fig_name>[-k=v...]" with unsafe characters sanitized.
    pieces = [fig_name, *(f"{k}={v}" for k, v in group)]
    return fig_dir / sanitize_filename("-".join(pieces) + ".png")
|
||||
|
||||
|
||||
class DummyExecutor:
    """A serial, in-process stand-in for a concurrent executor.

    Supports the same ``with DummyExecutor() as ex: ex.map(...)`` shape as
    ``ProcessPoolExecutor`` but runs the work inline.
    """

    # The builtin ``map`` is not a descriptor, so this class attribute is not
    # turned into a bound method: ``self.map(f, xs)`` behaves exactly like
    # ``map(f, xs)``.
    map = map

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        exc_traceback: TracebackType | None,
    ) -> None:
        # Returning None propagates any exception, like a real executor.
        return None
|
||||
|
||||
|
||||
def _plot_fig(
    fig_dir: Path,
    fig_group_data: tuple[tuple[tuple[str, str], ...], list[dict[str, object]]],
    row_by: list[str],
    col_by: list[str],
    curve_by: list[str],
    *,
    var_x: str,
    var_y: str,
    filter_by: PlotFilters,
    bin_by: PlotBinners,
    scale_x: str | None,
    scale_y: str | None,
    dry_run: bool,
    fig_name: str,
    error_bars: bool,
    fig_height: float,
    fig_dpi: int,
):
    """Render one figure (a seaborn FacetGrid) for a single figure group.

    ``fig_group_data`` is one ``(group_key, records)`` pair as produced by
    ``full_groupby`` in :func:`plot`. Rows/columns of the grid are derived
    from ``row_by``/``col_by``; one curve is drawn per ``curve_by``
    combination. When ``dry_run`` is set, only the figure metadata is
    printed and nothing is drawn.
    """
    fig_group, fig_data = fig_group_data

    # Grid shape is computed up front so it can be reported even in dry runs.
    row_groups = full_groupby(
        fig_data,
        key=lambda item: _get_group(item, row_by),
    )
    num_rows = len(row_groups)
    # Column count is the max over rows, since rows may have uneven coverage.
    num_cols = max(
        len(full_groupby(row_data, key=lambda item: _get_group(item, col_by)))
        for _, row_data in row_groups
    )

    fig_path = _get_fig_path(fig_dir, fig_group, fig_name)

    print("[BEGIN FIGURE]")
    print(f"Group: {dict(fig_group)}")
    print(f"Grid: {num_rows} rows x {num_cols} cols")
    print(f"Output file: {fig_path}")

    if dry_run:
        print("[END FIGURE]")
        return

    # Convert string "inf", "-inf", and "nan" to their float equivalents
    fig_data = _convert_inf_nan_strings(fig_data)
    df = pd.DataFrame.from_records(fig_data)

    # Validate every referenced variable against the available columns so
    # typos fail fast with a helpful message instead of a pandas KeyError.
    if var_x not in df.columns:
        raise ValueError(
            f"Cannot find {var_x=!r} in parameter sweep results. "
            f"Available variables: {df.columns.tolist()}"
        )
    if var_y not in df.columns:
        raise ValueError(
            f"Cannot find {var_y=!r} in parameter sweep results. "
            f"Available variables: {df.columns.tolist()}"
        )
    for k in row_by:
        if k not in df.columns:
            raise ValueError(
                f"Cannot find row_by={k!r} in parameter sweep results. "
                f"Available variables: {df.columns.tolist()}"
            )
    for k in col_by:
        if k not in df.columns:
            raise ValueError(
                f"Cannot find col_by={k!r} in parameter sweep results. "
                f"Available variables: {df.columns.tolist()}"
            )
    for k in curve_by:
        if k not in df.columns:
            raise ValueError(
                f"Cannot find curve_by={k!r} in parameter sweep results. "
                f"Available variables: {df.columns.tolist()}"
            )

    df = filter_by.apply(df)
    df = bin_by.apply(df)

    # Sort by curve_by columns alphabetically for consistent legend ordering
    if curve_by:
        df = df.sort_values(by=curve_by)

    # Synthesize "k=v" facet labels, one line per row_by / col_by variable.
    df["row_group"] = (
        pd.concat(
            [k + "=" + df[k].astype(str) for k in row_by],
            axis=1,
        ).agg("\n".join, axis=1)
        if row_by
        else "(All)"
    )

    df["col_group"] = (
        pd.concat(
            [k + "=" + df[k].astype(str) for k in col_by],
            axis=1,
        ).agg("\n".join, axis=1)
        if col_by
        else "(All)"
    )

    g = sns.FacetGrid(df, row="row_group", col="col_group", height=fig_height)

    # Only show the facet dimensions that were actually requested.
    if row_by and col_by:
        g.set_titles("{row_name}\n{col_name}")
    elif row_by:
        g.set_titles("{row_name}")
    elif col_by:
        g.set_titles("{col_name}")
    else:
        g.set_titles("")

    if scale_x:
        g.set(xscale=scale_x)
    if scale_y:
        g.set(yscale=scale_y)

    if len(curve_by) <= 3:
        # Up to 3 curve variables map directly onto seaborn's hue/style/size
        # channels; the trailing Nones pad unused channels.
        hue, style, size, *_ = (*curve_by, None, None, None)

        g.map_dataframe(
            sns.lineplot,
            x=var_x,
            y=var_y,
            hue=hue,
            style=style,
            size=size,
            markers=True,
            errorbar="sd" if error_bars else None,
        )

        g.add_legend(title=hue)
    else:
        # More than 3 curve variables: collapse them into a single combined
        # "curve_group" label and use it as the hue channel.
        df["curve_group"] = (
            pd.concat(
                [k + "=" + df[k].astype(str) for k in curve_by],
                axis=1,
            ).agg("\n".join, axis=1)
            if curve_by
            else "(All)"
        )

        g.map_dataframe(
            sns.lineplot,
            x=var_x,
            y=var_y,
            hue="curve_group",
            markers=True,
            errorbar="sd" if error_bars else None,
        )

        g.add_legend()

    g.savefig(fig_path, dpi=fig_dpi)
    # Close explicitly to avoid matplotlib accumulating open figures
    # across many figure groups.
    plt.close(g.figure)

    print("[END FIGURE]")
|
||||
|
||||
|
||||
def plot(
    output_dir: Path,
    fig_dir: Path,
    fig_by: list[str],
    row_by: list[str],
    col_by: list[str],
    curve_by: list[str],
    *,
    var_x: str,
    var_y: str,
    filter_by: PlotFilters,
    bin_by: PlotBinners,
    scale_x: str | None,
    scale_y: str | None,
    dry_run: bool,
    fig_name: str = "FIGURE",
    error_bars: bool = True,
    fig_height: float = 6.4,
    fig_dpi: int = 300,
):
    """Plot performance curves from parameter sweep results.

    Collects every ``summary.json`` under ``output_dir`` (recursively),
    groups the records by ``fig_by``, and renders one figure per group via
    :func:`_plot_fig`. Figures are plotted in parallel with a process pool
    when there is more than one group.

    Raises:
        ValueError: If no ``summary.json`` files are found under `output_dir`.
    """
    all_data = [
        run_data
        for path in output_dir.rglob("**/summary.json")
        for run_data in _json_load_bytes(path)
    ]

    if not all_data:
        raise ValueError(f"Did not find any parameter sweep results under {output_dir}")

    fig_dir.mkdir(parents=True, exist_ok=True)

    fig_groups = full_groupby(
        all_data,
        key=lambda item: _get_group(item, fig_by),
    )

    # A single figure gains nothing from multiprocessing; run it inline.
    with DummyExecutor() if len(fig_groups) <= 1 else ProcessPoolExecutor() as executor:
        # Resolve the iterable to ensure that the workers are run
        all(
            executor.map(
                partial(
                    _plot_fig,
                    fig_dir,
                    row_by=row_by,
                    col_by=col_by,
                    curve_by=curve_by,
                    var_x=var_x,
                    var_y=var_y,
                    filter_by=filter_by,
                    bin_by=bin_by,
                    scale_x=scale_x,
                    scale_y=scale_y,
                    dry_run=dry_run,
                    fig_name=fig_name,
                    error_bars=error_bars,
                    fig_height=fig_height,
                    fig_dpi=fig_dpi,
                ),
                fig_groups,
            )
        )
|
||||
|
||||
|
||||
@dataclass
class SweepPlotArgs:
    """Validated CLI arguments for the ``plot`` subcommand.

    Constructed from an ``argparse.Namespace`` via :meth:`from_cli_args`;
    the corresponding parser is built by :meth:`add_cli_args`. Field
    semantics mirror the CLI help strings below.
    """

    # Root of the sweep results (the sweep script's --output-dir).
    output_dir: Path
    # Where figures are written (resolved relative to output_dir).
    fig_dir: Path
    # Variables that split the data into separate figures / rows / columns /
    # curves, respectively. Empty list means "no split" on that axis.
    fig_by: list[str]
    row_by: list[str]
    col_by: list[str]
    curve_by: list[str]
    # Column names for the x and y axes.
    var_x: str
    var_y: str
    # Parsed row filters and binners applied before plotting.
    filter_by: PlotFilters
    bin_by: PlotBinners
    # Optional axis scales (e.g. "log", "sqrt"); None keeps linear.
    scale_x: str | None
    scale_y: str | None
    # When True, only print figure metadata without drawing.
    dry_run: bool
    # Filename prefix for output figures.
    fig_name: str = "FIGURE"
    # Whether to draw standard-deviation error bars.
    error_bars: bool = True
    # Per-subplot height (inches) and output resolution (dpi).
    fig_height: float = 6.4
    fig_dpi: int = 300

    parser_name: ClassVar[str] = "plot"
    parser_help: ClassVar[str] = "Plot performance curves from parameter sweep results."

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build a validated ``SweepPlotArgs`` from parsed CLI arguments.

        Raises:
            ValueError: If ``args.OUTPUT_DIR`` does not exist.
        """
        output_dir = Path(args.OUTPUT_DIR)
        if not output_dir.exists():
            raise ValueError(f"No parameter sweep results under {output_dir}")

        # Comma-separated CLI strings become (possibly empty) variable lists.
        curve_by = [] if not args.curve_by else args.curve_by.split(",")
        row_by = [] if not args.row_by else args.row_by.split(",")
        col_by = [] if not args.col_by else args.col_by.split(",")
        fig_by = [] if not args.fig_by else args.fig_by.split(",")

        return cls(
            output_dir=output_dir,
            fig_dir=output_dir / args.fig_dir,
            fig_by=fig_by,
            row_by=row_by,
            col_by=col_by,
            curve_by=curve_by,
            var_x=args.var_x,
            var_y=args.var_y,
            filter_by=PlotFilters.parse_str(args.filter_by),
            bin_by=PlotBinners.parse_str(args.bin_by),
            scale_x=args.scale_x,
            scale_y=args.scale_y,
            dry_run=args.dry_run,
            fig_name=args.fig_name,
            error_bars=not args.no_error_bars,
            fig_height=args.fig_height,
            fig_dpi=args.fig_dpi,
        )

    @classmethod
    def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Register the ``plot`` subcommand's arguments on `parser` and return it."""
        parser.add_argument(
            "OUTPUT_DIR",
            type=str,
            default="results",
            help="The directory containing the results to plot, "
            "i.e., the `--output-dir` argument to the parameter sweep script.",
        )
        parser.add_argument(
            "--fig-dir",
            type=str,
            default="",
            help="The directory to save the figures, relative to `OUTPUT_DIR`. "
            "By default, the same directory is used.",
        )
        parser.add_argument(
            "--fig-by",
            type=str,
            default="",
            help="A comma-separated list of variables, such that a separate figure "
            "is created for each combination of these variables.",
        )
        parser.add_argument(
            "--row-by",
            type=str,
            default="",
            help="A comma-separated list of variables, such that a separate row "
            "is created for each combination of these variables.",
        )
        parser.add_argument(
            "--col-by",
            type=str,
            default="",
            help="A comma-separated list of variables, such that a separate column "
            "is created for each combination of these variables.",
        )
        # NOTE(review): --curve-by defaults to None while its siblings default
        # to "" — both are falsy, so from_cli_args treats them the same, but
        # the inconsistency looks unintentional; confirm before changing.
        parser.add_argument(
            "--curve-by",
            type=str,
            default=None,
            help="A comma-separated list of variables, such that a separate curve "
            "is created for each combination of these variables.",
        )
        parser.add_argument(
            "--var-x",
            type=str,
            default="request_throughput",
            help="The variable for the x-axis.",
        )
        parser.add_argument(
            "--var-y",
            type=str,
            default="p99_e2el_ms",
            help="The variable for the y-axis",
        )
        parser.add_argument(
            "--filter-by",
            type=str,
            default="",
            help="A comma-separated list of statements indicating values to filter by. "
            "This is useful to remove outliers. "
            "Example: `max_concurrency<1000,max_num_batched_tokens<=4096` means "
            "plot only the points where `max_concurrency` is less than 1000 and "
            "`max_num_batched_tokens` is no greater than 4096.",
        )
        # `%%` is argparse's escape for a literal percent sign in help text.
        parser.add_argument(
            "--bin-by",
            type=str,
            default="",
            help="A comma-separated list of statements indicating values to bin by. "
            "This is useful to avoid plotting points that are too close together. "
            "Example: `request_throughput%%1` means "
            "use a bin size of 1 for the `request_throughput` variable.",
        )
        parser.add_argument(
            "--scale-x",
            type=str,
            default=None,
            help="The scale to use for the x-axis. "
            "Currently only accepts string values such as 'log' and 'sqrt'. "
            "See also: https://seaborn.pydata.org/generated/seaborn.objects.Plot.scale.html",
        )
        parser.add_argument(
            "--scale-y",
            type=str,
            default=None,
            help="The scale to use for the y-axis. "
            "Currently only accepts string values such as 'log' and 'sqrt'. "
            "See also: https://seaborn.pydata.org/generated/seaborn.objects.Plot.scale.html",
        )
        parser.add_argument(
            "--fig-name",
            type=str,
            default="FIGURE",
            help="Name prefix for the output figure file. "
            "Group data is always appended when present. "
            "Default: 'FIGURE'. Example: --fig-name my_performance_plot",
        )
        parser.add_argument(
            "--no-error-bars",
            action="store_true",
            help="If set, disables error bars on the plot. "
            "By default, error bars are shown.",
        )
        parser.add_argument(
            "--fig-height",
            type=float,
            default=6.4,
            help="Height of each subplot in inches. Default: 6.4",
        )
        parser.add_argument(
            "--fig-dpi",
            type=int,
            default=300,
            help="Resolution of the output figure in dots per inch. Default: 300",
        )
        parser.add_argument(
            "--dry-run",
            action="store_true",
            help="If set, prints the information about each figure to plot, "
            "then exits without drawing them.",
        )

        return parser
|
||||
|
||||
|
||||
def run_main(args: SweepPlotArgs):
    """Forward validated ``SweepPlotArgs`` fields to :func:`plot`."""
    return plot(
        output_dir=args.output_dir,
        fig_dir=args.fig_dir,
        fig_by=args.fig_by,
        row_by=args.row_by,
        col_by=args.col_by,
        curve_by=args.curve_by,
        var_x=args.var_x,
        var_y=args.var_y,
        filter_by=args.filter_by,
        bin_by=args.bin_by,
        scale_x=args.scale_x,
        scale_y=args.scale_y,
        dry_run=args.dry_run,
        fig_name=args.fig_name,
        error_bars=args.error_bars,
        fig_height=args.fig_height,
        fig_dpi=args.fig_dpi,
    )
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
    """CLI entry point: validate raw argparse output, then run the plotter."""
    run_main(SweepPlotArgs.from_cli_args(args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this module directly as a standalone plotting script.
    parser = argparse.ArgumentParser(description=SweepPlotArgs.parser_help)
    SweepPlotArgs.add_cli_args(parser)

    main(parser.parse_args())
|
||||
393
vllm/benchmarks/sweep/plot_pareto.py
Normal file
393
vllm/benchmarks/sweep/plot_pareto.py
Normal file
@@ -0,0 +1,393 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
import math
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import ClassVar
|
||||
|
||||
from vllm.utils.collection_utils import full_groupby
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
|
||||
from .plot import DummyExecutor, _json_load_bytes
|
||||
from .utils import sanitize_filename
|
||||
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
import seaborn as sns
|
||||
except ImportError:
|
||||
plt = PlaceholderModule("matplotlib").placeholder_attr("pyplot")
|
||||
pd = PlaceholderModule("pandas")
|
||||
sns = PlaceholderModule("seaborn")
|
||||
|
||||
|
||||
def _first_present(run_data: dict[str, object], keys: list[str]):
|
||||
for key in keys:
|
||||
for candidate in {key, key.replace("_", "-"), key.replace("-", "_")}:
|
||||
if candidate in run_data:
|
||||
return run_data[candidate]
|
||||
return None
|
||||
|
||||
|
||||
def _get_numeric(
    run_data: dict[str, object],
    keys: list[str],
    *,
    allow_zero: bool = True,
) -> float | None:
    """Fetch the first of ``keys`` present in ``run_data`` and coerce it to float.

    Returns ``None`` when no key is present, or when the value is zero and
    ``allow_zero`` is False. Raises ``ValueError`` when the value exists but
    is not numeric.
    """
    value = _first_present(run_data, keys)
    if value is None:
        return None

    try:
        as_float = float(value)
    except (TypeError, ValueError) as exc:
        raise ValueError(
            f"Expected numeric value for one of {keys}, "
            f"but found {value!r} in {run_data=}"
        ) from exc

    # A zero can be meaningless for some metrics (e.g. as a divisor);
    # callers opt out of it via allow_zero=False.
    if as_float == 0 and not allow_zero:
        return None

    return as_float
|
||||
|
||||
|
||||
def _infer_user_count(
    run_data: dict[str, object],
    user_count_var: str | None,
) -> float | None:
    """Estimate the number of concurrent users for a run, or None if unknown.

    Tries the configured ``user_count_var`` first (when given), then
    ``request_rate``; zero values are rejected since they cannot serve as a
    per-user divisor.
    """
    candidates: list[str] = []
    if user_count_var:
        candidates.append(user_count_var)
    candidates.append("request_rate")

    configured = _get_numeric(run_data, candidates, allow_zero=False)
    if configured is not None:
        return configured

    # Fallback to the observed peak if configured value is missing.
    return _get_numeric(run_data, ["max_concurrent_requests"], allow_zero=False)
|
||||
|
||||
|
||||
def _infer_gpu_count(
    run_data: dict[str, object],
    gpu_count_var: str | None,
) -> float:
    """Estimate the number of GPUs used by a run.

    Prefers the explicitly configured ``gpu_count_var`` field when present
    and non-zero; otherwise multiplies the tensor/pipeline/data parallel
    sizes that are recorded (missing or zero sizes count as 1).
    """
    if gpu_count_var:
        direct = _get_numeric(run_data, [gpu_count_var], allow_zero=False)
        if direct:
            return direct

    # World size = TP * PP * DP over whichever factors are recorded.
    parallel_sizes = (
        _get_numeric(run_data, ["tensor_parallel_size", "tp"]),
        _get_numeric(run_data, ["pipeline_parallel_size", "pp"]),
        _get_numeric(run_data, ["data_parallel_size", "dp"]),
    )

    world_size = 1.0
    for size in parallel_sizes:
        if size:
            world_size *= size

    return world_size
|
||||
|
||||
|
||||
def _get_throughput(
    run_data: dict[str, object],
    throughput_var: str,
) -> float:
    """Return the throughput metric for a run, raising if it is missing.

    Unlike the other ``_get_numeric`` callers, throughput is mandatory for
    the Pareto plot, so absence is an error rather than a skip.
    """
    throughput = _get_numeric(run_data, [throughput_var])
    if throughput is None:
        raise ValueError(
            f"Cannot find throughput metric {throughput_var!r} in run data. "
            f"Available keys: {sorted(run_data)}"
        )

    return throughput
|
||||
|
||||
|
||||
def _prepare_records(
    all_data: list[dict[str, object]],
    *,
    user_count_var: str | None,
    gpu_count_var: str | None,
) -> tuple[list[dict[str, object]], int]:
    """Derive per-user and per-GPU throughput for each run record.

    Returns a pair ``(prepared, skipped)``: `prepared` contains copies of the
    input records augmented with ``tokens_per_user``, ``tokens_per_gpu``,
    ``user_count_estimate``, and ``gpu_count``; `skipped` counts the records
    dropped because no user count could be inferred.
    """
    prepared = []
    skipped_missing_users = 0

    for record in all_data:
        throughput = _get_throughput(record, "output_throughput")
        user_count = _infer_user_count(record, user_count_var)
        if user_count is None:
            # Without a user count the per-user metric is undefined; skip
            # the record (the caller reports how many were skipped).
            skipped_missing_users += 1
            continue

        gpu_count = _infer_gpu_count(record, gpu_count_var)
        tokens_per_user = throughput / user_count
        tokens_per_gpu = throughput / gpu_count

        prepared.append(
            {
                **record,
                "tokens_per_user": tokens_per_user,
                "tokens_per_gpu": tokens_per_gpu,
                "user_count_estimate": user_count,
                "gpu_count": gpu_count,
            }
        )

    return prepared, skipped_missing_users
|
||||
|
||||
|
||||
def _pareto_frontier(
|
||||
df: "pd.DataFrame",
|
||||
x_col: str,
|
||||
y_col: str,
|
||||
*,
|
||||
epsilon: float = 1e-9,
|
||||
) -> "pd.DataFrame":
|
||||
sorted_df = df.sort_values([x_col, y_col], ascending=[False, False])
|
||||
frontier_indices = []
|
||||
best_y = -math.inf
|
||||
|
||||
for idx, row in sorted_df.iterrows():
|
||||
y_val = row[y_col]
|
||||
if y_val >= best_y - epsilon:
|
||||
frontier_indices.append(idx)
|
||||
best_y = max(best_y, y_val)
|
||||
|
||||
return df.loc[frontier_indices]
|
||||
|
||||
|
||||
def _get_fig_path(
    fig_dir: Path,
    fig_group: tuple[tuple[str, str], ...],
) -> Path:
    """Return the output PNG path for a Pareto figure.

    Filenames start with "PARETO" and append one ``key=value`` component per
    group entry (when the group is non-empty), joined by dashes.
    """
    components = ["PARETO", *(f"{key}={value}" for key, value in fig_group)]
    return fig_dir / sanitize_filename("-".join(components) + ".png")
|
||||
|
||||
|
||||
def _plot_fig(
    fig_dir: Path,
    fig_group_data: tuple[tuple[tuple[str, str], ...], list[dict[str, object]]],
    label_by: list[str],
    *,
    dry_run: bool,
):
    """Render one Pareto figure (scatter of all runs + frontier line).

    ``fig_group_data`` is one ``(group_key, records)`` pair; records must
    already carry ``tokens_per_user`` / ``tokens_per_gpu`` (see
    ``_prepare_records``). Frontier points are annotated with the fields
    named in ``label_by`` when present. ``dry_run`` prints metadata only.
    """
    fig_group, fig_data = fig_group_data
    fig_path = _get_fig_path(fig_dir, fig_group)

    print("[BEGIN FIGURE]")
    print(f"Group: {dict(fig_group)}")
    print(f"Output file: {fig_path}")

    if dry_run:
        print("[END FIGURE]")
        return

    df = pd.DataFrame.from_records(fig_data)
    # Drop rows where either derived metric is missing (NaN).
    df = df.dropna(subset=["tokens_per_user", "tokens_per_gpu"])

    if df.empty:
        print("No data points available after filtering; skipping.")
        print("[END FIGURE]")
        return

    frontier = _pareto_frontier(df, "tokens_per_user", "tokens_per_gpu")
    # Sort so the frontier line is drawn left-to-right.
    frontier = frontier.sort_values("tokens_per_user")

    fig, ax = plt.subplots()
    sns.scatterplot(
        data=df,
        x="tokens_per_user",
        y="tokens_per_gpu",
        color="0.5",
        alpha=0.6,
        ax=ax,
        label="All runs",
    )
    sns.lineplot(
        data=frontier,
        x="tokens_per_user",
        y="tokens_per_gpu",
        marker="o",
        ax=ax,
        label="Pareto frontier",
    )

    if label_by:
        # Annotate each frontier point with whichever label_by fields exist
        # on that row; rows with none of the fields get no annotation.
        for _, row in frontier.iterrows():
            label_parts = []
            for key in label_by:
                if key in row:
                    label_parts.append(f"{key}={row[key]}")
            if label_parts:
                ax.text(
                    row["tokens_per_user"],
                    row["tokens_per_gpu"],
                    "\n".join(label_parts),
                    fontsize=8,
                )

    ax.set_xlabel("Tokens/s/user")
    ax.set_ylabel("Tokens/s/GPU")
    ax.grid(True, linestyle="--", linewidth=0.5, alpha=0.6)
    fig.tight_layout()
    fig.savefig(fig_path)
    # Close explicitly so repeated figures don't accumulate in matplotlib.
    plt.close(fig)

    print(
        f"Plotted {len(df)} points; Pareto frontier size: {len(frontier)}.",
    )
    print("[END FIGURE]")
|
||||
|
||||
|
||||
def plot_pareto(
    output_dir: Path,
    user_count_var: str | None,
    gpu_count_var: str | None,
    label_by: list[str],
    *,
    dry_run: bool,
):
    """Plot the tokens/s/user vs tokens/s/GPU Pareto frontier for a sweep.

    Collects every ``summary.json`` under ``output_dir``, derives per-user
    and per-GPU throughput via ``_prepare_records``, and writes figures to
    ``<output_dir>/pareto``.

    Raises:
        ValueError: If no results are found, or none have both a throughput
            and an inferable user count.
    """
    fig_dir = output_dir / "pareto"
    raw_data = [
        run_data
        for path in output_dir.rglob("**/summary.json")
        for run_data in _json_load_bytes(path)
    ]

    if not raw_data:
        raise ValueError(f"Did not find any parameter sweep results under {output_dir}")

    fig_dir.mkdir(parents=True, exist_ok=True)

    prepared_data, skipped_missing_users = _prepare_records(
        raw_data,
        user_count_var=user_count_var,
        gpu_count_var=gpu_count_var,
    )

    if skipped_missing_users:
        print(
            f"Skipped {skipped_missing_users} runs without a user count "
            "(`max_concurrency` or `max_concurrent_requests`).",
        )

    if not prepared_data:
        raise ValueError(
            "No data points with both throughput and user count available "
            "to plot Pareto frontier.",
        )

    # A constant key puts everything in a single group (one figure); the
    # groupby/executor structure mirrors plot.py for consistency.
    fig_groups = full_groupby(
        prepared_data,
        key=lambda item: tuple(),
    )

    with DummyExecutor() if len(fig_groups) <= 1 else ProcessPoolExecutor() as executor:
        # Resolve the iterable to ensure that the workers are run.
        all(
            executor.map(
                partial(
                    _plot_fig,
                    fig_dir,
                    label_by=label_by,
                    dry_run=dry_run,
                ),
                fig_groups,
            )
        )
|
||||
|
||||
|
||||
@dataclass
class SweepPlotParetoArgs:
    """Validated CLI arguments for the ``plot_pareto`` subcommand."""

    # Root of the sweep results to read.
    output_dir: Path
    # Result key holding the concurrent-user count (None -> use fallbacks).
    user_count_var: str | None
    # Result key holding the GPU count (None -> derive from parallel sizes).
    gpu_count_var: str | None
    # Fields to annotate on frontier points.
    label_by: list[str]
    # When True, print figure metadata without drawing.
    dry_run: bool

    parser_name: ClassVar[str] = "plot_pareto"
    parser_help: ClassVar[str] = (
        "Plot Pareto frontier between tokens/s/user and tokens/s/GPU "
        "from parameter sweep results."
    )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build a validated ``SweepPlotParetoArgs`` from parsed CLI arguments.

        Raises:
            ValueError: If ``args.OUTPUT_DIR`` does not exist.
        """
        output_dir = Path(args.OUTPUT_DIR)
        if not output_dir.exists():
            raise ValueError(f"No parameter sweep results under {output_dir}")

        label_by = [] if not args.label_by else args.label_by.split(",")

        return cls(
            output_dir=output_dir,
            user_count_var=args.user_count_var,
            gpu_count_var=args.gpu_count_var,
            label_by=label_by,
            dry_run=args.dry_run,
        )

    @classmethod
    def add_cli_args(cls, parser: argparse.ArgumentParser):
        """Register the ``plot_pareto`` subcommand's arguments and return the parser."""
        parser.add_argument(
            "OUTPUT_DIR",
            type=str,
            default="results",
            help="The directory containing the sweep results to plot.",
        )
        parser.add_argument(
            "--user-count-var",
            type=str,
            default="max_concurrency",
            help="Result key that stores concurrent user count. "
            "Falls back to max_concurrent_requests if missing.",
        )
        parser.add_argument(
            "--gpu-count-var",
            type=str,
            default=None,
            help="Result key that stores GPU count. "
            "If not provided, falls back to num_gpus/gpu_count "
            "or tensor_parallel_size * pipeline_parallel_size.",
        )
        parser.add_argument(
            "--label-by",
            type=str,
            default="max_concurrency,gpu_count",
            help="Comma-separated list of fields to annotate on Pareto frontier "
            "points.",
        )
        parser.add_argument(
            "--dry-run",
            action="store_true",
            help="If set, prints the figures to plot without drawing them.",
        )

        return parser
|
||||
|
||||
|
||||
def run_main(args: SweepPlotParetoArgs):
    """Forward validated ``SweepPlotParetoArgs`` fields to :func:`plot_pareto`."""
    return plot_pareto(
        output_dir=args.output_dir,
        user_count_var=args.user_count_var,
        gpu_count_var=args.gpu_count_var,
        label_by=args.label_by,
        dry_run=args.dry_run,
    )
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
    """CLI entry point: validate raw argparse output, then run the Pareto plotter."""
    run_main(SweepPlotParetoArgs.from_cli_args(args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this module directly as a standalone script.
    parser = argparse.ArgumentParser(description=SweepPlotParetoArgs.parser_help)
    SweepPlotParetoArgs.add_cli_args(parser)

    main(parser.parse_args())
|
||||
450
vllm/benchmarks/sweep/serve.py
Normal file
450
vllm/benchmarks/sweep/serve.py
Normal file
@@ -0,0 +1,450 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
import contextlib
|
||||
import json
|
||||
import shlex
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import ClassVar
|
||||
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
|
||||
from .param_sweep import ParameterSweep, ParameterSweepItem
|
||||
from .server import ServerProcess
|
||||
from .utils import sanitize_filename
|
||||
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
pd = PlaceholderModule("pandas")
|
||||
|
||||
|
||||
@contextlib.contextmanager
def run_server(
    serve_cmd: list[str],
    after_bench_cmd: list[str],
    *,
    show_stdout: bool,
    serve_overrides: ParameterSweepItem,
    dry_run: bool,
):
    """Context manager that launches (and tears down) a vLLM server.

    Applies ``serve_overrides`` to ``serve_cmd``, prints the effective
    command, and yields a live ``ServerProcess``. In ``dry_run`` mode it
    only prints and yields ``None`` so callers can skip the benchmark.
    """
    server_cmd = serve_overrides.apply_to_cmd(serve_cmd)

    print("[BEGIN SERVER]")
    print(f"Server overrides: {serve_overrides}")
    print(f"Server command: {server_cmd}")

    if dry_run:
        yield None
        print("[END SERVER]")
        return

    # ServerProcess handles startup/shutdown; the server lives for the
    # duration of the `with` body (one serve combination's benchmarks).
    with ServerProcess(server_cmd, after_bench_cmd, show_stdout=show_stdout) as server:
        yield server

    print("[END SERVER]")
|
||||
|
||||
|
||||
def _update_run_data(
|
||||
run_data: dict[str, object],
|
||||
serve_overrides: ParameterSweepItem,
|
||||
bench_overrides: ParameterSweepItem,
|
||||
run_number: int,
|
||||
):
|
||||
run_data["run_number"] = run_number
|
||||
run_data.update(serve_overrides)
|
||||
run_data.update(bench_overrides)
|
||||
|
||||
return run_data
|
||||
|
||||
|
||||
def run_benchmark(
    server: ServerProcess | None,
    bench_cmd: list[str],
    *,
    serve_overrides: ParameterSweepItem,
    bench_overrides: ParameterSweepItem,
    run_number: int,
    output_path: Path,
    dry_run: bool,
):
    """Run one benchmark invocation against a live server and persist its result.

    If ``output_path`` already exists, the saved result is reused (making the
    sweep resumable). Otherwise the benchmark command is executed through
    ``server``, and the resulting JSON is re-written annotated with the
    overrides and run number. Returns the annotated result dict, or ``None``
    in a dry run with no cached result.

    Raises:
        ValueError: When ``server`` is None (non-dry run) and no cached
            result exists at ``output_path``.
    """
    # Force the standard result-saving flags so the output lands exactly at
    # output_path with the percentile metrics the plotting tools expect.
    benchmark_cmd = [
        *bench_overrides.apply_to_cmd(bench_cmd),
        "--percentile-metrics",
        "ttft,tpot,itl,e2el",
        "--save-result",
        "--result-dir",
        str(output_path.parent),
        "--result-filename",
        output_path.name,
    ]

    print("[BEGIN BENCHMARK]")
    print(f"Benchmark overrides: {bench_overrides}")
    print(f"Run Number: {run_number}")
    print(f"Benchmark command: {benchmark_cmd}")
    print(f"Output file: {output_path}")

    run_data: dict[str, object]

    if output_path.exists():
        print("Found existing results. Skipping.")

        # NOTE(review): this early-return path never prints "[END BENCHMARK]",
        # unlike the other exits — confirm whether that is intentional.
        with output_path.open("rb") as f:
            run_data = json.load(f)
        return _update_run_data(
            run_data,
            serve_overrides,
            bench_overrides,
            run_number,
        )

    if server is None:
        # server is None only in dry runs (run_server yields None); outside
        # a dry run, a missing cached result here is a hard error.
        if not dry_run:
            raise ValueError(f"Cannot find results at {output_path}")

        print("[END BENCHMARK]")
        return None

    output_path.parent.mkdir(parents=True, exist_ok=True)

    server.run_subcommand(benchmark_cmd)
    # e.g. clear the prefix cache so runs don't influence each other.
    server.after_bench()

    with output_path.open("rb") as f:
        run_data = json.load(f)

    run_data = _update_run_data(
        run_data,
        serve_overrides,
        bench_overrides,
        run_number,
    )

    # Re-write the file so the saved result includes the annotations.
    with output_path.open("w") as f:
        json.dump(run_data, f, indent=4)

    print("[END BENCHMARK]")

    return run_data
|
||||
|
||||
|
||||
def _get_comb_base_path(
    output_dir: Path,
    serve_comb: ParameterSweepItem,
    bench_comb: ParameterSweepItem,
):
    """Directory for one (server, benchmark) parameter combination's results."""
    parts = list[str]()
    if serve_comb:
        parts.extend(("SERVE-", serve_comb.name))
    if bench_comb:
        parts.extend(("BENCH-", bench_comb.name))

    # NOTE(review): joining with "-" produces a double dash after the
    # "SERVE-"/"BENCH-" markers (e.g. "SERVE--<name>"). Existing result
    # directories (and sweep resumption) depend on this exact spelling,
    # so do not "fix" it without a migration.
    return output_dir / sanitize_filename("-".join(parts))
|
||||
|
||||
|
||||
def _get_comb_run_path(base_path: Path, run_number: int | None):
|
||||
if run_number is None:
|
||||
return base_path / "summary.json"
|
||||
|
||||
return base_path / f"run={run_number}.json"
|
||||
|
||||
|
||||
def _comb_needs_server(
    serve_comb: ParameterSweepItem,
    bench_combs: ParameterSweep,
    output_dir: Path,
):
    """Return True when at least one benchmark combination under this server
    combination still lacks a saved ``summary.json`` (so the server must be
    started); False when every combination already has cached results.
    """
    # any() short-circuits on the first missing summary, matching the
    # early-return behavior of an explicit loop.
    return any(
        not _get_comb_run_path(
            _get_comb_base_path(output_dir, serve_comb, bench_comb),
            run_number=None,
        ).exists()
        for bench_comb in bench_combs
    )
|
||||
|
||||
|
||||
def run_comb(
    server: ServerProcess | None,
    bench_cmd: list[str],
    *,
    serve_comb: ParameterSweepItem,
    bench_comb: ParameterSweepItem,
    base_path: Path,
    num_runs: int,
    dry_run: bool,
):
    """Execute all repetitions of one (server, benchmark) combination.

    Runs the benchmark ``num_runs`` times (reusing cached per-run results
    where present), then writes the aggregate ``summary.json`` under
    ``base_path``. Returns the list of per-run result dicts, or ``None`` in
    a dry run.
    """
    comb_data = list[dict[str, object]]()

    for run_number in range(num_runs):
        run_data = run_benchmark(
            server,
            bench_cmd,
            serve_overrides=serve_comb,
            bench_overrides=bench_comb,
            run_number=run_number,
            output_path=_get_comb_run_path(base_path, run_number),
            dry_run=dry_run,
        )

        # run_data is None only in dry runs without cached results.
        if run_data is not None:
            comb_data.append(run_data)

    if dry_run:
        return None

    # The summary file also marks this combination as complete for
    # resumption (see _comb_needs_server).
    with _get_comb_run_path(base_path, run_number=None).open("w") as f:
        json.dump(comb_data, f, indent=4)

    return comb_data
|
||||
|
||||
|
||||
def run_combs(
    serve_cmd: list[str],
    bench_cmd: list[str],
    after_bench_cmd: list[str],
    *,
    show_stdout: bool,
    serve_params: ParameterSweep,
    bench_params: ParameterSweep,
    output_dir: Path,
    num_runs: int,
    dry_run: bool,
    links: list[tuple[str, str]],
):
    """Run the full sweep: every server combination x every benchmark combination.

    For each server combination, a server is started only if some benchmark
    combination still lacks cached results. ``links`` restricts which
    (serve, bench) pairs run: each ``(serve_key, bench_key)`` pair must be
    present in both combinations with equal values. Writes an aggregate
    ``summary.csv`` under ``output_dir`` and returns the combined DataFrame
    (``None`` in a dry run).
    """
    all_data = list[dict[str, object]]()
    for serve_comb in serve_params:
        with (
            run_server(
                serve_cmd,
                after_bench_cmd,
                show_stdout=show_stdout,
                serve_overrides=serve_comb,
                dry_run=dry_run,
            )
            # Skip server startup entirely when every benchmark combination
            # already has a cached summary for this server combination.
            if _comb_needs_server(serve_comb, bench_params, output_dir)
            else contextlib.nullcontext()
        ) as server:
            for bench_comb in bench_params:
                # Linked variables must match between the serve and bench
                # combinations for this pair to be executed.
                should_run = all(
                    serve_key in serve_comb
                    and bench_key in bench_comb
                    and serve_comb[serve_key] == bench_comb[bench_key]
                    for serve_key, bench_key in links
                )
                if not should_run:
                    continue
                base_path = _get_comb_base_path(output_dir, serve_comb, bench_comb)

                comb_data = run_comb(
                    server,
                    bench_cmd,
                    serve_comb=serve_comb,
                    bench_comb=bench_comb,
                    base_path=base_path,
                    num_runs=num_runs,
                    dry_run=dry_run,
                )

                if comb_data is not None:
                    all_data.extend(comb_data)

    if dry_run:
        return None

    combined_df = pd.DataFrame.from_records(all_data)
    combined_df.to_csv(output_dir / "summary.csv")

    return combined_df
|
||||
|
||||
|
||||
@dataclass
class SweepServeArgs:
    """Parsed CLI arguments for sweeping `vllm serve` benchmarks.

    Holds the tokenized commands plus the parameter combinations to iterate
    over (Cartesian product of `serve_params` x `bench_params`).
    """

    serve_cmd: list[str]  # Tokenized `vllm serve ...` command
    bench_cmd: list[str]  # Tokenized `vllm bench serve ...` command
    after_bench_cmd: list[str]  # Post-benchmark command; [] means default cache reset
    show_stdout: bool
    serve_params: ParameterSweep
    bench_params: ParameterSweep
    output_dir: Path
    num_runs: int  # Benchmark repetitions per parameter combination (>= 1)
    dry_run: bool
    resume: str | None  # Timestamp directory to resume, or None for a fresh run
    link_vars: list[tuple[str, str]] | None  # (serve_key, bench_key) pairs to keep equal

    parser_name: ClassVar[str] = "serve"
    parser_help: ClassVar[str] = "Run vLLM server benchmark under multiple settings."

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build an instance from parsed CLI args, validating and normalizing them.

        Raises:
            ValueError: If `num_runs` is below 1 or `--link-vars` is malformed.
        """
        serve_cmd = shlex.split(args.serve_cmd)
        bench_cmd = shlex.split(args.bench_cmd)
        after_bench_cmd = (
            [] if args.after_bench_cmd is None else shlex.split(args.after_bench_cmd)
        )

        if args.serve_params:
            serve_params = ParameterSweep.read_json(args.serve_params)
        else:
            # i.e.: run serve_cmd without any modification
            serve_params = ParameterSweep.from_records([{}])

        if args.bench_params:
            bench_params = ParameterSweep.read_json(args.bench_params)
        else:
            # i.e.: run bench_cmd without any modification
            bench_params = ParameterSweep.from_records([{}])

        link_vars = cls.parse_link_vars(args.link_vars)

        num_runs = args.num_runs
        if num_runs < 1:
            raise ValueError("`num_runs` should be at least 1.")

        return cls(
            serve_cmd=serve_cmd,
            bench_cmd=bench_cmd,
            after_bench_cmd=after_bench_cmd,
            show_stdout=args.show_stdout,
            serve_params=serve_params,
            bench_params=bench_params,
            output_dir=Path(args.output_dir),
            num_runs=num_runs,
            dry_run=args.dry_run,
            resume=args.resume,
            link_vars=link_vars,
        )

    @classmethod
    def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Register this command's CLI options on `parser` and return it."""
        parser.add_argument(
            "--serve-cmd",
            type=str,
            required=True,
            help="The command used to run the server: `vllm serve ...`",
        )
        parser.add_argument(
            "--bench-cmd",
            type=str,
            required=True,
            help="The command used to run the benchmark: `vllm bench serve ...`",
        )
        parser.add_argument(
            "--after-bench-cmd",
            type=str,
            default=None,
            help="After a benchmark run is complete, invoke this command instead of "
            "the default `ServerWrapper.clear_cache()`.",
        )
        parser.add_argument(
            "--show-stdout",
            action="store_true",
            help="If set, logs the standard output of subcommands. "
            "Useful for debugging but can be quite spammy.",
        )
        parser.add_argument(
            "--serve-params",
            type=str,
            default=None,
            help="Path to JSON file containing parameter combinations "
            "for the `vllm serve` command. Can be either a list of dicts or a dict "
            "where keys are benchmark names. "
            "If both `serve_params` and `bench_params` are given, "
            "this script will iterate over their Cartesian product.",
        )
        parser.add_argument(
            "--bench-params",
            type=str,
            default=None,
            help="Path to JSON file containing parameter combinations "
            "for the `vllm bench serve` command. Can be either a list of dicts or "
            "a dict where keys are benchmark names. "
            "If both `serve_params` and `bench_params` are given, "
            "this script will iterate over their Cartesian product.",
        )
        parser.add_argument(
            "-o",
            "--output-dir",
            type=str,
            default="results",
            help="The directory to which results are written.",
        )
        parser.add_argument(
            "--num-runs",
            type=int,
            default=3,
            help="Number of runs per parameter combination.",
        )
        parser.add_argument(
            "--dry-run",
            action="store_true",
            help="If set, prints the commands to run, "
            "then exits without executing them.",
        )
        parser.add_argument(
            "--resume",
            type=str,
            default=None,
            help="Set this to the name of a directory under `output_dir` (which is a "
            "timestamp) to resume a previous execution of this script, i.e., only run "
            "parameter combinations for which there are still no output files.",
        )

        parser.add_argument(
            "--link-vars",
            type=str,
            default="",
            help=(
                "Comma-separated list of linked variables between serve and bench, "
                "e.g. max_num_seqs=max_concurrency,max_model_len=random_input_len"
            ),
        )

        return parser

    @staticmethod
    def parse_link_vars(s: str) -> list[tuple[str, str]]:
        """Parse `--link-vars` (e.g. `"a=b,c=d"`) into `(serve_key, bench_key)` pairs.

        Raises:
            ValueError: If an item does not have the exact form
                `<serve_key>=<bench_key>`.
        """
        if not s:
            return []

        pairs: list[tuple[str, str]] = []
        for item in s.split(","):
            parts = item.split("=")
            # Fail with an actionable message instead of the opaque
            # "not enough values to unpack" error from tuple unpacking.
            if len(parts) != 2:
                raise ValueError(
                    f"Invalid `--link-vars` item {item!r}; "
                    "expected the format `<serve_key>=<bench_key>`."
                )
            serve_key, bench_key = parts
            pairs.append((serve_key.strip(), bench_key.strip()))
        return pairs
|
||||
|
||||
|
||||
def run_main(args: SweepServeArgs):
    """Resolve the output directory for this invocation and run the full sweep."""
    run_id = args.resume if args.resume else datetime.now().strftime("%Y%m%d_%H%M%S")
    target_dir = args.output_dir / run_id

    if args.resume and not target_dir.exists():
        raise ValueError(f"Cannot resume from non-existent directory ({target_dir})")

    try:
        return run_combs(
            serve_cmd=args.serve_cmd,
            bench_cmd=args.bench_cmd,
            after_bench_cmd=args.after_bench_cmd,
            show_stdout=args.show_stdout,
            serve_params=args.serve_params,
            bench_params=args.bench_params,
            output_dir=target_dir,
            num_runs=args.num_runs,
            dry_run=args.dry_run,
            links=args.link_vars,
        )
    except BaseException as exc:
        # Point the user at the resume flag so partially-completed sweeps
        # can be continued instead of restarted from scratch.
        message = (
            f"The script was terminated early. Use `--resume {run_id}` "
            f"to continue the script from its last checkpoint."
        )
        raise RuntimeError(message) from exc
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
    """CLI entry point: convert parsed args into `SweepServeArgs` and dispatch."""
    sweep_args = SweepServeArgs.from_cli_args(args)
    run_main(sweep_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Standalone entry point: build the CLI parser for the "serve" sweep
    # subcommand and dispatch to main().
    parser = argparse.ArgumentParser(description=SweepServeArgs.parser_help)
    SweepServeArgs.add_cli_args(parser)

    main(parser.parse_args())
|
||||
492
vllm/benchmarks/sweep/serve_sla.py
Normal file
492
vllm/benchmarks/sweep/serve_sla.py
Normal file
@@ -0,0 +1,492 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
import contextlib
|
||||
import json
|
||||
import math
|
||||
from dataclasses import asdict, dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import ClassVar, Literal, get_args
|
||||
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
|
||||
from .param_sweep import ParameterSweep, ParameterSweepItem
|
||||
from .serve import SweepServeArgs, run_benchmark, run_server
|
||||
from .server import ServerProcess
|
||||
from .sla_sweep import SLASweep, SLASweepItem
|
||||
from .utils import sanitize_filename
|
||||
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
pd = PlaceholderModule("pandas")
|
||||
|
||||
|
||||
def _get_sla_base_path(
    output_dir: Path,
    serve_comb: ParameterSweepItem,
    bench_comb: ParameterSweepItem,
):
    """Directory name that encodes the serve/bench parameter combination."""
    segments: list[str] = []
    for tag, comb in (("SERVE-", serve_comb), ("BENCH-", bench_comb)):
        # Empty combinations contribute nothing to the directory name.
        if comb:
            segments.append(tag)
            segments.append(comb.as_text(sep="-"))

    return output_dir / sanitize_filename("-".join(segments))
|
||||
|
||||
|
||||
def _get_sla_iter_path(
    base_path: Path,
    sla_comb: SLASweepItem,
    sla_variable: str,
    sla_value: int | None,
):
    """Directory for one tested value; final SLA result file when value is None."""
    if sla_value is not None:
        return base_path / f"{sla_variable}={sla_value}"

    # Summary file aggregating the whole search for this SLA combination.
    criteria_text = sla_comb.as_text(sep="-")
    return base_path / f"SLA--{criteria_text}.json"
|
||||
|
||||
|
||||
def _get_sla_run_path(iter_path: Path, run_number: int | None):
|
||||
if run_number is None:
|
||||
return iter_path / "summary.json"
|
||||
|
||||
return iter_path / f"run={run_number}.json"
|
||||
|
||||
|
||||
def _sla_needs_server(
    serve_comb: ParameterSweepItem,
    bench_combs: ParameterSweep,
    sla_combs: SLASweep,
    sla_variable: str,
    output_dir: Path,
):
    """Whether any (bench, SLA) combination under this serve config lacks results."""
    # Short-circuits on the first missing summary file, matching the
    # bench-major, SLA-minor iteration order used elsewhere.
    return any(
        not _get_sla_iter_path(
            _get_sla_base_path(output_dir, serve_comb, bench_comb),
            sla_comb,
            sla_variable,
            sla_value=None,
        ).exists()
        for bench_comb in bench_combs
        for sla_comb in sla_combs
    )
|
||||
|
||||
|
||||
def run_sla(
    server: ServerProcess | None,
    bench_cmd: list[str],
    *,
    serve_comb: ParameterSweepItem,
    bench_comb: ParameterSweepItem,
    iter_path: Path,
    num_runs: int,
    dry_run: bool,
):
    """Execute `num_runs` benchmark runs for one SLA iteration.

    Persists a per-iteration `summary.json` and returns the collected run
    records, or None in dry-run mode.
    """
    collected = list[dict[str, object]]()

    for run_idx in range(num_runs):
        result = run_benchmark(
            server,
            bench_cmd,
            serve_overrides=serve_comb,
            bench_overrides=bench_comb,
            run_number=run_idx,
            output_path=_get_sla_run_path(iter_path, run_idx),
            dry_run=dry_run,
        )

        if result is not None:
            collected.append(result)

    if dry_run:
        return None

    # Write the aggregated records for this iteration next to the raw runs.
    with _get_sla_run_path(iter_path, run_number=None).open("w") as f:
        json.dump(collected, f, indent=4)

    return collected
|
||||
|
||||
|
||||
SLAVariable = Literal["request_rate", "max_concurrency"]
|
||||
|
||||
|
||||
def _estimate_sla_value(run_data: dict[str, object], sla_variable: SLAVariable):
|
||||
request_throughput = float(run_data["request_throughput"]) # type: ignore
|
||||
if sla_variable == "request_rate":
|
||||
return request_throughput
|
||||
if sla_variable == "max_concurrency":
|
||||
mean_latency_ms = float(run_data["mean_e2el_ms"]) # type: ignore
|
||||
return request_throughput * mean_latency_ms / 1000
|
||||
|
||||
assert_never(sla_variable)
|
||||
|
||||
|
||||
def _estimate_sla_bounds(
    server: ServerProcess | None,
    bench_cmd: list[str],
    *,
    serve_comb: ParameterSweepItem,
    bench_comb: ParameterSweepItem,
    sla_comb: SLASweepItem,
    base_path: Path,
    num_runs: int,
    dry_run: bool,
    sla_variable: SLAVariable,
    init_value: int,
    max_value: int,
):
    """Exponentially grow `sla_variable` from `init_value` until the SLA fails.

    Returns:
        `(sla_data, (max_passing, min_failing))` where `sla_data` collects the
        per-run records and the pair bounds the largest passing value for the
        subsequent binary search. If the SLA still passes at `max_value`,
        `min_failing` is `max_value` so the search range stays valid.
    """
    sla_data = list[dict[str, object]]()

    max_passing: int = 0
    # BUGFIX: previously initialized to 0, which yielded an inverted range
    # `(max_passing, 0)` whenever the SLA was still met at `max_value`,
    # breaking the binary search in `_find_sla_value`.
    min_failing: int = max_value

    val: int = init_value
    assert val > 0

    while True:
        print(f"Testing {sla_variable}: {val} req/s")

        iter_data = run_sla(
            server,
            bench_cmd,
            serve_comb=serve_comb,
            bench_comb=bench_comb | {sla_variable: val},
            iter_path=_get_sla_iter_path(base_path, sla_comb, sla_variable, val),
            num_runs=num_runs,
            dry_run=dry_run,
        )

        # `run_sla` only returns None in dry-run mode, which this search
        # is never entered with (see `search_sla`).
        assert iter_data is not None
        sla_data.extend(iter_data)

        # Average each SLA metric across this iteration's runs.
        iter_data_mean = {
            k: sum(float(run_data[k]) for run_data in iter_data) / len(iter_data)  # type: ignore
            for k in sla_comb
        }

        # Materialize the list so every criterion is printed, even after
        # the first failure.
        sla_results = [
            criterion.print_and_validate(iter_data_mean, k)
            for k, criterion in sla_comb.items()
        ]

        if all(sla_results):
            print("SLA criteria are met.")
            max_passing = val
            val *= 2
        else:
            print("SLA criteria are not met.")
            min_failing = val
            break

        if val >= max_value:
            break

    return sla_data, (max_passing, min_failing)
|
||||
|
||||
|
||||
def _find_sla_value(
    server: ServerProcess | None,
    bench_cmd: list[str],
    *,
    serve_comb: ParameterSweepItem,
    bench_comb: ParameterSweepItem,
    sla_comb: SLASweepItem,
    base_path: Path,
    num_runs: int,
    dry_run: bool,
    sla_variable: SLAVariable,
    min_value: int,
    max_value: int,
):
    """Binary-search the largest `sla_variable` value that satisfies the SLA."""
    collected = list[dict[str, object]]()

    lo: int = min_value
    hi: int = max_value

    while True:
        candidate = (lo + hi) // 2
        print(f"Testing {sla_variable}: {candidate} req/s")

        iter_data = run_sla(
            server,
            bench_cmd,
            serve_comb=serve_comb,
            bench_comb=bench_comb | {sla_variable: candidate},
            iter_path=_get_sla_iter_path(base_path, sla_comb, sla_variable, candidate),
            num_runs=num_runs,
            dry_run=dry_run,
        )

        assert iter_data is not None
        collected.extend(iter_data)

        # Average each SLA metric across this iteration's runs.
        iter_data_mean = {
            k: sum(float(run_data[k]) for run_data in iter_data) / len(iter_data)  # type: ignore
            for k in sla_comb
        }

        # Keep the list materialized so every criterion gets printed.
        sla_results = [
            criterion.print_and_validate(iter_data_mean, k)
            for k, criterion in sla_comb.items()
        ]

        if all(sla_results):
            print("SLA criteria are met.")
            lo = candidate
        else:
            print("SLA criteria are not met.")
            hi = candidate

        if hi - lo <= 1:
            break

    return collected, lo
|
||||
|
||||
|
||||
def search_sla(
    server: ServerProcess | None,
    bench_cmd: list[str],
    *,
    serve_comb: ParameterSweepItem,
    bench_comb: ParameterSweepItem,
    sla_comb: SLASweepItem,
    sla_variable: SLAVariable,
    sla_inf_value: int = 65536,  # The value that represents infinite QPS
    base_path: Path,
    num_runs: int,
    dry_run: bool,
):
    """Find the maximum `sla_variable` that satisfies `sla_comb`.

    Three phases: (1) a benchmark at effectively-unlimited load to get an
    initial estimate, (2) exponential growth to bracket the passing range,
    (3) binary search within that range. All per-run records are written
    under `base_path` and also returned; returns None in dry-run mode.

    The `[SLA START]`/`[SLA END]` prints bracket the search in the logs.
    """
    print("[SLA START]")
    print(f"SLA criteria: {sla_comb.as_text()}")

    # Phase 0: run at "infinite" load to measure peak throughput/latency.
    sla_data_0 = run_sla(
        server,
        bench_cmd,
        serve_comb=serve_comb,
        bench_comb=bench_comb | {sla_variable: sla_inf_value},
        iter_path=_get_sla_iter_path(base_path, sla_comb, sla_variable, sla_inf_value),
        num_runs=num_runs,
        dry_run=dry_run,
    )
    if sla_data_0 is None:
        # `run_sla` only returns None when dry-running.
        assert dry_run
        print("Omitting SLA search.")
        print("[SLA END]")
        return None

    # Mean estimate across runs, rounded up to stay a positive integer.
    sla_init_value = math.ceil(
        sum(_estimate_sla_value(item, sla_variable) for item in sla_data_0)
        / len(sla_data_0)
    )
    print(f"Initial {sla_variable} to search: {sla_init_value} req/s.")

    # Phase 1: exponential growth to bracket the passing/failing boundary.
    sla_data_1, (sla_min, sla_max) = _estimate_sla_bounds(
        server,
        bench_cmd,
        serve_comb=serve_comb,
        bench_comb=bench_comb,
        sla_comb=sla_comb,
        base_path=base_path,
        num_runs=num_runs,
        dry_run=dry_run,
        sla_variable=sla_variable,
        init_value=sla_init_value,
        max_value=sla_inf_value,
    )
    print(f"Range of {sla_variable} to search: [{sla_min}, {sla_max}] req/s.")

    # Phase 2: binary search inside the bracketed range.
    sla_data_2, sla_value = _find_sla_value(
        server,
        bench_cmd,
        serve_comb=serve_comb,
        bench_comb=bench_comb,
        sla_comb=sla_comb,
        base_path=base_path,
        num_runs=num_runs,
        dry_run=dry_run,
        sla_variable=sla_variable,
        min_value=sla_min,
        max_value=sla_max,
    )

    sla_data = sla_data_0 + sla_data_1 + sla_data_2
    print(f"Maximum {sla_variable} for SLA: {sla_value} req/s.")

    # Persist every record from all three phases as the final SLA summary.
    with _get_sla_iter_path(
        base_path,
        sla_comb,
        sla_variable,
        sla_value=None,
    ).open("w") as f:
        json.dump(sla_data, f, indent=4)

    print("[SLA END]")

    return sla_data
|
||||
|
||||
|
||||
def run_slas(
    serve_cmd: list[str],
    bench_cmd: list[str],
    after_bench_cmd: list[str],
    *,
    show_stdout: bool,
    serve_params: ParameterSweep,
    bench_params: ParameterSweep,
    sla_params: SLASweep,
    sla_variable: SLAVariable,
    output_dir: Path,
    num_runs: int,
    dry_run: bool,
):
    """Run an SLA search for every serve/bench/SLA parameter combination.

    Starts one server per serve combination (skipped entirely when all its
    results already exist, e.g. on `--resume`), then searches the maximum
    `sla_variable` for each bench/SLA pair. Writes `summary.csv` under
    `output_dir` and returns the combined DataFrame, or None when dry-running.
    """
    # `sla_variable` is injected by the search itself, so a user override
    # in `bench_params` would silently be clobbered — reject it up front.
    if any(bench_comb.has_param(sla_variable) for bench_comb in bench_params):
        raise ValueError(
            f"You should not override `{sla_variable}` in `bench_params` in SLA mode, "
            "since it is supposed to be determined automatically."
        )

    all_data = list[dict[str, object]]()
    for serve_comb in serve_params:
        # Only pay the server startup cost if at least one combination under
        # this serve config still lacks results; otherwise use a null context
        # so the loop body runs with `server is None`.
        with (
            run_server(
                serve_cmd,
                after_bench_cmd,
                show_stdout=show_stdout,
                serve_overrides=serve_comb,
                dry_run=dry_run,
            )
            if _sla_needs_server(
                serve_comb,
                bench_params,
                sla_params,
                sla_variable,
                output_dir,
            )
            else contextlib.nullcontext()
        ) as server:
            for bench_comb in bench_params:
                for sla_comb in sla_params:
                    base_path = _get_sla_base_path(output_dir, serve_comb, bench_comb)

                    comb_data = search_sla(
                        server,
                        bench_cmd,
                        serve_comb=serve_comb,
                        bench_comb=bench_comb,
                        sla_comb=sla_comb,
                        sla_variable=sla_variable,
                        base_path=base_path,
                        num_runs=num_runs,
                        dry_run=dry_run,
                    )

                    if comb_data is not None:
                        all_data.extend(comb_data)

    if dry_run:
        return None

    # Flatten every run record across all combinations into one CSV.
    combined_df = pd.DataFrame.from_records(all_data)
    combined_df.to_csv(output_dir / "summary.csv")

    return combined_df
|
||||
|
||||
|
||||
@dataclass
class SweepServeSLAArgs(SweepServeArgs):
    """CLI arguments for the SLA-tuning variant of the serve sweep.

    Extends `SweepServeArgs` with the SLA constraints to satisfy and the
    variable (request rate or max concurrency) to tune.
    """

    # SLA constraint combinations to satisfy, parsed from `--sla-params`.
    sla_params: SLASweep
    # Which knob the search adjusts to meet the constraints.
    sla_variable: SLAVariable

    parser_name: ClassVar[str] = "serve_sla"
    parser_help: ClassVar[str] = "Tune a variable to meet SLAs under multiple settings."

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build an instance by extending the base args with SLA options."""
        # NOTE: Don't use super() as `from_cli_args` calls `cls()`
        base_args = SweepServeArgs.from_cli_args(args)

        if args.sla_params:
            sla_params = SLASweep.read_json(args.sla_params)
        else:
            sla_params = SLASweep.from_records([])

        # NOTE(review): `asdict` recursively converts nested dataclasses and
        # copies containers; looks safe for the base fields here, but verify
        # if a base field ever becomes a dataclass — it would be flattened
        # to a plain dict.
        return cls(
            **asdict(base_args),
            sla_params=sla_params,
            sla_variable=args.sla_variable,
        )

    @classmethod
    def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Register base options plus the SLA-specific option group."""
        parser = super().add_cli_args(parser)

        sla_group = parser.add_argument_group("sla options")
        sla_group.add_argument(
            "--sla-params",
            type=str,
            required=True,
            help="Path to JSON file containing a list of SLA constraints to satisfy. "
            'Each constraint is expressed in `{"<KEY>": "<OP><VALUE>"}` format, '
            'e.g.: `{"p99_e2el_ms": "<=500"}` means that '
            "the E2E latency should be less than 500ms 99%% of the time. "
            "Setting this option runs this script in SLA mode, which searches for "
            "the maximum `sla_variable` that satisfies the constraints for "
            "each combination of `serve_params`, `bench_params`, and `sla_params`.",
        )
        sla_group.add_argument(
            "--sla-variable",
            type=str,
            choices=get_args(SLAVariable),
            default="request_rate",
            help="Whether to tune request rate or maximum concurrency to satisfy "
            "the SLA constraints.",
        )

        return parser
|
||||
|
||||
|
||||
def run_main(args: SweepServeSLAArgs):
    """Resolve the output directory for this invocation and run the SLA sweep."""
    run_id = args.resume if args.resume else datetime.now().strftime("%Y%m%d_%H%M%S")
    target_dir = args.output_dir / run_id

    if args.resume and not target_dir.exists():
        raise ValueError(f"Cannot resume from non-existent directory ({target_dir})")

    try:
        return run_slas(
            serve_cmd=args.serve_cmd,
            bench_cmd=args.bench_cmd,
            after_bench_cmd=args.after_bench_cmd,
            show_stdout=args.show_stdout,
            serve_params=args.serve_params,
            bench_params=args.bench_params,
            sla_params=args.sla_params,
            sla_variable=args.sla_variable,
            output_dir=target_dir,
            num_runs=args.num_runs,
            dry_run=args.dry_run,
        )
    except BaseException as exc:
        # Point the user at the resume flag so partially-completed sweeps
        # can be continued instead of restarted from scratch.
        message = (
            f"The script was terminated early. Use `--resume {run_id}` "
            f"to continue the script from its last checkpoint."
        )
        raise RuntimeError(message) from exc
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
    """CLI entry point: convert parsed args into `SweepServeSLAArgs` and dispatch."""
    sla_args = SweepServeSLAArgs.from_cli_args(args)
    run_main(sla_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Standalone entry point: build the CLI parser for the "serve_sla" sweep
    # subcommand and dispatch to main().
    parser = argparse.ArgumentParser(description=SweepServeSLAArgs.parser_help)
    SweepServeSLAArgs.add_cli_args(parser)

    main(parser.parse_args())
|
||||
114
vllm/benchmarks/sweep/server.py
Normal file
114
vllm/benchmarks/sweep/server.py
Normal file
@@ -0,0 +1,114 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import contextlib
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
from types import TracebackType
|
||||
|
||||
import requests
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
class ServerProcess:
    """Manages a benchmark target server as a child process group.

    Usable as a context manager: entering starts the server and exiting
    kills its whole process group. Also knows how to reset server-side
    caches between benchmark runs so results stay independent.
    """

    def __init__(
        self,
        server_cmd: list[str],
        after_bench_cmd: list[str],
        *,
        show_stdout: bool,
    ) -> None:
        """Store the launch configuration; the process is not started yet.

        Args:
            server_cmd: Tokenized command that launches the server.
            after_bench_cmd: Command to run after each benchmark instead of
                the default cache reset; empty list selects the default.
            show_stdout: Forward subprocess stdout instead of discarding it.
        """
        super().__init__()

        self.server_cmd = server_cmd
        self.after_bench_cmd = after_bench_cmd
        self.show_stdout = show_stdout

    def __enter__(self) -> Self:
        self.start()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        exc_traceback: TracebackType | None,
    ) -> None:
        self.stop()

    def start(self):
        """Launch the server in its own session (process group)."""
        # Create new process for clean termination
        self._server_process = subprocess.Popen(
            self.server_cmd,
            start_new_session=True,
            stdout=None if self.show_stdout else subprocess.DEVNULL,
            # Need `VLLM_SERVER_DEV_MODE=1` for `_reset_caches`
            env=os.environ | {"VLLM_SERVER_DEV_MODE": "1"},
        )

    def stop(self):
        """Kill the server's entire process group if it is still running.

        NOTE(review): assumes `start()` was called first — `_server_process`
        is only assigned there.
        """
        server_process = self._server_process

        if server_process.poll() is None:
            # In case only some processes have been terminated
            with contextlib.suppress(ProcessLookupError):
                # We need to kill both API Server and Engine processes
                os.killpg(os.getpgid(server_process.pid), signal.SIGKILL)

    def run_subcommand(self, cmd: list[str]):
        """Run `cmd` to completion, raising on a non-zero exit code."""
        return subprocess.run(
            cmd,
            stdout=None if self.show_stdout else subprocess.DEVNULL,
            check=True,
        )

    def after_bench(self) -> None:
        """Post-benchmark hook: run the custom command, or reset caches."""
        if not self.after_bench_cmd:
            self.reset_caches()
            return

        self.run_subcommand(self.after_bench_cmd)

    def _get_vllm_server_address(self) -> str:
        """Derive the server URL from `--host`/`--port` flags in the command.

        NOTE(review): only the space-separated flag form is recognized;
        `--host=...`/`--port=...` would fall back to the defaults — confirm
        callers never use the `=` form.
        """
        server_cmd = self.server_cmd

        for host_key in ("--host",):
            if host_key in server_cmd:
                host = server_cmd[server_cmd.index(host_key) + 1]
                break
        else:
            host = "localhost"

        for port_key in ("-p", "--port"):
            if port_key in server_cmd:
                port = int(server_cmd[server_cmd.index(port_key) + 1])
                break
        else:
            port = 8000  # The default value in vllm serve

        return f"http://{host}:{port}"

    def reset_caches(self) -> None:
        """Reset server-side caches so consecutive runs don't share state.

        Raises:
            NotImplementedError: For server types without a known reset
                endpoint; use `--after-bench-cmd` instead.
        """
        server_cmd = self.server_cmd

        # Use `.endswith()` to match `/bin/...`
        if server_cmd[0].endswith("vllm"):
            server_address = self._get_vllm_server_address()
            print(f"Resetting caches at {server_address}")

            # Endpoints exposed only when VLLM_SERVER_DEV_MODE=1 (set in start()).
            res = requests.post(f"{server_address}/reset_prefix_cache")
            res.raise_for_status()

            res = requests.post(f"{server_address}/reset_mm_cache")
            res.raise_for_status()
        elif server_cmd[0].endswith("infinity_emb"):
            if "--vector-disk-cache" in server_cmd:
                raise NotImplementedError(
                    "Infinity server uses caching but does not expose a method "
                    "to reset the cache"
                )
        else:
            raise NotImplementedError(
                f"No implementation of `reset_caches` for `{server_cmd[0]}` server. "
                "Please specify a custom command via `--after-bench-cmd`."
            )
|
||||
132
vllm/benchmarks/sweep/sla_sweep.py
Normal file
132
vllm/benchmarks/sweep/sla_sweep.py
Normal file
@@ -0,0 +1,132 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import json
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
@dataclass
class SLACriterionBase(ABC):
    """One SLA constraint comparing a single benchmark metric to a target."""

    target: float  # Threshold that the measured metric is compared against.

    @abstractmethod
    def validate(self, actual: float) -> bool:
        """Return `True` if this criterion is met; otherwise `False`."""
        raise NotImplementedError

    @abstractmethod
    def format_cond(self, lhs: str) -> str:
        """Render the condition as text with `lhs` on the left-hand side."""
        raise NotImplementedError

    def print_and_validate(
        self,
        metrics: dict[str, float],
        metrics_key: str,
    ) -> bool:
        """Validate `metrics[metrics_key]` against this criterion, logging the outcome."""
        value = metrics[metrics_key]
        passed = self.validate(value)

        rendered = self.format_cond(f"{metrics_key} = {value:.2f}")
        verdict = "PASSED" if passed else "FAILED"
        print(f"Validating SLA: {rendered} | {verdict}")

        return passed
|
||||
|
||||
|
||||
@dataclass
class SLALessThan(SLACriterionBase):
    """Criterion: the metric must be strictly below the target."""

    @override
    def validate(self, actual: float) -> bool:
        return self.target > actual

    @override
    def format_cond(self, lhs: str) -> str:
        return "{}<{:.2f}".format(lhs, self.target)
|
||||
|
||||
|
||||
@dataclass
class SLALessThanOrEqualTo(SLACriterionBase):
    """Criterion: the metric must not exceed the target."""

    @override
    def validate(self, actual: float) -> bool:
        return self.target >= actual

    @override
    def format_cond(self, lhs: str) -> str:
        return "{}<={:.2f}".format(lhs, self.target)
|
||||
|
||||
|
||||
@dataclass
class SLAGreaterThan(SLACriterionBase):
    """Criterion: the metric must be strictly above the target."""

    @override
    def validate(self, actual: float) -> bool:
        return self.target < actual

    @override
    def format_cond(self, lhs: str) -> str:
        return "{}>{:.2f}".format(lhs, self.target)
|
||||
|
||||
|
||||
@dataclass
class SLAGreaterThanOrEqualTo(SLACriterionBase):
    """Criterion: the metric must be at least the target."""

    @override
    def validate(self, actual: float) -> bool:
        return self.target <= actual

    @override
    def format_cond(self, lhs: str) -> str:
        return "{}>={:.2f}".format(lhs, self.target)
|
||||
|
||||
|
||||
# NOTE: The ordering is important! Match longer op_keys first
# ("<=" must be tried before "<"), because `SLASweepItem.from_record`
# matches operators by string prefix in this dict's insertion order.
SLA_CRITERIA: dict[str, type[SLACriterionBase]] = {
    "<=": SLALessThanOrEqualTo,
    ">=": SLAGreaterThanOrEqualTo,
    "<": SLALessThan,
    ">": SLAGreaterThan,
}
|
||||
|
||||
|
||||
class SLASweep(list["SLASweepItem"]):
    """A list of SLA constraint combinations, each searched independently."""

    @classmethod
    def read_json(cls, filepath: os.PathLike):
        """Load a sweep from a JSON file containing a list of records."""
        with open(filepath, "rb") as f:
            parsed = json.load(f)

        return cls.from_records(parsed)

    @classmethod
    def from_records(cls, records: list[dict[str, str]]):
        """Build a sweep from already-parsed records, validating the shape."""
        if isinstance(records, list):
            return cls(SLASweepItem.from_record(record) for record in records)

        raise TypeError(
            f"The SLA sweep should be a list of dictionaries, "
            f"but found type: {type(records)}"
        )
|
||||
|
||||
|
||||
class SLASweepItem(dict[str, SLACriterionBase]):
    """Maps each metric key to the SLA criterion it must satisfy."""

    @classmethod
    def from_record(cls, record: dict[str, str]):
        """Parse `{"metric": "<op><value>"}` entries into criterion objects.

        Raises:
            ValueError: If a value does not start with a known operator.
        """
        criteria: dict[str, SLACriterionBase] = {}

        for metric_key, metric_value in record.items():
            # Relies on SLA_CRITERIA listing longer operators first so that
            # e.g. "<=" is matched before "<".
            matched_op = next(
                (op for op in SLA_CRITERIA if metric_value.startswith(op)),
                None,
            )
            if matched_op is None:
                raise ValueError(
                    f"Invalid operator for "
                    f"SLA constraint '{metric_key}={metric_value}'. "
                    f"Valid operators are: {sorted(SLA_CRITERIA)}",
                )

            criterion_cls = SLA_CRITERIA[matched_op]
            criteria[metric_key] = criterion_cls(
                float(metric_value.removeprefix(matched_op))
            )

        return cls(criteria)

    def as_text(self, sep: str = ", ") -> str:
        """Render every criterion as readable text, joined by `sep`."""
        return sep.join(crit.format_cond(key) for key, crit in self.items())
|
||||
4
vllm/benchmarks/sweep/utils.py
Normal file
4
vllm/benchmarks/sweep/utils.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
def sanitize_filename(filename: str) -> str:
    """Make a parameter-derived name safe to use as a file or directory name."""
    # Neutralize path separators and parent-directory sequences, then drop
    # surrounding shell-style quotes.
    safe = filename.replace("/", "_").replace("..", "__")
    return safe.strip("'").strip('"')
|
||||
808
vllm/benchmarks/throughput.py
Normal file
808
vllm/benchmarks/throughput.py
Normal file
@@ -0,0 +1,808 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Benchmark offline inference throughput."""
|
||||
|
||||
import argparse
|
||||
import dataclasses
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
import warnings
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import uvloop
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
|
||||
|
||||
from vllm.benchmarks.datasets import (
|
||||
AIMODataset,
|
||||
BurstGPTDataset,
|
||||
ConversationDataset,
|
||||
InstructCoderDataset,
|
||||
MultiModalConversationDataset,
|
||||
PrefixRepetitionRandomDataset,
|
||||
RandomDataset,
|
||||
SampleRequest,
|
||||
ShareGPTDataset,
|
||||
SonnetDataset,
|
||||
VisionArenaDataset,
|
||||
)
|
||||
from vllm.benchmarks.lib.utils import convert_to_pytorch_benchmark_format, write_to_json
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
|
||||
from vllm.inputs import TextPrompt, TokensPrompt
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import BeamSearchParams
|
||||
from vllm.tokenizers import TokenizerLike, get_tokenizer
|
||||
from vllm.utils.async_utils import merge_async_iterators
|
||||
|
||||
|
||||
def run_vllm(
    requests: list[SampleRequest],
    n: int,
    engine_args: EngineArgs,
    do_profile: bool,
    disable_detokenize: bool = False,
) -> tuple[float, list[RequestOutput] | None]:
    """Benchmark offline generation throughput with the synchronous LLM API.

    Args:
        requests: Sampled requests with prompts and expected output lengths.
        n: Number of generated sequences per prompt.
        engine_args: Engine configuration used to construct the LLM.
        do_profile: Wrap generation in start_profile()/stop_profile().
        disable_detokenize: Skip detokenization of outputs to isolate
            generation cost.

    Returns:
        A tuple of (elapsed seconds, outputs). Outputs is None on the beam
        search path, which is currently unreachable (see below).
    """
    from vllm import LLM, SamplingParams

    llm = LLM(**dataclasses.asdict(engine_args))
    # Every request must fit (prompt + expected output) in the model context.
    assert all(
        llm.llm_engine.model_config.max_model_len
        >= (request.prompt_len + request.expected_output_len)
        for request in requests
    ), (
        "Please ensure that max_model_len is greater than the sum of"
        " prompt_len and expected_output_len for all requests."
    )
    # Add the requests to the engine.
    prompts: list[TextPrompt | TokensPrompt] = []
    sampling_params: list[SamplingParams] = []
    for request in requests:
        # Pre-tokenized prompts bypass tokenization in the engine.
        prompt = (
            TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"])
            if "prompt_token_ids" in request.prompt
            else TextPrompt(prompt=request.prompt)
        )
        if request.multi_modal_data:
            assert isinstance(request.multi_modal_data, dict)
            prompt["multi_modal_data"] = request.multi_modal_data
        prompts.append(prompt)

        # ignore_eos + fixed max_tokens make output length deterministic,
        # so measured throughput is comparable across runs.
        sampling_params.append(
            SamplingParams(
                n=n,
                temperature=1.0,
                top_p=1.0,
                ignore_eos=True,
                max_tokens=request.expected_output_len,
                detokenize=not disable_detokenize,
            )
        )
    lora_requests: list[LoRARequest] | None = None
    if engine_args.enable_lora:
        lora_requests = [request.lora_request for request in requests]

    # NOTE(review): hardcoded False, so the beam-search branch below is
    # currently dead code — presumably kept from an earlier CLI flag;
    # confirm whether it should be wired to an argument or removed.
    use_beam_search = False

    outputs = None
    if not use_beam_search:
        start = time.perf_counter()
        if do_profile:
            llm.start_profile()
        outputs = llm.generate(
            prompts, sampling_params, lora_request=lora_requests, use_tqdm=True
        )
        if do_profile:
            llm.stop_profile()
        end = time.perf_counter()
    else:
        assert lora_requests is None, "BeamSearch API does not support LoRA"
        prompts = [request.prompt for request in requests]
        # output_len should be the same for all requests.
        output_len = requests[0].expected_output_len
        for request in requests:
            assert request.expected_output_len == output_len
        start = time.perf_counter()
        if do_profile:
            llm.start_profile()
        llm.beam_search(
            prompts,
            BeamSearchParams(
                beam_width=n,
                max_tokens=output_len,
                ignore_eos=True,
            ),
        )
        if do_profile:
            llm.stop_profile()
        end = time.perf_counter()
    return end - start, outputs
|
||||
|
||||
|
||||
def run_vllm_chat(
    requests: list[SampleRequest],
    n: int,
    engine_args: EngineArgs,
    do_profile: bool,
    disable_detokenize: bool = False,
) -> tuple[float, list[RequestOutput]]:
    """
    Run vLLM chat benchmark. This function is recommended ONLY for benchmarking
    multimodal models as it properly handles multimodal inputs and chat
    formatting. For non-multimodal models, use run_vllm() instead.

    Returns:
        A tuple of (elapsed wall-clock seconds, list of request outputs).
    """
    from vllm import LLM, SamplingParams

    llm = LLM(**dataclasses.asdict(engine_args))

    # Every request (prompt + generated tokens) must fit in the context window.
    max_model_len = llm.llm_engine.model_config.max_model_len
    assert all(
        max_model_len >= (req.prompt_len + req.expected_output_len)
        for req in requests
    ), (
        "Please ensure that max_model_len is greater than the sum of "
        "prompt_len and expected_output_len for all requests."
    )

    chat_prompts = [req.prompt for req in requests]
    params_per_request = [
        SamplingParams(
            n=n,
            temperature=1.0,
            top_p=1.0,
            ignore_eos=True,
            max_tokens=req.expected_output_len,
            detokenize=not disable_detokenize,
        )
        for req in requests
    ]

    # Only the generate call (and optional profiling) is timed.
    start = time.perf_counter()
    if do_profile:
        llm.start_profile()
    outputs = llm.chat(chat_prompts, params_per_request, use_tqdm=True)
    if do_profile:
        llm.stop_profile()
    end = time.perf_counter()
    return end - start, outputs
|
||||
|
||||
|
||||
async def run_vllm_async(
    requests: list[SampleRequest],
    n: int,
    engine_args: AsyncEngineArgs,
    do_profile: bool,
    disable_frontend_multiprocessing: bool = False,
    disable_detokenize: bool = False,
) -> float:
    """Benchmark the async vLLM engine on a fixed set of requests.

    Submits every request to the engine at once, drains all result streams,
    and returns the elapsed wall-clock time in seconds. Unlike run_vllm(),
    the generated outputs are discarded — only timing is reported.

    Args:
        requests: Sampled benchmark requests (prompt, lengths, optional
            multimodal data and LoRA adapter).
        n: Number of generated sequences per prompt.
        engine_args: Arguments used to build the async engine client.
        do_profile: If True, wrap generation in start/stop_profile calls.
        disable_frontend_multiprocessing: Forwarded to the engine-client
            builder to disable the decoupled frontend process.
        disable_detokenize: If True, skip detokenization of outputs.

    Returns:
        Elapsed seconds from just before submission until all streams finish.
    """
    from vllm import SamplingParams
    from vllm.entrypoints.openai.api_server import (
        build_async_engine_client_from_engine_args,
    )

    async with build_async_engine_client_from_engine_args(
        engine_args,
        disable_frontend_multiprocessing=disable_frontend_multiprocessing,
    ) as llm:
        model_config = llm.model_config
        # Every request (prompt + generated tokens) must fit in the context.
        assert all(
            model_config.max_model_len
            >= (request.prompt_len + request.expected_output_len)
            for request in requests
        ), (
            "Please ensure that max_model_len is greater than the sum of"
            " prompt_len and expected_output_len for all requests."
        )

        # Add the requests to the engine.
        prompts: list[TextPrompt | TokensPrompt] = []
        sampling_params: list[SamplingParams] = []
        lora_requests: list[LoRARequest | None] = []
        for request in requests:
            # A request either carries pre-tokenized ids or raw text.
            prompt = (
                TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"])
                if "prompt_token_ids" in request.prompt
                else TextPrompt(prompt=request.prompt)
            )

            if request.multi_modal_data:
                assert isinstance(request.multi_modal_data, dict)
                prompt["multi_modal_data"] = request.multi_modal_data

            sampling_params.append(
                SamplingParams(
                    n=n,
                    temperature=1.0,
                    top_p=1.0,
                    ignore_eos=True,
                    max_tokens=request.expected_output_len,
                    detokenize=not disable_detokenize,
                )
            )
            prompts.append(prompt)
            lora_requests.append(request.lora_request)

        generators = []
        # Timing starts before submission so queueing/scheduling is included.
        start = time.perf_counter()
        if do_profile:
            await llm.start_profile()
        for i, (prompt, sp, lr) in enumerate(
            zip(prompts, sampling_params, lora_requests)
        ):
            generator = llm.generate(prompt, sp, lora_request=lr, request_id=f"test{i}")
            generators.append(generator)
        # Drain every stream to completion; results themselves are discarded.
        all_gens = merge_async_iterators(*generators)
        async for i, res in all_gens:
            pass
        if do_profile:
            await llm.stop_profile()
        end = time.perf_counter()
        return end - start
|
||||
|
||||
|
||||
def run_hf(
    requests: list[SampleRequest],
    model: str,
    tokenizer: TokenizerLike,
    n: int,
    max_batch_size: int,
    trust_remote_code: bool,
    disable_detokenize: bool = False,
) -> float:
    """Benchmark plain HuggingFace ``generate()`` on a fixed set of requests.

    Greedily packs consecutive requests into batches (bounded by
    ``max_batch_size`` and a hard-coded 2048-token budget), runs
    ``model.generate`` per batch on CUDA, and returns elapsed seconds.

    Args:
        requests: Sampled benchmark requests (text prompts).
        model: HF model name or path passed to ``from_pretrained``.
        tokenizer: Must be an HF ``PreTrainedTokenizerBase``.
        n: Number of returned sequences per prompt.
        max_batch_size: Upper bound on prompts per generate call.
        trust_remote_code: Forwarded to ``from_pretrained``.
        disable_detokenize: If True, skip batch_decode (exclude decode time).

    Returns:
        Elapsed seconds over all batches, including tokenization and
        (unless disabled) detokenization.
    """
    assert isinstance(tokenizer, PreTrainedTokenizerBase), (
        "the hf backend only supports HF tokenizers"
    )
    llm = AutoModelForCausalLM.from_pretrained(
        model, dtype=torch.float16, trust_remote_code=trust_remote_code
    )
    if llm.config.model_type == "llama":
        # To enable padding in the HF backend.
        tokenizer.pad_token = tokenizer.eos_token
    # NOTE(review): requires a CUDA device; there is no CPU fallback here.
    llm = llm.cuda()

    pbar = tqdm(total=len(requests))
    start = time.perf_counter()
    batch: list[str] = []
    max_prompt_len = 0
    max_output_len = 0
    for i in range(len(requests)):
        prompt = requests[i].prompt
        prompt_len = requests[i].prompt_len
        output_len = requests[i].expected_output_len
        # Add the prompt to the batch.
        batch.append(prompt)
        max_prompt_len = max(max_prompt_len, prompt_len)
        max_output_len = max(max_output_len, output_len)
        if len(batch) < max_batch_size and i != len(requests) - 1:
            # Check if we can add more requests to the batch.
            # Padding makes the batch cost max_prompt + max_output tokens;
            # 2048 is a hard-coded context budget — presumably the model's
            # context length; confirm for larger-context models.
            next_prompt_len = requests[i + 1].prompt_len
            next_output_len = requests[i + 1].expected_output_len
            if (
                max(max_prompt_len, next_prompt_len)
                + max(max_output_len, next_output_len)
            ) <= 2048:
                # We can add more requests to the batch.
                continue

        # Generate the sequences.
        input_ids = tokenizer(batch, return_tensors="pt", padding=True).input_ids
        llm_outputs = llm.generate(
            input_ids=input_ids.cuda(),
            do_sample=True,
            num_return_sequences=n,
            temperature=1.0,
            top_p=1.0,
            use_cache=True,
            max_new_tokens=max_output_len,
        )
        if not disable_detokenize:
            # Include the decoding time.
            tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
        pbar.update(len(batch))

        # Clear the batch.
        batch = []
        max_prompt_len = 0
        max_output_len = 0
    end = time.perf_counter()
    return end - start
|
||||
|
||||
|
||||
def save_to_pytorch_benchmark_format(
    args: argparse.Namespace, results: dict[str, Any]
) -> None:
    """Write throughput results in the PyTorch benchmark record format.

    The output file sits next to ``args.output_json`` with a
    ``.pytorch.json`` suffix; nothing is written if conversion yields no
    records.
    """
    metrics = {
        "requests_per_second": [results["requests_per_second"]],
        "tokens_per_second": [results["tokens_per_second"]],
    }
    extra = {
        key: results[key]
        for key in ("elapsed_time", "num_requests", "total_num_tokens")
    }
    records = convert_to_pytorch_benchmark_format(
        args=args, metrics=metrics, extra_info=extra
    )
    if not records:
        return
    # Don't use json suffix here as we don't want CI to pick it up
    base, _ = os.path.splitext(args.output_json)
    write_to_json(f"{base}.pytorch.json", records)
|
||||
|
||||
|
||||
def get_requests(args, tokenizer):
    """Build the list of benchmark requests for the configured dataset.

    Dispatches on ``args.dataset_name`` (and, for ``hf`` datasets, on
    ``args.dataset_path``) to the matching dataset class, samples
    ``args.num_prompts`` requests, and filters them for data-parallel runs.

    Args:
        args: Parsed CLI namespace (dataset selection and sampling knobs).
        tokenizer: Tokenizer forwarded to the dataset sampler.

    Returns:
        The sampled (and possibly DP-filtered) list of requests.

    Raises:
        ValueError: If the dataset name, or the HF dataset path, is not
            supported.
    """
    # Common parameters for all dataset types.
    common_kwargs = {
        "dataset_path": args.dataset_path,
        "random_seed": args.seed,
    }
    sample_kwargs = {
        "tokenizer": tokenizer,
        "lora_path": args.lora_path,
        "max_loras": args.max_loras,
        "num_requests": args.num_prompts,
        "input_len": args.input_len,
        "output_len": args.output_len,
    }

    if args.dataset_path is None or args.dataset_name == "random":
        sample_kwargs["range_ratio"] = args.random_range_ratio
        sample_kwargs["prefix_len"] = args.prefix_len
        dataset_cls = RandomDataset
    elif args.dataset_name == "sharegpt":
        dataset_cls = ShareGPTDataset
        if args.backend == "vllm-chat":
            sample_kwargs["enable_multimodal_chat"] = True
    elif args.dataset_name == "sonnet":
        assert tokenizer.chat_template or tokenizer.default_chat_template, (
            "Tokenizer/model must have chat template for sonnet dataset."
        )
        dataset_cls = SonnetDataset
        sample_kwargs["prefix_len"] = args.prefix_len
        sample_kwargs["return_prompt_formatted"] = True
    elif args.dataset_name == "burstgpt":
        dataset_cls = BurstGPTDataset
    elif args.dataset_name == "hf":
        if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = VisionArenaDataset
            common_kwargs["dataset_subset"] = None
            common_kwargs["dataset_split"] = "train"
            sample_kwargs["enable_multimodal_chat"] = True
        elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = InstructCoderDataset
            common_kwargs["dataset_split"] = "train"
        elif args.dataset_path in MultiModalConversationDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = MultiModalConversationDataset
            common_kwargs["dataset_subset"] = args.hf_subset
            common_kwargs["dataset_split"] = args.hf_split
            sample_kwargs["enable_multimodal_chat"] = True
        elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = ConversationDataset
            common_kwargs["dataset_subset"] = args.hf_subset
            common_kwargs["dataset_split"] = args.hf_split
            sample_kwargs["enable_multimodal_chat"] = True
        elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = AIMODataset
            common_kwargs["dataset_subset"] = None
            common_kwargs["dataset_split"] = "train"
        else:
            # Fail fast: previously an unrecognized hf dataset path left
            # `dataset_cls` unbound and crashed below with a confusing
            # NameError. Same message as the check in validate_args().
            raise ValueError(f"{args.dataset_path} is not supported by hf dataset.")
    elif args.dataset_name == "prefix_repetition":
        dataset_cls = PrefixRepetitionRandomDataset
        sample_kwargs["prefix_len"] = args.prefix_repetition_prefix_len
        sample_kwargs["suffix_len"] = args.prefix_repetition_suffix_len
        sample_kwargs["num_prefixes"] = args.prefix_repetition_num_prefixes
        sample_kwargs["output_len"] = args.prefix_repetition_output_len
    else:
        raise ValueError(f"Unknown dataset name: {args.dataset_name}")
    # Remove None values so dataset samplers fall back to their own defaults.
    sample_kwargs = {k: v for k, v in sample_kwargs.items() if v is not None}
    requests = dataset_cls(**common_kwargs).sample(**sample_kwargs)
    requests = filter_requests_for_dp(requests, args.data_parallel_size)
    return requests
|
||||
|
||||
|
||||
def filter_requests_for_dp(requests, data_parallel_size):
    """Return the subset of ``requests`` owned by this data-parallel rank.

    Requests are distributed round-robin across DP ranks, so rank ``r``
    keeps requests ``r``, ``r + dp``, ``r + 2*dp``, ...

    Note(zhuohan): The way we get data_parallel_rank is hacky and only
    works for external launcher mode. Should be cleaned up and deprecated
    in the future with a better vLLM distributed process design.
    """
    if data_parallel_size == 1:
        return requests

    global_rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    ranks_per_dp_group = world_size // data_parallel_size
    dp_rank = global_rank // ranks_per_dp_group
    # Equivalent to keeping indices i with i % data_parallel_size == dp_rank.
    return list(requests[dp_rank::data_parallel_size])
|
||||
|
||||
|
||||
def validate_args(args):
    """
    Validate command-line arguments.

    Mutates ``args`` in place for deprecation shims (``--dataset`` copied to
    ``--dataset-path``, ``--tokenizer`` defaulted to ``--model``, dataset
    defaulted to ``random`` when no path is given) and raises ``ValueError``
    for unsupported option combinations.

    Fix: warning messages previously used backslash line-continuations
    *inside* the string literals, which embedded the next line's indentation
    (a long run of spaces) into the emitted text; they now use implicit
    string concatenation and read cleanly.
    """

    # === Deprecation and Defaulting ===
    if args.dataset is not None:
        warnings.warn(
            "The '--dataset' argument will be deprecated in the next release. "
            "Please use '--dataset-name' and '--dataset-path' instead.",
            stacklevel=2,
        )
        args.dataset_path = args.dataset

    if not getattr(args, "tokenizer", None):
        args.tokenizer = args.model

    # === Backend Validation ===
    valid_backends = {"vllm", "hf", "mii", "vllm-chat"}
    if args.backend not in valid_backends:
        raise ValueError(f"Unsupported backend: {args.backend}")

    # === Dataset Configuration ===
    if (
        not args.dataset
        and not args.dataset_path
        and args.dataset_name not in {"prefix_repetition"}
    ):
        print("When dataset path is not set, it will default to random dataset")
        args.dataset_name = "random"
        if args.input_len is None:
            raise ValueError("input_len must be provided for a random dataset")

    # === Dataset Name Specific Checks ===
    # --hf-subset and --hf-split: only used
    # when dataset_name is 'hf'
    if args.dataset_name != "hf" and (
        getattr(args, "hf_subset", None) is not None
        or getattr(args, "hf_split", None) is not None
    ):
        warnings.warn(
            "--hf-subset and --hf-split will be ignored "
            "since --dataset-name is not 'hf'.",
            stacklevel=2,
        )
    elif args.dataset_name == "hf":
        if args.dataset_path in (
            VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys()
            | MultiModalConversationDataset.SUPPORTED_DATASET_PATHS
            | ConversationDataset.SUPPORTED_DATASET_PATHS
        ):
            assert args.backend == "vllm-chat", (
                f"{args.dataset_path} needs to use vllm-chat as the backend."
            )
        elif args.dataset_path in (
            InstructCoderDataset.SUPPORTED_DATASET_PATHS
            | AIMODataset.SUPPORTED_DATASET_PATHS
        ):
            assert args.backend == "vllm", (
                f"{args.dataset_path} needs to use vllm as the backend."
            )
        else:
            raise ValueError(f"{args.dataset_path} is not supported by hf dataset.")

    # --random-range-ratio: only used when dataset_name is 'random'
    if args.dataset_name != "random" and args.random_range_ratio is not None:
        warnings.warn(
            "--random-range-ratio will be ignored since "
            "--dataset-name is not 'random'.",
            stacklevel=2,
        )

    # --prefix-len: only used when dataset_name is 'random', 'sonnet', or not
    # set.
    if (
        args.dataset_name not in {"random", "sonnet", None}
        and args.prefix_len is not None
    ):
        warnings.warn(
            "--prefix-len will be ignored since --dataset-name "
            "is not 'random', 'sonnet', or not set.",
            stacklevel=2,
        )

    # === LoRA Settings ===
    if getattr(args, "enable_lora", False) and args.backend != "vllm":
        raise ValueError("LoRA benchmarking is only supported for vLLM backend")
    if getattr(args, "enable_lora", False) and args.lora_path is None:
        raise ValueError("LoRA path must be provided when enable_lora is True")

    # === Backend-specific Validations ===
    if args.backend == "hf" and args.hf_max_batch_size is None:
        raise ValueError("HF max batch size is required for HF backend")
    if args.backend != "hf" and args.hf_max_batch_size is not None:
        raise ValueError("HF max batch size is only for HF backend.")

    if (
        args.backend in {"hf", "mii"}
        and getattr(args, "quantization", None) is not None
    ):
        raise ValueError("Quantization is only for vLLM backend.")

    if args.backend == "mii" and args.dtype != "auto":
        raise ValueError("dtype must be auto for MII backend.")
    if args.backend == "mii" and args.n != 1:
        raise ValueError("n must be 1 for MII backend.")
    if args.backend == "mii" and args.tokenizer != args.model:
        raise ValueError("Tokenizer must be the same as the model for MII backend.")

    if args.data_parallel_size > 1 and (
        args.distributed_executor_backend != "external_launcher" or args.async_engine
    ):
        # --data-parallel is not supported fully.
        # Old issue: https://github.com/vllm-project/vllm/issues/16222
        # Currently we only support data parallel with external launcher
        # mode (i.e., launch with torchrun).
        raise ValueError(
            "Data parallel is only supported with external launcher mode "
            "with synchronous engine in offline benchmark, "
            "please use benchmark serving instead"
        )
|
||||
|
||||
|
||||
def add_cli_args(parser: argparse.ArgumentParser):
    """Register throughput-benchmark CLI arguments on ``parser``.

    Mutates ``parser`` in place: benchmark-specific flags first, then the
    full set of engine arguments via ``AsyncEngineArgs.add_cli_args``.
    """
    parser.add_argument(
        "--backend",
        type=str,
        choices=["vllm", "hf", "mii", "vllm-chat"],
        default="vllm",
    )
    parser.add_argument(
        "--dataset-name",
        type=str,
        choices=["sharegpt", "random", "sonnet", "burstgpt", "hf", "prefix_repetition"],
        help="Name of the dataset to benchmark on.",
        default="sharegpt",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default=None,
        help="Path to the ShareGPT dataset, will be deprecated in\
        the next release. The dataset is expected to "
        "be a json in form of list[dict[..., conversations: "
        "list[dict[..., value: <prompt_or_response>]]]]",
    )
    parser.add_argument(
        "--dataset-path", type=str, default=None, help="Path to the dataset"
    )
    parser.add_argument(
        "--input-len",
        type=int,
        default=None,
        help="Input prompt length for each request",
    )
    parser.add_argument(
        "--output-len",
        type=int,
        default=None,
        help="Output length for each request. Overrides the "
        "output length from the dataset.",
    )
    parser.add_argument(
        "--n", type=int, default=1, help="Number of generated sequences per prompt."
    )
    parser.add_argument(
        "--num-prompts", type=int, default=1000, help="Number of prompts to process."
    )
    parser.add_argument(
        "--hf-max-batch-size",
        type=int,
        default=None,
        help="Maximum batch size for HF backend.",
    )
    parser.add_argument(
        "--output-json",
        type=str,
        default=None,
        help="Path to save the throughput results in JSON format.",
    )
    parser.add_argument(
        "--async-engine",
        action="store_true",
        default=False,
        help="Use vLLM async engine rather than LLM class.",
    )
    parser.add_argument(
        "--disable-frontend-multiprocessing",
        action="store_true",
        default=False,
        help="Disable decoupled async engine frontend.",
    )
    parser.add_argument(
        "--disable-detokenize",
        action="store_true",
        help=(
            "Do not detokenize the response (i.e. do not include "
            "detokenization time in the measurement)"
        ),
    )
    # LoRA
    parser.add_argument(
        "--lora-path",
        type=str,
        default=None,
        help="Path to the lora adapters to use. This can be an absolute path, "
        "a relative path, or a Hugging Face model identifier.",
    )
    parser.add_argument(
        "--prefix-len",
        type=int,
        default=0,
        help="Number of fixed prefix tokens before the random "
        "context in a request (default: 0).",
    )
    # random dataset
    parser.add_argument(
        "--random-range-ratio",
        type=float,
        default=0.0,
        help="Range ratio for sampling input/output length, "
        "used only for RandomDataset. Must be in the range [0, 1) to define "
        "a symmetric sampling range "
        "[length * (1 - range_ratio), length * (1 + range_ratio)].",
    )

    # hf dataset
    parser.add_argument(
        "--hf-subset", type=str, default=None, help="Subset of the HF dataset."
    )
    parser.add_argument(
        "--hf-split", type=str, default=None, help="Split of the HF dataset."
    )
    parser.add_argument(
        "--profile",
        action="store_true",
        default=False,
        help="Use vLLM Profiling. --profiler-config must be provided on the server.",
    )

    # prefix repetition dataset
    prefix_repetition_group = parser.add_argument_group(
        "prefix repetition dataset options"
    )
    prefix_repetition_group.add_argument(
        "--prefix-repetition-prefix-len",
        type=int,
        default=None,
        help="Number of prefix tokens per request, used only for prefix "
        "repetition dataset.",
    )
    prefix_repetition_group.add_argument(
        "--prefix-repetition-suffix-len",
        type=int,
        default=None,
        help="Number of suffix tokens per request, used only for prefix "
        "repetition dataset. Total input length is prefix_len + suffix_len.",
    )
    prefix_repetition_group.add_argument(
        "--prefix-repetition-num-prefixes",
        type=int,
        default=None,
        help="Number of prefixes to generate, used only for prefix repetition "
        "dataset. Prompts per prefix is num_requests // num_prefixes.",
    )
    prefix_repetition_group.add_argument(
        "--prefix-repetition-output-len",
        type=int,
        default=None,
        help="Number of output tokens per request, used only for prefix "
        "repetition dataset.",
    )

    # Append the full engine-argument set (model, parallelism, kv-cache, ...).
    parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
    """Entry point: validate args, sample requests, run the chosen backend,
    then report and optionally persist throughput metrics."""
    validate_args(args)
    if args.seed is None:
        args.seed = 0
    random.seed(args.seed)
    # Sample the requests.
    if (
        args.backend == "hf" or args.backend == "mii"
    ) and args.tokenizer_mode == "auto":
        # mistral_common tokenizer is only supported on vllm and vllm-chat backends;
        # for hf and mii backends, we use hf tokenizer
        args.tokenizer_mode = "hf"
    tokenizer = get_tokenizer(
        args.tokenizer,
        tokenizer_mode=args.tokenizer_mode,
        trust_remote_code=args.trust_remote_code,
    )
    requests = get_requests(args, tokenizer)
    is_multi_modal = any(request.multi_modal_data is not None for request in requests)
    # Only the sync vllm / vllm-chat paths yield per-request outputs; the
    # async and hf paths report elapsed time only.
    request_outputs: list[RequestOutput] | None = None
    if args.backend == "vllm":
        if args.async_engine:
            elapsed_time = uvloop.run(
                run_vllm_async(
                    requests,
                    args.n,
                    AsyncEngineArgs.from_cli_args(args),
                    disable_frontend_multiprocessing=args.disable_frontend_multiprocessing,
                    disable_detokenize=args.disable_detokenize,
                    do_profile=args.profile,
                )
            )
        else:
            elapsed_time, request_outputs = run_vllm(
                requests,
                args.n,
                EngineArgs.from_cli_args(args),
                disable_detokenize=args.disable_detokenize,
                do_profile=args.profile,
            )
    elif args.backend == "hf":
        assert args.tensor_parallel_size == 1
        if args.profile:
            raise NotImplementedError("Profiling not implemented yet for backend='hf'.")
        elapsed_time = run_hf(
            requests,
            args.model,
            tokenizer,
            args.n,
            args.hf_max_batch_size,
            args.trust_remote_code,
            args.disable_detokenize,
        )
    elif args.backend == "vllm-chat":
        elapsed_time, request_outputs = run_vllm_chat(
            requests,
            args.n,
            EngineArgs.from_cli_args(args),
            disable_detokenize=args.disable_detokenize,
            do_profile=args.profile,
        )
    else:
        raise ValueError(f"Unknown backend: {args.backend}")

    if request_outputs:
        # Note: with the vllm and vllm-chat backends,
        # we have request_outputs, which we use to count tokens.
        total_prompt_tokens = 0
        total_output_tokens = 0
        for ro in request_outputs:
            if not isinstance(ro, RequestOutput):
                continue
            total_prompt_tokens += (
                len(ro.prompt_token_ids) if ro.prompt_token_ids else 0
            )
            total_output_tokens += sum(len(o.token_ids) for o in ro.outputs if o)
        total_num_tokens = total_prompt_tokens + total_output_tokens
    else:
        # No outputs available: fall back to the lengths the requests expected.
        total_num_tokens = sum(r.prompt_len + r.expected_output_len for r in requests)
        total_output_tokens = sum(r.expected_output_len for r in requests)
        total_prompt_tokens = total_num_tokens - total_output_tokens

    if is_multi_modal and args.backend != "vllm-chat":
        print(
            "\033[91mWARNING\033[0m: Multi-modal request with "
            f"{args.backend} backend detected. The "
            "following metrics are not accurate because image tokens are not"
            " counted. See vllm-project/vllm/issues/9778 for details."
        )
        # TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
        # vllm-chat backend counts the image tokens now

    print(
        f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
        f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
        f"{total_output_tokens / elapsed_time:.2f} output tokens/s"
    )
    print(f"Total num prompt tokens: {total_prompt_tokens}")
    print(f"Total num output tokens: {total_output_tokens}")

    # Output JSON results if specified
    if args.output_json:
        results = {
            "elapsed_time": elapsed_time,
            "num_requests": len(requests),
            "total_num_tokens": total_num_tokens,
            "requests_per_second": len(requests) / elapsed_time,
            "tokens_per_second": total_num_tokens / elapsed_time,
        }
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)
        save_to_pytorch_benchmark_format(args, results)
|
||||
Reference in New Issue
Block a user