From 2b302b93938c3de5fc98c5149a7ebcce86648051 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 7 Oct 2024 00:44:38 -0700 Subject: [PATCH] Fix the port_args in bench_latency (#1597) --- python/sglang/bench_latency.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py index b265745e7..ccff4524f 100644 --- a/python/sglang/bench_latency.py +++ b/python/sglang/bench_latency.py @@ -123,11 +123,10 @@ class BenchArgs: ) -def load_model(server_args, tp_rank): +def load_model(server_args, port_args, tp_rank): suppress_other_loggers() rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None - port_args = PortArgs.init_new(server_args) model_config = ModelConfig( server_args.model_path, server_args.trust_remote_code, @@ -248,13 +247,14 @@ def decode(input_token_ids, batch, model_runner): @torch.inference_mode() def correctness_test( server_args, + port_args, bench_args, tp_rank, ): rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None # Load the model - model_runner, tokenizer = load_model(server_args, tp_rank) + model_runner, tokenizer = load_model(server_args, port_args, tp_rank) # Prepare inputs input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer) @@ -362,6 +362,7 @@ def latency_test_run_once( def latency_test( server_args, + port_args, bench_args, tp_rank, ): @@ -369,7 +370,7 @@ def latency_test( rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None # Load the model - model_runner, tokenizer = load_model(server_args, tp_rank) + model_runner, tokenizer = load_model(server_args, port_args, tp_rank) # Prepare inputs for warm up reqs = prepare_synthetic_inputs_for_latency_test( @@ -487,8 +488,10 @@ def main(server_args, bench_args): "provide --result-filename for plotting the results" ) + port_args = PortArgs.init_new(server_args) + if server_args.tp_size == 1: - work_func(server_args, bench_args, 0) + work_func(server_args, port_args, bench_args, 0) else: workers = [] for tp_rank in range(server_args.tp_size): @@ -496,6 +499,7 @@ def main(server_args, bench_args): target=work_func, args=( server_args, + port_args, bench_args, tp_rank, ),