Improve profiler and integrate profiler in bench_one_batch_server (#6787)

This commit is contained in:
Lianmin Zheng
2025-05-31 15:53:55 -07:00
committed by GitHub
parent b520d02888
commit 2d72fc47cf
25 changed files with 481 additions and 223 deletions

View File

@@ -8,6 +8,7 @@ Usage:
python3 -m sglang.bench_one_batch_server --model meta-llama/Meta-Llama-3.1-8B --batch-size 1 16 64 --input-len 1024 --output-len 8
python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8
python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 --show-report --profile --profile-by-stage
"""
import argparse
@@ -19,10 +20,10 @@ import os
import time
from typing import Tuple
import numpy as np
import requests
from sglang.bench_serving import get_tokenizer, sample_random_requests
from sglang.profiler import run_profile
from sglang.srt.entrypoints.http_server import launch_server
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import kill_process_tree
@@ -42,6 +43,8 @@ class BenchArgs:
base_url: str = ""
skip_warmup: bool = False
show_report: bool = False
profile: bool = False
profile_by_stage: bool = False
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
@@ -68,6 +71,8 @@ class BenchArgs:
parser.add_argument("--base-url", type=str, default=BenchArgs.base_url)
parser.add_argument("--skip-warmup", action="store_true")
parser.add_argument("--show-report", action="store_true")
parser.add_argument("--profile", action="store_true")
parser.add_argument("--profile-by-stage", action="store_true")
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
@@ -93,8 +98,8 @@ def launch_server_process(server_args: ServerArgs):
base_url = f"http://{server_args.host}:{server_args.port}"
timeout = 600
start_time = time.perf_counter()
while time.perf_counter() - start_time < timeout:
start_time = time.time()
while time.time() - start_time < timeout:
try:
headers = {
"Content-Type": "application/json; charset=utf-8",
@@ -119,6 +124,8 @@ def run_one_case(
run_name: str,
result_filename: str,
tokenizer,
profile: bool = False,
profile_by_stage: bool = False,
):
requests.post(url + "/flush_cache")
input_requests = sample_random_requests(
@@ -145,6 +152,12 @@ def run_one_case(
else:
json_schema = None
profile_link = None
if profile:
profile_link: str = run_profile(
url, 3, ["CPU", "GPU"], None, None, profile_by_stage
)
tic = time.perf_counter()
response = requests.post(
url + "/generate",
@@ -194,8 +207,8 @@ def run_one_case(
print(f"output_len: {output_len}")
print(f"latency: {latency:.2f} s")
print(f"ttft: {ttft:.2f} s")
print(f"Last generation throughput: {last_gen_throughput:.2f} tok/s")
print(f"Input throughput: {input_throughput:.2f} tok/s")
print(f"last generation throughput: {last_gen_throughput:.2f} tok/s")
print(f"input throughput: {input_throughput:.2f} tok/s")
if output_len != 1:
print(f"output throughput: {output_throughput:.2f} tok/s")
@@ -222,6 +235,7 @@ def run_one_case(
overall_throughput,
last_gen_throughput,
acc_length,
profile_link if profile else None,
)
@@ -253,6 +267,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
# benchmark
result = []
bench_result = []
try:
for bs, il, ol in itertools.product(
bench_args.batch_size, bench_args.input_len, bench_args.output_len
@@ -271,6 +286,33 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
tokenizer=tokenizer,
)
)
if bench_args.profile:
try:
for bs, il, ol in itertools.product(
bench_args.batch_size, bench_args.input_len, bench_args.output_len
):
bench_result.append(
(
run_one_case(
base_url,
bs,
il,
ol,
temperature=bench_args.temperature,
return_logprob=bench_args.return_logprob,
input_len_step_percentage=bench_args.input_len_step_percentage,
run_name=bench_args.run_name,
result_filename=bench_args.result_filename,
tokenizer=tokenizer,
profile=bench_args.profile,
profile_by_stage=bench_args.profile_by_stage,
)[-1],
)
)
result = [t1[:-1] + t2 for t1, t2 in zip(result, bench_result)]
except Exception as e:
print(f"Error profiling, there will be no profile trace dump: {e}")
finally:
if proc:
kill_process_tree(proc.pid)
@@ -280,8 +322,20 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
if not bench_args.show_report:
return
summary = " | batch size | latency (s) | input throughput (tok/s) | output throughput (tok/s) | acc length | ITL (ms) | input price ($/1M) | output price ($/1M) |\n"
summary += "| ---------- | ----------- | ------------------------- | ------------------------- | ---------- | -------- | ------------------ | ------------------- |\n"
summary = (
f"\nInput lens: {bench_args.input_len}. Output lens: {bench_args.output_len}.\n"
)
summary += "| batch size | latency (s) | input throughput (tok/s) | output throughput (tok/s) | acc length | ITL (ms) | input cost ($/1M) | output cost ($/1M) |"
if bench_args.profile:
summary += " profile |"
summary += "\n"
summary += "| ---------- | ----------- | ------------------------- | ------------------------- | ---------- | -------- | ----------------- | ------------------ |"
if bench_args.profile:
summary += "-------------|"
summary += "\n"
for (
batch_size,
@@ -292,6 +346,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
overall_throughput,
last_gen_throughput,
acc_length,
trace_link,
) in result:
hourly_cost = 2 * server_args.tp_size # $2/hour for one H100
input_util = 0.7
@@ -304,17 +359,18 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
f"{accept_length} | "
f"{1 / (output_throughput/batch_size) * 1000:.2f} | "
f"{1e6 / (input_throughput * input_util) / 3600 * hourly_cost:.2f} | "
f"{1e6 / output_throughput / 3600 * hourly_cost:.2f} |\n"
f"{1e6 / output_throughput / 3600 * hourly_cost:.2f} |"
)
if trace_link:
line += f" [Profile]({trace_link}) |"
line += "\n"
summary += line
# print metrics table
print(summary)
if is_in_ci():
write_github_step_summary(
f"### Test Nightly Benchmark (bench_one_batch) \n{summary}"
)
write_github_step_summary(summary)
if __name__ == "__main__":