[misc] more decode step log for bench_one_batch (#5565)

This commit is contained in:
JieXin Liang
2025-04-27 10:50:28 +08:00
committed by GitHub
parent 408ba02218
commit f55933e1cc

View File

@@ -85,6 +85,7 @@ class BenchArgs:
correctness_test: bool = False
# This is only used for correctness test
cut_len: int = 4
log_decode_step: int = 0
profile: bool = False
profile_filename_prefix: str = "profile"
@@ -105,6 +106,12 @@ class BenchArgs:
)
parser.add_argument("--correctness-test", action="store_true")
parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
parser.add_argument(
"--log-decode-step",
type=int,
default=BenchArgs.log_decode_step,
help="Log decode latency by step, default is set to zero to disable.",
)
parser.add_argument(
"--profile", action="store_true", help="Use Torch Profiler."
)
@@ -335,6 +342,7 @@ def latency_test_run_once(
input_len,
output_len,
device,
log_decode_step,
profile,
profile_filename_prefix,
):
@@ -394,9 +402,9 @@ def latency_test_run_once(
tot_latency += latency
throughput = batch_size / latency
decode_latencies.append(latency)
if i < 5:
if i < 5 or (log_decode_step > 0 and i % log_decode_step == 0):
rank_print(
f"Decode. Batch size: {batch_size}, latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
f"Decode {i}. Batch size: {batch_size}, latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
)
if profile:
@@ -457,8 +465,9 @@ def latency_test(
reqs,
bench_args.batch_size[0],
bench_args.input_len[0],
8, # shorter decoding to speed up the warmup
min(32, bench_args.output_len[0]), # shorter decoding to speed up the warmup
server_args.device,
log_decode_step=0,
profile=False,
profile_filename_prefix="", # not used
)
@@ -480,6 +489,7 @@ def latency_test(
il,
ol,
server_args.device,
bench_args.log_decode_step,
bench_args.profile if tp_rank == 0 else None,
bench_args.profile_filename_prefix,
)