# Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/backend_request_func.py
# Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/benchmark_serving.py
"""
Benchmark online serving with dynamic requests.

Usage:
python3 -m sglang.bench_serving --backend sglang --num-prompts 10

python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 3000 --random-input 1024 --random-output 1024 --random-range-ratio 0.5
"""
import argparse
import asyncio
import base64
import io
import json
import os
import pickle
import random
import resource
import sys
import time
import traceback
import warnings
from argparse import ArgumentParser
from dataclasses import dataclass, field
from datetime import datetime
from json import JSONDecodeError
from pathlib import Path
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
import aiohttp
import numpy as np
import requests
from tqdm.asyncio import tqdm
from transformers import (
AutoTokenizer,
PreTrainedTokenizer,
PreTrainedTokenizerBase,
PreTrainedTokenizerFast,
)
ASSISTANT_SUFFIX = "Assistant:"
global args
# don't want to import sglang package here
def _get_bool_env_var(name: str, default: str = "false") -> bool:
value = os.getenv(name, default)
return value.lower() in ("true", "1")
def _create_bench_client_session():
# Under heavy load, the read buffer can fill up before the aiohttp worker
# reads the content, so we increase read_bufsize from 64 KiB to 10 MiB.
# Define constants for timeout and buffer size for clarity and maintainability
BENCH_AIOHTTP_TIMEOUT_SECONDS = 6 * 60 * 60 # 6 hours
BENCH_AIOHTTP_READ_BUFSIZE_BYTES = 10 * 1024**2 # 10 MB
aiohttp_timeout = aiohttp.ClientTimeout(total=BENCH_AIOHTTP_TIMEOUT_SECONDS)
return aiohttp.ClientSession(
timeout=aiohttp_timeout, read_bufsize=BENCH_AIOHTTP_READ_BUFSIZE_BYTES
)
@dataclass
class RequestFuncInput:
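"""Input for a single benchmark request: prompt, target URL, token lengths, and extra payload fields."""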
prompt: str
api_url: str
prompt_len: int
output_len: int
model: str
lora_name: str
image_data: Optional[List[str]]
extra_request_body: Dict[str, Any]
timestamp: Optional[float] = None
@dataclass
class RequestFuncOutput:
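"""Result of a single benchmark request: generated text, latency, TTFT, inter-token latencies, and token counts."""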
generated_text: str = ""
success: bool = False
latency: float = 0.0
ttft: float = 0.0 # Time to first token
itl: List[float] = field(default_factory=list) # List of inter-token latencies
prompt_len: int = 0
error: str = ""
output_len: int = 0
@staticmethod
def init_new(request_func_input: RequestFuncInput):
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
return output
def remove_prefix(text: str, prefix: str) -> str:
return text[len(prefix) :] if text.startswith(prefix) else text
def remove_suffix(text: str, suffix: str) -> str:
return text[: -len(suffix)] if text.endswith(suffix) else text
def get_auth_headers() -> Dict[str, str]:
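"""Build an Authorization header from OPENAI_API_KEY (Bearer) or API_KEY, if either is set."""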
openai_api_key = os.environ.get("OPENAI_API_KEY")
if openai_api_key:
return {"Authorization": f"Bearer {openai_api_key}"}
else:
api_key = os.environ.get("API_KEY")
if api_key:
return {"Authorization": f"{api_key}"}
return {}
# trt llm does not support ignore_eos
# https://github.com/triton-inference-server/tensorrtllm_backend/issues/505
async def async_request_trt_llm(
request_func_input: RequestFuncInput,
pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
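"""Send one streaming request to a TensorRT-LLM (Triton) generate_stream endpoint and record TTFT/ITL."""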
api_url = request_func_input.api_url
assert api_url.endswith("generate_stream")
async with _create_bench_client_session() as session:
payload = {
"accumulate_tokens": True,
"text_input": request_func_input.prompt,
"temperature": 0.000001,
"top_p": 1.0,
"max_tokens": request_func_input.output_len,
"stream": True,
"min_length": request_func_input.output_len,
"end_id": 1048576,
**request_func_input.extra_request_body,
}
if args.disable_ignore_eos:
del payload["min_length"]
del payload["end_id"]
output = RequestFuncOutput.init_new(request_func_input)
ttft = 0.0
st = time.perf_counter()
most_recent_timestamp = st
try:
async with session.post(url=api_url, json=payload) as response:
if response.status == 200:
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue
chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data:")
data = json.loads(chunk)
output.generated_text += data["text_output"]
timestamp = time.perf_counter()
# First token
if ttft == 0.0:
ttft = timestamp - st
output.ttft = ttft
# Decoding phase
else:
output.itl.append(timestamp - most_recent_timestamp)
most_recent_timestamp = timestamp
output.latency = most_recent_timestamp - st
output.success = True
output.output_len = request_func_input.output_len
else:
output.error = response.reason or ""
output.success = False
except Exception:
output.success = False
exc_info = sys.exc_info()
output.error = "".join(traceback.format_exception(*exc_info))
if pbar:
pbar.update(1)
return output
# set ignore_eos True by default
async def async_request_openai_completions(
request_func_input: RequestFuncInput,
pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
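"""Send one request to an OpenAI-compatible Completions endpoint and record streaming metrics."""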
api_url = request_func_input.api_url
assert api_url.endswith(
"completions"
), "OpenAI Completions API URL must end with 'completions'."
prompt = request_func_input.prompt
async with _create_bench_client_session() as session:
payload = {
"model": request_func_input.model,
"prompt": prompt,
"temperature": 0.0,
"best_of": 1,
"max_tokens": request_func_input.output_len,
"stream": not args.disable_stream,
"ignore_eos": not args.disable_ignore_eos,
**request_func_input.extra_request_body,
}
if request_func_input.image_data:
payload.update({"image_data": request_func_input.image_data})
headers = get_auth_headers()
output = RequestFuncOutput.init_new(request_func_input)
generated_text = ""
output_len = request_func_input.output_len
ttft = 0.0
st = time.perf_counter()
most_recent_timestamp = st
try:
async with session.post(
url=api_url, json=payload, headers=headers
) as response:
if response.status == 200:
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue
chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
latency = time.perf_counter() - st
if chunk == "[DONE]":
pass
else:
data = json.loads(chunk)
# NOTE: Some Completions APIs send a final usage-summary chunk that
# contains no token, so verify a token was actually generated.
if data["choices"][0]["text"]:
timestamp = time.perf_counter()
# First token
if ttft == 0.0:
ttft = time.perf_counter() - st
output.ttft = ttft
# Decoding phase
else:
output.itl.append(timestamp - most_recent_timestamp)
most_recent_timestamp = timestamp
generated_text += data["choices"][0]["text"]
output_len = (data.get("usage") or {}).get(
"completion_tokens", output_len
)
output.generated_text = generated_text
output.success = True
output.latency = latency
output.output_len = output_len
else:
output.error = response.reason or ""
output.success = False
except Exception:
output.success = False
exc_info = sys.exc_info()
output.error = "".join(traceback.format_exception(*exc_info))
if pbar:
pbar.update(1)
return output
async def async_request_openai_chat_completions(
request_func_input: RequestFuncInput,
pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
"""Makes a request to the OpenAI Chat Completions API.
Handles both streaming and non-streaming responses, including support
for image data in messages. Calculates and returns various performance
metrics.
Args:
request_func_input: Input parameters for the request.
pbar: Optional tqdm progress bar to update.
Returns:
RequestFuncOutput: Output of the request, including generated text,
latency, TTFT, ITL, and success status.
"""
api_url = request_func_input.api_url
assert api_url.endswith(
"chat/completions"
), "OpenAI Chat Completions API URL must end with 'chat/completions'."
if request_func_input.image_data:
# Build multi-image content: a list of image_url entries followed by the text
content_items = [
{
"type": "image_url",
"image_url": {"url": img_url},
}
for img_url in request_func_input.image_data
]
content_items.append({"type": "text", "text": request_func_input.prompt})
messages = [
{
"role": "user",
"content": content_items,
},
]
else:
messages = [{"role": "user", "content": request_func_input.prompt}]
async with _create_bench_client_session() as session:
payload = {
"model": request_func_input.model,
"messages": messages,
"temperature": 0.0,
"max_tokens": request_func_input.output_len,
"stream": not args.disable_stream,
**request_func_input.extra_request_body,
}
headers = get_auth_headers()
output = RequestFuncOutput.init_new(request_func_input)
generated_text = ""
output_len = request_func_input.output_len
ttft = 0.0
st = time.perf_counter()
most_recent_timestamp = st
try:
async with session.post(
url=api_url, json=payload, headers=headers
) as response:
if response.status == 200:
if args.disable_stream:
# Non-streaming response
response_json = await response.json()
output.generated_text = response_json["choices"][0]["message"][
"content"
]
output.success = True
output.latency = time.perf_counter() - st
output.ttft = (
output.latency
) # For non-streaming, TTFT = total latency
output.output_len = response_json.get("usage", {}).get(
"completion_tokens", output_len
)
else:
# Streaming response
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue
chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
latency = time.perf_counter() - st
if chunk == "[DONE]":
pass
else:
data = json.loads(chunk)
# Check if this chunk contains content
delta = data.get("choices", [{}])[0].get("delta", {})
content = delta.get("content", "")
if content:
timestamp = time.perf_counter()
# First token
if ttft == 0.0:
ttft = timestamp - st
output.ttft = ttft
# Decoding phase
else:
output.itl.append(
timestamp - most_recent_timestamp
)
most_recent_timestamp = timestamp
generated_text += content
# Check for usage info in final chunk
output_len = (data.get("usage") or {}).get(
"completion_tokens", output_len
)
output.generated_text = generated_text
output.success = True
output.latency = latency
output.output_len = output_len
else:
output.error = response.reason or ""
output.success = False
except Exception:
output.success = False
exc_info = sys.exc_info()
output.error = "".join(traceback.format_exception(*exc_info))
if pbar:
pbar.update(1)
return output
async def async_request_truss(
request_func_input: RequestFuncInput,
pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
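"""Send one streaming request to a Truss-served endpoint using an OpenAI-style completions payload."""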
api_url = request_func_input.api_url
prompt = request_func_input.prompt
async with _create_bench_client_session() as session:
payload = {
"model": request_func_input.model,
"prompt": prompt,
"temperature": 0.0,
"best_of": 1,
"max_tokens": request_func_input.output_len,
"stream": not args.disable_stream,
"ignore_eos": not args.disable_ignore_eos,
**request_func_input.extra_request_body,
}
headers = get_auth_headers()
output = RequestFuncOutput.init_new(request_func_input)
generated_text = ""
ttft = 0.0
st = time.perf_counter()
most_recent_timestamp = st
try:
async with session.post(
url=api_url, json=payload, headers=headers
) as response:
if response.status == 200:
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue
chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
latency = time.perf_counter() - st
if chunk == "[DONE]":
pass
else:
data = json.loads(chunk)
# NOTE: Some Completions APIs send a final usage-summary chunk that
# contains no token, so verify a token was actually generated.
if data["choices"][0]["text"]:
timestamp = time.perf_counter()
# First token
if ttft == 0.0:
ttft = time.perf_counter() - st
output.ttft = ttft
# Decoding phase
else:
output.itl.append(timestamp - most_recent_timestamp)
most_recent_timestamp = timestamp
generated_text += data["choices"][0]["text"]
output.generated_text = generated_text
output.success = True
output.latency = latency
output.output_len = request_func_input.output_len
else:
output.error = response.reason or ""
output.success = False
except Exception:
output.success = False
exc_info = sys.exc_info()
output.error = "".join(traceback.format_exception(*exc_info))
if pbar:
pbar.update(1)
return output
async def async_request_sglang_generate(
request_func_input: RequestFuncInput,
pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
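"""Send one request to the SGLang native /generate endpoint and record TTFT and per-token ITL."""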
api_url = request_func_input.api_url
prompt = request_func_input.prompt
async with _create_bench_client_session() as session:
payload = {
("text" if isinstance(prompt, str) else "input_ids"): prompt,
"sampling_params": {
"temperature": 0.0,
"max_new_tokens": request_func_input.output_len,
"ignore_eos": not args.disable_ignore_eos,
},
"stream": not args.disable_stream,
"lora_path": request_func_input.lora_name,
"return_logprob": args.return_logprob,
"logprob_start_len": -1,
**request_func_input.extra_request_body,
}
# Add image data if available (list of image urls/base64)
if request_func_input.image_data:
payload["image_data"] = request_func_input.image_data
headers = get_auth_headers()
output = RequestFuncOutput.init_new(request_func_input)
generated_text = ""
output_len = request_func_input.output_len
ttft = 0.0
st = time.perf_counter()
most_recent_timestamp = st
last_output_len = 0
try:
async with session.post(
url=api_url, json=payload, headers=headers
) as response:
if response.status == 200:
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue
chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
latency = time.perf_counter() - st
if chunk == "[DONE]":
pass
else:
data = json.loads(chunk)
# NOTE: Some Completions APIs send a final usage-summary chunk that
# contains no token, so verify a token was actually generated.
if "text" in data and data["text"]:
timestamp = time.perf_counter()
generated_text = data["text"]
output_len = data["meta_info"]["completion_tokens"]
# First token
if ttft == 0.0:
ttft = time.perf_counter() - st
output.ttft = ttft
# Decoding phase
else:
num_new_tokens = output_len - last_output_len
if num_new_tokens == 0:
continue
adjust_itl = (
timestamp - most_recent_timestamp
) / num_new_tokens
output.itl.extend([adjust_itl] * num_new_tokens)
most_recent_timestamp = timestamp
last_output_len = output_len
output.generated_text = generated_text
output.success = True
output.latency = latency
output.output_len = output_len
else:
output.error = response.reason or ""
output.success = False
except Exception:
output.success = False
exc_info = sys.exc_info()
output.error = "".join(traceback.format_exception(*exc_info))
print(f"{output.error=}")
if pbar:
pbar.update(1)
return output
async def async_request_gserver(
request_func_input: RequestFuncInput,
pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
raise NotImplementedError()
async def async_request_profile(api_url: str) -> RequestFuncOutput:
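"""POST to a profiler control endpoint (e.g. /start_profile or /stop_profile) and report success."""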
async with _create_bench_client_session() as session:
output = RequestFuncOutput()
try:
async with session.post(url=api_url) as response:
if response.status == 200:
output.success = True
else:
output.error = response.reason or ""
output.success = False
except Exception:
output.success = False
exc_info = sys.exc_info()
output.error = "".join(traceback.format_exception(*exc_info))
return output
def get_model(pretrained_model_name_or_path: str) -> str:
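"""Resolve the model path, downloading from ModelScope when SGLANG_USE_MODELSCOPE is set."""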
if os.getenv("SGLANG_USE_MODELSCOPE", "false").lower() == "true":
import huggingface_hub.constants
from modelscope import snapshot_download
model_path = snapshot_download(
model_id=pretrained_model_name_or_path,
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"],
)
return model_path
return pretrained_model_name_or_path
def get_tokenizer(
pretrained_model_name_or_path: str,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
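"""Load a tokenizer from a model name/path, or from a tokenizer file (.json/.model) via sglang."""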
assert (
pretrained_model_name_or_path is not None
and pretrained_model_name_or_path != ""
)
if pretrained_model_name_or_path.endswith(
".json"
) or pretrained_model_name_or_path.endswith(".model"):
from sglang.srt.hf_transformers_utils import get_tokenizer
return get_tokenizer(pretrained_model_name_or_path)
if pretrained_model_name_or_path is not None and not os.path.exists(
pretrained_model_name_or_path
):
pretrained_model_name_or_path = get_model(pretrained_model_name_or_path)
return AutoTokenizer.from_pretrained(
pretrained_model_name_or_path, trust_remote_code=True
)
def get_dataset(args, tokenizer):
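"""Build the list of benchmark requests for the dataset selected by --dataset-name."""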
tokenize_prompt = getattr(args, "tokenize_prompt", False)
if args.dataset_name == "sharegpt":
assert not tokenize_prompt
input_requests = sample_sharegpt_requests(
dataset_path=args.dataset_path,
num_requests=args.num_prompts,
tokenizer=tokenizer,
fixed_output_len=args.sharegpt_output_len,
context_len=args.sharegpt_context_len,
prompt_suffix=args.prompt_suffix,
apply_chat_template=args.apply_chat_template,
)
elif args.dataset_name.startswith("random") and args.dataset_name != "random-image":
input_requests = sample_random_requests(
input_len=args.random_input_len,
output_len=args.random_output_len,
num_prompts=args.num_prompts,
range_ratio=args.random_range_ratio,
tokenizer=tokenizer,
dataset_path=args.dataset_name,
random_sample=args.dataset_name == "random",
return_text=not tokenize_prompt,
)
elif args.dataset_name == "random-image":
assert not tokenize_prompt, "random-image does not support --tokenize-prompt"
input_requests = sample_random_image_requests(
num_requests=args.num_prompts,
num_images=args.random_image_num_images,
input_len=args.random_input_len,
output_len=args.random_output_len,
range_ratio=args.random_range_ratio,
tokenizer=tokenizer,
apply_chat_template=args.apply_chat_template,
image_resolution=args.random_image_resolution,
)
elif args.dataset_name == "generated-shared-prefix":
assert not tokenize_prompt
input_requests = sample_generated_shared_prefix_requests(
num_groups=args.gsp_num_groups,
prompts_per_group=args.gsp_prompts_per_group,
system_prompt_len=args.gsp_system_prompt_len,
question_len=args.gsp_question_len,
output_len=args.gsp_output_len,
tokenizer=tokenizer,
args=args,
)
elif args.dataset_name == "mmmu":
assert not tokenize_prompt
input_requests = sample_mmmu_requests(
num_requests=args.num_prompts,
tokenizer=tokenizer,
fixed_output_len=args.random_output_len,
apply_chat_template=args.apply_chat_template,
random_sample=True,
)
elif args.dataset_name == "mooncake":
# For mooncake, we don't generate the prompts here.
# We just load the raw trace data. The async generator will handle the rest.
if not args.dataset_path:
local_path = os.path.join("/tmp", args.mooncake_workload + "_trace.jsonl")
else:
local_path = args.dataset_path
if not os.path.exists(local_path):
download_and_cache_file(
MOONCAKE_DATASET_URL[args.mooncake_workload], local_path
)
with open(local_path, "r") as f:
all_requests_data = [json.loads(line) for line in f if line.strip()]
# Limit the number of requests based on --num-prompts
input_requests = all_requests_data[: args.num_prompts]
else:
raise ValueError(f"Unknown dataset: {args.dataset_name}")
return input_requests
ASYNC_REQUEST_FUNCS = {
"sglang": async_request_sglang_generate,
"sglang-native": async_request_sglang_generate,
"sglang-oai": async_request_openai_completions,
"sglang-oai-chat": async_request_openai_chat_completions,
"vllm": async_request_openai_completions,
"vllm-chat": async_request_openai_chat_completions,
"lmdeploy": async_request_openai_completions,
"lmdeploy-chat": async_request_openai_chat_completions,
"trt": async_request_trt_llm,
"gserver": async_request_gserver,
"truss": async_request_truss,
}
@dataclass
class BenchmarkMetrics:
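"""Aggregated benchmark metrics: throughput, TTFT, TPOT, ITL, and end-to-end latency statistics."""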
completed: int
total_input: int
total_output: int
total_output_retokenized: int
request_throughput: float
input_throughput: float
output_throughput: float
output_throughput_retokenized: float
total_throughput: float
total_throughput_retokenized: float
mean_ttft_ms: float
median_ttft_ms: float
std_ttft_ms: float
p99_ttft_ms: float
mean_tpot_ms: float
median_tpot_ms: float
std_tpot_ms: float
p99_tpot_ms: float
mean_itl_ms: float
median_itl_ms: float
std_itl_ms: float
p95_itl_ms: float
p99_itl_ms: float
max_itl_ms: float
mean_e2e_latency_ms: float
median_e2e_latency_ms: float
std_e2e_latency_ms: float
p99_e2e_latency_ms: float
concurrency: float
SHAREGPT_URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
MOONCAKE_DATASET_URL = {
"mooncake": "https://raw.githubusercontent.com/kvcache-ai/Mooncake/main/FAST25-release/arxiv-trace/mooncake_trace.jsonl",
"conversation": "https://raw.githubusercontent.com/kvcache-ai/Mooncake/main/FAST25-release/traces/conversation_trace.jsonl",
"synthetic": "https://raw.githubusercontent.com/kvcache-ai/Mooncake/main/FAST25-release/traces/synthetic_trace.jsonl",
"toolagent": "https://raw.githubusercontent.com/kvcache-ai/Mooncake/main/FAST25-release/traces/toolagent_trace.jsonl",
}
def download_and_cache_file(url: str, filename: Optional[str] = None):
"""Read and cache a file from a url."""
if filename is None:
filename = os.path.join("/tmp", url.split("/")[-1])
# Check if the cache file already exists
if is_file_valid_json(filename):
return filename
print(f"Downloading from {url} to {filename}")
# Stream the response to show the progress bar
response = requests.get(url, stream=True)
response.raise_for_status() # Check for request errors
# Total size of the file in bytes
total_size = int(response.headers.get("content-length", 0))
chunk_size = 1024 # Download in chunks of 1KB
# Use tqdm to display the progress bar
with open(filename, "wb") as f, tqdm(
desc=filename,
total=total_size,
unit="B",
unit_scale=True,
unit_divisor=1024,
) as bar:
for chunk in response.iter_content(chunk_size=chunk_size):
f.write(chunk)
bar.update(len(chunk))
return filename
def is_file_valid_json(path):
if not os.path.isfile(path):
return False
# TODO can fuse into the real file open later
try:
with open(path) as f:
json.load(f)
return True
except JSONDecodeError as e:
print(
f"{path} exists but JSON loading failed ({e=}); treating the file as invalid"
)
return False
@dataclass
class DatasetRow:
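"""One benchmark request: prompt text, token lengths, optional image data, and an optional trace timestamp."""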
prompt: str
prompt_len: int
output_len: int
image_data: Optional[List[str]] = None
timestamp: Optional[float] = None
async def get_mooncake_request_over_time(
input_requests: List[Dict],
tokenizer: PreTrainedTokenizerBase,
slowdown_factor: float,
num_rounds: int,
) -> AsyncGenerator[DatasetRow, None]:
"""
An async generator that yields requests based on the timestamps in the Mooncake trace file,
with support for multi-round sessions.
"""
if not input_requests:
return
input_requests.sort(key=lambda r: r["timestamp"])
start_time = time.perf_counter()
trace_start_time_ms = input_requests[0]["timestamp"]
for record in input_requests:
# Calculate when this entire session should start
relative_arrival_time_s = (record["timestamp"] - trace_start_time_ms) / 1000.0
target_arrival_time_s = relative_arrival_time_s * slowdown_factor
current_elapsed_time_s = time.perf_counter() - start_time
sleep_duration_s = target_arrival_time_s - current_elapsed_time_s
if sleep_duration_s > 0:
await asyncio.sleep(sleep_duration_s)
# Once the session starts, generate all rounds for it as a burst
# This simulates a user engaging in a multi-turn conversation
# Base user query constructed from hash_ids
user_query_base = ""
hash_ids = record.get("hash_ids", [])
for hash_id in hash_ids:
user_query_base += f"{hash_id}" + " ".join(
["hi"] * 128
) # Shorter for multi-round
user_query_base += "Tell me a story based on this context."
output_len_per_round = record.get("output_length", 256)
chat_history = []
for i in range(num_rounds):
# Add user query for the current round
chat_history.append(
{"role": "user", "content": f"Round {i+1}: {user_query_base}"}
)
# Form the full prompt from history
try:
full_prompt_text = tokenizer.apply_chat_template(
chat_history, tokenize=False, add_generation_prompt=True
)
except Exception:
full_prompt_text = "\n".join(
[f"{msg['role']}: {msg['content']}" for msg in chat_history]
)
prompt_len = len(tokenizer.encode(full_prompt_text))
yield DatasetRow(
prompt=full_prompt_text,
prompt_len=prompt_len,
output_len=output_len_per_round,
)
# Add a placeholder assistant response for the next round's context
# We use a placeholder because we don't know the real response
placeholder_response = " ".join(["story"] * output_len_per_round)
chat_history.append({"role": "assistant", "content": placeholder_response})
def sample_mmmu_requests(
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
fixed_output_len: Optional[int] = None,
apply_chat_template: bool = True,
random_sample: bool = True,
) -> List[DatasetRow]:
"""
Sample requests from the MMMU dataset using HuggingFace datasets.
Args:
num_requests: Number of requests to sample.
tokenizer: Tokenizer to use for token counting.
fixed_output_len: If provided, use this fixed output length for all requests.
apply_chat_template: Whether to apply the chat template to the prompt.
random_sample: Whether to randomly sample or take the first N.
Returns:
List of DatasetRow objects with the prompt, token lengths, and base64 image data.
"""
try:
import io
import pybase64
from datasets import load_dataset
except ImportError:
raise ImportError("Please install datasets: pip install datasets")
print("Loading MMMU dataset from HuggingFace...")
try:
print("Attempting to load MMMU Math dataset...")
mmmu_dataset = load_dataset("MMMU/MMMU", "Math", split="test")
print(
f"Successfully loaded MMMU Math dataset from HuggingFace with {len(mmmu_dataset)} examples"
)
except Exception as e:
print(f"Failed to load MMMU Math dataset: {e}")
raise ValueError(f"Failed to load MMMU dataset: {e}")
# Sample from the dataset
if len(mmmu_dataset) > num_requests:
if random_sample:
# Random sample
indices = random.sample(range(len(mmmu_dataset)), num_requests)
sample_dataset = mmmu_dataset.select(indices)
else:
# Take first N
sample_dataset = mmmu_dataset.select(
range(min(num_requests, len(mmmu_dataset)))
)
else:
print(f"Dataset has fewer than {num_requests} examples, using all examples")
sample_dataset = mmmu_dataset
print(f"Selected {len(sample_dataset)} examples for benchmarking")
# Create prompts
filtered_dataset = []
for i, example in enumerate(sample_dataset):
try:
# Extract image_1
image = example.get("image_1")
if image is not None:
if hasattr(image, "save"):
# Convert RGBA images to RGB before encoding
if image.mode == "RGBA":
image = image.convert("RGB")
# Encode image to base64 (save as PNG to support palette/alpha modes)
buffered = io.BytesIO()
image.save(buffered, format="PNG")
img_str = pybase64.b64encode(buffered.getvalue()).decode("utf-8")
image_data = f"data:image/png;base64,{img_str}"
else:
continue
# Extract the question
question = example.get("question")
# Construct the prompt
prompt = f"Question: {question}\n\nAnswer: "
if apply_chat_template:
try:
is_phi4_multimodal = (
"phi-4-multimodal" in tokenizer.name_or_path.lower()
)
if is_phi4_multimodal:
# <|endoftext10|> is the image token used in the phi-4-multimodal model.
content = prompt.replace("image 1", "<|endoftext10|>")
else:
content = [
{
"type": "image_url",
"image_url": {"url": image_data},
},
{"type": "text", "text": prompt},
]
prompt = tokenizer.apply_chat_template(
[
{
"role": "user",
"content": content,
}
],
add_generation_prompt=True,
tokenize=False,
)
except Exception as e:
# Note (Xinyuan): This is a workaround for an issue where some tokenizers do not support content as a list. (e.g. InternVL)
print(
f"Error applying chat template: {e}, fallback to <image> tag"
)
prompt = f"<image>{prompt}"
# Calculate token lengths for text only (without image data)
prompt_token_ids = tokenizer.encode(prompt)
prompt_len = len(prompt_token_ids)
output_len = fixed_output_len if fixed_output_len is not None else 256
filtered_dataset.append(
DatasetRow(
prompt=prompt,
prompt_len=prompt_len,
output_len=output_len,
image_data=[image_data],
)
)
except Exception as e:
print(f"Error processing example {i}: {e}")
print(f"\nCreated {len(filtered_dataset)} MMMU prompts")
return filtered_dataset
def sample_sharegpt_requests(
dataset_path: str,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
fixed_output_len: Optional[int] = None,
context_len: Optional[int] = None,
prompt_suffix: Optional[str] = "",
apply_chat_template=False,
) -> List[DatasetRow]:
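"""Sample prompts from the ShareGPT dataset, pruning sequences that are too short or too long."""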
if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small")
# Download sharegpt if necessary
if not is_file_valid_json(dataset_path) and dataset_path == "":
dataset_path = download_and_cache_file(SHAREGPT_URL)
# Load the dataset.
with open(dataset_path) as f:
dataset = json.load(f)
# Filter out the conversations with less than 2 turns.
dataset = [
data
for data in dataset
if len(data.get("conversations", data.get("conversation", []))) >= 2
]
# Only keep the first two turns of each conversation.
dataset = [
(
data.get("conversations", data.get("conversation", []))[0]["value"],
data.get("conversations", data.get("conversation", []))[1]["value"],
)
for data in dataset
]
# Shuffle the dataset.
random.shuffle(dataset)
# Filter out sequences that are too long or too short
filtered_dataset: List[DatasetRow] = []
for i in range(len(dataset)):
if len(filtered_dataset) == num_requests:
break
# Tokenize the prompts and completions.
prompt = dataset[i][0]
if prompt_suffix:
prompt = (
remove_suffix(prompt, ASSISTANT_SUFFIX)
+ prompt_suffix
+ ASSISTANT_SUFFIX
)
if apply_chat_template:
prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
add_generation_prompt=True,
tokenize=False,
)
if tokenizer.bos_token:
prompt = prompt.replace(tokenizer.bos_token, "")
prompt_token_ids = tokenizer.encode(prompt)
completion = dataset[i][1]
completion_token_ids = tokenizer.encode(completion)
prompt_len = len(prompt_token_ids)
output_len = (
len(completion_token_ids) if fixed_output_len is None else fixed_output_len
)
if prompt_len < 2 or output_len < 2:
# Prune too short sequences.
continue
if context_len and prompt_len + output_len > context_len:
# Prune too long sequences.
continue
filtered_dataset.append(
DatasetRow(prompt=prompt, prompt_len=prompt_len, output_len=output_len)
)
print(f"#Input tokens: {np.sum([x.prompt_len for x in filtered_dataset])}")
print(f"#Output tokens: {np.sum([x.output_len for x in filtered_dataset])}")
return filtered_dataset
def sample_random_requests(
input_len: int,
output_len: int,
num_prompts: int,
range_ratio: float,
tokenizer: PreTrainedTokenizerBase,
dataset_path: str,
random_sample: bool = True,
return_text: bool = True,
) -> List[DatasetRow]:
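"""Generate requests with random input/output lengths, drawing text from ShareGPT or random token ids."""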
input_lens = np.random.randint(
max(int(input_len * range_ratio), 1),
input_len + 1,
size=num_prompts,
)
output_lens = np.random.randint(
int(output_len * range_ratio),
output_len + 1,
size=num_prompts,
)
if random_sample:
# Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens
# Download sharegpt if necessary
if not is_file_valid_json(dataset_path):
dataset_path = download_and_cache_file(SHAREGPT_URL)
# Load the dataset.
with open(dataset_path) as f:
dataset = json.load(f)
# Filter out the conversations with less than 2 turns.
dataset = [
data
for data in dataset
if len(data.get("conversations", data.get("conversation", []))) >= 2
]
# Only keep the first two turns of each conversation.
dataset = [
(
data.get("conversations", data.get("conversation", []))[0]["value"],
data.get("conversations", data.get("conversation", []))[1]["value"],
)
for data in dataset
]
# Shuffle the dataset.
random.shuffle(dataset)
# Filter out sequences that are too long or too short
input_requests: List[DatasetRow] = []
for data in dataset:
i = len(input_requests)
if i == num_prompts:
break
# Tokenize the prompts and completions.
prompt = data[0]
prompt_token_ids = tokenizer.encode(prompt)
prompt_len = len(prompt_token_ids)
# Skip empty prompt
if prompt_len == 0:
continue
if prompt_len > input_lens[i]:
input_ids = prompt_token_ids[: input_lens[i]]
else:
ratio = (input_lens[i] + prompt_len - 1) // prompt_len
input_ids = (prompt_token_ids * ratio)[: input_lens[i]]
input_content = input_ids
if return_text:
input_content = tokenizer.decode(input_content)
input_requests.append(
DatasetRow(
prompt=input_content,
prompt_len=int(input_lens[i]),
output_len=int(output_lens[i]),
)
)
else:
# Sample token ids from random integers. This can cause some NaN issues.
offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
input_requests = []
for i in range(num_prompts):
input_content = [
(offsets[i] + i + j) % tokenizer.vocab_size
for j in range(input_lens[i])
]
if return_text:
input_content = tokenizer.decode(input_content)
input_requests.append(
DatasetRow(
prompt=input_content,
prompt_len=int(input_lens[i]),
output_len=int(output_lens[i]),
)
)
print(f"#Input tokens: {np.sum(input_lens)}")
print(f"#Output tokens: {np.sum(output_lens)}")
return input_requests
def parse_random_image_resolution(image_resolution: str) -> Tuple[int, int]:
"""Parse image resolution into (width, height).
Supports presets '4k', '1080p', '720p', '360p' and a custom 'heightxwidth' format
(e.g., '1080x1920' means height=1080, width=1920).
"""
resolution_to_size = {
"4k": (3840, 2160),
"1080p": (1920, 1080),
"720p": (1280, 720),
"360p": (640, 360),
}
if image_resolution in resolution_to_size:
return resolution_to_size[image_resolution]
res = image_resolution.strip().lower()
if "x" in res:
parts = res.split("x")
if len(parts) == 2 and parts[0].isdigit() and parts[1].isdigit():
height = int(parts[0])
width = int(parts[1])
if height > 0 and width > 0:
return (width, height)
raise ValueError(
f"Unsupported random-image resolution: {image_resolution}. "
"Choose from 4k, 1080p, 720p, 360p, or provide custom 'heightxwidth' (e.g., 1080x1920)."
)
def sample_random_image_requests(
num_requests: int,
num_images: int,
input_len: int,
output_len: int,
range_ratio: float,
tokenizer: PreTrainedTokenizerBase,
apply_chat_template: bool = True,
image_resolution: str = "1080p",
) -> List[DatasetRow]:
"""Generate requests with random images.
- Each request includes ``num_images`` random images.
- Supported resolutions: 4k (3840x2160), 1080p (1920x1080), 720p (1280x720), 360p (640x360),
or custom 'heightxwidth' (e.g., 1080x1920).
- Text lengths follow the 'random' dataset sampling rule. ``prompt_len``
only counts text tokens and excludes image data.
"""
try:
import pybase64
from PIL import Image
except ImportError as e:
raise ImportError(
"Please install Pillow to generate random images: pip install pillow"
) from e
# Parse resolution (supports presets and 'heightxwidth')
width, height = parse_random_image_resolution(image_resolution)
# Check for potentially problematic combinations and warn user
if width * height >= 1920 * 1080 and num_images * num_requests >= 100:
warnings.warn(
f"High resolution ({width}x{height}) with {num_images * num_requests} total images "
f"may take a long time. Consider reducing resolution or image count.",
UserWarning,
stacklevel=2,
)
# Sample text lengths
input_lens = np.random.randint(
max(int(input_len * range_ratio), 1), input_len + 1, size=num_requests
)
output_lens = np.random.randint(
int(output_len * range_ratio), output_len + 1, size=num_requests
)
def _gen_random_image_data_uri(width: int = width, height: int = height) -> str:
arr = (np.random.rand(height, width, 3) * 255).astype(np.uint8)
img = Image.fromarray(arr, mode="RGB")
buf = io.BytesIO()
img.save(buf, format="JPEG", quality=85)
encoded = pybase64.b64encode(buf.getvalue()).decode("utf-8")
return f"data:image/jpeg;base64,{encoded}"
dataset: List[DatasetRow] = []
for i in range(num_requests):
# Generate text prompt
text_prompt = gen_prompt(tokenizer, int(input_lens[i]))
# Generate image list
images = [_gen_random_image_data_uri() for _ in range(num_images)]
prompt_str = text_prompt
if apply_chat_template:
try:
content_items = [
{"type": "image_url", "image_url": {"url": img_url}}
for img_url in images
]
content_items.append({"type": "text", "text": text_prompt})
prompt_str = tokenizer.apply_chat_template(
[{"role": "user", "content": content_items}],
add_generation_prompt=True,
tokenize=False,
)
except Exception:
# Some tokenizers do not support list content; fall back to a placeholder in the text
prompt_str = f"<image>{text_prompt}"
prompt_token_ids = tokenizer.encode(prompt_str)
prompt_token_len = len(prompt_token_ids)
dataset.append(
DatasetRow(
prompt=prompt_str,
prompt_len=prompt_token_len,
output_len=int(output_lens[i]),
image_data=images,
)
)
print(f"#Input tokens: {np.sum([x.prompt_len for x in dataset])}")
print(f"#Output tokens: {np.sum([x.output_len for x in dataset])}")
return dataset
def gen_prompt(tokenizer, token_num):
"""Generate a random prompt of specified token length using tokenizer vocabulary."""
all_available_tokens = list(tokenizer.get_vocab().values())
selected_tokens = random.choices(all_available_tokens, k=token_num)
return tokenizer.decode(selected_tokens)
def get_gen_prefix_cache_path(args, tokenizer):
"""Create cache directory under ~/.cache/sglang/benchmark"""
cache_dir = Path.home() / ".cache" / "sglang" / "benchmark"
# Create a unique cache filename based on the generation parameters
cache_key = (
f"gen_shared_prefix_{args.gsp_num_groups}_{args.gsp_prompts_per_group}_"
f"{args.gsp_system_prompt_len}_{args.gsp_question_len}_{args.gsp_output_len}_"
f"{tokenizer.__class__.__name__}.pkl"
)
return cache_dir / cache_key
def sample_generated_shared_prefix_requests(
num_groups: int,
prompts_per_group: int,
system_prompt_len: int,
question_len: int,
output_len: int,
tokenizer: PreTrainedTokenizerBase,
args: argparse.Namespace,
) -> List[DatasetRow]:
"""Generate benchmark requests with shared system prompts using random tokens and caching."""
cache_path = get_gen_prefix_cache_path(args, tokenizer)
# Try to load from cache first
if cache_path.exists():
print(f"\nLoading cached generated input data from {cache_path}")
with open(cache_path, "rb") as f:
return pickle.load(f)
print("\nGenerating new input data...")
# Generate system prompts for each group
system_prompts = []
for _ in range(num_groups):
system_prompt = gen_prompt(tokenizer, system_prompt_len)
system_prompts.append(system_prompt)
# Generate questions
questions = []
for _ in range(num_groups * prompts_per_group):
question = gen_prompt(tokenizer, question_len)
questions.append(question)
# Combine system prompts with questions
input_requests = []
total_input_tokens = 0
total_output_tokens = 0
for group_idx in tqdm(range(num_groups), desc="Generating system prompt"):
system_prompt = system_prompts[group_idx]
for prompt_idx in tqdm(
range(prompts_per_group), desc="Generating questions", leave=False
):
question = questions[group_idx * prompts_per_group + prompt_idx]
full_prompt = f"{system_prompt}\n\n{question}"
prompt_len = len(tokenizer.encode(full_prompt))
input_requests.append(
DatasetRow(
prompt=full_prompt, prompt_len=prompt_len, output_len=output_len
)
)
total_input_tokens += prompt_len
total_output_tokens += output_len
# Shuffle questions
random.shuffle(input_requests)
# Print statistics
print(f"\nGenerated shared prefix dataset statistics:")
print(f"Number of groups: {num_groups}")
print(f"Prompts per group: {prompts_per_group}")
print(f"Total prompts: {len(input_requests)}")
print(f"Total input tokens: {total_input_tokens}")
print(f"Total output tokens: {total_output_tokens}")
print(
f"Average system prompt length: {sum(len(tokenizer.encode(sp)) for sp in system_prompts) / len(system_prompts):.1f} tokens"
)
print(
f"Average question length: {sum(len(tokenizer.encode(q)) for q in questions) / len(questions):.1f} tokens\n"
)
# Save to cache
cache_path.parent.mkdir(parents=True, exist_ok=True)
print(f"Caching generated input data to {cache_path}")
with open(cache_path, "wb") as f:
pickle.dump(input_requests, f)
return input_requests
async def get_request(
input_requests: List[DatasetRow],
request_rate: float,
use_trace_timestamps: bool = False,
slowdown_factor: float = 1.0,
) -> AsyncGenerator[DatasetRow, None]:
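"""Yield requests either on their trace timestamps or with Poisson-distributed arrivals at request_rate."""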
if use_trace_timestamps:
print(
f"Using trace timestamps for request generation with slowdown factor {slowdown_factor}."
)
# Sort requests by timestamp for correct replay
input_requests.sort(key=lambda r: r.timestamp)
start_time = time.perf_counter()
trace_start_time_ms = input_requests[0].timestamp if input_requests else 0
for request in input_requests:
trace_time_s = (request.timestamp - trace_start_time_ms) / 1000.0
target_arrival_time = start_time + (trace_time_s * slowdown_factor)
sleep_duration = target_arrival_time - time.perf_counter()
if sleep_duration > 0:
await asyncio.sleep(sleep_duration)
yield request
else:
input_requests_iter = iter(input_requests)
for request in input_requests_iter:
yield request
if request_rate == float("inf"):
# If the request rate is infinity, then we don't need to wait.
continue
# Sample the request interval from the exponential distribution.
interval = np.random.exponential(1.0 / request_rate)
# The next request will be sent after the interval.
await asyncio.sleep(interval)
def calculate_metrics(
input_requests: List[DatasetRow],
outputs: List[RequestFuncOutput],
dur_s: float,
tokenizer: PreTrainedTokenizerBase,
backend: str,
) -> Tuple[BenchmarkMetrics, List[int]]:
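"""Aggregate per-request outputs into BenchmarkMetrics and return them with the list of output lengths."""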
output_lens: List[int] = []
retokenized_output_lens: List[int] = []
total_input = 0
completed = 0
itls: List[float] = []
tpots: List[float] = []
ttfts: List[float] = []
e2e_latencies: List[float] = []
for i in range(len(outputs)):
if outputs[i].success:
output_len = outputs[i].output_len
output_lens.append(output_len)
retokenized_output_len = len(
tokenizer.encode(outputs[i].generated_text, add_special_tokens=False)
)
retokenized_output_lens.append(retokenized_output_len)
total_input += outputs[i].prompt_len
if output_len > 1:
tpots.append((outputs[i].latency - outputs[i].ttft) / (output_len - 1))
itls += outputs[i].itl
ttfts.append(outputs[i].ttft)
e2e_latencies.append(outputs[i].latency)
completed += 1
else:
output_lens.append(0)
retokenized_output_lens.append(0)
if completed == 0:
warnings.warn(
"All requests failed. This is likely due to a misconfiguration "
"of the benchmark arguments.",
stacklevel=2,
)
metrics = BenchmarkMetrics(
completed=completed,
total_input=total_input,
total_output=sum(output_lens),
total_output_retokenized=sum(retokenized_output_lens),
request_throughput=completed / dur_s,
input_throughput=total_input / dur_s,
output_throughput=sum(output_lens) / dur_s,
output_throughput_retokenized=sum(retokenized_output_lens) / dur_s,
total_throughput=(total_input + sum(output_lens)) / dur_s,
total_throughput_retokenized=(total_input + sum(retokenized_output_lens))
/ dur_s,
mean_ttft_ms=np.mean(ttfts or 0)
* 1000, # ttfts is empty if streaming is not supported by backend
median_ttft_ms=np.median(ttfts or 0) * 1000,
std_ttft_ms=np.std(ttfts or 0) * 1000,
p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,
mean_tpot_ms=np.mean(tpots or 0) * 1000,
median_tpot_ms=np.median(tpots or 0) * 1000,
std_tpot_ms=np.std(tpots or 0) * 1000,
p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000,
mean_itl_ms=np.mean(itls or 0) * 1000,
median_itl_ms=np.median(itls or 0) * 1000,
std_itl_ms=np.std(itls or 0) * 1000,
p95_itl_ms=np.percentile(itls or 0, 95) * 1000,
p99_itl_ms=np.percentile(itls or 0, 99) * 1000,
max_itl_ms=np.max(itls or 0) * 1000,
mean_e2e_latency_ms=np.mean(e2e_latencies) * 1000,
median_e2e_latency_ms=np.median(e2e_latencies) * 1000,
std_e2e_latency_ms=np.std(e2e_latencies) * 1000,
p99_e2e_latency_ms=np.percentile(e2e_latencies, 99) * 1000,
concurrency=np.sum(e2e_latencies) / dur_s,
)
return metrics, output_lens
async def benchmark(
backend: str,
api_url: str,
base_url: str,
model_id: str,
tokenizer: PreTrainedTokenizerBase,
input_requests: List[DatasetRow],
request_rate: float,
max_concurrency: Optional[int],
disable_tqdm: bool,
lora_names: List[str],
extra_request_body: Dict[str, Any],
profile: bool,
pd_separated: bool = False,
flush_cache: bool = False,
warmup_requests: int = 1,
use_trace_timestamps: bool = False,
mooncake_slowdown_factor=1.0,
mooncake_num_rounds=1,
):
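"""Run the benchmark: warm up, optionally flush the cache and start the profiler, send all requests, then compute and report metrics."""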
if backend in ASYNC_REQUEST_FUNCS:
request_func = ASYNC_REQUEST_FUNCS[backend]
else:
raise ValueError(f"Unknown backend: {backend}")
# Limit concurrency
# From https://github.com/vllm-project/vllm/pull/9390
semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None
async def limited_request_func(request_func_input, pbar):
if semaphore is None:
return await request_func(request_func_input=request_func_input, pbar=pbar)
async with semaphore:
return await request_func(request_func_input=request_func_input, pbar=pbar)
# Warmup
print(f"Starting warmup with {warmup_requests} sequences...")
# Handle the data structure difference for the warmup request
if args.dataset_name == "mooncake":
# For mooncake, input_requests is a list of dicts.
# We need to build a temporary DatasetRow for the warmup phase.
warmup_record = input_requests[0]
# Build prompt from hash_ids, just like in the async generator
hash_ids = warmup_record.get("hash_ids", [])
prompt_text = ""
for hash_id in hash_ids:
prompt_text += f"{hash_id}" + " ".join(["hi"] * 512)
prompt_text += "Can you tell me a detailed story in 1000 words?"
output_len = warmup_record.get("output_length", 32)
prompt_len = len(tokenizer.encode(prompt_text))
# Create a temporary DatasetRow object for warmup
test_request = DatasetRow(
prompt=prompt_text,
prompt_len=prompt_len,
output_len=output_len,
image_data=None, # Mooncake doesn't have image data
)
else:
# For all other datasets, input_requests is a list of DatasetRow objects
test_request = input_requests[0]
if lora_names is not None and len(lora_names) != 0:
lora_name = lora_names[0]
else:
lora_name = None
# Create the test input once
test_input = RequestFuncInput(
model=model_id,
prompt=test_request.prompt,
api_url=api_url,
prompt_len=test_request.prompt_len,
output_len=min(test_request.output_len, 32),
lora_name=lora_name,
image_data=test_request.image_data,
extra_request_body=extra_request_body,
)
# Run warmup requests
warmup_tasks = []
for _ in range(warmup_requests):
warmup_tasks.append(
asyncio.create_task(request_func(request_func_input=test_input))
)
warmup_outputs = await asyncio.gather(*warmup_tasks)
# Check if at least one warmup request succeeded
if warmup_requests > 0 and not any(output.success for output in warmup_outputs):
raise ValueError(
"Warmup failed - Please make sure benchmark arguments "
f"are correctly specified. Error: {warmup_outputs[0].error}"
)
else:
print(
f"Warmup completed with {args.warmup_requests} sequences. Starting main benchmark run..."
)
# Flush cache
if ("sglang" in backend and _get_bool_env_var("SGLANG_IS_IN_CI")) or flush_cache:
requests.post(base_url + "/flush_cache", headers=get_auth_headers())
time.sleep(1.0)
# Start profiler
if profile:
print("Starting profiler...")
profile_output = await async_request_profile(
api_url=base_url + "/start_profile"
)
if profile_output.success:
print("Profiler started")
# Run all requests
benchmark_start_time = time.perf_counter()
tasks: List[asyncio.Task] = []
pbar_total = len(input_requests)
if (
backend == "sglang" and args.dataset_name == "mooncake"
): # Assuming mooncake is mainly for sglang or similar backends
print("Using time-based Mooncake request scheduler, ignoring --request-rate.")
request_generator = get_mooncake_request_over_time(
input_requests, tokenizer, mooncake_slowdown_factor, mooncake_num_rounds
)
print(
f"Starting Mooncake trace replay. Sessions: {len(input_requests)}, Rounds per session: {mooncake_num_rounds}. Slowdown factor: {mooncake_slowdown_factor}"
)
pbar_total *= args.mooncake_num_rounds
else:
request_generator = get_request(input_requests, request_rate)
pbar = None if disable_tqdm else tqdm(total=pbar_total)
async for request in request_generator:
if lora_names is not None and len(lora_names) != 0:
idx = random.randint(0, len(lora_names) - 1)
lora_name = lora_names[idx]
else:
lora_name = None
request_func_input = RequestFuncInput(
model=model_id,
prompt=request.prompt,
api_url=api_url,
prompt_len=request.prompt_len,
output_len=request.output_len,
lora_name=lora_name,
image_data=request.image_data,
extra_request_body=extra_request_body,
timestamp=request.timestamp,
)
tasks.append(
asyncio.create_task(
limited_request_func(request_func_input=request_func_input, pbar=pbar)
)
)
outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
# Stop profiler
if profile:
print("Stopping profiler...")
profile_output = await async_request_profile(api_url=base_url + "/stop_profile")
if profile_output.success:
print("Profiler stopped")
if pbar is not None:
pbar.close()
if "sglang" in backend:
server_info = requests.get(base_url + "/get_server_info")
if server_info.status_code == 200:
server_info_json = server_info.json()
if "decode" in server_info_json:
server_info_json = server_info_json["decode"][0]
accept_length = server_info_json["internal_states"][0].get(
"avg_spec_accept_length", None
)
else:
accept_length = None
else:
accept_length = None
# Compute metrics and print results
benchmark_duration = time.perf_counter() - benchmark_start_time
metrics, output_lens = calculate_metrics(
input_requests=input_requests,
outputs=outputs,
dur_s=benchmark_duration,
tokenizer=tokenizer,
backend=backend,
)
print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
print("{:<40} {:<10}".format("Backend:", backend))
print(
"{:<40} {:<10}".format(
"Traffic request rate:", "trace" if use_trace_timestamps else request_rate
)
)
print(
"{:<40} {:<10}".format(
"Max request concurrency:",
max_concurrency if max_concurrency else "not set",
)
)
print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
print(
"{:<40} {:<10}".format(
"Total generated tokens (retokenized):", metrics.total_output_retokenized
)
)
print(
"{:<40} {:<10.2f}".format(
"Request throughput (req/s):", metrics.request_throughput
)
)
print(
"{:<40} {:<10.2f}".format(
"Input token throughput (tok/s):", metrics.input_throughput
)
)
print(
"{:<40} {:<10.2f}".format(
"Output token throughput (tok/s):", metrics.output_throughput
)
)
print(
"{:<40} {:<10.2f}".format(
"Total token throughput (tok/s):", metrics.total_throughput
)
)
print("{:<40} {:<10.2f}".format("Concurrency:", metrics.concurrency))
if accept_length:
print("{:<40} {:<10.2f}".format("Accept length:", accept_length))
print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-"))
print(
"{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms)
)
print(
"{:<40} {:<10.2f}".format(
"Median E2E Latency (ms):", metrics.median_e2e_latency_ms
)
)
print("{s:{c}^{n}}".format(s="Time to First Token", n=50, c="-"))
print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms))
print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
print("{s:{c}^{n}}".format(s="Inter-Token Latency", n=50, c="-"))
print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
print("{:<40} {:<10.2f}".format("P95 ITL (ms):", metrics.p95_itl_ms))
print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
print("{:<40} {:<10.2f}".format("Max ITL (ms):", metrics.max_itl_ms))
print("=" * 50)
if (
metrics.median_ttft_ms is not None
and metrics.mean_itl_ms is not None
and metrics.output_throughput is not None
):
result = {
# Arguments
"backend": args.backend,
"dataset_name": args.dataset_name,
"request_rate": "trace" if use_trace_timestamps else request_rate,
"max_concurrency": max_concurrency,
"sharegpt_output_len": args.sharegpt_output_len,
"random_input_len": args.random_input_len,
"random_output_len": args.random_output_len,
"random_range_ratio": args.random_range_ratio,
# Results
"duration": benchmark_duration,
"completed": metrics.completed,
"total_input_tokens": metrics.total_input,
"total_output_tokens": metrics.total_output,
"total_output_tokens_retokenized": metrics.total_output_retokenized,
"request_throughput": metrics.request_throughput,
"input_throughput": metrics.input_throughput,
"output_throughput": metrics.output_throughput,
"mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
"median_e2e_latency_ms": metrics.median_e2e_latency_ms,
"std_e2e_latency_ms": metrics.std_e2e_latency_ms,
"p99_e2e_latency_ms": metrics.p99_e2e_latency_ms,
"mean_ttft_ms": metrics.mean_ttft_ms,
"median_ttft_ms": metrics.median_ttft_ms,
"std_ttft_ms": metrics.std_ttft_ms,
"p99_ttft_ms": metrics.p99_ttft_ms,
"mean_tpot_ms": metrics.mean_tpot_ms,
"median_tpot_ms": metrics.median_tpot_ms,
"std_tpot_ms": metrics.std_tpot_ms,
"p99_tpot_ms": metrics.p99_tpot_ms,
"mean_itl_ms": metrics.mean_itl_ms,
"median_itl_ms": metrics.median_itl_ms,
"std_itl_ms": metrics.std_itl_ms,
"p95_itl_ms": metrics.p95_itl_ms,
"p99_itl_ms": metrics.p99_itl_ms,
"concurrency": metrics.concurrency,
"accept_length": accept_length,
}
else:
print(f"Error running benchmark for request rate: {request_rate}")
print("-" * 30)
# Determine output file name
if args.output_file:
output_file_name = args.output_file
else:
now = datetime.now().strftime("%m%d")
if args.dataset_name == "random-image":
output_file_name = (
f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_"
f"{args.random_output_len}_{args.random_image_num_images}imgs_"
f"{args.random_image_resolution}.jsonl"
)
elif args.dataset_name.startswith("random"):
output_file_name = f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl"
else:
output_file_name = (
f"{args.backend}_{now}_{args.num_prompts}_{args.dataset_name}.jsonl"
)
result_details = {
"input_lens": [output.prompt_len for output in outputs],
"output_lens": output_lens,
"ttfts": [output.ttft for output in outputs],
"itls": [output.itl for output in outputs],
"generated_texts": [output.generated_text for output in outputs],
"errors": [output.error for output in outputs],
}
# Append results to a JSONL file
with open(output_file_name, "a") as file:
if args.output_details:
result_for_dump = result | result_details
else:
result_for_dump = result
file.write(json.dumps(result_for_dump) + "\n")
return result | result_details


def check_chat_template(model_path):
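    """Return True if the tokenizer at `model_path` defines a chat template."""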
try:
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
return "chat_template" in tokenizer.init_kwargs
except Exception as e:
print(f"Fail to load tokenizer config with error={e}")
        return False


def set_global_args(args_: argparse.Namespace):
"""Set the global args."""
global args
    args = args_


def run_benchmark(args_: argparse.Namespace):
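    """Fill in default args, resolve the API URL and model, then run the async benchmark."""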
global args
args = args_
# Set default value for max_concurrency if not present
if not hasattr(args, "max_concurrency"):
args.max_concurrency = None
# Set default value for warmup_requests if not present
if not hasattr(args, "warmup_requests"):
args.warmup_requests = 1
if not hasattr(args, "output_details"):
args.output_details = False
if not hasattr(args, "tokenize_prompt"):
args.tokenize_prompt = False
if not hasattr(args, "use_trace_timestamps"):
args.use_trace_timestamps = False
if not hasattr(args, "mooncake_slowdown_factor"):
args.mooncake_slowdown_factor = 1.0
if not hasattr(args, "mooncake_slowdown_factor"):
args.mooncake_slowdown_factor = 1.0
if not hasattr(args, "mooncake_num_rounds"):
args.mooncake_num_rounds = 1
print(f"benchmark_args={args}")
    # Set global environment settings
    set_ulimit()
random.seed(args.seed)
np.random.seed(args.seed)
extra_request_body = {}
if args.extra_request_body:
extra_request_body = json.loads(args.extra_request_body)
if args.tokenize_prompt:
assert (
args.backend == "sglang"
), "`--tokenize-prompt` only compatible with `--backend sglang` currently"
    # Set url
if args.port is None:
args.port = {
"sglang": 30000,
"sglang-native": 30000,
"sglang-oai": 30000,
"lmdeploy": 23333,
"vllm": 8000,
"trt": 8000,
"gserver": 9988,
"truss": 8080,
}.get(args.backend, 30000)
model_url = (
f"{args.base_url}/v1/models"
if args.base_url
else f"http://{args.host}:{args.port}/v1/models"
)
if args.backend in ["sglang", "sglang-native"]:
api_url = (
f"{args.base_url}/generate"
if args.base_url
else f"http://{args.host}:{args.port}/generate"
)
elif args.backend in ["sglang-oai", "vllm", "lmdeploy"]:
api_url = (
f"{args.base_url}/v1/completions"
if args.base_url
else f"http://{args.host}:{args.port}/v1/completions"
)
elif args.backend in ["sglang-oai-chat", "vllm-chat", "lmdeploy-chat"]:
api_url = (
f"{args.base_url}/v1/chat/completions"
if args.base_url
else f"http://{args.host}:{args.port}/v1/chat/completions"
)
elif args.backend == "trt":
api_url = (
f"{args.base_url}/v2/models/ensemble/generate_stream"
if args.base_url
else f"http://{args.host}:{args.port}/v2/models/ensemble/generate_stream"
)
if args.model is None:
print("Please provide a model using `--model` when using `trt` backend.")
sys.exit(1)
elif args.backend == "gserver":
api_url = args.base_url if args.base_url else f"{args.host}:{args.port}"
args.model = args.model or "default"
elif args.backend == "truss":
api_url = (
f"{args.base_url}/v1/models/model:predict"
if args.base_url
else f"http://{args.host}:{args.port}/v1/models/model:predict"
)
base_url = (
f"http://{args.host}:{args.port}" if args.base_url is None else args.base_url
)
    # Get model name
if args.model is None:
if args.backend == "truss":
print(
"Please provide a model with `--model` when using truss backend. e.g. --model meta-llama/Llama-3.1-8B-Instruct"
)
sys.exit(1)
try:
response = requests.get(model_url, headers=get_auth_headers())
model_list = response.json().get("data", [])
args.model = model_list[0]["id"] if model_list else None
except Exception as e:
print(f"Failed to fetch model from {model_url}. Error: {e}")
print(
"Please specify the correct host and port using `--host` and `--port`."
)
sys.exit(1)
if args.model is None:
print("No model specified or found. Please provide a model using `--model`.")
sys.exit(1)
if not check_chat_template(args.model):
print(
"\nWARNING It is recommended to use the `Chat` or `Instruct` model for benchmarking.\n"
"Because when the tokenizer counts the output tokens, if there is gibberish, it might count incorrectly.\n"
)
print(f"{args}\n")
    # Read dataset
backend = args.backend
model_id = args.model
tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
tokenizer = get_tokenizer(tokenizer_id)
input_requests = get_dataset(args, tokenizer)
# compatible with SimpleNamespace
if not hasattr(args, "flush_cache"):
args.flush_cache = False
return asyncio.run(
benchmark(
backend=backend,
api_url=api_url,
base_url=base_url,
model_id=model_id,
tokenizer=tokenizer,
input_requests=input_requests,
request_rate=args.request_rate,
max_concurrency=args.max_concurrency,
disable_tqdm=args.disable_tqdm,
lora_names=args.lora_name,
extra_request_body=extra_request_body,
profile=args.profile,
pd_separated=args.pd_separated,
flush_cache=args.flush_cache,
warmup_requests=args.warmup_requests,
use_trace_timestamps=args.use_trace_timestamps,
mooncake_slowdown_factor=args.mooncake_slowdown_factor,
mooncake_num_rounds=args.mooncake_num_rounds,
)
)


def set_ulimit(target_soft_limit=65535):
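    """Raise the soft RLIMIT_NOFILE limit to `target_soft_limit` if it is currently lower."""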
resource_type = resource.RLIMIT_NOFILE
current_soft, current_hard = resource.getrlimit(resource_type)
if current_soft < target_soft_limit:
try:
resource.setrlimit(resource_type, (target_soft_limit, current_hard))
except ValueError as e:
print(f"Fail to set RLIMIT_NOFILE: {e}")
class LoRAPathAction(argparse.Action):
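    """Argparse action that collects the provided LoRA adapter names into a list."""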
def __call__(self, parser, namespace, values, option_string=None):
setattr(namespace, self.dest, [])
for lora_name in values:
getattr(namespace, self.dest).append(lora_name)


if __name__ == "__main__":
parser = ArgumentParser(description="Benchmark the online serving throughput.")
parser.add_argument(
"--backend",
type=str,
choices=list(ASYNC_REQUEST_FUNCS.keys()),
default="sglang",
help="Must specify a backend, depending on the LLM Inference Engine.",
)
parser.add_argument(
"--base-url",
type=str,
default=None,
help="Server or API base url if not using http host and port.",
)
parser.add_argument(
"--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0."
)
parser.add_argument(
"--port",
type=int,
help="If not set, the default port is configured according to its default value for different LLM Inference Engines.",
)
parser.add_argument(
"--dataset-name",
type=str,
default="sharegpt",
choices=[
"sharegpt",
"random",
"random-ids",
"generated-shared-prefix",
"mmmu",
"random-image",
"mooncake",
],
help="Name of the dataset to benchmark on.",
)
parser.add_argument(
"--dataset-path", type=str, default="", help="Path to the dataset."
)
parser.add_argument(
"--model",
type=str,
help="Name or path of the model. If not set, the default model will request /v1/models for conf.",
)
parser.add_argument(
"--tokenizer",
type=str,
help="Name or path of the tokenizer. If not set, using the model conf.",
)
parser.add_argument(
"--num-prompts",
type=int,
default=1000,
help="Number of prompts to process. Default is 1000.",
)
parser.add_argument(
"--sharegpt-output-len",
type=int,
default=None,
help="Output length for each request. Overrides the output length from the ShareGPT dataset.",
)
parser.add_argument(
"--sharegpt-context-len",
type=int,
default=None,
help="The context length of the model for the ShareGPT dataset. Requests longer than the context length will be dropped.",
)
parser.add_argument(
"--random-input-len",
type=int,
default=1024,
help="Number of input tokens per request, used only for random dataset.",
)
parser.add_argument(
"--random-output-len",
default=1024,
type=int,
help="Number of output tokens per request, used only for random dataset.",
)
parser.add_argument(
"--random-range-ratio",
type=float,
default=0.0,
help="Range of sampled ratio of input/output length, "
"used only for random dataset.",
)
# random-image dataset args
parser.add_argument(
"--random-image-num-images",
type=int,
default=1,
help="Number of images per request (only available with the random-image dataset)",
)
parser.add_argument(
"--random-image-resolution",
type=str,
default="1080p",
help=(
"Resolution of random images for random-image dataset. "
"Supports presets 4k/1080p/720p/360p or custom 'heightxwidth' (e.g., 1080x1920)."
),
)
parser.add_argument(
"--request-rate",
type=float,
default=float("inf"),
help="Number of requests per second. If this is inf, then all the requests are sent at time 0. "
        "Otherwise, we use a Poisson process to synthesize the request arrival times. Default is inf.",
)
parser.add_argument(
"--use-trace-timestamps",
action="store_true",
help="Use timestamps from the trace file for request scheduling. Only valid for 'mooncake' dataset.",
)
parser.add_argument(
"--max-concurrency",
type=int,
default=None,
help="Maximum number of concurrent requests. This can be used "
"to help simulate an environment where a higher level component "
"is enforcing a maximum number of concurrent requests. While the "
"--request-rate argument controls the rate at which requests are "
"initiated, this argument will control how many are actually allowed "
"to execute at a time. This means that when used in combination, the "
"actual request rate may be lower than specified with --request-rate, "
"if the server is not processing requests fast enough to keep up.",
)
parser.add_argument("--output-file", type=str, help="Output JSONL file name.")
parser.add_argument(
"--output-details", action="store_true", help="Output details of benchmarking."
)
parser.add_argument(
"--disable-tqdm",
action="store_true",
help="Specify to disable tqdm progress bar.",
)
parser.add_argument(
"--disable-stream",
action="store_true",
help="Disable streaming mode.",
)
parser.add_argument(
"--return-logprob",
action="store_true",
help="Return logprob.",
)
parser.add_argument("--seed", type=int, default=1, help="The random seed.")
parser.add_argument(
"--disable-ignore-eos",
action="store_true",
help="Disable ignoring EOS.",
)
parser.add_argument(
"--extra-request-body",
metavar='{"key1": "value1", "key2": "value2"}',
type=str,
help="Append given JSON object to the request payload. You can use this to specify"
"additional generate params like sampling params.",
)
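    # Example (illustrative values): --extra-request-body '{"temperature": 0.0, "top_p": 1.0}'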
parser.add_argument(
"--apply-chat-template",
action="store_true",
help="Apply chat template",
)
parser.add_argument(
"--profile",
action="store_true",
help="Use Torch Profiler. The endpoint must be launched with "
"SGLANG_TORCH_PROFILER_DIR to enable profiler.",
)
parser.add_argument(
"--lora-name",
type=str,
nargs="*",
default=None,
action=LoRAPathAction,
help="The names of LoRA adapters. You can provide a list of names in the format {name} {name} {name}...",
)
parser.add_argument(
"--prompt-suffix",
type=str,
default="",
help="Suffix applied to the end of all user prompts, followed by assistant prompt suffix.",
)
parser.add_argument(
"--pd-separated",
action="store_true",
help="Benchmark PD disaggregation server",
)
parser.add_argument(
"--flush-cache",
action="store_true",
help="Flush the cache before running the benchmark",
)
parser.add_argument(
"--warmup-requests",
type=int,
default=1,
help="Number of warmup requests to run before the benchmark",
)
parser.add_argument(
"--tokenize-prompt",
action="store_true",
help="Use integer ids instead of string for inputs. Useful to control prompt lengths accurately",
)
group = parser.add_argument_group("generated-shared-prefix dataset arguments")
group.add_argument(
"--gsp-num-groups",
type=int,
default=64,
help="Number of system prompt groups for generated-shared-prefix dataset",
)
group.add_argument(
"--gsp-prompts-per-group",
type=int,
default=16,
help="Number of prompts per system prompt group for generated-shared-prefix dataset",
)
group.add_argument(
"--gsp-system-prompt-len",
type=int,
default=2048,
help="Target length in tokens for system prompts in generated-shared-prefix dataset",
)
group.add_argument(
"--gsp-question-len",
type=int,
default=128,
help="Target length in tokens for questions in generated-shared-prefix dataset",
)
group.add_argument(
"--gsp-output-len",
type=int,
default=256,
help="Target length in tokens for outputs in generated-shared-prefix dataset",
)
mooncake_group = parser.add_argument_group("mooncake dataset arguments")
mooncake_group.add_argument(
"--mooncake-slowdown-factor",
type=float,
default=1.0,
help="Slowdown factor for replaying the mooncake trace. "
"A value of 2.0 means the replay is twice as slow. "
"NOTE: --request-rate is IGNORED in mooncake mode.",
)
mooncake_group.add_argument(
"--mooncake-num-rounds",
type=int,
default=1,
help="Number of conversation rounds for each session in the mooncake dataset. "
"A value > 1 will enable true multi-turn session benchmarking.",
)
mooncake_group.add_argument(
"--mooncake-workload",
type=str,
default="conversation",
choices=[
"mooncake",
"conversation",
"synthetic",
"toolagent",
],
help="Underlying workload for the mooncake dataset.",
)
args = parser.parse_args()
run_benchmark(args)