diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py
index a3bf9158c..ee35960ba 100644
--- a/python/sglang/bench_latency.py
+++ b/python/sglang/bench_latency.py
@@ -220,6 +220,68 @@ def correctness_test(
             rank_print(tokenizer.decode(output_ids[i]))
 
 
+@torch.inference_mode()
+def latency_test_run_once(
+    model_runner, rank_print, reqs, batch_size, input_len, output_len
+):
+
+    # Clear the pools.
+    model_runner.req_to_token_pool.clear()
+    model_runner.token_to_kv_pool.clear()
+
+    measurement_results = {
+        "run_name": "before",
+        "batch_size": batch_size,
+        "input_len": input_len,
+        "output_len": output_len,
+    }
+
+    tot_latency = 0
+
+    # Prefill
+    torch.cuda.synchronize()
+    tic = time.time()
+    next_token_ids, _, batch = extend(reqs, model_runner)
+    torch.cuda.synchronize()
+    prefill_latency = time.time() - tic
+    tot_latency += prefill_latency
+    throughput = input_len * batch_size / prefill_latency
+    rank_print(
+        f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
+    )
+    measurement_results["prefill_latency"] = prefill_latency
+    measurement_results["prefill_throughput"] = throughput
+
+    # Decode
+    for i in range(output_len):
+        torch.cuda.synchronize()
+        tic = time.time()
+        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
+        torch.cuda.synchronize()
+        latency = time.time() - tic
+        tot_latency += latency
+        throughput = batch_size / latency
+        if i < 5:
+            rank_print(
+                f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
+            )
+    avg_decode_latency = (tot_latency - prefill_latency) / output_len
+    avg_decode_throughput = batch_size / avg_decode_latency
+    rank_print(
+        f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s"
+    )
+    measurement_results["avg_decode_latency"] = avg_decode_latency
+    measurement_results["avg_decode_throughput"] = avg_decode_throughput
+
+    throughput = (input_len + output_len) * batch_size / tot_latency
+    rank_print(
+        f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
+    )
+    measurement_results["total_latency"] = tot_latency
+    measurement_results["total_throughput"] = throughput
+    return measurement_results
+
+
 def latency_test(
     server_args,
     bench_args,
@@ -241,72 +303,23 @@ def latency_test(
         bench_args.batch_size, bench_args.input_len
     )
 
-    def clear():
-        model_runner.req_to_token_pool.clear()
-        model_runner.token_to_kv_pool.clear()
-
-    @torch.inference_mode()
-    def run_once(output_len):
-        measurement_results = {
-            "batch_size": bench_args.batch_size,
-            "output_len": output_len,
-        }
-
-        # Prefill
-        torch.cuda.synchronize()
-        tot_latency = 0
-        tic = time.time()
-        next_token_ids, _, batch = extend(reqs, model_runner)
-        torch.cuda.synchronize()
-        prefill_latency = time.time() - tic
-        tot_latency += prefill_latency
-        throughput = bench_args.input_len * bench_args.batch_size / prefill_latency
-        rank_print(
-            f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
-        )
-        measurement_results["prefill_latency"] = prefill_latency
-        measurement_results["prefill_throughput"] = throughput
-
-        # Decode
-        for i in range(output_len):
-            torch.cuda.synchronize()
-            tic = time.time()
-            next_token_ids, _ = decode(next_token_ids, batch, model_runner)
-            torch.cuda.synchronize()
-            latency = time.time() - tic
-            tot_latency += latency
-            throughput = bench_args.batch_size / latency
-            if i < 5:
-                rank_print(
-                    f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
-                )
-        avg_decode_latency = (tot_latency - prefill_latency) / output_len
-        avg_decode_throughput = bench_args.batch_size / avg_decode_latency
-        rank_print(
-            f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s"
-        )
-        measurement_results["avg_decode_latency"] = avg_decode_latency
-        measurement_results["avg_decode_throughput"] = avg_decode_throughput
-
-        throughput = (
-            (bench_args.input_len + bench_args.output_len)
-            * bench_args.batch_size
-            / tot_latency
-        )
-        rank_print(
-            f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
-        )
-        measurement_results["total_latency"] = tot_latency
-        measurement_results["total_throughput"] = throughput
-        return measurement_results
-
     # Warm up
-    run_once(4)
-    clear()
+    latency_test_run_once(
+        model_runner, rank_print, reqs, bench_args.batch_size, bench_args.input_len, 4
+    )
 
     # Run again
    result_list = []
-    result_list.append(run_once(bench_args.output_len))
+    result_list.append(
+        latency_test_run_once(
+            model_runner,
+            rank_print,
+            reqs,
+            bench_args.batch_size,
+            bench_args.input_len,
+            bench_args.output_len,
+        )
+    )
 
     # Write results in jsonlines format.
     if bench_args.result_filename: