Improve bench_one_batch_server script (#9608)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Liangsheng Yin
2025-08-26 10:38:37 +08:00
committed by GitHub
parent 80dc76e11a
commit 0ff7241995
3 changed files with 80 additions and 54 deletions

View File

@@ -113,6 +113,7 @@ test = [
"peft", "peft",
"sentence_transformers", "sentence_transformers",
"pytest", "pytest",
"tabulate",
] ]
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[torch_memory_saver]", "sglang[decord]"] all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[torch_memory_saver]", "sglang[decord]"]
all_hip = ["sglang[srt_hip]", "sglang[openai]", "sglang[anthropic]", "sglang[decord]"] all_hip = ["sglang[srt_hip]", "sglang[openai]", "sglang[anthropic]", "sglang[decord]"]

View File

@@ -18,7 +18,7 @@ import json
import multiprocessing import multiprocessing
import os import os
import time import time
from typing import Tuple from typing import List, Tuple
import requests import requests
@@ -45,6 +45,7 @@ class BenchArgs:
skip_warmup: bool = False skip_warmup: bool = False
show_report: bool = False show_report: bool = False
profile: bool = False profile: bool = False
profile_steps: int = 3
profile_by_stage: bool = False profile_by_stage: bool = False
@staticmethod @staticmethod
@@ -78,6 +79,9 @@ class BenchArgs:
parser.add_argument("--skip-warmup", action="store_true") parser.add_argument("--skip-warmup", action="store_true")
parser.add_argument("--show-report", action="store_true") parser.add_argument("--show-report", action="store_true")
parser.add_argument("--profile", action="store_true") parser.add_argument("--profile", action="store_true")
parser.add_argument(
"--profile-steps", type=int, default=BenchArgs.profile_steps
)
parser.add_argument("--profile-by-stage", action="store_true") parser.add_argument("--profile-by-stage", action="store_true")
@classmethod @classmethod
@@ -132,6 +136,7 @@ def run_one_case(
result_filename: str, result_filename: str,
tokenizer, tokenizer,
profile: bool = False, profile: bool = False,
profile_steps: int = 3,
profile_by_stage: bool = False, profile_by_stage: bool = False,
): ):
requests.post(url + "/flush_cache") requests.post(url + "/flush_cache")
@@ -162,7 +167,7 @@ def run_one_case(
profile_link = None profile_link = None
if profile: if profile:
profile_link: str = run_profile( profile_link: str = run_profile(
url, 3, ["CPU", "GPU"], None, None, profile_by_stage url, profile_steps, ["CPU", "GPU"], None, None, profile_by_stage
) )
tic = time.perf_counter() tic = time.perf_counter()
@@ -247,6 +252,71 @@ def run_one_case(
) )
def get_report_summary(
result: List[Tuple], server_args: ServerArgs, bench_args: BenchArgs
):
import tabulate
summary = (
f"\nInput lens: {bench_args.input_len}. Output lens: {bench_args.output_len}.\n"
)
headers = [
"batch size",
"latency (s)",
"input throughput (tok/s)",
"output throughput (tok/s)",
"acc length",
"ITL (ms)",
"input cost ($/1M)",
"output cost ($/1M)",
]
if bench_args.profile:
headers.append("profile")
rows = []
for (
batch_size,
latency,
ttft,
input_throughput,
output_throughput,
_,
_,
acc_length,
trace_link,
) in result:
if is_blackwell():
hourly_cost_per_gpu = 4 # $4/hour for one B200
else:
hourly_cost_per_gpu = 2 # $2/hour for one H100
hourly_cost = hourly_cost_per_gpu * server_args.tp_size
input_util = 0.7
accept_length = round(acc_length, 2) if acc_length is not None else "n/a"
itl = 1 / (output_throughput / batch_size) * 1000
input_cost = 1e6 / (input_throughput * input_util) / 3600 * hourly_cost
output_cost = 1e6 / output_throughput / 3600 * hourly_cost
row = [
batch_size,
latency,
input_throughput,
output_throughput,
accept_length,
itl,
input_cost,
output_cost,
]
if trace_link:
row.append(f"[Profile]({trace_link})")
rows.append(row)
summary += tabulate.tabulate(
rows, headers=headers, tablefmt="github", floatfmt=".2f"
)
return summary
def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs): def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
if bench_args.base_url: if bench_args.base_url:
proc, base_url = None, bench_args.base_url proc, base_url = None, bench_args.base_url
@@ -321,6 +391,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
result_filename=bench_args.result_filename, result_filename=bench_args.result_filename,
tokenizer=tokenizer, tokenizer=tokenizer,
profile=bench_args.profile, profile=bench_args.profile,
profile_steps=bench_args.profile_steps,
profile_by_stage=bench_args.profile_by_stage, profile_by_stage=bench_args.profile_by_stage,
)[-1], )[-1],
) )
@@ -337,63 +408,14 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
if not bench_args.show_report: if not bench_args.show_report:
return return
summary = ( summary = get_report_summary(result, server_args, bench_args)
f"\nInput lens: {bench_args.input_len}. Output lens: {bench_args.output_len}.\n"
)
summary += "| batch size | latency (s) | input throughput (tok/s) | output throughput (tok/s) | acc length | ITL (ms) | input cost ($/1M) | output cost ($/1M) |"
if bench_args.profile:
summary += " profile |"
summary += "\n"
summary += "| ---------- | ----------- | ------------------------- | ------------------------- | ---------- | -------- | ----------------- | ------------------ |"
if bench_args.profile:
summary += "-------------|"
summary += "\n"
for (
batch_size,
latency,
ttft,
input_throughput,
output_throughput,
overall_throughput,
last_gen_throughput,
acc_length,
trace_link,
) in result:
if is_blackwell():
hourly_cost_per_gpu = 4 # $4/hour for one B200
else:
hourly_cost_per_gpu = 2 # $2/hour for one H100
hourly_cost = hourly_cost_per_gpu * server_args.tp_size
input_util = 0.7
accept_length = round(acc_length, 2) if acc_length is not None else "n/a"
line = (
f"| {batch_size} | "
f"{latency:.2f} | "
f"{input_throughput:.2f} | "
f"{output_throughput:.2f} | "
f"{accept_length} | "
f"{1 / (output_throughput/batch_size) * 1000:.2f} | "
f"{1e6 / (input_throughput * input_util) / 3600 * hourly_cost:.2f} | "
f"{1e6 / output_throughput / 3600 * hourly_cost:.2f} |"
)
if trace_link:
line += f" [Profile]({trace_link}) |"
line += "\n"
summary += line
# print metrics table
print(summary) print(summary)
if is_in_ci(): if is_in_ci():
write_github_step_summary(summary) write_github_step_summary(summary)
if __name__ == "__main__": def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser) ServerArgs.add_cli_args(parser)
BenchArgs.add_cli_args(parser) BenchArgs.add_cli_args(parser)
@@ -402,3 +424,7 @@ if __name__ == "__main__":
bench_args = BenchArgs.from_cli_args(args) bench_args = BenchArgs.from_cli_args(args)
run_benchmark(server_args, bench_args) run_benchmark(server_args, bench_args)
if __name__ == "__main__":
main()

View File

@@ -9,7 +9,6 @@ import argparse
import json import json
import os import os
import time import time
import urllib.parse
from argparse import ArgumentParser from argparse import ArgumentParser
from pathlib import Path from pathlib import Path
from typing import List, Optional from typing import List, Optional