diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py index 0c492626e..fb7343532 100644 --- a/python/sglang/bench_one_batch.py +++ b/python/sglang/bench_one_batch.py @@ -85,6 +85,7 @@ class BenchArgs: correctness_test: bool = False # This is only used for correctness test cut_len: int = 4 + log_decode_step: int = 0 profile: bool = False profile_filename_prefix: str = "profile" @@ -105,6 +106,12 @@ class BenchArgs: ) parser.add_argument("--correctness-test", action="store_true") parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len) + parser.add_argument( + "--log-decode-step", + type=int, + default=BenchArgs.log_decode_step, + help="Log decode latency by step, default is set to zero to disable.", + ) parser.add_argument( "--profile", action="store_true", help="Use Torch Profiler." ) @@ -335,6 +342,7 @@ def latency_test_run_once( input_len, output_len, device, + log_decode_step, profile, profile_filename_prefix, ): @@ -394,9 +402,9 @@ def latency_test_run_once( tot_latency += latency throughput = batch_size / latency decode_latencies.append(latency) - if i < 5: + if i < 5 or (log_decode_step > 0 and i % log_decode_step == 0): rank_print( - f"Decode. Batch size: {batch_size}, latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s" + f"Decode {i}. Batch size: {batch_size}, latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s" ) if profile: @@ -457,8 +465,9 @@ def latency_test( reqs, bench_args.batch_size[0], bench_args.input_len[0], - 8, # shorter decoding to speed up the warmup + min(32, bench_args.output_len[0]), # shorter decoding to speed up the warmup server_args.device, + log_decode_step=0, profile=False, profile_filename_prefix="", # not used ) @@ -480,6 +489,7 @@ def latency_test( il, ol, server_args.device, + bench_args.log_decode_step, bench_args.profile if tp_rank == 0 else None, bench_args.profile_filename_prefix, )