Fix allgather ops inside cuda graphs (#3709)

This commit is contained in:
Nicolas Castet
2025-02-25 10:39:10 -06:00
committed by GitHub
parent c0bb9eb3b3
commit 127998cc41
2 changed files with 41 additions and 7 deletions

View File

@@ -139,6 +139,27 @@ if supports_custom_op():
fake_impl=outplace_all_reduce_fake,
)
def reg_all_gather_into_tensor(
    output: torch.Tensor, input: torch.Tensor, group_name: str
) -> None:
    """Run an all-gather into ``output`` on the group registered as ``group_name``.

    This is the eager implementation behind the ``reg_all_gather_into_tensor``
    custom op. The group is looked up by name in the module-level ``_groups``
    registry and the work is delegated to its ``_all_gather_into_tensor``.

    Args:
        output: Tensor that receives the gathered result.
        input: Tensor contributed by the calling rank.
        group_name: Key of the target group in ``_groups``.

    Raises:
        AssertionError: If ``group_name`` is not a key of ``_groups``.
        ValueError: If the registry entry resolves to ``None``.
    """
    assert group_name in _groups, f"Group {group_name} is not found."
    # Registry values are callables that hand back the group, or None once the
    # group is gone (presumably weak references -- confirm at the definition
    # of `_groups`).
    resolve_group = _groups[group_name]
    coordinator = resolve_group()
    if coordinator is None:
        raise ValueError(f"Group {group_name} is destroyed.")
    coordinator._all_gather_into_tensor(output, input)
def reg_all_gather_into_tensor_fake(
    output: torch.Tensor, input: torch.Tensor, group_name: str
) -> None:
    """Fake implementation of ``reg_all_gather_into_tensor``; does nothing.

    It takes the same arguments as the real op, performs no communication and
    leaves ``output`` untouched.
    """
    return None
# Register the all-gather as a custom op so it can be invoked through
# `torch.ops.sglang.reg_all_gather_into_tensor`.
# The real implementation writes the gathered data into `output`, so that
# argument has to be declared as mutated. Declaring no mutation tells the
# dispatcher/compiler that the op has no observable effect, which allows the
# call to be reordered or dropped and later reads of `output` to see stale
# data.
direct_register_custom_op(
    op_name="reg_all_gather_into_tensor",
    op_func=reg_all_gather_into_tensor,
    mutates_args=["output"],
    fake_impl=reg_all_gather_into_tensor_fake,
)
class GroupCoordinator:
"""
@@ -414,6 +435,23 @@ class GroupCoordinator:
else:
torch.distributed.all_reduce(input_, group=self.device_group)
def _all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.all_gather(output, input)
else:
torch.distributed.all_gather_into_tensor(
output, input, group=self.device_group
)
def all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
    """All-gather ``input`` into ``output`` for this group.

    When ``supports_custom_op()`` is true, the call goes through the
    registered ``torch.ops.sglang.reg_all_gather_into_tensor`` op, keyed by
    this group's ``unique_name``. Otherwise ``_all_gather_into_tensor`` is
    called directly.

    Args:
        output: Tensor that receives the gathered result.
        input: Tensor contributed by the calling rank.
    """
    if supports_custom_op():
        torch.ops.sglang.reg_all_gather_into_tensor(
            output, input, group_name=self.unique_name
        )
        return
    # Custom ops unavailable: call the implementation directly.
    self._all_gather_into_tensor(output, input)
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
world_size = self.world_size
# Bypass the function if we are using only 1 GPU.
@@ -441,9 +479,7 @@ class GroupCoordinator:
output_size, dtype=input_.dtype, device=input_.device
)
# All-gather.
torch.distributed.all_gather_into_tensor(
output_tensor, input_, group=self.device_group
)
self.all_gather_into_tensor(output_tensor, input_)
# Reshape
output_tensor = output_tensor.reshape((world_size,) + input_size)
output_tensor = output_tensor.movedim(0, dim)