[CPU] Fix fallback allgather issue (#8041)
This commit is contained in:
@@ -650,17 +650,19 @@ class GroupCoordinator:
|
||||
output_size, dtype=input_.dtype, device=input_.device
|
||||
)
|
||||
|
||||
if input_.is_cpu:
|
||||
if is_shm_available(input_.dtype, self.world_size, self.local_size):
|
||||
return torch.ops.sgl_kernel.shm_allgather(input_, dim)
|
||||
else:
|
||||
torch.distributed.all_gather_into_tensor(
|
||||
output_tensor, input_, group=self.device_group
|
||||
)
|
||||
return output_tensor
|
||||
|
||||
# All-gather.
|
||||
self.all_gather_into_tensor(output_tensor, input_)
|
||||
if input_.is_cpu and is_shm_available(
|
||||
input_.dtype, self.world_size, self.local_size
|
||||
):
|
||||
return torch.ops.sgl_kernel.shm_allgather(input_, dim)
|
||||
|
||||
if input_.is_cpu:
|
||||
torch.distributed.all_gather_into_tensor(
|
||||
output_tensor, input_, group=self.device_group
|
||||
)
|
||||
else:
|
||||
self.all_gather_into_tensor(output_tensor, input_)
|
||||
|
||||
# Reshape
|
||||
output_tensor = output_tensor.reshape((world_size,) + input_size)
|
||||
output_tensor = output_tensor.movedim(0, dim)
|
||||
|
||||
Reference in New Issue
Block a user