[Bug] Fix decode stats error on output_len 1 (#1585)
This commit is contained in:
@@ -340,6 +340,9 @@ def latency_test_run_once(
|
|||||||
rank_print(
|
rank_print(
|
||||||
f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
|
f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# record decode timing from 2nd output
|
||||||
|
if output_len > 1:
|
||||||
med_decode_latency = np.median(decode_latencies)
|
med_decode_latency = np.median(decode_latencies)
|
||||||
med_decode_throughput = batch_size / med_decode_latency
|
med_decode_throughput = batch_size / med_decode_latency
|
||||||
rank_print(
|
rank_print(
|
||||||
@@ -382,7 +385,7 @@ def latency_test(
|
|||||||
reqs,
|
reqs,
|
||||||
bench_args.batch_size[0],
|
bench_args.batch_size[0],
|
||||||
bench_args.input_len[0],
|
bench_args.input_len[0],
|
||||||
4, # shorter decoding to speed up the warmup
|
8, # shorter decoding to speed up the warmup
|
||||||
)
|
)
|
||||||
rank_print("Benchmark ...")
|
rank_print("Benchmark ...")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user