[misc] more decode step logging for bench_one_batch (#5565)
This commit is contained in:
@@ -85,6 +85,7 @@ class BenchArgs:
|
|||||||
correctness_test: bool = False
|
correctness_test: bool = False
|
||||||
# This is only used for correctness test
|
# This is only used for correctness test
|
||||||
cut_len: int = 4
|
cut_len: int = 4
|
||||||
|
log_decode_step: int = 0
|
||||||
profile: bool = False
|
profile: bool = False
|
||||||
profile_filename_prefix: str = "profile"
|
profile_filename_prefix: str = "profile"
|
||||||
|
|
||||||
@@ -105,6 +106,12 @@ class BenchArgs:
|
|||||||
)
|
)
|
||||||
parser.add_argument("--correctness-test", action="store_true")
|
parser.add_argument("--correctness-test", action="store_true")
|
||||||
parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
|
parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
|
||||||
|
parser.add_argument(
|
||||||
|
"--log-decode-step",
|
||||||
|
type=int,
|
||||||
|
default=BenchArgs.log_decode_step,
|
||||||
|
help="Log decode latency by step, default is set to zero to disable.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--profile", action="store_true", help="Use Torch Profiler."
|
"--profile", action="store_true", help="Use Torch Profiler."
|
||||||
)
|
)
|
||||||
@@ -335,6 +342,7 @@ def latency_test_run_once(
|
|||||||
input_len,
|
input_len,
|
||||||
output_len,
|
output_len,
|
||||||
device,
|
device,
|
||||||
|
log_decode_step,
|
||||||
profile,
|
profile,
|
||||||
profile_filename_prefix,
|
profile_filename_prefix,
|
||||||
):
|
):
|
||||||
@@ -394,9 +402,9 @@ def latency_test_run_once(
|
|||||||
tot_latency += latency
|
tot_latency += latency
|
||||||
throughput = batch_size / latency
|
throughput = batch_size / latency
|
||||||
decode_latencies.append(latency)
|
decode_latencies.append(latency)
|
||||||
if i < 5:
|
if i < 5 or (log_decode_step > 0 and i % log_decode_step == 0):
|
||||||
rank_print(
|
rank_print(
|
||||||
f"Decode. Batch size: {batch_size}, latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
|
f"Decode {i}. Batch size: {batch_size}, latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
|
||||||
)
|
)
|
||||||
|
|
||||||
if profile:
|
if profile:
|
||||||
@@ -457,8 +465,9 @@ def latency_test(
|
|||||||
reqs,
|
reqs,
|
||||||
bench_args.batch_size[0],
|
bench_args.batch_size[0],
|
||||||
bench_args.input_len[0],
|
bench_args.input_len[0],
|
||||||
8, # shorter decoding to speed up the warmup
|
min(32, bench_args.output_len[0]), # shorter decoding to speed up the warmup
|
||||||
server_args.device,
|
server_args.device,
|
||||||
|
log_decode_step=0,
|
||||||
profile=False,
|
profile=False,
|
||||||
profile_filename_prefix="", # not used
|
profile_filename_prefix="", # not used
|
||||||
)
|
)
|
||||||
@@ -480,6 +489,7 @@ def latency_test(
|
|||||||
il,
|
il,
|
||||||
ol,
|
ol,
|
||||||
server_args.device,
|
server_args.device,
|
||||||
|
bench_args.log_decode_step,
|
||||||
bench_args.profile if tp_rank == 0 else None,
|
bench_args.profile if tp_rank == 0 else None,
|
||||||
bench_args.profile_filename_prefix,
|
bench_args.profile_filename_prefix,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user