diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py
index 63787addf..99fba8be9 100644
--- a/python/sglang/bench_one_batch.py
+++ b/python/sglang/bench_one_batch.py
@@ -9,7 +9,8 @@ It accepts server arguments (the same as launch_server.py) and benchmark argumen
 python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
 ## sweep through multiple data points and store (append) the results in a jsonl file:
 python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --run-name test_run
-
+## run with profiling:
+python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --profile
 # Usage (correctness test):
 python -m sglang.bench_one_batch --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct
 
@@ -77,6 +78,8 @@ class BenchArgs:
     correctness_test: bool = False
     # This is only used for correctness test
     cut_len: int = 4
+    profile: bool = False
+    profile_filename_prefix: str = "profile"
 
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
@@ -95,6 +98,19 @@ class BenchArgs:
         )
         parser.add_argument("--correctness-test", action="store_true")
         parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
+        parser.add_argument(
+            "--profile",
+            action="store_true",
+            help="Use Torch Profiler to trace the prefill and decode passes. "
+            "The trace is exported as a Chrome trace file named by --profile-filename-prefix.",
+        )
+        parser.add_argument(
+            "--profile-filename-prefix",
+            type=str,
+            default=BenchArgs.profile_filename_prefix,
+            help="Prefix of the profiling file names. The full profiling result file(s) will be "
+            '"[profile_filename_prefix]_batch[batch_size]_input[input_len]_output[output_len].trace.json.gz"',
+        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
@@ -286,7 +302,16 @@ def synchronize(device):
 
 
 def latency_test_run_once(
-    run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len, device
+    run_name,
+    model_runner,
+    rank_print,
+    reqs,
+    batch_size,
+    input_len,
+    output_len,
+    device,
+    profile,
+    profile_filename_prefix,
 ):
     max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
     if batch_size > max_batch_size:
@@ -308,6 +333,17 @@ def latency_test_run_once(
 
     tot_latency = 0
 
+    profiler = None
+    if profile:
+        profiler = torch.profiler.profile(
+            activities=[
+                torch.profiler.ProfilerActivity.CPU,
+                torch.profiler.ProfilerActivity.CUDA,
+            ],
+            with_stack=True,
+        )
+        profiler.start()
+
     # Prefill
     synchronize(device)
     tic = time.time()
@@ -338,6 +374,13 @@ def latency_test_run_once(
                 f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
             )
 
+    if profile:
+        profiler.stop()
+        profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}.trace.json.gz"
+        parent_dir = os.path.dirname(os.path.abspath(profile_filename))
+        os.makedirs(parent_dir, exist_ok=True)
+        profiler.export_chrome_trace(profile_filename)
+
     # Record decode timing from 2nd output
     if output_len > 1:
         med_decode_latency = np.median(decode_latencies)
@@ -386,6 +429,8 @@ def latency_test(
         bench_args.input_len[0],
         8,  # shorter decoding to speed up the warmup
         server_args.device,
+        profile=False,
+        profile_filename_prefix="",  # not used
     )
 
     rank_print("Benchmark ...")
@@ -405,6 +450,8 @@ def latency_test(
             il,
             ol,
             server_args.device,
+            bench_args.profile,
+            bench_args.profile_filename_prefix,
         )
         if ret is not None:
             result_list.append(ret)
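
For reference, the profiling flow this patch wires into latency_test_run_once follows the standard torch.profiler pattern (start the profiler, run the timed work, stop it, export a Chrome trace). Below is a minimal standalone sketch of that pattern, assuming PyTorch with a CUDA device is available; run_workload and the example filename are illustrative placeholders, not part of the patch.

# Minimal sketch of the torch.profiler pattern used by this patch.
# Assumptions: CUDA is available; `run_workload` stands in for the prefill
# and decode steps that bench_one_batch actually measures.
import os

import torch


def run_workload():
    # Placeholder workload: a few matmuls on the GPU.
    x = torch.randn(1024, 1024, device="cuda")
    for _ in range(8):
        x = x @ x
    torch.cuda.synchronize()


def profile_once(profile_filename="profile_batch1_input256_output32.trace.json.gz"):
    profiler = torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ],
        with_stack=True,  # keep Python stacks so kernels map back to call sites
    )
    profiler.start()
    run_workload()
    profiler.stop()

    # Make sure the target directory exists, then export a Chrome trace that
    # can be opened in chrome://tracing or Perfetto.
    os.makedirs(os.path.dirname(os.path.abspath(profile_filename)), exist_ok=True)
    profiler.export_chrome_trace(profile_filename)


if __name__ == "__main__":
    profile_once()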