Add bench_server_latency.py (#1452)
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
"""
|
||||
Benchmark the latency of a given model. It accepts arguments similar to those of launch_server.py.
|
||||
Benchmark the latency of running a single static batch.
|
||||
This script does not launch a server and uses the low-level APIs.
|
||||
It accepts arguments similar to those of launch_server.py.
|
||||
|
||||
# Usage (latency test)
|
||||
## with dummy weights:
|
||||
|
||||
187
python/sglang/bench_server_latency.py
Normal file
187
python/sglang/bench_server_latency.py
Normal file
@@ -0,0 +1,187 @@
|
||||
"""
|
||||
Benchmark the latency of serving a single batch with a real server.
|
||||
This script launches a server and uses the HTTP interface.
|
||||
It accepts arguments similar to those of launch_server.py.
|
||||
|
||||
Usage:
|
||||
|
||||
python3 -m sglang.bench_server_latency --model meta-llama/Meta-Llama-3.1-8B --batch-size 1 16 64 --input-len 1024 --output-len 8
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import dataclasses
|
||||
import itertools
|
||||
import json
|
||||
import multiprocessing
|
||||
import os
|
||||
import time
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
from sglang.srt.server import launch_server
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import kill_child_process
|
||||
|
||||
|
||||
@dataclasses.dataclass
class BenchArgs:
    """Benchmark-specific CLI arguments (server options live in ServerArgs)."""

    # Label written into each result record to tag this run.
    run_name: str = "default"
    # Variadic tuples: each (batch_size, input_len, output_len) combination is
    # benchmarked.  Note Tuple[int, ...] — Tuple[int] would mean a 1-tuple.
    batch_size: Tuple[int, ...] = (1,)
    input_len: Tuple[int, ...] = (1024,)
    output_len: Tuple[int, ...] = (16,)
    # Results are appended to this file as JSON lines.
    result_filename: str = "result.jsonl"

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Register the benchmark options on an existing parser."""
        parser.add_argument("--run-name", type=str, default=BenchArgs.run_name)
        parser.add_argument(
            "--batch-size", type=int, nargs="+", default=BenchArgs.batch_size
        )
        parser.add_argument(
            "--input-len", type=int, nargs="+", default=BenchArgs.input_len
        )
        parser.add_argument(
            "--output-len", type=int, nargs="+", default=BenchArgs.output_len
        )
        parser.add_argument(
            "--result-filename", type=str, default=BenchArgs.result_filename
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build a BenchArgs from parsed CLI args.

        Uses each field's default-value type to cast the parsed values into
        the correct types (e.g. the nargs="+" lists become tuples).
        """
        attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
        return cls(
            **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
        )
|
||||
|
||||
|
||||
def launch_server_internal(server_args):
    """Entry point for the server subprocess.

    Always cleans up the worker processes spawned by the server, whether
    launch_server returns normally or raises.  The original
    ``except Exception as e: raise e`` clause was a no-op re-raise and has
    been removed; ``try/finally`` alone gives identical behavior.
    """
    try:
        launch_server(server_args)
    finally:
        # Kill only the children of this subprocess, not the subprocess itself.
        kill_child_process(os.getpid(), including_parent=False)
|
||||
|
||||
|
||||
def launch_server_process(server_args: ServerArgs):
    """Launch the server in a subprocess and wait until it responds.

    Polls ``GET /v1/models`` every 10 seconds until the server answers with
    HTTP 200, the subprocess dies, or the timeout expires.

    Returns:
        (proc, base_url): the server process handle and its base URL.

    Raises:
        RuntimeError: if the server process exits before becoming ready.
        TimeoutError: if the server does not become ready within 600 s.
    """
    proc = multiprocessing.Process(target=launch_server_internal, args=(server_args,))
    proc.start()
    base_url = f"http://{server_args.host}:{server_args.port}"
    timeout = 600

    start_time = time.time()
    while time.time() - start_time < timeout:
        # Fail fast instead of polling for the full timeout when the server
        # process crashed during startup.
        if not proc.is_alive():
            raise RuntimeError("Server process terminated before becoming ready.")
        try:
            headers = {
                "Content-Type": "application/json; charset=utf-8",
            }
            response = requests.get(f"{base_url}/v1/models", headers=headers)
            if response.status_code == 200:
                return proc, base_url
        except requests.RequestException:
            # Server not accepting connections yet; keep polling.
            pass
        time.sleep(10)
    raise TimeoutError("Server failed to start within the timeout period.")
|
||||
|
||||
|
||||
def run_one_case(
    url: str,
    batch_size: int,
    input_len: int,
    output_len: int,
    run_name: str,
    result_filename: str,
):
    """Send one static batch to the server and report latency/throughput.

    When ``result_filename`` is non-empty, the metrics are appended to that
    file as a JSON line tagged with ``run_name``.
    """
    # Random token ids stand in for real prompts; only the shapes matter here.
    input_ids = [
        [int(tok) for tok in np.random.randint(0, high=16384, size=(input_len,))]
        for _ in range(batch_size)
    ]

    payload = {
        "input_ids": input_ids,
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": output_len,
            "ignore_eos": True,
        },
    }

    start = time.time()
    response = requests.post(url + "/generate", json=payload)
    latency = time.time() - start

    # Force the body to be read/parsed before we report timing-derived stats.
    _ = response.json()
    output_throughput = batch_size * output_len / latency
    overall_throughput = batch_size * (input_len + output_len) / latency

    print(f"batch size: {batch_size}")
    print(f"latency: {latency:.2f} s")
    print(f"output throughput: {output_throughput:.2f} token/s")
    print(f"(input + output) throughput: {overall_throughput:.2f} token/s")

    if result_filename:
        record = {
            "run_name": run_name,
            "batch_size": batch_size,
            "input_len": input_len,
            "output_len": output_len,
            "latency": round(latency, 4),
            "output_throughput": round(output_throughput, 2),
            "overall_throughput": round(overall_throughput, 2),
        }
        with open(result_filename, "a") as fout:
            fout.write(json.dumps(record) + "\n")
|
||||
|
||||
|
||||
def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
    """Launch a server and benchmark every (batch_size, input_len, output_len) combo.

    Fix: the warmup call now sits inside the try/finally, so the server
    process is killed even when the warmup itself fails (previously a warmup
    failure leaked the server process).
    """
    proc, base_url = launch_server_process(server_args)

    try:
        # Warmup; results are discarded (empty run_name/result_filename).
        run_one_case(
            base_url,
            batch_size=16,
            input_len=1024,
            output_len=16,
            run_name="",
            result_filename="",
        )

        # Benchmark the full cartesian product of the requested settings.
        for bs, il, ol in itertools.product(
            bench_args.batch_size, bench_args.input_len, bench_args.output_len
        ):
            run_one_case(
                base_url,
                bs,
                il,
                ol,
                bench_args.run_name,
                bench_args.result_filename,
            )
    finally:
        # Never let the server (and its workers) outlive the benchmark.
        kill_child_process(proc.pid)

    print(f"\nResults are saved to {bench_args.result_filename}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    BenchArgs.add_cli_args(parser)
    # For this script, model-path is not required.  Look the option up by
    # name instead of relying on its position in parser._actions (the old
    # positional-index check used `assert`, which is stripped under -O).
    for action in parser._actions:
        if "--model-path" in action.option_strings:
            action.required = False
            break
    else:
        raise RuntimeError("options changed, this code need to be updated")
    args = parser.parse_args()

    server_args = ServerArgs.from_cli_args(args)
    bench_args = BenchArgs.from_cli_args(args)

    run_benchmark(server_args, bench_args)
|
||||
@@ -2,7 +2,7 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/benchmark_serving.py
|
||||
|
||||
"""
|
||||
Benchmark online serving.
|
||||
Benchmark online serving with dynamic requests.
|
||||
|
||||
Usage:
|
||||
python3 -m sglang.bench_serving --backend sglang --num-prompt 10
|
||||
|
||||
@@ -26,17 +26,6 @@ from sglang.srt.utils import is_hip
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoRAPathAction(argparse.Action):
    """Parse repeated LoRA path values into a ``{name: path}`` dict.

    Each value is either ``name=path`` or a bare path; a bare path is used
    as both the adapter name and the path.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        paths = {}
        for entry in values:
            name, sep, path = entry.partition("=")
            # No "=" present -> the whole entry serves as name and path.
            paths[name] = path if sep else name
        setattr(namespace, self.dest, paths)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ServerArgs:
|
||||
# Model and tokenizer
|
||||
@@ -619,3 +608,14 @@ class PortArgs:
|
||||
controller_port: int
|
||||
detokenizer_port: int
|
||||
nccl_ports: List[int]
|
||||
|
||||
|
||||
class LoRAPathAction(argparse.Action):
    """argparse action that collects LoRA paths into a name -> path mapping.

    Accepts values of the form ``name=path`` (split on the first "=") or a
    bare path, which maps to itself.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, {})
        mapping = getattr(namespace, self.dest)
        for entry in values:
            if "=" not in entry:
                # Bare path: use it as both the adapter name and the path.
                mapping[entry] = entry
            else:
                key, val = entry.split("=", 1)
                mapping[key] = val
|
||||
|
||||
@@ -44,7 +44,7 @@ def get_answer_value(answer_str):
|
||||
return INVALID
|
||||
|
||||
|
||||
def main(args):
|
||||
def run_eval(args):
|
||||
# Select backend
|
||||
set_default_backend(RuntimeEndpoint(f"{args.host}:{args.port}"))
|
||||
|
||||
@@ -119,6 +119,12 @@ def main(args):
|
||||
# Dump results
|
||||
dump_state_text("tmp_output_gsm8k.txt", states)
|
||||
|
||||
return {
|
||||
"accuracy": acc,
|
||||
"latency": latency,
|
||||
"output_throughput": output_throughput,
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
@@ -129,4 +135,4 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--host", type=str, default="http://127.0.0.1")
|
||||
parser.add_argument("--port", type=int, default=30000)
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
run_eval(args)
|
||||
|
||||
Reference in New Issue
Block a user