diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index 231efb965..c572fc8e8 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -139,6 +139,27 @@ if supports_custom_op(): fake_impl=outplace_all_reduce_fake, ) + def reg_all_gather_into_tensor( + output: torch.Tensor, input: torch.Tensor, group_name: str + ) -> None: + assert group_name in _groups, f"Group {group_name} is not found." + group = _groups[group_name]() + if group is None: + raise ValueError(f"Group {group_name} is destroyed.") + group._all_gather_into_tensor(output, input) + + def reg_all_gather_into_tensor_fake( + output: torch.Tensor, input: torch.Tensor, group_name: str + ) -> None: + pass + + direct_register_custom_op( + op_name="reg_all_gather_into_tensor", + op_func=reg_all_gather_into_tensor, + mutates_args=[], + fake_impl=reg_all_gather_into_tensor_fake, + ) + class GroupCoordinator: """ @@ -414,6 +435,23 @@ class GroupCoordinator: else: torch.distributed.all_reduce(input_, group=self.device_group) + def _all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor): + pynccl_comm = self.pynccl_comm + if pynccl_comm is not None and not pynccl_comm.disabled: + pynccl_comm.all_gather(output, input) + else: + torch.distributed.all_gather_into_tensor( + output, input, group=self.device_group + ) + + def all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor): + if not supports_custom_op(): + self._all_gather_into_tensor(output, input) + else: + torch.ops.sglang.reg_all_gather_into_tensor( + output, input, group_name=self.unique_name + ) + def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: world_size = self.world_size # Bypass the function if we are using only 1 GPU. @@ -441,9 +479,7 @@ class GroupCoordinator: output_size, dtype=input_.dtype, device=input_.device ) # All-gather. - torch.distributed.all_gather_into_tensor( - output_tensor, input_, group=self.device_group - ) + self.all_gather_into_tensor(output_tensor, input_) # Reshape output_tensor = output_tensor.reshape((world_size,) + input_size) output_tensor = output_tensor.movedim(0, dim) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 0de01f3f9..f5182c828 100755 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -824,9 +824,7 @@ def all_gather( input_tensor, (0, 0, 0, max_len - input_tensor.shape[0]) ) - torch.distributed.all_gather_into_tensor( - forward_batch.gathered_buffer, padded_tensor, group=group - ) + group.all_gather_into_tensor(forward_batch.gathered_buffer, padded_tensor) gathered_tensors = torch.concat( [ @@ -862,7 +860,7 @@ class DeepseekV2DecoderLayer(nn.Module): if self.enable_dp_attention: self.tp_rank = get_tensor_model_parallel_rank() self.tp_size = get_tensor_model_parallel_world_size() - self.tp_group = get_tp_group().device_group + self.tp_group = get_tp_group() if not global_server_args_dict["disable_mla"]: self.self_attn = DeepseekV2AttentionMLA( config=config,