diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 8253a303b..14ed362cf 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -1399,29 +1399,6 @@ class Scheduler(
         self.metrics_collector.log_stats(self.stats)
         self._publish_kv_events()
 
-    def coordinate_spec_dp_attn_batch(self, new_batch: Optional[ScheduleBatch]):
-        """Coordinate the DP attention batch."""
-
-        local_info = torch.tensor(
-            [
-                (new_batch is not None),
-            ],
-            dtype=torch.int64,
-        )
-        global_info = torch.empty(
-            (self.server_args.dp_size, self.attn_tp_size, 1),
-            dtype=torch.int64,
-        )
-        torch.distributed.all_gather_into_tensor(
-            global_info.flatten(),
-            local_info,
-            group=self.tp_cpu_group,
-        )
-        any_new_batch = any(
-            global_info[:, 0, 0].tolist()
-        )  # Any DP worker has forward batch
-        return any_new_batch
-
     def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
         # Merge the prefill batch into the running batch
         chunked_req_to_exclude = set()
@@ -1456,13 +1433,15 @@ class Scheduler(
 
         new_batch = self.get_new_batch_prefill()
 
-        # TODO(ch-wan): minor refactor is needed here to improve readability
-        any_new_batch = (
-            self.server_args.enable_dp_attention
-            and not self.spec_algorithm.is_none()
-            and self.coordinate_spec_dp_attn_batch(new_batch)
-        )
-        if new_batch is not None or any_new_batch:
+        need_dp_attn_preparation = require_mlp_sync(self.server_args)
+
+        if need_dp_attn_preparation and not self.spec_algorithm.is_none():
+            # In speculative decoding, prefill batches and decode batches cannot be processed in the same DP attention group.
+            # We prepare idle batches in advance to skip preparing decode batches when there are prefill batches in the group.
+            new_batch, _ = self.prepare_dp_attn_batch(new_batch)
+            need_dp_attn_preparation = new_batch is None
+
+        if new_batch is not None:
             # Run prefill first if possible
             ret = new_batch
         else:
@@ -1473,8 +1452,9 @@ class Scheduler(
         else:
             ret = None
 
-        if require_mlp_sync(self.server_args):
-            ret, _ = self.prepare_mlp_sync_batch(ret)
+        # Handle DP attention
+        if need_dp_attn_preparation:
+            ret, _ = self.prepare_dp_attn_batch(ret)
 
         return ret
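
To make the new control flow above easier to follow, here is a minimal, self-contained sketch of it. Everything in it (Batch, TinySchedulerSketch, peer_has_prefill, and the idle-batch behavior of its prepare_dp_attn_batch) is a hypothetical stand-in rather than the real SGLang API; in particular, the real prepare_dp_attn_batch coordinates across DP workers with a collective, which this toy replaces with a simple flag.

from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class Batch:
    kind: str  # "prefill", "decode", or "idle"


class TinySchedulerSketch:
    """Toy model of the scheduling flow; not the real SGLang Scheduler."""

    def __init__(self, enable_dp_attention: bool, spec_decoding: bool, peer_has_prefill: bool):
        self.enable_dp_attention = enable_dp_attention
        self.spec_decoding = spec_decoding
        self.peer_has_prefill = peer_has_prefill  # pretend result of the cross-worker sync
        self.running_batch: Optional[Batch] = Batch("decode")

    def get_new_batch_prefill(self) -> Optional[Batch]:
        # This rank has no local prefill work in the toy example.
        return None

    def prepare_dp_attn_batch(self, batch: Optional[Batch]) -> Tuple[Optional[Batch], None]:
        # Stand-in for the collective preparation step: if some DP worker in the
        # group has a batch, a worker without one gets an "idle" batch so all
        # ranks enter the forward pass together; otherwise nothing changes.
        if batch is None and self.peer_has_prefill:
            return Batch("idle"), None
        return batch, None

    def get_next_batch_to_run(self) -> Optional[Batch]:
        new_batch = self.get_new_batch_prefill()

        need_dp_attn_preparation = self.enable_dp_attention
        if need_dp_attn_preparation and self.spec_decoding:
            # Prefill and decode batches cannot share a DP attention group under
            # speculative decoding, so prepare (possibly idle) batches eagerly.
            new_batch, _ = self.prepare_dp_attn_batch(new_batch)
            need_dp_attn_preparation = new_batch is None

        if new_batch is not None:
            ret: Optional[Batch] = new_batch  # run prefill (or idle) first if possible
        else:
            ret = self.running_batch  # otherwise fall back to decode

        # Late preparation happens only when the eager step above did nothing.
        if need_dp_attn_preparation:
            ret, _ = self.prepare_dp_attn_batch(ret)
        return ret


if __name__ == "__main__":
    # A peer rank has a prefill batch: this rank runs an idle batch and skips decode.
    print(TinySchedulerSketch(True, True, peer_has_prefill=True).get_next_batch_to_run())
    # No peer has prefill work: this rank proceeds with its decode batch.
    print(TinySchedulerSketch(True, True, peer_has_prefill=False).get_next_batch_to_run())

The point of the eager preparation, as the added comments in the diff describe, is that a worker whose peers have prefill work runs an idle batch instead of assembling a decode batch, so prefill and decode never land in the same DP attention group; when no worker has prefill work, need_dp_attn_preparation stays true and the single late preparation step handles the decode (or idle) batch as before.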