[BugFix]Fix bugs when initializing communication groups with dp on 300I Duo (#1478)

### What this PR does / why we need it?
This PR fixes a bug that occurs when broadcast is used with the cpu_group while running DP.
The `broadcast310p` patch takes effect for both the cpu_group and the
device group, but it is only needed for the device group. Hence a wrapper
is added so that the cpu_group uses the native torch broadcast, which
fixes the bug.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
With this PR, DP on 310p runs normally and generates reasonable answers.

Signed-off-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: angazenn <zengyanjia@huawei.com>
This commit is contained in:
Angazenn
2025-06-28 16:07:52 +08:00
committed by GitHub
parent 2cf9c4c3a2
commit 5c53cbaf2a

View File

@@ -77,20 +77,28 @@ class NullHandle:
def communication_adaptation_310p():
def broadcast310p(tensor, src, group=None, async_op=False):
    """Emulate ``torch.distributed.broadcast`` on 310P via all_gather.

    Every rank gathers all ranks' tensors, then copies the source
    rank's slot into its own ``tensor`` in place.

    Args:
        tensor: tensor to broadcast; overwritten in place on every rank.
        src: rank whose tensor is taken as the broadcast value.
        group: process group to operate on (default group if ``None``).
        async_op: if True, return a ``NullHandle`` so callers that
            ``wait()`` on the result still work.

    Returns:
        ``NullHandle`` when ``async_op`` is True, otherwise ``None``.
    """
    world_size = torch.distributed.get_world_size(group)
    my_rank = torch.distributed.get_rank(group)
    # One slot per rank; pre-fill our own slot with the local tensor.
    gathered = [torch.empty_like(tensor) for _ in range(world_size)]
    gathered[my_rank] = tensor
    torch.distributed.all_gather(gathered, tensor, group=group)
    # In-place copy so callers holding a reference to `tensor` see the result.
    tensor[...] = gathered[src]
    return NullHandle() if async_op else None
def broadcast310p_wrapper(fn):
    """Build a broadcast replacement for 310P that spares CPU tensors.

    The returned function emulates broadcast via all_gather on device
    tensors (310P lacks native broadcast support), but delegates CPU
    tensors to the original ``fn`` — the 310P workaround must not hijack
    the cpu_group path, which broke DP initialization.

    Args:
        fn: the original broadcast callable to fall back to for CPU tensors.

    Returns:
        A drop-in replacement with the ``broadcast(tensor, src, group,
        async_op)`` signature.
    """

    def broadcast310p(tensor, src, group=None, async_op=False):
        # CPU tensors (e.g. the DP cpu_group): use the native broadcast.
        if tensor.device == torch.device('cpu'):
            return fn(tensor, src, group, async_op)
        rank = torch.distributed.get_rank(group)
        world_size = torch.distributed.get_world_size(group)
        # Emulate broadcast: gather everyone's tensor, keep src's copy.
        tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
        tensor_list[rank] = tensor
        torch.distributed.all_gather(tensor_list, tensor, group=group)
        tensor[...] = tensor_list[src]
        if async_op:
            # Callers expecting an async handle get a no-op waitable.
            return NullHandle()
        else:
            return None

    return broadcast310p


# Patch both public entry points, capturing the originals as fallbacks.
torch.distributed.broadcast = broadcast310p_wrapper(
    torch.distributed.broadcast)
torch.distributed.distributed_c10d.broadcast = broadcast310p_wrapper(
    torch.distributed.distributed_c10d.broadcast)
def all_reduce_wrapper_310p(fn):