Improve bench_one_batch_server script (#9608)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -113,6 +113,7 @@ test = [
|
|||||||
"peft",
|
"peft",
|
||||||
"sentence_transformers",
|
"sentence_transformers",
|
||||||
"pytest",
|
"pytest",
|
||||||
|
"tabulate",
|
||||||
]
|
]
|
||||||
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[torch_memory_saver]", "sglang[decord]"]
|
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[torch_memory_saver]", "sglang[decord]"]
|
||||||
all_hip = ["sglang[srt_hip]", "sglang[openai]", "sglang[anthropic]", "sglang[decord]"]
|
all_hip = ["sglang[srt_hip]", "sglang[openai]", "sglang[anthropic]", "sglang[decord]"]
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ import json
|
|||||||
import multiprocessing
|
import multiprocessing
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from typing import Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
@@ -45,6 +45,7 @@ class BenchArgs:
|
|||||||
skip_warmup: bool = False
|
skip_warmup: bool = False
|
||||||
show_report: bool = False
|
show_report: bool = False
|
||||||
profile: bool = False
|
profile: bool = False
|
||||||
|
profile_steps: int = 3
|
||||||
profile_by_stage: bool = False
|
profile_by_stage: bool = False
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -78,6 +79,9 @@ class BenchArgs:
|
|||||||
parser.add_argument("--skip-warmup", action="store_true")
|
parser.add_argument("--skip-warmup", action="store_true")
|
||||||
parser.add_argument("--show-report", action="store_true")
|
parser.add_argument("--show-report", action="store_true")
|
||||||
parser.add_argument("--profile", action="store_true")
|
parser.add_argument("--profile", action="store_true")
|
||||||
|
parser.add_argument(
|
||||||
|
"--profile-steps", type=int, default=BenchArgs.profile_steps
|
||||||
|
)
|
||||||
parser.add_argument("--profile-by-stage", action="store_true")
|
parser.add_argument("--profile-by-stage", action="store_true")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -132,6 +136,7 @@ def run_one_case(
|
|||||||
result_filename: str,
|
result_filename: str,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
profile: bool = False,
|
profile: bool = False,
|
||||||
|
profile_steps: int = 3,
|
||||||
profile_by_stage: bool = False,
|
profile_by_stage: bool = False,
|
||||||
):
|
):
|
||||||
requests.post(url + "/flush_cache")
|
requests.post(url + "/flush_cache")
|
||||||
@@ -162,7 +167,7 @@ def run_one_case(
|
|||||||
profile_link = None
|
profile_link = None
|
||||||
if profile:
|
if profile:
|
||||||
profile_link: str = run_profile(
|
profile_link: str = run_profile(
|
||||||
url, 3, ["CPU", "GPU"], None, None, profile_by_stage
|
url, profile_steps, ["CPU", "GPU"], None, None, profile_by_stage
|
||||||
)
|
)
|
||||||
|
|
||||||
tic = time.perf_counter()
|
tic = time.perf_counter()
|
||||||
@@ -247,6 +252,71 @@ def run_one_case(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_report_summary(
|
||||||
|
result: List[Tuple], server_args: ServerArgs, bench_args: BenchArgs
|
||||||
|
):
|
||||||
|
import tabulate
|
||||||
|
|
||||||
|
summary = (
|
||||||
|
f"\nInput lens: {bench_args.input_len}. Output lens: {bench_args.output_len}.\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
headers = [
|
||||||
|
"batch size",
|
||||||
|
"latency (s)",
|
||||||
|
"input throughput (tok/s)",
|
||||||
|
"output throughput (tok/s)",
|
||||||
|
"acc length",
|
||||||
|
"ITL (ms)",
|
||||||
|
"input cost ($/1M)",
|
||||||
|
"output cost ($/1M)",
|
||||||
|
]
|
||||||
|
if bench_args.profile:
|
||||||
|
headers.append("profile")
|
||||||
|
rows = []
|
||||||
|
|
||||||
|
for (
|
||||||
|
batch_size,
|
||||||
|
latency,
|
||||||
|
ttft,
|
||||||
|
input_throughput,
|
||||||
|
output_throughput,
|
||||||
|
_,
|
||||||
|
_,
|
||||||
|
acc_length,
|
||||||
|
trace_link,
|
||||||
|
) in result:
|
||||||
|
if is_blackwell():
|
||||||
|
hourly_cost_per_gpu = 4 # $4/hour for one B200
|
||||||
|
else:
|
||||||
|
hourly_cost_per_gpu = 2 # $2/hour for one H100
|
||||||
|
|
||||||
|
hourly_cost = hourly_cost_per_gpu * server_args.tp_size
|
||||||
|
input_util = 0.7
|
||||||
|
accept_length = round(acc_length, 2) if acc_length is not None else "n/a"
|
||||||
|
itl = 1 / (output_throughput / batch_size) * 1000
|
||||||
|
input_cost = 1e6 / (input_throughput * input_util) / 3600 * hourly_cost
|
||||||
|
output_cost = 1e6 / output_throughput / 3600 * hourly_cost
|
||||||
|
row = [
|
||||||
|
batch_size,
|
||||||
|
latency,
|
||||||
|
input_throughput,
|
||||||
|
output_throughput,
|
||||||
|
accept_length,
|
||||||
|
itl,
|
||||||
|
input_cost,
|
||||||
|
output_cost,
|
||||||
|
]
|
||||||
|
if trace_link:
|
||||||
|
row.append(f"[Profile]({trace_link})")
|
||||||
|
rows.append(row)
|
||||||
|
|
||||||
|
summary += tabulate.tabulate(
|
||||||
|
rows, headers=headers, tablefmt="github", floatfmt=".2f"
|
||||||
|
)
|
||||||
|
return summary
|
||||||
|
|
||||||
|
|
||||||
def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
|
def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
|
||||||
if bench_args.base_url:
|
if bench_args.base_url:
|
||||||
proc, base_url = None, bench_args.base_url
|
proc, base_url = None, bench_args.base_url
|
||||||
@@ -321,6 +391,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
|
|||||||
result_filename=bench_args.result_filename,
|
result_filename=bench_args.result_filename,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
profile=bench_args.profile,
|
profile=bench_args.profile,
|
||||||
|
profile_steps=bench_args.profile_steps,
|
||||||
profile_by_stage=bench_args.profile_by_stage,
|
profile_by_stage=bench_args.profile_by_stage,
|
||||||
)[-1],
|
)[-1],
|
||||||
)
|
)
|
||||||
@@ -337,63 +408,14 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
|
|||||||
if not bench_args.show_report:
|
if not bench_args.show_report:
|
||||||
return
|
return
|
||||||
|
|
||||||
summary = (
|
summary = get_report_summary(result, server_args, bench_args)
|
||||||
f"\nInput lens: {bench_args.input_len}. Output lens: {bench_args.output_len}.\n"
|
|
||||||
)
|
|
||||||
summary += "| batch size | latency (s) | input throughput (tok/s) | output throughput (tok/s) | acc length | ITL (ms) | input cost ($/1M) | output cost ($/1M) |"
|
|
||||||
|
|
||||||
if bench_args.profile:
|
|
||||||
summary += " profile |"
|
|
||||||
|
|
||||||
summary += "\n"
|
|
||||||
summary += "| ---------- | ----------- | ------------------------- | ------------------------- | ---------- | -------- | ----------------- | ------------------ |"
|
|
||||||
|
|
||||||
if bench_args.profile:
|
|
||||||
summary += "-------------|"
|
|
||||||
summary += "\n"
|
|
||||||
|
|
||||||
for (
|
|
||||||
batch_size,
|
|
||||||
latency,
|
|
||||||
ttft,
|
|
||||||
input_throughput,
|
|
||||||
output_throughput,
|
|
||||||
overall_throughput,
|
|
||||||
last_gen_throughput,
|
|
||||||
acc_length,
|
|
||||||
trace_link,
|
|
||||||
) in result:
|
|
||||||
if is_blackwell():
|
|
||||||
hourly_cost_per_gpu = 4 # $4/hour for one B200
|
|
||||||
else:
|
|
||||||
hourly_cost_per_gpu = 2 # $2/hour for one H100
|
|
||||||
|
|
||||||
hourly_cost = hourly_cost_per_gpu * server_args.tp_size
|
|
||||||
input_util = 0.7
|
|
||||||
accept_length = round(acc_length, 2) if acc_length is not None else "n/a"
|
|
||||||
line = (
|
|
||||||
f"| {batch_size} | "
|
|
||||||
f"{latency:.2f} | "
|
|
||||||
f"{input_throughput:.2f} | "
|
|
||||||
f"{output_throughput:.2f} | "
|
|
||||||
f"{accept_length} | "
|
|
||||||
f"{1 / (output_throughput/batch_size) * 1000:.2f} | "
|
|
||||||
f"{1e6 / (input_throughput * input_util) / 3600 * hourly_cost:.2f} | "
|
|
||||||
f"{1e6 / output_throughput / 3600 * hourly_cost:.2f} |"
|
|
||||||
)
|
|
||||||
if trace_link:
|
|
||||||
line += f" [Profile]({trace_link}) |"
|
|
||||||
line += "\n"
|
|
||||||
summary += line
|
|
||||||
|
|
||||||
# print metrics table
|
|
||||||
print(summary)
|
print(summary)
|
||||||
|
|
||||||
if is_in_ci():
|
if is_in_ci():
|
||||||
write_github_step_summary(summary)
|
write_github_step_summary(summary)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
ServerArgs.add_cli_args(parser)
|
ServerArgs.add_cli_args(parser)
|
||||||
BenchArgs.add_cli_args(parser)
|
BenchArgs.add_cli_args(parser)
|
||||||
@@ -402,3 +424,7 @@ if __name__ == "__main__":
|
|||||||
bench_args = BenchArgs.from_cli_args(args)
|
bench_args = BenchArgs.from_cli_args(args)
|
||||||
|
|
||||||
run_benchmark(server_args, bench_args)
|
run_benchmark(server_args, bench_args)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import argparse
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import urllib.parse
|
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|||||||
Reference in New Issue
Block a user