From 4d086719e5cee5dc84d89d9b47522b11bb776157 Mon Sep 17 00:00:00 2001 From: HAI Date: Sun, 6 Oct 2024 01:09:09 -0700 Subject: [PATCH] [Bug] Fix decode stats error on output_len 1 (#1585) --- python/sglang/bench_latency.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py index 2baa8e72c..b265745e7 100644 --- a/python/sglang/bench_latency.py +++ b/python/sglang/bench_latency.py @@ -340,13 +340,16 @@ def latency_test_run_once( rank_print( f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s" ) - med_decode_latency = np.median(decode_latencies) - med_decode_throughput = batch_size / med_decode_latency - rank_print( - f"Decode. median latency: {med_decode_latency:6.5f} s, median throughput: {med_decode_throughput:9.2f} token/s" - ) - measurement_results["median_decode_latency"] = med_decode_latency - measurement_results["median_decode_throughput"] = med_decode_throughput + + # record decode timing from 2nd output + if output_len > 1: + med_decode_latency = np.median(decode_latencies) + med_decode_throughput = batch_size / med_decode_latency + rank_print( + f"Decode. median latency: {med_decode_latency:6.5f} s, median throughput: {med_decode_throughput:9.2f} token/s" + ) + measurement_results["median_decode_latency"] = med_decode_latency + measurement_results["median_decode_throughput"] = med_decode_throughput throughput = (input_len + output_len) * batch_size / tot_latency rank_print( @@ -382,7 +385,7 @@ def latency_test( reqs, bench_args.batch_size[0], bench_args.input_len[0], - 4, # shorter decoding to speed up the warmup + 8, # shorter decoding to speed up the warmup ) rank_print("Benchmark ...")