[Bug] Fix decode stats error on output_len 1 (#1585)

This commit is contained in:
HAI
2024-10-06 01:09:09 -07:00
committed by GitHub
parent 9244f27f0a
commit 4d086719e5

View File

@@ -340,13 +340,16 @@ def latency_test_run_once(
rank_print(
f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
)
med_decode_latency = np.median(decode_latencies)
med_decode_throughput = batch_size / med_decode_latency
rank_print(
f"Decode. median latency: {med_decode_latency:6.5f} s, median throughput: {med_decode_throughput:9.2f} token/s"
)
measurement_results["median_decode_latency"] = med_decode_latency
measurement_results["median_decode_throughput"] = med_decode_throughput
# record decode timing from 2nd output
if output_len > 1:
med_decode_latency = np.median(decode_latencies)
med_decode_throughput = batch_size / med_decode_latency
rank_print(
f"Decode. median latency: {med_decode_latency:6.5f} s, median throughput: {med_decode_throughput:9.2f} token/s"
)
measurement_results["median_decode_latency"] = med_decode_latency
measurement_results["median_decode_throughput"] = med_decode_throughput
throughput = (input_len + output_len) * batch_size / tot_latency
rank_print(
@@ -382,7 +385,7 @@ def latency_test(
reqs,
bench_args.batch_size[0],
bench_args.input_len[0],
4, # shorter decoding to speed up the warmup
8, # shorter decoding to speed up the warmup
)
rank_print("Benchmark ...")