Support (1 <= dp < tp) in the dp attention in DeepEP (#4770)
Co-authored-by: Cheng Wan <cwan39@gatech.edu>
This commit is contained in:
@@ -439,6 +439,15 @@ class GroupCoordinator:
|
||||
else:
|
||||
torch.distributed.all_reduce(input_, group=self.device_group)
|
||||
|
||||
def reduce_scatter(
    self,
    output: torch.Tensor,
    input_list: List[torch.Tensor],
) -> torch.Tensor:
    """Reduce-scatter ``input_list`` over this coordinator's device group.

    Thin wrapper around ``torch.distributed.reduce_scatter`` that always
    targets ``self.device_group``. No reduction op is passed, so the
    collective runs with that function's default op.

    Args:
        output: Tensor that receives this rank's result; it is written in
            place by the collective.
        input_list: Tensors to reduce and scatter, forwarded unchanged to
            ``torch.distributed.reduce_scatter``.

    Returns:
        The same ``output`` tensor object that was passed in, returned for
        the caller's convenience.
    """
    # TODO(ch-wan): support other backends
    torch.distributed.reduce_scatter(output, input_list, group=self.device_group)
    # The collective fills `output` in place; hand the same object back so the
    # declared return type (torch.Tensor) matches what callers actually get.
    return output
|
||||
|
||||
def _all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
|
||||
pynccl_comm = self.pynccl_comm
|
||||
if pynccl_comm is not None and not pynccl_comm.disabled:
|
||||
@@ -456,11 +465,23 @@ class GroupCoordinator:
|
||||
output, input, group_name=self.unique_name
|
||||
)
|
||||
|
||||
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
||||
def all_gather(
|
||||
self,
|
||||
input_: torch.Tensor,
|
||||
dim: int = -1,
|
||||
tensor_list: List[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
world_size = self.world_size
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if world_size == 1:
|
||||
return input_
|
||||
|
||||
if tensor_list is not None:
|
||||
# TODO(ch-wan): support other backends
|
||||
return torch.distributed.all_gather(
|
||||
tensor_list, input_, group=self.device_group
|
||||
)
|
||||
|
||||
assert (
|
||||
-input_.dim() <= dim < input_.dim()
|
||||
), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
|
||||
|
||||
Reference in New Issue
Block a user