From 2f2e07439ce2ab7598a6b2ee92ee51ac14b7dc01 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sat, 16 Nov 2024 00:30:39 -0800 Subject: [PATCH] Fix weight update for data parallelism (#2050) --- python/sglang/bench_offline_throughput.py | 8 +++++--- python/sglang/srt/managers/data_parallel_controller.py | 1 + 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/python/sglang/bench_offline_throughput.py b/python/sglang/bench_offline_throughput.py index 924bad9f3..5320da72b 100644 --- a/python/sglang/bench_offline_throughput.py +++ b/python/sglang/bench_offline_throughput.py @@ -231,15 +231,16 @@ def throughput_test( input_requests = get_dataset(bench_args, tokenizer) warmup_requests = sample_random_requests( - input_len=20, - output_len=4, - num_prompts=2, + input_len=256, + output_len=16, + num_prompts=16, range_ratio=0.8, tokenizer=tokenizer, dataset_path=bench_args.dataset_path, ) # Warm up + logging.info("\nWarmup...") throughput_test_once( backend_name=bench_args.backend, backend=backend, @@ -247,6 +248,7 @@ def throughput_test( ignore_eos=not bench_args.disable_ignore_eos, ) + logging.info("\nBenchmark...") result = throughput_test_once( backend_name=bench_args.backend, backend=backend, diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index 132827966..95af3decc 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -83,6 +83,7 @@ class DataParallelController: self.workers = [] for dp_rank in range(server_args.dp_size): tmp_port_args = PortArgs.init_new(server_args) + tmp_port_args.tokenizer_ipc_name = port_args.tokenizer_ipc_name tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name send_to = self.launch_tensor_parallel_group(