[CPU] Fix fallback allgather issue (#8041)
This commit is contained in:
@@ -650,17 +650,19 @@ class GroupCoordinator:
|
|||||||
output_size, dtype=input_.dtype, device=input_.device
|
output_size, dtype=input_.dtype, device=input_.device
|
||||||
)
|
)
|
||||||
|
|
||||||
if input_.is_cpu:
|
# All-gather.
|
||||||
if is_shm_available(input_.dtype, self.world_size, self.local_size):
|
if input_.is_cpu and is_shm_available(
|
||||||
|
input_.dtype, self.world_size, self.local_size
|
||||||
|
):
|
||||||
return torch.ops.sgl_kernel.shm_allgather(input_, dim)
|
return torch.ops.sgl_kernel.shm_allgather(input_, dim)
|
||||||
else:
|
|
||||||
|
if input_.is_cpu:
|
||||||
torch.distributed.all_gather_into_tensor(
|
torch.distributed.all_gather_into_tensor(
|
||||||
output_tensor, input_, group=self.device_group
|
output_tensor, input_, group=self.device_group
|
||||||
)
|
)
|
||||||
return output_tensor
|
else:
|
||||||
|
|
||||||
# All-gather.
|
|
||||||
self.all_gather_into_tensor(output_tensor, input_)
|
self.all_gather_into_tensor(output_tensor, input_)
|
||||||
|
|
||||||
# Reshape
|
# Reshape
|
||||||
output_tensor = output_tensor.reshape((world_size,) + input_size)
|
output_tensor = output_tensor.reshape((world_size,) + input_size)
|
||||||
output_tensor = output_tensor.movedim(0, dim)
|
output_tensor = output_tensor.movedim(0, dim)
|
||||||
|
|||||||
Reference in New Issue
Block a user