diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index 4e81f80dc..ad336c808 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -650,17 +650,19 @@ class GroupCoordinator: output_size, dtype=input_.dtype, device=input_.device ) - if input_.is_cpu: - if is_shm_available(input_.dtype, self.world_size, self.local_size): - return torch.ops.sgl_kernel.shm_allgather(input_, dim) - else: - torch.distributed.all_gather_into_tensor( - output_tensor, input_, group=self.device_group - ) - return output_tensor - # All-gather. - self.all_gather_into_tensor(output_tensor, input_) + if input_.is_cpu and is_shm_available( + input_.dtype, self.world_size, self.local_size + ): + return torch.ops.sgl_kernel.shm_allgather(input_, dim) + + if input_.is_cpu: + torch.distributed.all_gather_into_tensor( + output_tensor, input_, group=self.device_group + ) + else: + self.all_gather_into_tensor(output_tensor, input_) + # Reshape output_tensor = output_tensor.reshape((world_size,) + input_size) output_tensor = output_tensor.movedim(0, dim)