Support (1 <= dp < tp) in the dp attention in DeepEP (#4770)

Co-authored-by: Cheng Wan <cwan39@gatech.edu>
This commit is contained in:
tarinkk
2025-03-27 20:09:35 -04:00
committed by GitHub
parent 98a2cfa9b2
commit 7f19e083c1
10 changed files with 238 additions and 47 deletions

View File

@@ -5,7 +5,7 @@ import logging
import os
from contextlib import contextmanager
from functools import wraps
from typing import Callable, List, Optional, TypeVar, Union
from typing import Any, Callable, List, Optional, TypeVar, Union
import torch
import torch.distributed as dist

View File

@@ -439,6 +439,15 @@ class GroupCoordinator:
else:
torch.distributed.all_reduce(input_, group=self.device_group)
def reduce_scatter(
    self,
    output: torch.Tensor,
    input_list: List[torch.Tensor],
) -> torch.Tensor:
    """Run a reduce-scatter collective on this coordinator's device group.

    Thin wrapper that forwards to ``torch.distributed.reduce_scatter`` using
    ``self.device_group`` as the process group.

    Args:
        output: Tensor handed to ``torch.distributed.reduce_scatter`` as the
            destination buffer.
        input_list: List of tensors to reduce and scatter.

    Returns:
        The same ``output`` tensor object that was passed in, so the call
        can be used inline in an expression.
    """
    # TODO(ch-wan): support other backends
    torch.distributed.reduce_scatter(output, input_list, group=self.device_group)
    # The annotation previously claimed ``None`` although ``output`` is
    # returned; the signature now matches the actual behaviour.
    return output
def _all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
@@ -456,11 +465,23 @@ class GroupCoordinator:
output, input, group_name=self.unique_name
)
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
def all_gather(
self,
input_: torch.Tensor,
dim: int = -1,
tensor_list: List[torch.Tensor] = None,
) -> torch.Tensor:
world_size = self.world_size
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
if tensor_list is not None:
# TODO(ch-wan): support other backends
return torch.distributed.all_gather(
tensor_list, input_, group=self.device_group
)
assert (
-input_.dim() <= dim < input_.dim()
), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"