diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py
index b265745e7..ccff4524f 100644
--- a/python/sglang/bench_latency.py
+++ b/python/sglang/bench_latency.py
@@ -123,11 +123,10 @@ class BenchArgs:
         )
 
 
-def load_model(server_args, tp_rank):
+def load_model(server_args, port_args, tp_rank):
     suppress_other_loggers()
     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
 
-    port_args = PortArgs.init_new(server_args)
     model_config = ModelConfig(
         server_args.model_path,
         server_args.trust_remote_code,
@@ -248,13 +247,14 @@ def decode(input_token_ids, batch, model_runner):
 @torch.inference_mode()
 def correctness_test(
     server_args,
+    port_args,
     bench_args,
     tp_rank,
 ):
     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
 
     # Load the model
-    model_runner, tokenizer = load_model(server_args, tp_rank)
+    model_runner, tokenizer = load_model(server_args, port_args, tp_rank)
 
     # Prepare inputs
     input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)
@@ -362,6 +362,7 @@ def latency_test_run_once(
 
 def latency_test(
     server_args,
+    port_args,
     bench_args,
     tp_rank,
 ):
@@ -369,7 +370,7 @@ def latency_test(
     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
 
     # Load the model
-    model_runner, tokenizer = load_model(server_args, tp_rank)
+    model_runner, tokenizer = load_model(server_args, port_args, tp_rank)
 
     # Prepare inputs for warm up
     reqs = prepare_synthetic_inputs_for_latency_test(
@@ -487,8 +488,10 @@ def main(server_args, bench_args):
             "provide --result-filename for plotting the results"
         )
 
+    port_args = PortArgs.init_new(server_args)
+
     if server_args.tp_size == 1:
-        work_func(server_args, bench_args, 0)
+        work_func(server_args, port_args, bench_args, 0)
     else:
         workers = []
         for tp_rank in range(server_args.tp_size):
@@ -496,6 +499,7 @@
                 target=work_func,
                 args=(
                     server_args,
+                    port_args,
                     bench_args,
                     tp_rank,
                 ),