diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py
index cca7d5a49..4a027ae99 100644
--- a/python/sglang/bench_one_batch.py
+++ b/python/sglang/bench_one_batch.py
@@ -271,12 +271,13 @@ def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
             batch,
             dp_size=model_runner.server_args.dp_size,
             attn_tp_size=1,
-            tp_cpu_group=model_runner.tp_group.cpu_group,
+            tp_group=model_runner.tp_group,
             get_idle_batch=None,
             disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
             spec_algorithm=SpeculativeAlgorithm.NONE,
             speculative_num_draft_tokens=None,
             require_mlp_tp_gather=require_mlp_tp_gather(model_runner.server_args),
+            disable_overlap_schedule=model_runner.server_args.disable_overlap_schedule,
         )
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 9a1654343..a7f893253 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -1945,7 +1945,7 @@ class Scheduler(
             local_batch,
             dp_size=self.server_args.dp_size,
             attn_tp_size=self.attn_tp_size,
-            tp_cpu_group=self.tp_cpu_group,
+            tp_group=self.tp_group,
             get_idle_batch=self.get_idle_batch,
             disable_cuda_graph=self.server_args.disable_cuda_graph,
             spec_algorithm=self.spec_algorithm,
@@ -1954,6 +1954,7 @@ class Scheduler(
             enable_deepep_moe=self.server_args.enable_deepep_moe,
             deepep_mode=DeepEPMode[self.server_args.deepep_mode],
             require_mlp_tp_gather=require_mlp_tp_gather(self.server_args),
+            disable_overlap_schedule=self.server_args.disable_overlap_schedule,
         )

     @staticmethod
@@ -1961,7 +1962,7 @@ class Scheduler(
         local_batch: ScheduleBatch,
         dp_size,
         attn_tp_size: int,
-        tp_cpu_group,
+        tp_group,
         get_idle_batch,
         disable_cuda_graph: bool,
         spec_algorithm,
@@ -1970,6 +1971,7 @@ class Scheduler(
         enable_deepep_moe: bool,
         deepep_mode: DeepEPMode,
         require_mlp_tp_gather: bool,
+        disable_overlap_schedule: bool,
     ):
         # Check if other DP workers have running batches
         if local_batch is None:
@@ -2000,6 +2002,12 @@ class Scheduler(
         )
         tbo_preparer = TboDPAttentionPreparer()
+        if disable_overlap_schedule:
+            group = tp_group.device_group
+            device = tp_group.device
+        else:
+            group = tp_group.cpu_group
+            device = "cpu"

         local_info = torch.tensor(
             [
@@ -2015,15 +2023,17 @@ class Scheduler(
                 ),
             ],
             dtype=torch.int64,
+            device=device,
         )
         global_info = torch.empty(
             (dp_size, attn_tp_size, 6),
             dtype=torch.int64,
+            device=device,
         )
         torch.distributed.all_gather_into_tensor(
             global_info.flatten(),
             local_info,
-            group=tp_cpu_group,
+            group=group,
         )
         global_num_tokens = global_info[:, 0, 0].tolist()
         can_cuda_graph = min(global_info[:, 0, 1].tolist())
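
For reference, a minimal standalone sketch of the gather path this diff introduces: the process group and tensor device are chosen from the `tp_group` coordinator (its `device_group`, `cpu_group`, and `device` attributes, as used in the diff above), and the metadata all-gather then runs over whichever group was selected. The function name and `local_values` parameter below are illustrative, not part of the change itself.

```python
# Sketch of the device/group selection added in prepare_mlp_sync_batch_raw.
# Assumes `tp_group` exposes .device_group, .cpu_group, and .device,
# matching the attributes referenced in the diff.
import torch
import torch.distributed


def gather_mlp_sync_info(
    local_values, tp_group, dp_size, attn_tp_size, disable_overlap_schedule
):
    if disable_overlap_schedule:
        # Overlap schedule disabled: gather directly on the device group.
        group = tp_group.device_group
        device = tp_group.device
    else:
        # Overlap schedule enabled: keep the gather on the CPU group.
        group = tp_group.cpu_group
        device = "cpu"

    # One row of 6 int64 metadata values per rank, gathered across the TP group.
    local_info = torch.tensor(local_values, dtype=torch.int64, device=device)
    global_info = torch.empty(
        (dp_size, attn_tp_size, 6), dtype=torch.int64, device=device
    )
    torch.distributed.all_gather_into_tensor(
        global_info.flatten(), local_info, group=group
    )
    return global_info
```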