Improve benchmark scripts (#1672)
This commit is contained in:
@@ -6,6 +6,8 @@ It accepts arguments similar to those of launch_server.py.
|
||||
Usage:
|
||||
|
||||
python3 -m sglang.bench_server_latency --model meta-llama/Meta-Llama-3.1-8B --batch-size 1 16 64 --input-len 1024 --output-len 8
|
||||
|
||||
python3 -m sglang.bench_server_latency --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8
|
||||
"""
|
||||
|
||||
import argparse
|
||||
@@ -15,7 +17,7 @@ import json
|
||||
import multiprocessing
|
||||
import os
|
||||
import time
|
||||
from typing import Tuple
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
@@ -32,6 +34,8 @@ class BenchArgs:
|
||||
input_len: Tuple[int] = (1024,)
|
||||
output_len: Tuple[int] = (16,)
|
||||
result_filename: str = "result.jsonl"
|
||||
base_url: str = ""
|
||||
skip_warmup: bool = False
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(parser: argparse.ArgumentParser):
|
||||
@@ -48,6 +52,8 @@ class BenchArgs:
|
||||
parser.add_argument(
|
||||
"--result-filename", type=str, default=BenchArgs.result_filename
|
||||
)
|
||||
parser.add_argument("--base-url", type=str, default=BenchArgs.base_url)
|
||||
parser.add_argument("--skip-warmup", action="store_true")
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
@@ -139,17 +145,21 @@ def run_one_case(
|
||||
|
||||
|
||||
def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
|
||||
proc, base_url = launch_server_process(server_args)
|
||||
if bench_args.base_url:
|
||||
proc, base_url = None, bench_args.base_url
|
||||
else:
|
||||
proc, base_url = launch_server_process(server_args)
|
||||
|
||||
# warmup
|
||||
run_one_case(
|
||||
base_url,
|
||||
batch_size=16,
|
||||
input_len=1024,
|
||||
output_len=16,
|
||||
run_name="",
|
||||
result_filename="",
|
||||
)
|
||||
if not bench_args.skip_warmup:
|
||||
run_one_case(
|
||||
base_url,
|
||||
batch_size=16,
|
||||
input_len=1024,
|
||||
output_len=16,
|
||||
run_name="",
|
||||
result_filename="",
|
||||
)
|
||||
|
||||
# benchmark
|
||||
try:
|
||||
@@ -165,7 +175,8 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
|
||||
bench_args.result_filename,
|
||||
)
|
||||
finally:
|
||||
kill_child_process(proc.pid)
|
||||
if proc:
|
||||
kill_child_process(proc.pid)
|
||||
|
||||
print(f"\nResults are saved to {bench_args.result_filename}")
|
||||
|
||||
|
||||
@@ -222,6 +222,85 @@ async def async_request_openai_completions(
|
||||
return output
|
||||
|
||||
|
||||
async def async_request_sglang_generate(
|
||||
request_func_input: RequestFuncInput,
|
||||
pbar: Optional[tqdm] = None,
|
||||
) -> RequestFuncOutput:
|
||||
api_url = request_func_input.api_url
|
||||
prompt = request_func_input.prompt
|
||||
|
||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||
payload = {
|
||||
"text": prompt,
|
||||
"sampling_params": {
|
||||
"temperature": 0.0,
|
||||
"max_new_tokens": request_func_input.output_len,
|
||||
"ignore_eos": not args.disable_ignore_eos,
|
||||
},
|
||||
"stream": not args.disable_stream,
|
||||
**request_func_input.extra_request_body,
|
||||
}
|
||||
headers = {}
|
||||
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
|
||||
generated_text = ""
|
||||
ttft = 0.0
|
||||
st = time.perf_counter()
|
||||
most_recent_timestamp = st
|
||||
try:
|
||||
async with session.post(
|
||||
url=api_url, json=payload, headers=headers
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
async for chunk_bytes in response.content:
|
||||
chunk_bytes = chunk_bytes.strip()
|
||||
if not chunk_bytes:
|
||||
continue
|
||||
# print(chunk_bytes)
|
||||
|
||||
chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
|
||||
latency = time.perf_counter() - st
|
||||
if chunk == "[DONE]":
|
||||
pass
|
||||
else:
|
||||
data = json.loads(chunk)
|
||||
|
||||
# NOTE: Some completion API might have a last
|
||||
# usage summary response without a token so we
|
||||
# want to check a token was generated
|
||||
if data["text"]:
|
||||
timestamp = time.perf_counter()
|
||||
# First token
|
||||
if ttft == 0.0:
|
||||
ttft = time.perf_counter() - st
|
||||
output.ttft = ttft
|
||||
|
||||
# Decoding phase
|
||||
else:
|
||||
output.itl.append(timestamp - most_recent_timestamp)
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
generated_text = data["text"]
|
||||
|
||||
output.generated_text = generated_text
|
||||
output.success = True
|
||||
output.latency = latency
|
||||
output.output_len = request_func_input.output_len
|
||||
else:
|
||||
output.error = response.reason or ""
|
||||
output.success = False
|
||||
except Exception:
|
||||
output.success = False
|
||||
exc_info = sys.exc_info()
|
||||
output.error = "".join(traceback.format_exception(*exc_info))
|
||||
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
return output
|
||||
|
||||
|
||||
async def async_request_gserver(
|
||||
request_func_input: RequestFuncInput,
|
||||
pbar: Optional[tqdm] = None,
|
||||
@@ -264,7 +343,9 @@ def get_tokenizer(
|
||||
|
||||
|
||||
ASYNC_REQUEST_FUNCS = {
|
||||
"sglang": async_request_openai_completions,
|
||||
"sglang": async_request_sglang_generate,
|
||||
"sglang-native": async_request_sglang_generate,
|
||||
"sglang-oai": async_request_openai_completions,
|
||||
"vllm": async_request_openai_completions,
|
||||
"lmdeploy": async_request_openai_completions,
|
||||
"trt": async_request_trt_llm,
|
||||
@@ -387,6 +468,8 @@ def sample_sharegpt_requests(
|
||||
continue
|
||||
filtered_dataset.append((prompt, prompt_len, output_len))
|
||||
|
||||
print(f"#Input tokens: {np.sum([x[1] for x in filtered_dataset])}")
|
||||
print(f"#Output tokens: {np.sum([x[2] for x in filtered_dataset])}")
|
||||
return filtered_dataset
|
||||
|
||||
|
||||
@@ -784,24 +867,33 @@ def run_benchmark(args_: argparse.Namespace):
|
||||
if args.port is None:
|
||||
args.port = {
|
||||
"sglang": 30000,
|
||||
"sglang-native": 30000,
|
||||
"sglang-oai": 30000,
|
||||
"lmdeploy": 23333,
|
||||
"vllm": 8000,
|
||||
"trt": 8000,
|
||||
"gserver": 9988,
|
||||
}.get(args.backend, 30000)
|
||||
|
||||
api_url = (
|
||||
f"{args.base_url}/v1/completions"
|
||||
if args.base_url
|
||||
else f"http://{args.host}:{args.port}/v1/completions"
|
||||
)
|
||||
model_url = (
|
||||
f"{args.base_url}/v1/models"
|
||||
if args.base_url
|
||||
else f"http://{args.host}:{args.port}/v1/models"
|
||||
)
|
||||
|
||||
if args.backend == "trt":
|
||||
if args.backend in ["sglang", "sglang-native"]:
|
||||
api_url = (
|
||||
f"{args.base_url}/generate"
|
||||
if args.base_url
|
||||
else f"http://{args.host}:{args.port}/generate"
|
||||
)
|
||||
elif args.backend in ["sglang-oai", "vllm", "lmdeploy"]:
|
||||
api_url = (
|
||||
f"{args.base_url}/v1/completions"
|
||||
if args.base_url
|
||||
else f"http://{args.host}:{args.port}/v1/completions"
|
||||
)
|
||||
elif args.backend == "trt":
|
||||
api_url = (
|
||||
f"{args.base_url}/v2/models/ensemble/generate_stream"
|
||||
if args.base_url
|
||||
|
||||
Reference in New Issue
Block a user