From 8609e637a961dd0bd17bbf7f8f81b34cb2f7863a Mon Sep 17 00:00:00 2001 From: Cheng Wan <54331508+ch-wan@users.noreply.github.com> Date: Fri, 20 Jun 2025 14:57:34 -0700 Subject: [PATCH] Fix All-Gather under world size one (#7219) --- python/sglang/srt/distributed/parallel_state.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index cc2ba95a6..24731d2b8 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -523,17 +523,25 @@ class GroupCoordinator: self, input_: torch.Tensor, dim: int = -1, - tensor_list: List[torch.Tensor] = None, + output_tensor_list: Optional[List[torch.Tensor]] = None, ) -> torch.Tensor: world_size = self.world_size # Bypass the function if we are using only 1 GPU. if world_size == 1: - return input_ + if output_tensor_list is not None: + logger.warning( + "Performing in-place all-gather with a group size of 1. " + "This may be unnecessary; consider bypassing it for better efficiency." + ) + output_tensor_list[0].copy_(input_) + return None + else: + return input_ - if tensor_list is not None: + if output_tensor_list is not None: # TODO(ch-wan): support other backends return torch.distributed.all_gather( - tensor_list, input_, group=self.device_group + output_tensor_list, input_, group=self.device_group ) assert (