Improve profiler and integrate profiler in bench_one_batch_server (#6787)
This commit is contained in:
@@ -8,6 +8,7 @@ Usage:
|
||||
python3 -m sglang.bench_one_batch_server --model meta-llama/Meta-Llama-3.1-8B --batch-size 1 16 64 --input-len 1024 --output-len 8
|
||||
|
||||
python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8
|
||||
python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 --show-report --profile --profile-by-stage
|
||||
"""
|
||||
|
||||
import argparse
|
||||
@@ -19,10 +20,10 @@ import os
|
||||
import time
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
from sglang.bench_serving import get_tokenizer, sample_random_requests
|
||||
from sglang.profiler import run_profile
|
||||
from sglang.srt.entrypoints.http_server import launch_server
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
@@ -42,6 +43,8 @@ class BenchArgs:
|
||||
base_url: str = ""
|
||||
skip_warmup: bool = False
|
||||
show_report: bool = False
|
||||
profile: bool = False
|
||||
profile_by_stage: bool = False
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(parser: argparse.ArgumentParser):
|
||||
@@ -68,6 +71,8 @@ class BenchArgs:
|
||||
parser.add_argument("--base-url", type=str, default=BenchArgs.base_url)
|
||||
parser.add_argument("--skip-warmup", action="store_true")
|
||||
parser.add_argument("--show-report", action="store_true")
|
||||
parser.add_argument("--profile", action="store_true")
|
||||
parser.add_argument("--profile-by-stage", action="store_true")
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
@@ -93,8 +98,8 @@ def launch_server_process(server_args: ServerArgs):
|
||||
base_url = f"http://{server_args.host}:{server_args.port}"
|
||||
timeout = 600
|
||||
|
||||
start_time = time.perf_counter()
|
||||
while time.perf_counter() - start_time < timeout:
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout:
|
||||
try:
|
||||
headers = {
|
||||
"Content-Type": "application/json; charset=utf-8",
|
||||
@@ -119,6 +124,8 @@ def run_one_case(
|
||||
run_name: str,
|
||||
result_filename: str,
|
||||
tokenizer,
|
||||
profile: bool = False,
|
||||
profile_by_stage: bool = False,
|
||||
):
|
||||
requests.post(url + "/flush_cache")
|
||||
input_requests = sample_random_requests(
|
||||
@@ -145,6 +152,12 @@ def run_one_case(
|
||||
else:
|
||||
json_schema = None
|
||||
|
||||
profile_link = None
|
||||
if profile:
|
||||
profile_link: str = run_profile(
|
||||
url, 3, ["CPU", "GPU"], None, None, profile_by_stage
|
||||
)
|
||||
|
||||
tic = time.perf_counter()
|
||||
response = requests.post(
|
||||
url + "/generate",
|
||||
@@ -194,8 +207,8 @@ def run_one_case(
|
||||
print(f"output_len: {output_len}")
|
||||
print(f"latency: {latency:.2f} s")
|
||||
print(f"ttft: {ttft:.2f} s")
|
||||
print(f"Last generation throughput: {last_gen_throughput:.2f} tok/s")
|
||||
print(f"Input throughput: {input_throughput:.2f} tok/s")
|
||||
print(f"last generation throughput: {last_gen_throughput:.2f} tok/s")
|
||||
print(f"input throughput: {input_throughput:.2f} tok/s")
|
||||
if output_len != 1:
|
||||
print(f"output throughput: {output_throughput:.2f} tok/s")
|
||||
|
||||
@@ -222,6 +235,7 @@ def run_one_case(
|
||||
overall_throughput,
|
||||
last_gen_throughput,
|
||||
acc_length,
|
||||
profile_link if profile else None,
|
||||
)
|
||||
|
||||
|
||||
@@ -253,6 +267,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
|
||||
|
||||
# benchmark
|
||||
result = []
|
||||
bench_result = []
|
||||
try:
|
||||
for bs, il, ol in itertools.product(
|
||||
bench_args.batch_size, bench_args.input_len, bench_args.output_len
|
||||
@@ -271,6 +286,33 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
)
|
||||
|
||||
if bench_args.profile:
|
||||
try:
|
||||
for bs, il, ol in itertools.product(
|
||||
bench_args.batch_size, bench_args.input_len, bench_args.output_len
|
||||
):
|
||||
bench_result.append(
|
||||
(
|
||||
run_one_case(
|
||||
base_url,
|
||||
bs,
|
||||
il,
|
||||
ol,
|
||||
temperature=bench_args.temperature,
|
||||
return_logprob=bench_args.return_logprob,
|
||||
input_len_step_percentage=bench_args.input_len_step_percentage,
|
||||
run_name=bench_args.run_name,
|
||||
result_filename=bench_args.result_filename,
|
||||
tokenizer=tokenizer,
|
||||
profile=bench_args.profile,
|
||||
profile_by_stage=bench_args.profile_by_stage,
|
||||
)[-1],
|
||||
)
|
||||
)
|
||||
result = [t1[:-1] + t2 for t1, t2 in zip(result, bench_result)]
|
||||
except Exception as e:
|
||||
print(f"Error profiling, there will be no profile trace dump: {e}")
|
||||
finally:
|
||||
if proc:
|
||||
kill_process_tree(proc.pid)
|
||||
@@ -280,8 +322,20 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
|
||||
if not bench_args.show_report:
|
||||
return
|
||||
|
||||
summary = " | batch size | latency (s) | input throughput (tok/s) | output throughput (tok/s) | acc length | ITL (ms) | input price ($/1M) | output price ($/1M) |\n"
|
||||
summary += "| ---------- | ----------- | ------------------------- | ------------------------- | ---------- | -------- | ------------------ | ------------------- |\n"
|
||||
summary = (
|
||||
f"\nInput lens: {bench_args.input_len}. Output lens: {bench_args.output_len}.\n"
|
||||
)
|
||||
summary += "| batch size | latency (s) | input throughput (tok/s) | output throughput (tok/s) | acc length | ITL (ms) | input cost ($/1M) | output cost ($/1M) |"
|
||||
|
||||
if bench_args.profile:
|
||||
summary += " profile |"
|
||||
|
||||
summary += "\n"
|
||||
summary += "| ---------- | ----------- | ------------------------- | ------------------------- | ---------- | -------- | ----------------- | ------------------ |"
|
||||
|
||||
if bench_args.profile:
|
||||
summary += "-------------|"
|
||||
summary += "\n"
|
||||
|
||||
for (
|
||||
batch_size,
|
||||
@@ -292,6 +346,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
|
||||
overall_throughput,
|
||||
last_gen_throughput,
|
||||
acc_length,
|
||||
trace_link,
|
||||
) in result:
|
||||
hourly_cost = 2 * server_args.tp_size # $2/hour for one H100
|
||||
input_util = 0.7
|
||||
@@ -304,17 +359,18 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
|
||||
f"{accept_length} | "
|
||||
f"{1 / (output_throughput/batch_size) * 1000:.2f} | "
|
||||
f"{1e6 / (input_throughput * input_util) / 3600 * hourly_cost:.2f} | "
|
||||
f"{1e6 / output_throughput / 3600 * hourly_cost:.2f} |\n"
|
||||
f"{1e6 / output_throughput / 3600 * hourly_cost:.2f} |"
|
||||
)
|
||||
if trace_link:
|
||||
line += f" [Profile]({trace_link}) |"
|
||||
line += "\n"
|
||||
summary += line
|
||||
|
||||
# print metrics table
|
||||
print(summary)
|
||||
|
||||
if is_in_ci():
|
||||
write_github_step_summary(
|
||||
f"### Test Nightly Benchmark (bench_one_batch) \n{summary}"
|
||||
)
|
||||
write_github_step_summary(summary)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user