feat: add benchmark serving (#657)
This commit is contained in:
627
python/sglang/bench.py
Normal file
627
python/sglang/bench.py
Normal file
@@ -0,0 +1,627 @@
|
|||||||
|
# Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/backend_request_func.py
|
||||||
|
# Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/benchmark_serving.py
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import resource
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
|
import warnings
|
||||||
|
from argparse import ArgumentParser as FlexibleArgumentParser
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import AsyncGenerator, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import numpy as np
|
||||||
|
import requests
|
||||||
|
from tqdm.asyncio import tqdm
|
||||||
|
from transformers import (
|
||||||
|
AutoTokenizer,
|
||||||
|
PreTrainedTokenizer,
|
||||||
|
PreTrainedTokenizerBase,
|
||||||
|
PreTrainedTokenizerFast,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Total time budget applied to every aiohttp client session created by the
# request functions below: 6 * 60 * 60 seconds, i.e. six hours.
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class RequestFuncInput:
    """Parameters describing one request sent to a completions endpoint."""

    # Prompt text, sent as the "prompt" field of the request payload.
    prompt: str
    # Full URL the request is POSTed to.
    api_url: str
    # Length of the prompt; copied onto the matching RequestFuncOutput.
    prompt_len: int
    # Number of tokens to generate, sent as "max_tokens".
    output_len: int
    # Model identifier, sent as the "model" field of the payload.
    model: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class RequestFuncOutput:
    """Result and timing measurements collected for one request.

    Every field has a default, so an empty instance can be created first and
    filled in while the response is being streamed.
    """

    # Concatenation of all text chunks received from the server.
    generated_text: str = ""
    # True only when the request completed successfully.
    success: bool = False
    # Time from sending the request to the end of the stream.
    latency: float = 0.0
    # Time to first token.
    ttft: float = 0.0
    # List of inter-token latencies; a fresh list is created per instance.
    itl: List[float] = field(default_factory=list)
    # Prompt length copied from the originating RequestFuncInput.
    prompt_len: int = 0
    # HTTP reason or formatted traceback when the request failed.
    error: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
def remove_prefix(text: str, prefix: str) -> str:
    """Return ``text`` with a leading ``prefix`` removed.

    If ``text`` does not start with ``prefix`` it is returned unchanged.
    """
    if not text.startswith(prefix):
        return text
    return text[len(prefix) :]
|
||||||
|
|
||||||
|
|
||||||
|
# The payload always sets "ignore_eos": True, so generation is not cut short by an EOS token.
|
||||||
|
async def async_request_openai_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one streaming completions request and record its timings.

    The request is POSTed to ``request_func_input.api_url`` with greedy
    sampling (``temperature`` 0.0), ``stream`` enabled and ``ignore_eos`` set.
    The response is consumed line by line; each ``data: ...`` line is parsed
    as JSON and its ``choices[0].text`` is appended to the generated text.

    Args:
        request_func_input: Prompt, endpoint and length settings of the request.
        pbar: Optional progress bar, advanced by one when the request finishes
            (successfully or not).

    Returns:
        A ``RequestFuncOutput``. Failures are not raised: a non-200 status
        stores the HTTP reason in ``error``, and any exception raised while
        sending or reading the request stores the formatted traceback there;
        in both cases ``success`` is False.

    Raises:
        AssertionError: If the API URL does not end with ``"completions"``.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(
        "completions"
    ), "OpenAI Completions API URL must end with 'completions'."

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        payload = {
            "model": request_func_input.model,
            "prompt": request_func_input.prompt,
            "temperature": 0.0,
            "best_of": 1,
            "max_tokens": request_func_input.output_len,
            "stream": True,
            "ignore_eos": True,
        }
        # NOTE: when OPENAI_API_KEY is unset this renders as "Bearer None".
        headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        ttft = 0.0
        st = time.perf_counter()
        # Refreshed on every received chunk, so it is always defined even if
        # the stream ends without a terminating "[DONE]" message.
        latency = 0.0
        most_recent_timestamp = st
        try:
            async with session.post(
                url=api_url, json=payload, headers=headers
            ) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
                        latency = time.perf_counter() - st
                        if chunk == "[DONE]":
                            continue

                        data = json.loads(chunk)

                        # NOTE: Some completion API might have a last
                        # usage summary response without a token so we
                        # want to check a token was generated
                        text = data["choices"][0]["text"]
                        if not text:
                            continue

                        timestamp = time.perf_counter()
                        if ttft == 0.0:
                            # First token: counts towards TTFT only, not
                            # towards the inter-token latencies.
                            ttft = timestamp - st
                            output.ttft = ttft
                        else:
                            # Decoding phase
                            output.itl.append(timestamp - most_recent_timestamp)

                        most_recent_timestamp = timestamp
                        generated_text += text

                    output.generated_text = generated_text
                    output.success = True
                    output.latency = latency
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            # Best effort: a failed request is reported through the output
            # object so that one bad request does not abort the benchmark.
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar is not None:
        pbar.update(1)
    return output
|
||||||
|
|
||||||
|
|
||||||
|
def get_model(pretrained_model_name_or_path: str) -> str:
    """Resolve a model identifier to the name or path to load it from.

    When the ``SGLANG_USE_MODELSCOPE`` environment variable equals ``"true"``
    (case-insensitive), the model is fetched through ModelScope's
    ``snapshot_download`` (skipping files matching the weight patterns below)
    and the value it returns is used. Otherwise the argument is returned
    unchanged.
    """
    use_modelscope = os.getenv("SGLANG_USE_MODELSCOPE", "False").lower() == "true"
    if not use_modelscope:
        return pretrained_model_name_or_path

    # Imported lazily so these packages are only needed on the ModelScope path.
    import huggingface_hub.constants
    from modelscope import snapshot_download

    return snapshot_download(
        model_id=pretrained_model_name_or_path,
        local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
        ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"],
    )
|
||||||
|
|
||||||
|
|
||||||
|
def get_tokenizer(
    pretrained_model_name_or_path: str,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    """Load the tokenizer for a model name or local path.

    A value that is not an existing filesystem path is first passed through
    ``get_model``; the result is then handed to
    ``AutoTokenizer.from_pretrained`` with ``trust_remote_code=True``.
    """
    name_or_path = pretrained_model_name_or_path
    needs_resolution = name_or_path is not None and not os.path.exists(name_or_path)
    if needs_resolution:
        name_or_path = get_model(name_or_path)
    return AutoTokenizer.from_pretrained(name_or_path, trust_remote_code=True)
|
||||||
|
|
||||||
|
|
||||||
|
# Maps each supported backend name to the coroutine used to query it. All
# currently supported backends are queried through the same OpenAI-style
# completions request function.
ASYNC_REQUEST_FUNCS = dict.fromkeys(
    ("sglang", "vllm", "lmdeploy"),
    async_request_openai_completions,
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class BenchmarkMetrics:
    """Aggregate results of one benchmark run.

    Fields ending in ``_ms`` are latencies in milliseconds; for each latency
    kind the mean, median, standard deviation and 99th percentile are stored.
    """

    # --- Volume ---
    completed: int
    total_input: int
    total_output: int

    # --- Throughput (per second of benchmark duration) ---
    request_throughput: float
    input_throughput: float
    output_throughput: float

    # --- Time to first token ---
    mean_ttft_ms: float
    median_ttft_ms: float
    std_ttft_ms: float
    p99_ttft_ms: float

    # --- Time per output token ---
    mean_tpot_ms: float
    median_tpot_ms: float
    std_tpot_ms: float
    p99_tpot_ms: float

    # --- Inter-token latency ---
    mean_itl_ms: float
    median_itl_ms: float
    std_itl_ms: float
    p99_itl_ms: float
|
||||||
|
|
||||||
|
|
||||||
|
def sample_sharegpt_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int] = None,
) -> List[Tuple[str, int, int]]:
    """Sample benchmark requests from a ShareGPT-format JSON dataset.

    If ``dataset_path`` is not an existing file, the default ShareGPT file in
    the current working directory is used, and downloaded first when it is
    missing. The first two turns of every conversation are taken as prompt and
    completion, the pairs are shuffled with the ``random`` module, and pairs
    are kept until ``num_requests`` have been collected.

    Args:
        dataset_path: Path to the dataset JSON file.
        num_requests: Maximum number of requests to return.
        tokenizer: Callable returning an object with ``input_ids``; used to
            measure prompt and completion lengths.
        fixed_output_len: If given, used as the output length of every
            request instead of the tokenized completion length.

    Returns:
        A list of ``(prompt, prompt_len, output_len)`` tuples. It may contain
        fewer than ``num_requests`` entries if the dataset runs out.

    Raises:
        ValueError: If ``fixed_output_len`` is smaller than 4.
        Exception: If downloading the dataset fails.
    """
    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")

    default_dataset_path = "ShareGPT_V3_unfiltered_cleaned_split.json"
    url = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"

    if not os.path.isfile(dataset_path):
        if not os.path.isfile(default_dataset_path):
            print(f"Downloading dataset from {url}")
            # Download into a sibling ".part" file and rename it only once the
            # transfer has finished, so an interrupted download never leaves a
            # truncated file that a later run would mistake for the dataset.
            partial_path = default_dataset_path + ".part"
            try:
                with requests.get(url, stream=True) as response:
                    response.raise_for_status()

                    total_size = int(response.headers.get("content-length", 0))
                    block_size = 8192

                    with open(partial_path, "wb") as f, tqdm(
                        desc="Downloading",
                        total=total_size,
                        unit="iB",
                        unit_scale=True,
                        unit_divisor=1024,
                    ) as progress_bar:
                        for data in response.iter_content(block_size):
                            size = f.write(data)
                            progress_bar.update(size)

                os.replace(partial_path, default_dataset_path)
                print(f"Dataset downloaded and saved to {default_dataset_path}")
            except requests.RequestException as e:
                raise Exception(f"Failed to download dataset: {e}") from e
            finally:
                # Only present here if the download did not complete.
                if os.path.exists(partial_path):
                    os.remove(partial_path)
        dataset_path = default_dataset_path

    # Load the dataset.
    with open(dataset_path, encoding="utf-8") as f:
        dataset = json.load(f)

    # Drop conversations with fewer than 2 turns and keep only the first two
    # turns (prompt, completion) of the rest.
    dataset = [
        (data["conversations"][0]["value"], data["conversations"][1]["value"])
        for data in dataset
        if len(data["conversations"]) >= 2
    ]

    # Shuffle the dataset.
    random.shuffle(dataset)

    # Filter out sequences that are too long or too short
    filtered_dataset: List[Tuple[str, int, int]] = []
    for prompt, completion in dataset:
        if len(filtered_dataset) == num_requests:
            break

        prompt_len = len(tokenizer(prompt).input_ids)
        if fixed_output_len is None:
            output_len = len(tokenizer(completion).input_ids)
        else:
            # The completion length is irrelevant when overridden, so the
            # completion is not tokenized at all.
            output_len = fixed_output_len

        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            continue
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            # Prune too long sequences.
            continue
        filtered_dataset.append((prompt, prompt_len, output_len))

    return filtered_dataset
|
||||||
|
|
||||||
|
|
||||||
|
async def get_request(
    input_requests: List[Tuple[str, int, int]],
    request_rate: float,
) -> AsyncGenerator[Tuple[str, int, int], None]:
    """Yield the requests one by one, paced at ``request_rate`` per second.

    After each yielded request the generator sleeps for an interval drawn
    from an exponential distribution with mean ``1 / request_rate``. With an
    infinite rate there is no waiting and requests are yielded back to back.
    """
    unthrottled = request_rate == float("inf")
    for request in input_requests:
        yield request

        if not unthrottled:
            # Exponentially distributed gap before the next request is released.
            pause = np.random.exponential(1.0 / request_rate)
            await asyncio.sleep(pause)
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_metrics(
    input_requests: List[Tuple[str, int, int]],
    outputs: List[RequestFuncOutput],
    dur_s: float,
    tokenizer: PreTrainedTokenizerBase,
) -> Tuple[BenchmarkMetrics, List[int]]:
    """Aggregate per-request outputs into benchmark-wide metrics.

    Args:
        input_requests: The ``(prompt, prompt_len, output_len)`` tuples that
            were sent, index-aligned with ``outputs``.
        outputs: One result per request.
        dur_s: Benchmark duration in seconds, used for the throughputs.
        tokenizer: Used to count the tokens of each generated text.

    Returns:
        The aggregated ``BenchmarkMetrics`` and the list of measured output
        lengths (0 for failed requests), index-aligned with ``outputs``.
    """
    actual_output_lens: List[int] = []
    ttfts: List[float] = []
    tpots: List[float] = []
    itls: List[float] = []
    total_input = 0
    completed = 0

    for idx, result in enumerate(outputs):
        if not result.success:
            actual_output_lens.append(0)
            continue

        # We use the tokenizer to count the number of output tokens for all
        # serving backends instead of looking at len(result.itl) since
        # multiple output tokens may be bundled together
        # Note : this may inflate the output token count slightly
        token_ids = tokenizer(result.generated_text, add_special_tokens=False).input_ids
        num_output_tokens = len(token_ids)
        actual_output_lens.append(num_output_tokens)
        total_input += input_requests[idx][1]
        if num_output_tokens > 1:
            # Time per output token excludes the first token (covered by TTFT).
            tpots.append((result.latency - result.ttft) / (num_output_tokens - 1))
        itls += result.itl
        ttfts.append(result.ttft)
        completed += 1

    if completed == 0:
        warnings.warn(
            "All requests failed. This is likely due to a misconfiguration "
            "on the benchmark arguments.",
            stacklevel=2,
        )

    total_output = sum(actual_output_lens)
    # An empty sample list falls back to the scalar 0 so the statistics are 0.
    # ttfts is empty if streaming is not supported by backend
    metrics = BenchmarkMetrics(
        completed=completed,
        total_input=total_input,
        total_output=total_output,
        request_throughput=completed / dur_s,
        input_throughput=total_input / dur_s,
        output_throughput=total_output / dur_s,
        mean_ttft_ms=np.mean(ttfts or 0) * 1000,
        median_ttft_ms=np.median(ttfts or 0) * 1000,
        std_ttft_ms=np.std(ttfts or 0) * 1000,
        p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,
        mean_tpot_ms=np.mean(tpots or 0) * 1000,
        median_tpot_ms=np.median(tpots or 0) * 1000,
        std_tpot_ms=np.std(tpots or 0) * 1000,
        p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000,
        mean_itl_ms=np.mean(itls or 0) * 1000,
        median_itl_ms=np.median(itls or 0) * 1000,
        std_itl_ms=np.std(itls or 0) * 1000,
        p99_itl_ms=np.percentile(itls or 0, 99) * 1000,
    )

    return metrics, actual_output_lens
|
||||||
|
|
||||||
|
|
||||||
|
async def benchmark(
    backend: str,
    api_url: str,
    model_id: str,
    tokenizer: PreTrainedTokenizerBase,
    input_requests: List[Tuple[str, int, int]],
    request_rate: float,
    disable_tqdm: bool,
):
    """Run the serving benchmark and return the collected results.

    A single request is sent first to verify the configuration. Then every
    request in ``input_requests`` is dispatched as its own task, paced by
    ``get_request``; once all tasks finish the metrics are computed, printed
    as a table, and returned.

    Args:
        backend: Key into ``ASYNC_REQUEST_FUNCS`` selecting the request function.
        api_url: Completions endpoint URL.
        model_id: Model name sent with each request.
        tokenizer: Tokenizer used to count generated tokens.
        input_requests: ``(prompt, prompt_len, output_len)`` tuples to send.
        request_rate: Requests per second; ``inf`` sends everything at once.
        disable_tqdm: If True, no progress bar is shown.

    Returns:
        A dict with the aggregate metrics plus per-request input/output
        lengths, TTFTs, ITLs, generated texts and errors.

    Raises:
        ValueError: If ``input_requests`` is empty, the backend is unknown, or
            the initial test request fails.
    """
    if not input_requests:
        # Without this check the test run below fails with a bare IndexError.
        raise ValueError(
            "input_requests is empty - no prompts were sampled from the "
            "dataset, so there is nothing to benchmark."
        )

    if backend in ASYNC_REQUEST_FUNCS:
        request_func = ASYNC_REQUEST_FUNCS[backend]
    else:
        raise ValueError(f"Unknown backend: {backend}")

    print("Starting initial single prompt test run...")
    test_prompt, test_prompt_len, test_output_len = input_requests[0]
    test_input = RequestFuncInput(
        model=model_id,
        prompt=test_prompt,
        api_url=api_url,
        prompt_len=test_prompt_len,
        output_len=test_output_len,
    )
    test_output = await request_func(request_func_input=test_input)
    if not test_output.success:
        raise ValueError(
            "Initial test run failed - Please make sure benchmark arguments "
            f"are correctly specified. Error: {test_output.error}"
        )
    else:
        print("Initial test run completed. Starting main benchmark run...")

    pbar = None if disable_tqdm else tqdm(total=len(input_requests))

    benchmark_start_time = time.perf_counter()
    tasks: List[asyncio.Task] = []
    try:
        # Each request becomes a task as soon as the pacing generator
        # releases it, so requests overlap according to the request rate.
        async for request in get_request(input_requests, request_rate):
            prompt, prompt_len, output_len = request
            request_func_input = RequestFuncInput(
                model=model_id,
                prompt=prompt,
                api_url=api_url,
                prompt_len=prompt_len,
                output_len=output_len,
            )
            tasks.append(
                asyncio.create_task(
                    request_func(request_func_input=request_func_input, pbar=pbar)
                )
            )
        outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
    finally:
        # Close the progress bar even if dispatching or gathering raised.
        if pbar is not None:
            pbar.close()

    benchmark_duration = time.perf_counter() - benchmark_start_time

    metrics, actual_output_lens = calculate_metrics(
        input_requests=input_requests,
        outputs=outputs,
        dur_s=benchmark_duration,
        tokenizer=tokenizer,
    )

    print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
    print("{:<40} {:<10}".format("Traffic request rate:", request_rate))
    print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
    print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
    print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
    print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
    print(
        "{:<40} {:<10.2f}".format(
            "Request throughput (req/s):", metrics.request_throughput
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Input token throughput (tok/s):", metrics.input_throughput
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Output token throughput (tok/s):", metrics.output_throughput
        )
    )
    print("{s:{c}^{n}}".format(s="Time to First Token", n=50, c="-"))
    print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
    print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms))
    print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
    print(
        "{s:{c}^{n}}".format(s="Time per Output Token (excl. 1st token)", n=50, c="-")
    )
    print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms))
    print("{:<40} {:<10.2f}".format("Median TPOT (ms):", metrics.median_tpot_ms))
    print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
    print("{s:{c}^{n}}".format(s="Inter-token Latency", n=50, c="-"))
    print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
    print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
    print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
    print("=" * 50)

    result = {
        "duration": benchmark_duration,
        "completed": metrics.completed,
        "total_input_tokens": metrics.total_input,
        "total_output_tokens": metrics.total_output,
        "request_throughput": metrics.request_throughput,
        "input_throughput": metrics.input_throughput,
        "output_throughput": metrics.output_throughput,
        "mean_ttft_ms": metrics.mean_ttft_ms,
        "median_ttft_ms": metrics.median_ttft_ms,
        "std_ttft_ms": metrics.std_ttft_ms,
        "p99_ttft_ms": metrics.p99_ttft_ms,
        "mean_tpot_ms": metrics.mean_tpot_ms,
        "median_tpot_ms": metrics.median_tpot_ms,
        "std_tpot_ms": metrics.std_tpot_ms,
        "p99_tpot_ms": metrics.p99_tpot_ms,
        "mean_itl_ms": metrics.mean_itl_ms,
        "median_itl_ms": metrics.median_itl_ms,
        "std_itl_ms": metrics.std_itl_ms,
        "p99_itl_ms": metrics.p99_itl_ms,
        "input_lens": [output.prompt_len for output in outputs],
        "output_lens": actual_output_lens,
        "ttfts": [output.ttft for output in outputs],
        "itls": [output.itl for output in outputs],
        "generated_texts": [output.generated_text for output in outputs],
        "errors": [output.error for output in outputs],
    }
    return result
|
||||||
|
|
||||||
|
|
||||||
|
def fire(args: argparse.Namespace):
    """Run the benchmark described by the parsed command-line arguments.

    Seeds the ``random`` and NumPy generators, fills in the backend's default
    port when none was given, builds the endpoint URLs, discovers the model
    name from ``/v1/models`` when ``--model`` is not set, loads the tokenizer,
    samples the requests and finally runs ``benchmark``.

    The process exits with status 1 if the model list cannot be fetched or no
    model can be determined.
    """
    random.seed(args.seed)
    np.random.seed(args.seed)

    if args.port is None:
        args.port = {
            "sglang": 30000,
            "lmdeploy": 23333,
            "vllm": 8000,
        }.get(args.backend, 30000)

    # Strip trailing slashes from a user-supplied base URL so that
    # "http://host:port/" does not turn into "http://host:port//v1/...".
    if args.base_url:
        base_url = args.base_url.rstrip("/")
    else:
        base_url = f"http://{args.host}:{args.port}"
    api_url = f"{base_url}/v1/completions"
    model_url = f"{base_url}/v1/models"

    if args.model is None:
        try:
            response = requests.get(model_url)
            model_list = response.json().get("data", [])
            args.model = model_list[0]["id"] if model_list else None
        except Exception as e:
            print(f"Failed to fetch model from {model_url}. Error: {e}")
            print(
                "Please specify the correct host and port using `--host` and `--port`."
            )
            sys.exit(1)

    if args.model is None:
        print("No model specified or found. Please provide a model using `--model`.")
        sys.exit(1)

    print(f"{args}\n")

    backend = args.backend
    model_id = args.model
    tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model

    tokenizer = get_tokenizer(tokenizer_id)

    assert args.dataset is not None
    input_requests = sample_sharegpt_requests(
        dataset_path=args.dataset,
        num_requests=args.num_prompts,
        tokenizer=tokenizer,
        fixed_output_len=args.sharegpt_output_len,
    )

    asyncio.run(
        benchmark(
            backend=backend,
            api_url=api_url,
            model_id=model_id,
            tokenizer=tokenizer,
            input_requests=input_requests,
            request_rate=args.request_rate,
            disable_tqdm=args.disable_tqdm,
        )
    )
|
||||||
|
|
||||||
|
|
||||||
|
# Defined locally (instead of being imported) to avoid relying on SGLang's components.
|
||||||
|
def set_ulimit(target_soft_limit=65535):
    """Raise the soft limit on open file descriptors, if it is lower.

    The soft ``RLIMIT_NOFILE`` limit is raised to ``target_soft_limit``,
    capped at the current hard limit. The limit is never lowered, and a
    failure to change it is only printed, never raised.

    Args:
        target_soft_limit: Desired minimum soft limit.
    """
    resource_type = resource.RLIMIT_NOFILE
    current_soft, current_hard = resource.getrlimit(resource_type)

    # An unlimited soft limit cannot be improved. RLIM_INFINITY may be a
    # negative number, so it must not go through the "<" comparison below.
    if current_soft == resource.RLIM_INFINITY:
        return

    # Asking for more than the hard limit makes setrlimit fail outright;
    # settle for the highest value that is actually allowed.
    if current_hard != resource.RLIM_INFINITY:
        target_soft_limit = min(target_soft_limit, current_hard)

    if current_soft < target_soft_limit:
        try:
            resource.setrlimit(resource_type, (target_soft_limit, current_hard))
        except (ValueError, OSError) as e:
            print(f"Fail to set RLIMIT_NOFILE: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Command-line entry point: declare the options, raise the open-file
    # limit, parse the arguments and run the benchmark.
    parser = FlexibleArgumentParser(
        description="Benchmark the online serving throughput."
    )

    # Each entry is (flag, keyword arguments for add_argument). The order is
    # kept because it is the order shown by --help.
    cli_options = [
        (
            "--backend",
            dict(
                type=str,
                required=True,
                choices=list(ASYNC_REQUEST_FUNCS),
                help="Must specify a backend, depending on the LLM Inference Engine.",
            ),
        ),
        (
            "--base-url",
            dict(
                type=str,
                default=None,
                help="Server or API base url if not using http host and port.",
            ),
        ),
        (
            "--host",
            dict(type=str, default="0.0.0.0", help="Default host is 0.0.0.0."),
        ),
        (
            "--port",
            dict(
                type=int,
                help="If not set, the default port is configured according to its default value for different LLM Inference Engines.",
            ),
        ),
        (
            "--dataset",
            dict(type=str, default="sharegpt", help="Path to the ShareGPT dataset"),
        ),
        (
            "--model",
            dict(
                type=str,
                help="Name or path of the model. If not set, the default model will request /v1/models for conf.",
            ),
        ),
        (
            "--tokenizer",
            dict(
                type=str,
                help="Name or path of the tokenizer. If not set, using the model conf.",
            ),
        ),
        (
            "--num-prompts",
            dict(
                type=int,
                default=1000,
                help="Number of prompts to process. Default is 1000.",
            ),
        ),
        (
            "--sharegpt-output-len",
            dict(
                type=int,
                default=None,
                help="Output length for each request. Overrides the output length from the ShareGPT dataset.",
            ),
        ),
        (
            "--request-rate",
            dict(
                type=float,
                default=128.0,
                help="Number of requests per second. If this is inf, then all the requests are sent at time 0. "
                "Otherwise, we use Poisson process to synthesize the request arrival times. Default is 128.0.",
            ),
        ),
        ("--seed", dict(type=int, default=0, help="Default is 0.")),
        (
            "--disable-tqdm",
            dict(
                action="store_true",
                help="Specify to disable tqdm progress bar.",
            ),
        ),
    ]
    for flag, options in cli_options:
        parser.add_argument(flag, **options)

    set_ulimit()

    args = parser.parse_args()
    fire(args)
|
||||||
@@ -7,6 +7,23 @@ from pydantic import BaseModel, Field
|
|||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
|
|
||||||
|
|
||||||
|
class ModelCard(BaseModel):
    """Model cards."""

    # Identifier of the model.
    id: str
    # Object type tag; always "model" unless overridden.
    object: str = "model"
    # Whole-second timestamp taken from time.time() when the card is created.
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "sglang"
    # Optional root model name; None when not provided.
    root: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ModelList(BaseModel):
    """Model list consists of model cards."""

    # Object type tag; always "list" unless overridden.
    object: str = "list"
    # The model cards in this list; empty by default.
    data: List[ModelCard] = []
|
||||||
|
|
||||||
|
|
||||||
class ErrorResponse(BaseModel):
|
class ErrorResponse(BaseModel):
|
||||||
object: str = "error"
|
object: str = "error"
|
||||||
message: str
|
message: str
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ from sglang.srt.openai_api_adapter import (
|
|||||||
v1_chat_completions,
|
v1_chat_completions,
|
||||||
v1_completions,
|
v1_completions,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.openai_protocol import ModelCard, ModelList
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
API_KEY_HEADER_NAME,
|
API_KEY_HEADER_NAME,
|
||||||
@@ -73,6 +74,21 @@ async def health() -> Response:
|
|||||||
return Response(status_code=200)
|
return Response(status_code=200)
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_list():
    """Return the names of the models available on this server.

    The list holds a single entry: the model path of the tokenizer manager.
    """
    return [tokenizer_manager.model_path]
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/v1/models")
def available_models():
    """Show available models."""
    # One card per served model; the model name is used as both id and root.
    model_cards = [
        ModelCard(id=model_name, root=model_name) for model_name in get_model_list()
    ]
    return ModelList(data=model_cards)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/get_model_info")
|
@app.get("/get_model_info")
|
||||||
async def get_model_info():
|
async def get_model_info():
|
||||||
result = {
|
result = {
|
||||||
|
|||||||
Reference in New Issue
Block a user