This commit is contained in:
2026-04-02 04:53:13 +00:00
parent 80932c96e5
commit 24df76db9d
1987 changed files with 447445 additions and 0 deletions

View File

@@ -0,0 +1,24 @@
from typing import Optional

import torch
import torch.distributed as dist
def all_gather_into_tensor(self, input_: torch.Tensor, dim: int = -1, output_tensor: torch.Tensor = None) -> torch.Tensor:
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
input_size = input_.size()
# NOTE: we have to use concat-style all-gather here,
# stack-style all-gather has compatibility issues with
# torch.compile . see https://github.com/pytorch/pytorch/issues/138795
output_size = (input_size[0] * self.world_size, ) + input_size[1:]
# Allocate output tensor.
# [N,] => [N*world_size], 1D Tensor
if output_tensor is None:
output_tensor = torch.empty(output_size,
dtype=input_.dtype,
device=input_.device)
# print("o tensor is:", output_tensor.shape, "i tensor is:", input_.shape, input_size)
# All-gather.
dist.all_gather_into_tensor(output_tensor,
input_,
group=self.device_group)
return output_tensor