# Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/backend_request_func.py
# Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/benchmark_serving.py
"""
Benchmark online serving with dynamic requests.

Usage:
python3 -m sglang.bench_serving --backend sglang --num-prompt 10

python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 3000 --random-input 1024 --random-output 1024 --random-range-ratio 0.5
python3 -m sglang.bench_serving --backend sglang --dataset-name random --request-rate-range 1,2,4,8,16,32 --random-input 4096 --random-output 1024 --random-range-ratio 0.125 --multi
"""
import argparse
import asyncio
import json
import os
import random
import resource
import sys
import time
import traceback
import warnings
from argparse import ArgumentParser
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union

import aiohttp
import numpy as np
import requests
from tqdm.asyncio import tqdm
from transformers import (
    AutoTokenizer,
    PreTrainedTokenizer,
    PreTrainedTokenizerBase,
    PreTrainedTokenizerFast,
)

# Generous session timeout: a single benchmark run can stream for hours.
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)

# Parsed CLI namespace; assigned by run_benchmark() and read by the request functions.
global args
@dataclass
class RequestFuncInput:
    """Arguments for issuing a single benchmark request."""

    prompt: str
    api_url: str
    prompt_len: int  # number of prompt tokens, pre-computed by the sampler
    output_len: int  # requested number of output tokens
    model: str
    extra_request_body: Dict[str, Any]  # extra fields merged into the JSON payload
2024-07-20 02:15:21 +10:00
@dataclass
class RequestFuncOutput:
    """Timing and result data collected for a single benchmark request."""

    generated_text: str = ""
    success: bool = False
    latency: float = 0.0  # end-to-end latency in seconds
    ttft: float = 0.0  # Time to first token
    itl: List[float] = field(default_factory=list)  # List of inter-token latencies
    prompt_len: int = 0
    error: str = ""
    output_len: int = 0  # requested output length (not the retokenized count)
2024-07-20 02:15:21 +10:00
def remove_prefix(text: str, prefix: str) -> str:
    """Return *text* with *prefix* stripped from the front, if present."""
    if text.startswith(prefix):
        return text[len(prefix):]
    return text
# trt llm not support ignore_eos
# https://github.com/triton-inference-server/tensorrtllm_backend/issues/505
async def async_request_trt_llm(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one streaming request to a TensorRT-LLM (Triton) server and time it."""
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    payload = {
        "accumulate_tokens": True,
        "text_input": request_func_input.prompt,
        "temperature": 0.000001,
        "top_p": 1.0,
        "max_tokens": request_func_input.output_len,
        "stream": True,
        "min_length": request_func_input.output_len,
        "end_id": 1048576,
        **request_func_input.extra_request_body,
    }
    if args.disable_ignore_eos:
        # min_length/end_id force exactly output_len tokens; drop them to honor EOS.
        payload.pop("min_length")
        payload.pop("end_id")

    output = RequestFuncOutput()
    output.prompt_len = request_func_input.prompt_len

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        ttft = 0.0
        start = time.perf_counter()
        last_ts = start
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data:")
                        data = json.loads(chunk)
                        output.generated_text += data["text_output"]
                        timestamp = time.perf_counter()
                        if ttft == 0.0:
                            # First token.
                            ttft = time.perf_counter() - start
                            output.ttft = ttft
                        else:
                            # Decoding phase.
                            output.itl.append(timestamp - last_ts)

                        last_ts = timestamp

                    output.latency = last_ts - start
                    output.success = True
                    output.output_len = request_func_input.output_len
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output
2024-07-20 02:15:21 +10:00
# set ignore_eos True by default
async def async_request_openai_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one request to an OpenAI-compatible /v1/completions endpoint and time it."""
    api_url = request_func_input.api_url
    assert api_url.endswith(
        "completions"
    ), "OpenAI Completions API URL must end with 'completions'."

    prompt = request_func_input.prompt

    payload = {
        "model": request_func_input.model,
        "prompt": prompt,
        "temperature": 0.0,
        "best_of": 1,
        "max_tokens": request_func_input.output_len,
        "stream": not args.disable_stream,
        "ignore_eos": not args.disable_ignore_eos,
        **request_func_input.extra_request_body,
    }
    headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}

    output = RequestFuncOutput()
    output.prompt_len = request_func_input.prompt_len

    generated_text = ""
    ttft = 0.0
    start = time.perf_counter()
    last_ts = start

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        try:
            async with session.post(
                url=api_url, json=payload, headers=headers
            ) as response:
                if response.status != 200:
                    output.error = response.reason or ""
                    output.success = False
                else:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
                        latency = time.perf_counter() - start
                        if chunk == "[DONE]":
                            pass
                        else:
                            data = json.loads(chunk)

                            # NOTE: Some completion API might have a last
                            # usage summary response without a token so we
                            # want to check a token was generated
                            if data["choices"][0]["text"]:
                                timestamp = time.perf_counter()
                                if ttft == 0.0:
                                    # First token.
                                    ttft = time.perf_counter() - start
                                    output.ttft = ttft
                                else:
                                    # Decoding phase.
                                    output.itl.append(timestamp - last_ts)

                                last_ts = timestamp
                                generated_text += data["choices"][0]["text"]

                    output.generated_text = generated_text
                    output.success = True
                    output.latency = latency
                    output.output_len = request_func_input.output_len
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output
async def async_request_gserver(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    # Placeholder: the gserver backend is registered but not implemented yet.
    raise NotImplementedError()
2024-07-20 02:15:21 +10:00
def get_model(pretrained_model_name_or_path: str) -> str:
    """Resolve a model id to a local path, using ModelScope when enabled via env."""
    if os.getenv("SGLANG_USE_MODELSCOPE", "False").lower() != "true":
        return pretrained_model_name_or_path

    import huggingface_hub.constants
    from modelscope import snapshot_download

    # Skip the heavy weight files; only config/tokenizer assets are needed here.
    return snapshot_download(
        model_id=pretrained_model_name_or_path,
        local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
        ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"],
    )
def get_tokenizer(
    pretrained_model_name_or_path: str,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    """Load a tokenizer; raw tokenizer files (.json/.model) go through sglang's loader."""
    if pretrained_model_name_or_path.endswith((".json", ".model")):
        from sglang.srt.hf_transformers_utils import get_tokenizer

        return get_tokenizer(pretrained_model_name_or_path)

    # Resolve remote model ids (possibly via ModelScope) before loading.
    if pretrained_model_name_or_path is not None and not os.path.exists(
        pretrained_model_name_or_path
    ):
        pretrained_model_name_or_path = get_model(pretrained_model_name_or_path)
    return AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path, trust_remote_code=True
    )
# Maps each --backend value to the coroutine that issues a single request.
ASYNC_REQUEST_FUNCS = {
    "sglang": async_request_openai_completions,
    "vllm": async_request_openai_completions,
    "lmdeploy": async_request_openai_completions,
    "trt": async_request_trt_llm,
    "gserver": async_request_gserver,
}
@dataclass
class BenchmarkMetrics:
    """Aggregate statistics computed over one benchmark run."""

    completed: int  # number of successful requests
    total_input: int  # total prompt tokens across successful requests
    total_output: int  # total output tokens (as requested per request)
    total_output_retokenized: int  # output tokens counted by re-tokenizing generated text
    request_throughput: float
    input_throughput: float
    output_throughput: float
    output_throughput_retokenized: float
    mean_ttft_ms: float
    median_ttft_ms: float
    std_ttft_ms: float
    p99_ttft_ms: float
    mean_tpot_ms: float
    median_tpot_ms: float
    std_tpot_ms: float
    p99_tpot_ms: float
    mean_itl_ms: float
    median_itl_ms: float
    std_itl_ms: float
    p99_itl_ms: float
    mean_e2e_latency_ms: float
    median_e2e_latency_ms: float
2024-09-09 04:14:11 -07:00
# Canonical download location for the ShareGPT dataset used by both samplers.
SHAREGPT_URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
2024-07-20 01:57:43 -07:00
2024-09-09 04:14:11 -07:00
def download_and_cache_file(url: str, filename: Optional[str] = None):
    """Download *url* to *filename* (default: /tmp/<basename>) and return the path.

    If the file already exists it is reused without touching the network.
    """
    if filename is None:
        filename = os.path.join("/tmp", url.split("/")[-1])

    # Check if the cache file already exists
    if os.path.exists(filename):
        return filename

    # BUG FIX: the message previously printed a literal "(unknown)" instead of
    # the destination path.
    print(f"Downloading from {url} to {filename}")

    # Stream the response to show the progress bar
    response = requests.get(url, stream=True)
    response.raise_for_status()  # Check for request errors

    # Total size of the file in bytes
    total_size = int(response.headers.get("content-length", 0))
    chunk_size = 1024  # Download in chunks of 1KB

    # Use tqdm to display the progress bar
    with open(filename, "wb") as f, tqdm(
        desc=filename,
        total=total_size,
        unit="B",
        unit_scale=True,
        unit_divisor=1024,
    ) as bar:
        for chunk in response.iter_content(chunk_size=chunk_size):
            f.write(chunk)
            bar.update(len(chunk))

    return filename
2024-07-20 01:57:43 -07:00
2024-07-20 02:15:21 +10:00
def sample_sharegpt_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int] = None,
) -> List[Tuple[str, int, int]]:
    """Sample up to *num_requests* (prompt, prompt_len, output_len) triples from ShareGPT."""
    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")

    # Download sharegpt if necessary
    if not os.path.isfile(dataset_path):
        dataset_path = download_and_cache_file(SHAREGPT_URL)

    # Load the dataset.
    with open(dataset_path) as f:
        dataset = json.load(f)

    # Keep the first two turns of every conversation that has at least two turns.
    conversations = [
        (entry["conversations"][0]["value"], entry["conversations"][1]["value"])
        for entry in dataset
        if len(entry["conversations"]) >= 2
    ]

    # Shuffle the dataset.
    random.shuffle(conversations)

    # Filter out sequences that are too long or too short.
    filtered_dataset: List[Tuple[str, int, int]] = []
    for prompt, completion in conversations:
        if len(filtered_dataset) == num_requests:
            break

        # Tokenize the prompts and completions.
        prompt_token_ids = tokenizer.encode(prompt)
        completion_token_ids = tokenizer.encode(completion)
        prompt_len = len(prompt_token_ids)
        output_len = (
            len(completion_token_ids) if fixed_output_len is None else fixed_output_len
        )

        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            continue
        if prompt_len > 1024 or (
            prompt_len + output_len > 2048 and fixed_output_len is None
        ):
            # Prune too long sequences.
            continue
        filtered_dataset.append((prompt, prompt_len, output_len))

    return filtered_dataset
def sample_random_requests(
    input_len: int,
    output_len: int,
    num_prompts: int,
    range_ratio: float,
    tokenizer: PreTrainedTokenizerBase,
    dataset_path: str,
) -> List[Tuple[str, int, int]]:
    """Build synthetic (prompt, prompt_len, output_len) requests with random lengths.

    Lengths are drawn uniformly from [len * range_ratio, len]. Prompt text is
    sampled from ShareGPT and repeated/truncated to hit the sampled token count;
    fully random token ids were dropped because they can cause NaN issues.
    """
    input_lens = np.random.randint(
        max(int(input_len * range_ratio), 1),
        input_len + 1,
        size=num_prompts,
    )
    output_lens = np.random.randint(
        int(output_len * range_ratio),
        output_len + 1,
        size=num_prompts,
    )

    # Download sharegpt if necessary
    if not os.path.isfile(dataset_path):
        dataset_path = download_and_cache_file(SHAREGPT_URL)

    # Load the dataset.
    with open(dataset_path) as f:
        dataset = json.load(f)

    # Keep the first two turns of conversations with at least two turns.
    dataset = [
        (data["conversations"][0]["value"], data["conversations"][1]["value"])
        for data in dataset
        if len(data["conversations"]) >= 2
    ]

    # Shuffle the dataset.
    random.shuffle(dataset)

    input_requests: List[Tuple[str, int, int]] = []
    for i in range(num_prompts):
        # Wrap around so num_prompts may exceed the dataset size
        # (plain dataset[i] raised IndexError in that case).
        prompt = dataset[i % len(dataset)][0]
        prompt_token_ids = tokenizer.encode(prompt)
        prompt_len = len(prompt_token_ids)

        if prompt_len > input_lens[i]:
            input_ids = prompt_token_ids[: input_lens[i]]
        else:
            # Repeat the prompt enough times, then truncate to the target length.
            ratio = (input_lens[i] + prompt_len - 1) // prompt_len
            input_ids = (prompt_token_ids * ratio)[: input_lens[i]]
        prompt = tokenizer.decode(input_ids)
        input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))

    print(f"#Input tokens: {np.sum(input_lens)}")
    print(f"#Output tokens: {np.sum(output_lens)}")
    return input_requests
2024-07-20 02:15:21 +10:00
async def get_request(
    input_requests: List[Tuple[str, int, int]],
    request_rate: float,
) -> AsyncGenerator[Tuple[str, int, int], None]:
    """Yield requests, pacing arrivals as a Poisson process at *request_rate* req/s."""
    for request in input_requests:
        yield request

        if request_rate == float("inf"):
            # If the request rate is infinity, then we don't need to wait.
            continue

        # Sample the request interval from the exponential distribution;
        # the next request is sent after that interval elapses.
        await asyncio.sleep(np.random.exponential(1.0 / request_rate))
def calculate_metrics(
    input_requests: List[Tuple[str, int, int]],
    outputs: List[RequestFuncOutput],
    dur_s: float,
    tokenizer: PreTrainedTokenizerBase,
    backend: str,
) -> Tuple[BenchmarkMetrics, List[int]]:
    """Aggregate per-request outputs into a BenchmarkMetrics plus per-request output lengths."""
    output_lens: List[int] = []
    retokenized_output_lens: List[int] = []
    total_input = 0
    completed = 0
    itls: List[float] = []
    tpots: List[float] = []
    ttfts: List[float] = []
    e2e_latencies: List[float] = []

    for request, output in zip(input_requests, outputs):
        if not output.success:
            output_lens.append(0)
            retokenized_output_lens.append(0)
            continue

        output_len = output.output_len
        output_lens.append(output_len)
        # Re-tokenize the generated text to sanity-check the reported token count.
        retokenized_output_lens.append(
            len(tokenizer.encode(output.generated_text, add_special_tokens=False))
        )
        total_input += request[1]
        if output_len > 1:
            # TPOT excludes the first token, hence the (output_len - 1) divisor.
            tpots.append((output.latency - output.ttft) / (output_len - 1))
        itls += output.itl
        ttfts.append(output.ttft)
        e2e_latencies.append(output.latency)
        completed += 1

    if completed == 0:
        warnings.warn(
            "All requests failed. This is likely due to a misconfiguration "
            "on the benchmark arguments.",
            stacklevel=2,
        )

    metrics = BenchmarkMetrics(
        completed=completed,
        total_input=total_input,
        total_output=sum(output_lens),
        total_output_retokenized=sum(retokenized_output_lens),
        request_throughput=completed / dur_s,
        input_throughput=total_input / dur_s,
        output_throughput=sum(output_lens) / dur_s,
        output_throughput_retokenized=sum(retokenized_output_lens) / dur_s,
        # `or 0` guards against empty lists when streaming is unsupported.
        mean_ttft_ms=np.mean(ttfts or 0) * 1000,
        median_ttft_ms=np.median(ttfts or 0) * 1000,
        std_ttft_ms=np.std(ttfts or 0) * 1000,
        p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,
        mean_tpot_ms=np.mean(tpots or 0) * 1000,
        median_tpot_ms=np.median(tpots or 0) * 1000,
        std_tpot_ms=np.std(tpots or 0) * 1000,
        p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000,
        mean_itl_ms=np.mean(itls or 0) * 1000,
        median_itl_ms=np.median(itls or 0) * 1000,
        std_itl_ms=np.std(itls or 0) * 1000,
        p99_itl_ms=np.percentile(itls or 0, 99) * 1000,
        mean_e2e_latency_ms=np.mean(e2e_latencies) * 1000,
        median_e2e_latency_ms=np.median(e2e_latencies) * 1000,
    )

    return metrics, output_lens
2024-07-20 02:15:21 +10:00
async def benchmark(
    backend: str,
    api_url: str,
    model_id: str,
    tokenizer: PreTrainedTokenizerBase,
    input_requests: List[Tuple[str, int, int]],
    request_rate: float,
    disable_tqdm: bool,
    extra_request_body: Dict[str, Any],
):
    """Run one benchmark pass: a warmup request, then all requests at *request_rate*.

    Prints a summary table, appends a one-line JSON summary to the output file
    (when metrics are valid), and returns the full result dict.
    """
    if backend in ASYNC_REQUEST_FUNCS:
        request_func = ASYNC_REQUEST_FUNCS[backend]
    else:
        raise ValueError(f"Unknown backend: {backend}")

    # Warmup with a single request so misconfiguration fails fast.
    print("Starting initial single prompt test run...")
    test_prompt, test_prompt_len, test_output_len = input_requests[0]
    test_input = RequestFuncInput(
        model=model_id,
        prompt=test_prompt,
        api_url=api_url,
        prompt_len=test_prompt_len,
        output_len=test_output_len,
        extra_request_body=extra_request_body,
    )
    test_output = await request_func(request_func_input=test_input)
    if not test_output.success:
        raise ValueError(
            "Initial test run failed - Please make sure benchmark arguments "
            f"are correctly specified. Error: {test_output.error}"
        )
    else:
        print("Initial test run completed. Starting main benchmark run...")

    pbar = None if disable_tqdm else tqdm(total=len(input_requests))

    benchmark_start_time = time.perf_counter()
    tasks: List[asyncio.Task] = []
    async for request in get_request(input_requests, request_rate):
        prompt, prompt_len, output_len = request
        request_func_input = RequestFuncInput(
            model=model_id,
            prompt=prompt,
            api_url=api_url,
            prompt_len=prompt_len,
            output_len=output_len,
            extra_request_body=extra_request_body,
        )
        tasks.append(
            asyncio.create_task(
                request_func(request_func_input=request_func_input, pbar=pbar)
            )
        )
    outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)

    if pbar is not None:
        pbar.close()

    benchmark_duration = time.perf_counter() - benchmark_start_time
    metrics, output_lens = calculate_metrics(
        input_requests=input_requests,
        outputs=outputs,
        dur_s=benchmark_duration,
        tokenizer=tokenizer,
        backend=backend,
    )

    print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
    print("{:<40} {:<10}".format("Backend:", backend))
    print("{:<40} {:<10}".format("Traffic request rate:", request_rate))
    print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
    print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
    print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
    print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
    print(
        "{:<40} {:<10}".format(
            "Total generated tokens (retokenized):", metrics.total_output_retokenized
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Request throughput (req/s):", metrics.request_throughput
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Input token throughput (tok/s):", metrics.input_throughput
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Output token throughput (tok/s):", metrics.output_throughput
        )
    )
    print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-"))
    print(
        "{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms)
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Median E2E Latency (ms):", metrics.median_e2e_latency_ms
        )
    )
    print("{s:{c}^{n}}".format(s="Time to First Token", n=50, c="-"))
    print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
    print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms))
    print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
    print(
        "{s:{c}^{n}}".format(s="Time per Output Token (excl. 1st token)", n=50, c="-")
    )
    print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms))
    print("{:<40} {:<10.2f}".format("Median TPOT (ms):", metrics.median_tpot_ms))
    print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
    print("{s:{c}^{n}}".format(s="Inter-token Latency", n=50, c="-"))
    print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
    print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
    print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
    print("=" * 50)

    if (
        metrics.median_ttft_ms is not None
        and metrics.mean_itl_ms is not None
        and metrics.output_throughput is not None
    ):
        result = {
            "backend": args.backend,
            "dataset_name": args.dataset_name,
            "request_rate": request_rate,
            "total_input_tokens": metrics.total_input,
            "total_output_tokens": metrics.total_output,
            "total_output_tokens_retokenized": metrics.total_output_retokenized,
            "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
            "median_e2e_latency_ms": metrics.median_e2e_latency_ms,
            "median_ttft_ms": metrics.median_ttft_ms,
            "median_itl_ms": metrics.median_itl_ms,
            "output_throughput": metrics.output_throughput,
            "sharegpt_output_len": args.sharegpt_output_len,
            "random_input_len": args.random_input_len,
            "random_output_len": args.random_output_len,
            "random_range_ratio": args.random_range_ratio,
            "duration": benchmark_duration,
            "completed": metrics.completed,
        }

        # Determine output file name
        if args.output_file:
            output_file_name = args.output_file
        else:
            now = datetime.now().strftime("%m%d")
            if args.dataset_name == "random":
                output_file_name = f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl"
            else:
                output_file_name = f"{args.backend}_{now}_{args.num_prompts}_sharegpt.jsonl"

        # Append results to a JSONL file.
        # BUG FIX: this write used to run unconditionally, raising NameError on
        # `result` when the metrics check above failed.
        with open(output_file_name, "a") as file:
            file.write(json.dumps(result) + "\n")
    else:
        print(f"Error running benchmark for request rate: {request_rate}")
        print("-" * 30)

    result = {
        "duration": benchmark_duration,
        "completed": metrics.completed,
        "total_input_tokens": metrics.total_input,
        "total_output_tokens": metrics.total_output,
        "total_output_tokens_retokenized": metrics.total_output_retokenized,
        "request_throughput": metrics.request_throughput,
        "input_throughput": metrics.input_throughput,
        "output_throughput": metrics.output_throughput,
        "mean_ttft_ms": metrics.mean_ttft_ms,
        "median_ttft_ms": metrics.median_ttft_ms,
        "std_ttft_ms": metrics.std_ttft_ms,
        "p99_ttft_ms": metrics.p99_ttft_ms,
        "mean_tpot_ms": metrics.mean_tpot_ms,
        "median_tpot_ms": metrics.median_tpot_ms,
        "std_tpot_ms": metrics.std_tpot_ms,
        "p99_tpot_ms": metrics.p99_tpot_ms,
        "mean_itl_ms": metrics.mean_itl_ms,
        "median_itl_ms": metrics.median_itl_ms,
        "std_itl_ms": metrics.std_itl_ms,
        "p99_itl_ms": metrics.p99_itl_ms,
        "input_lens": [output.prompt_len for output in outputs],
        "output_lens": output_lens,
        "ttfts": [output.ttft for output in outputs],
        "itls": [output.itl for output in outputs],
        "generated_texts": [output.generated_text for output in outputs],
        "errors": [output.error for output in outputs],
        "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
        "median_e2e_latency_ms": metrics.median_e2e_latency_ms,
    }
    return result
def parse_request_rate_range(request_rate_range):
    """Parse --request-rate-range.

    Exactly three comma-separated values are treated as range(start, stop, step);
    any other count is treated as an explicit list of rates.
    """
    # Split once instead of twice (the original re-split the string).
    values = list(map(int, request_rate_range.split(",")))
    if len(values) == 3:
        start, stop, step = values
        return list(range(start, stop, step))
    return values
def check_chat_template(model_path):
    """Best-effort check that the tokenizer at *model_path* declares a chat template.

    Returns False (after logging) on any load failure instead of raising.
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        return "chat_template" in tokenizer.init_kwargs
    except Exception as e:
        print(f"Fail to load tokenizer config with error={e}")
        return False
def run_benchmark(args_: argparse.Namespace):
    """Entry point: configure the environment, build the request set, and run the benchmark(s)."""
    global args
    args = args_

    # Set global environments
    set_ulimit()
    random.seed(args.seed)
    np.random.seed(args.seed)

    extra_request_body = {}
    if args.extra_request_body:
        extra_request_body = json.loads(args.extra_request_body)

    # Set url
    if args.port is None:
        # Each backend listens on a different well-known port by default.
        args.port = {
            "sglang": 30000,
            "lmdeploy": 23333,
            "vllm": 8000,
            "trt": 8000,
            "gserver": 9988,
        }.get(args.backend, 30000)

    api_url = (
        f"{args.base_url}/v1/completions"
        if args.base_url
        else f"http://{args.host}:{args.port}/v1/completions"
    )
    model_url = (
        f"{args.base_url}/v1/models"
        if args.base_url
        else f"http://{args.host}:{args.port}/v1/models"
    )

    if args.backend == "trt":
        # TensorRT-LLM (Triton) uses its own streaming endpoint and cannot
        # report its model name, so --model is mandatory.
        api_url = (
            f"{args.base_url}/v2/models/ensemble/generate_stream"
            if args.base_url
            else f"http://{args.host}:{args.port}/v2/models/ensemble/generate_stream"
        )
        if args.model is None:
            print("Please provide a model using `--model` when using `trt` backend.")
            sys.exit(1)
    elif args.backend == "gserver":
        api_url = args.base_url if args.base_url else f"{args.host}:{args.port}"
        args.model = args.model or "default"

    # Get model name
    if args.model is None:
        # Fall back to asking the server which models it serves.
        try:
            response = requests.get(model_url)
            model_list = response.json().get("data", [])
            args.model = model_list[0]["id"] if model_list else None
        except Exception as e:
            print(f"Failed to fetch model from {model_url}. Error: {e}")
            print(
                "Please specify the correct host and port using `--host` and `--port`."
            )
            sys.exit(1)

    if args.model is None:
        print("No model specified or found. Please provide a model using `--model`.")
        sys.exit(1)

    if not check_chat_template(args.model):
        print(
            "\nWARNING It is recommended to use the `Chat` or `Instruct` model for benchmarking.\n"
            "Because when the tokenizer counts the output tokens, if there is gibberish, it might count incorrectly.\n"
        )

    print(f"{args}\n")

    # Read dataset
    backend = args.backend
    model_id = args.model
    tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
    tokenizer = get_tokenizer(tokenizer_id)

    if args.dataset_name == "sharegpt":
        assert args.random_input_len is None and args.random_output_len is None
        input_requests = sample_sharegpt_requests(
            dataset_path=args.dataset_path,
            num_requests=args.num_prompts,
            tokenizer=tokenizer,
            fixed_output_len=args.sharegpt_output_len,
        )
    elif args.dataset_name == "random":
        assert args.random_input_len is not None and args.random_output_len is not None
        input_requests = sample_random_requests(
            input_len=args.random_input_len,
            output_len=args.random_output_len,
            num_prompts=args.num_prompts,
            range_ratio=args.random_range_ratio,
            tokenizer=tokenizer,
            dataset_path=args.dataset_path,
        )
    else:
        raise ValueError(f"Unknown dataset: {args.dataset_name}")

    if not args.multi:
        return asyncio.run(
            benchmark(
                backend=backend,
                api_url=api_url,
                model_id=model_id,
                tokenizer=tokenizer,
                input_requests=input_requests,
                request_rate=args.request_rate,
                disable_tqdm=args.disable_tqdm,
                extra_request_body=extra_request_body,
            )
        )
    else:
        # Benchmark multiple rps. TODO: use a fixed duration to compute num_prompts
        request_rates = parse_request_rate_range(args.request_rate_range)

        for rate in request_rates:
            asyncio.run(
                benchmark(
                    backend=backend,
                    api_url=api_url,
                    model_id=model_id,
                    tokenizer=tokenizer,
                    input_requests=input_requests,
                    request_rate=rate,
                    disable_tqdm=args.disable_tqdm,
                    extra_request_body=extra_request_body,
                )
            )
2024-07-20 02:15:21 +10:00
def set_ulimit(target_soft_limit=65535):
    """Raise the soft RLIMIT_NOFILE to *target_soft_limit* if it is currently lower."""
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    if soft >= target_soft_limit:
        return

    try:
        resource.setrlimit(resource.RLIMIT_NOFILE, (target_soft_limit, hard))
    except ValueError as e:
        # Insufficient privileges or limit above the hard cap; benchmark still runs.
        print(f"Fail to set RLIMIT_NOFILE: {e}")
if __name__ == "__main__":
    # Command-line interface; the parsed namespace is handed to run_benchmark(),
    # which stores it in the module-level `args` global.
    parser = ArgumentParser(description="Benchmark the online serving throughput.")
    parser.add_argument(
        "--backend",
        type=str,
        choices=list(ASYNC_REQUEST_FUNCS.keys()),
        default="sglang",
        help="Must specify a backend, depending on the LLM Inference Engine.",
    )
    parser.add_argument(
        "--base-url",
        type=str,
        default=None,
        help="Server or API base url if not using http host and port.",
    )
    parser.add_argument(
        "--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0."
    )
    parser.add_argument(
        "--port",
        type=int,
        help="If not set, the default port is configured according to its default value for different LLM Inference Engines.",
    )
    parser.add_argument(
        "--dataset-name",
        type=str,
        default="sharegpt",
        choices=["sharegpt", "random"],
        help="Name of the dataset to benchmark on.",
    )
    parser.add_argument(
        "--dataset-path", type=str, default="", help="Path to the dataset."
    )
    parser.add_argument(
        "--model",
        type=str,
        help="Name or path of the model. If not set, the default model will request /v1/models for conf.",
    )
    parser.add_argument(
        "--tokenizer",
        type=str,
        help="Name or path of the tokenizer. If not set, using the model conf.",
    )
    parser.add_argument(
        "--num-prompts",
        type=int,
        default=1000,
        help="Number of prompts to process. Default is 1000.",
    )
    parser.add_argument(
        "--sharegpt-output-len",
        type=int,
        default=None,
        help="Output length for each request. Overrides the output length from the ShareGPT dataset.",
    )
    parser.add_argument(
        "--random-input-len",
        type=int,
        help="Number of input tokens per request, used only for random dataset.",
    )
    parser.add_argument(
        "--random-output-len",
        type=int,
        help="Number of output tokens per request, used only for random dataset.",
    )
    parser.add_argument(
        "--random-range-ratio",
        type=float,
        default=0.0,
        help="Range of sampled ratio of input/output length, "
        "used only for random dataset.",
    )
    parser.add_argument(
        "--request-rate",
        type=float,
        default=float("inf"),
        help="Number of requests per second. If this is inf, then all the requests are sent at time 0. "
        "Otherwise, we use Poisson process to synthesize the request arrival times. Default is inf.",
    )
    parser.add_argument("--seed", type=int, default=1, help="The random seed.")
    parser.add_argument(
        "--multi",
        action="store_true",
        help="Use request rate range rather than single value.",
    )
    parser.add_argument(
        "--request-rate-range",
        type=str,
        default="2,34,2",
        help="Range of request rates in the format start,stop,step. Default is 2,34,2. It also supports a list of request rates, requiring the parameters to not equal three.",
    )
    parser.add_argument("--output-file", type=str, help="Output JSONL file name.")
    parser.add_argument(
        "--disable-tqdm",
        action="store_true",
        help="Specify to disable tqdm progress bar.",
    )
    parser.add_argument(
        "--disable-stream",
        action="store_true",
        help="Disable streaming mode.",
    )
    parser.add_argument(
        "--disable-ignore-eos",
        action="store_true",
        help="Disable ignoring EOS.",
    )
    parser.add_argument(
        "--extra-request-body",
        metavar='{"key1": "value1", "key2": "value2"}',
        type=str,
        help="Append given JSON object to the request payload. You can use this to specify"
        "additional generate params like sampling params.",
    )
    args = parser.parse_args()
    run_benchmark(args)