diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py
index 900272282..a163cbd30 100644
--- a/python/sglang/bench_latency.py
+++ b/python/sglang/bench_latency.py
@@ -81,6 +81,7 @@ def load_model(server_args, tp_rank):
         nccl_port=28888,
         server_args=server_args,
     )
+    print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
     tokenizer = get_tokenizer(
         server_args.tokenizer_path,
         tokenizer_mode=server_args.tokenizer_mode,
@@ -209,6 +210,7 @@ def latency_test(
 
     # Load the model
     model_runner, tokenizer = load_model(server_args, tp_rank)
+    print(f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}")
 
     # Prepare inputs
     reqs = prepare_synthetic_inputs(bench_args, tokenizer)
@@ -221,22 +223,37 @@ def latency_test(
     def run_once(output_len):
         # Prefill
         torch.cuda.synchronize()
+        tot_latency = 0
         tic = time.time()
         next_token_ids, _, batch = extend(reqs, model_runner)
         torch.cuda.synchronize()
-        latency = time.time() - tic
-        throughput = bench_args.input_len * bench_args.batch_size / latency
-        rank_print(f"Prefill. latency: {latency:6.3f} ms, throughput: {throughput:9.2f} token/s")
+        # time.time() differences are in seconds, so latencies are labelled "s".
+        prefill_latency = time.time() - tic
+        tot_latency += prefill_latency
+        throughput = bench_args.input_len * bench_args.batch_size / prefill_latency
+        rank_print(f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s")
 
         # Decode
-        for _ in range(output_len):
+        for i in range(output_len):
             torch.cuda.synchronize()
             tic = time.time()
             next_token_ids, _ = decode(next_token_ids, batch, model_runner)
             torch.cuda.synchronize()
             latency = time.time() - tic
+            tot_latency += latency
             throughput = bench_args.batch_size / latency
-            rank_print(f"Decode. latency: {latency:6.3f} ms, throughput: {throughput:9.2f} token/s")
+            # Log only the first few steps; the average below summarizes the rest.
+            if i < 5:
+                rank_print(f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s")
+        # Skip the average when no decode step ran, to avoid dividing by zero.
+        if output_len > 0:
+            avg_decode_latency = (tot_latency - prefill_latency) / output_len
+            avg_decode_throughput = bench_args.batch_size / avg_decode_latency
+            rank_print(f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s")
+
+        # Count the decode steps of this call (output_len), not bench_args.output_len,
+        # so the warm-up run reports a throughput for what it actually generated.
+        throughput = (bench_args.input_len + output_len) * bench_args.batch_size / tot_latency
+        rank_print(f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s")
 
     # Warm up
     run_once(4)