init
@@ -0,0 +1,24 @@
import torch
import torch.distributed as dist


def all_gather_into_tensor(self,
                           input_: torch.Tensor,
                           dim: int = -1,
                           output_tensor: torch.Tensor = None) -> torch.Tensor:
    if dim < 0:
        # Convert negative dim to positive.
        dim += input_.dim()
    input_size = input_.size()
    # NOTE: we have to use concat-style all-gather here;
    # stack-style all-gather has compatibility issues with
    # torch.compile, see https://github.com/pytorch/pytorch/issues/138795
    output_size = (input_size[0] * self.world_size, ) + input_size[1:]
    # Allocate the output tensor if the caller did not provide one,
    # e.g. [N] => [N * world_size] for a 1-D input.
    if output_tensor is None:
        output_tensor = torch.empty(output_size,
                                    dtype=input_.dtype,
                                    device=input_.device)
    # All-gather: every rank's input_ is concatenated along dim 0.
    dist.all_gather_into_tensor(output_tensor,
                                input_,
                                group=self.device_group)
    if dim != 0:
        # The collective always concatenates along dim 0, so move the
        # per-rank chunks to the requested dim and flatten them back
        # into a single axis of size world_size * input_size[dim].
        output_tensor = output_tensor.reshape((self.world_size, ) +
                                              tuple(input_size))
        output_tensor = output_tensor.movedim(0, dim)
        output_tensor = output_tensor.reshape(input_size[:dim] +
                                              (self.world_size *
                                               input_size[dim], ) +
                                              input_size[dim + 1:])
    return output_tensor
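A minimal usage sketch (illustrative, not part of this commit): it assumes the function above lives in the same module and is attached to a small coordinator class that provides the `world_size` and `device_group` attributes the method expects. The class name `GatherGroup` and the launch command are hypothetical; the NCCL backend is used since the tensors are gathered on GPUs.

import torch
import torch.distributed as dist


class GatherGroup:
    """Hypothetical coordinator exposing the attributes used above."""

    def __init__(self):
        self.world_size = dist.get_world_size()
        self.device_group = dist.group.WORLD  # the default process group


# Reuse the function defined above as a bound method.
GatherGroup.all_gather_into_tensor = all_gather_into_tensor

if __name__ == "__main__":
    # Launch with, e.g.: torchrun --nproc_per_node=2 <this_file>.py
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    group = GatherGroup()
    # Each rank contributes a [2] tensor filled with its rank id;
    # every rank receives the concatenated [2 * world_size] result.
    local = torch.full((2, ), float(rank), device="cuda")
    gathered = group.all_gather_into_tensor(local)
    print(f"rank {rank}: {gathered.tolist()}")
    dist.destroy_process_group()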